diff --git a/clip/clip.py b/clip/clip.py index 6ce5565..9d6f4d6 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -68,11 +68,15 @@ def _download(url: str, root: str): return download_target +def _convert_image_to_rgb(image): + return image.convert("RGB") + + def _transform(n_px): return Compose([ Resize(n_px, interpolation=BICUBIC), CenterCrop(n_px), - lambda image: image.convert("RGB"), + _convert_image_to_rgb, ToTensor(), Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ])