correctly checks for cpu for torch.device("cpu")
This commit is contained in:
parent
decf1904ae
commit
3a6e8d0034
@ -75,7 +75,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||||||
|
|
||||||
if not jit:
|
if not jit:
|
||||||
model = build_model(model.state_dict()).to(device)
|
model = build_model(model.state_dict()).to(device)
|
||||||
if device == "cpu":
|
if str(device) == "cpu":
|
||||||
model.float()
|
model.float()
|
||||||
return model, transform
|
return model, transform
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||||||
patch_device(model.encode_text)
|
patch_device(model.encode_text)
|
||||||
|
|
||||||
# patch dtype to float32 on CPU
|
# patch dtype to float32 on CPU
|
||||||
if device == "cpu":
|
if str(device) == "cpu":
|
||||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||||
float_node = float_input.node()
|
float_node = float_input.node()
|
||||||
|
Loading…
Reference in New Issue
Block a user