correctly checks for cpu for torch.device("cpu")

This commit is contained in:
Jong Wook Kim 2021-01-30 02:53:03 +09:00
parent decf1904ae
commit 3a6e8d0034
1 changed files with 2 additions and 2 deletions

View File

@ -75,7 +75,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
if not jit:
model = build_model(model.state_dict()).to(device)
if device == "cpu":
if str(device) == "cpu":
model.float()
return model, transform
@ -98,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if device == "cpu":
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()