correctly checks for cpu for torch.device("cpu")
This commit is contained in:
parent
decf1904ae
commit
3a6e8d0034
|
@ -75,7 +75,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
|
||||
if not jit:
|
||||
model = build_model(model.state_dict()).to(device)
|
||||
if device == "cpu":
|
||||
if str(device) == "cpu":
|
||||
model.float()
|
||||
return model, transform
|
||||
|
||||
|
@ -98,7 +98,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
patch_device(model.encode_text)
|
||||
|
||||
# patch dtype to float32 on CPU
|
||||
if device == "cpu":
|
||||
if str(device) == "cpu":
|
||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||
float_node = float_input.node()
|
||||
|
|
Loading…
Reference in New Issue