diff --git a/clip/clip.py b/clip/clip.py index e4640e5..8f37ba8 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -75,7 +75,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a if not jit: model = build_model(model.state_dict()).to(device) - if device == "cpu": + if str(device) == "cpu": model.float() return model, transform @@ -98,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a patch_device(model.encode_text) # patch dtype to float32 on CPU - if device == "cpu": + if str(device) == "cpu": float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node()