From 3a6e8d003417e6264fac91e92baf6216dc3427d2 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Sat, 30 Jan 2021 02:53:03 +0900 Subject: [PATCH] correctly checks for cpu for torch.device("cpu") --- clip/clip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clip/clip.py b/clip/clip.py index e4640e5..8f37ba8 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -75,7 +75,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a if not jit: model = build_model(model.state_dict()).to(device) - if device == "cpu": + if str(device) == "cpu": model.float() return model, transform @@ -98,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a patch_device(model.encode_text) # patch dtype to float32 on CPU - if device == "cpu": + if str(device) == "cpu": float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node()