diff --git a/clip/model.py b/clip/model.py index 422a34a..f2c95c4 100644 --- a/clip/model.py +++ b/clip/model.py @@ -199,7 +199,7 @@ class Transformer(nn.Module): return self.resblocks(x) -class VisualTransformer(nn.Module): +class VisionTransformer(nn.Module): def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): super().__init__() self.input_resolution = input_resolution @@ -266,7 +266,7 @@ class CLIP(nn.Module): ) else: vision_heads = vision_width // 64 - self.visual = VisualTransformer( + self.visual = VisionTransformer( input_resolution=image_resolution, patch_size=vision_patch_size, width=vision_width,