diff --git a/tests/test_consistency.py b/tests/test_consistency.py index 29d343d..f2c6fd4 100644 --- a/tests/test_consistency.py +++ b/tests/test_consistency.py @@ -9,7 +9,7 @@ import clip @pytest.mark.parametrize('model_name', clip.available_models()) def test_consistency(model_name): device = "cpu" - jit_model, transform = clip.load(model_name, device=device) + jit_model, transform = clip.load(model_name, device=device, jit=True) py_model, _ = clip.load(model_name, device=device, jit=False) image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)