import os
import clip
from PIL import Image
from torchvision import transforms
from .constants import CACHE_DIR

# Model-name groups. Every name listed below is accepted by get_model.
BLIP1_MODELS = [
    'blip-base',
    'blip-base-14M',
    'blip-large',
    'blip-large-random',
    'blip-base-random',
    'blip-coco-base-march-21-three-losses',
]
BLIP1_LM_MODELS = [
    'blip-coco-large-caption',
]
BLIP1_ITM_MODELS = [
    'blip-coco-base',
    'blip-flickr-base',
    'blip-coco-large-retrieval',
    'blip-flickr-large-retrieval',
]
BLIP2_MODELS = [
    'blip2-pretrain',
    'blip2-pretrain_vitL',
    'blip2-coco',
    'blip2-pretrain_flant5xl',
    'blip2-pretrain_flant5xl_vitL',
    'blip2-pretrain_flant5xxl',
    'blip2-caption_coco_flant5xl',
]

# Union of all BLIP groups above, in declaration order.
BLIP_MODELS = [
    *BLIP1_MODELS,
    *BLIP1_LM_MODELS,
    *BLIP1_ITM_MODELS,
    *BLIP2_MODELS,
]

# Names that get_model maps to text-only wrappers (bart / flan-t5 / opt).
LLM_MODELS = [
    'bart-large',
    'bart-base',
    'flan-t5-large',
    'flan-t5-xl',
    'opt-2.7b',
    'opt-6.7b',
]

# Names for which get_model returns a preprocessing step that discards the
# input image and produces Gaussian noise instead.
BLIP_MODELS_RANDOM = [
    'blip-large-random',
    'blip-base-random',
    'blip-large-random-mean-0',
    'blip-base-random-mean-0',
    'blip-large-random-mean-1',
]
def get_model(model_name, device, root_dir=CACHE_DIR):
    """
    Helper function that returns a model and a potential image preprocessing function.

    Wrapper modules are imported inside the branch that needs them, so selecting
    one model does not import the modules of the others.

    Args:
        model_name: Name of the model to build. Most names are matched exactly
            (e.g. 'blip-base', 'blip2-coco', 'xvlm-flickr', 'flava', 'NegCLIP',
            'coca'). CLIP-style names contain 'openai-clip' or 'laion-clip' and
            carry the architecture after a colon: '<family>:<variant>'.
        device: Device forwarded unchanged to the wrapper / loader.
        root_dir: Directory forwarded to the wrappers as ``root_dir``, used as
            ``download_root`` for ``clip.load`` and as the location of
            ``negclip.pth``.

    Returns:
        A ``(model, image_preprocess)`` tuple. ``image_preprocess`` is ``None``
        for 'flava'. For the text-only models (bart / flan-t5 / opt) a transform
        is still returned, but it is not used.

    Raises:
        ValueError: If ``model_name`` matches no known model, or if an
            'openai-clip' / 'laion-clip' name has no ':<variant>' part.
    """
    # --- BLIP checkpoints whose variant string is the model name itself ------
    if model_name in ('blip-base', 'blip-base-14M', 'blip-large', 'blip-coco-large-caption'):
        from .blip_models import BLIPModelWrapperFull
        blip_model = BLIPModelWrapperFull(root_dir=root_dir, device=device, variant=model_name)
        return blip_model, _resize_normalize_preprocess(384)

    # --- BLIP checkpoints fed with random noise instead of the real image ----
    elif model_name in ('blip-large-random', 'blip-base-random', 'blip-base-14M-random'):
        from .blip_models import BLIPModelWrapperFull
        # Drop the "-random" suffix to recover the checkpoint variant.
        base_variant = model_name[:-len('-random')]
        blip_model = BLIPModelWrapperFull(root_dir=root_dir, device=device, variant=base_variant)
        return blip_model, _random_image_preprocess(mean=0.45)

    elif model_name in ('blip-large-random-mean-0', 'blip-base-random-mean-0', 'blip-base-14M-random-mean-0'):
        from .blip_models import BLIPModelWrapperFull
        base_variant = model_name[:-len('-random-mean-0')]
        blip_model = BLIPModelWrapperFull(root_dir=root_dir, device=device, variant=base_variant)
        return blip_model, _random_image_preprocess(mean=0.0)

    elif model_name == 'blip-large-random-mean-1':
        from .blip_models import BLIPModelWrapperFull
        base_variant = model_name[:-len('-random-mean-1')]
        blip_model = BLIPModelWrapperFull(root_dir=root_dir, device=device, variant=base_variant)
        return blip_model, _random_image_preprocess(mean=1.0)

    # --- Locally stored BLIP checkpoint (real and random-input flavours) -----
    elif model_name == 'blip-coco-base-march-21-three-losses':
        from .blip_models import BLIPModelWrapperFullLocal
        blip_model = BLIPModelWrapperFullLocal(root_dir=root_dir, device=device, variant=model_name)
        return blip_model, _resize_normalize_preprocess(384)

    elif model_name == 'blip-coco-base-march-21-three-losses-random':
        from .blip_models import BLIPModelWrapperFullLocal
        blip_model = BLIPModelWrapperFullLocal(root_dir=root_dir, device=device,
                                               variant="blip-coco-base-march-21-three-losses")
        return blip_model, _random_image_preprocess(mean=0.45)

    # --- BLIP checkpoints served through BLIPModelWrapper --------------------
    elif model_name in ('blip-flickr-base', 'blip-coco-base',
                        'blip-coco-large-retrieval', 'blip-flickr-large-retrieval'):
        from .blip_models import BLIPModelWrapper
        blip_model = BLIPModelWrapper(root_dir=root_dir, device=device, variant=model_name)
        return blip_model, _resize_normalize_preprocess(384)

    # --- BLIP-2: the variant is the part of the name after "blip2-" ----------
    elif model_name in ('blip2-pretrain', 'blip2-pretrain_vitL', 'blip2-coco'):
        from .blip2_models import BLIP2QFormerModelWrapper
        variant = model_name.split('-')[1]
        blip2_model = BLIP2QFormerModelWrapper(root_dir=root_dir, device=device, variant=variant)
        # The 'coco' variant is fed 364x364 inputs; the others 224x224.
        return blip2_model, _resize_normalize_preprocess(364 if variant == 'coco' else 224)

    elif model_name in ('blip2-pretrain_flant5xl', 'blip2-pretrain_flant5xl_vitL',
                        'blip2-pretrain_flant5xxl', 'blip2-caption_coco_flant5xl'):
        from .blip2_models import BLIP2FlanT5ModelWrapper
        variant = model_name.split('-')[1]
        blip2_model = BLIP2FlanT5ModelWrapper(root_dir=root_dir, device=device, variant=variant)
        # The 'caption_coco_flant5xl' variant is fed 364x364 inputs; the others 224x224.
        return blip2_model, _resize_normalize_preprocess(364 if variant == 'caption_coco_flant5xl' else 224)

    # --- Text-only models: the returned transform is however not used --------
    elif model_name in ('bart-large', 'bart-base'):
        from .bart_models import BartModelWrapper
        bart_model = BartModelWrapper(root_dir=root_dir, device=device, variant=model_name)
        return bart_model, _resize_normalize_preprocess(384)

    elif model_name in ('flan-t5-large', 'flan-t5-xl'):
        from .flan_t5_models import FlanT5ModelWrapper
        flan_t5_model = FlanT5ModelWrapper(root_dir=root_dir, device=device, variant=model_name)
        return flan_t5_model, _resize_normalize_preprocess(384)

    elif model_name in ('opt-2.7b', 'opt-6.7b'):
        from .opt_models import OptModelWrapper
        opt_model = OptModelWrapper(root_dir=root_dir, device=device, variant=model_name)
        return opt_model, _resize_normalize_preprocess(384)

    # --- X-VLM ---------------------------------------------------------------
    elif model_name in ('xvlm-flickr', 'xvlm-coco'):
        from .xvlm_models import XVLMWrapper
        xvlm_model = XVLMWrapper(root_dir=root_dir, device=device, variant=model_name)
        return xvlm_model, _resize_normalize_preprocess(384)

    # --- FLAVA: no image preprocessing function is returned ------------------
    elif model_name == "flava":
        from .flava import FlavaWrapper
        flava_model = FlavaWrapper(root_dir=root_dir, device=device)
        return flava_model, None

    # --- CLIP family: preprocessing comes from the loading library -----------
    elif "openai-clip" in model_name:
        from .clip_models import CLIPWrapper
        variant = _variant_after_colon(model_name)
        model, image_preprocess = clip.load(variant, device=device, download_root=root_dir)
        return CLIPWrapper(model, device), image_preprocess

    elif model_name == "NegCLIP":
        import open_clip
        from .clip_models import CLIPWrapper

        path = os.path.join(root_dir, "negclip.pth")
        if not os.path.exists(path):
            print("Downloading the NegCLIP model...")
            import gdown
            # Make sure the download target directory exists before writing to it.
            if root_dir:
                os.makedirs(root_dir, exist_ok=True)
            gdown.download(id="1ooVVPxB-tvptgmHlIMMFGV3Cg-IrhbRZ", output=path, quiet=False)
        model, _, image_preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained=path, device=device)
        return CLIPWrapper(model, device), image_preprocess

    elif model_name == "coca":
        import open_clip
        from .clip_models import CLIPWrapper
        model, _, image_preprocess = open_clip.create_model_and_transforms(
            model_name="coca_ViT-B-32", pretrained="laion2B-s13B-b90k", device=device)
        return CLIPWrapper(model, device), image_preprocess

    elif "laion-clip" in model_name:
        import open_clip
        from .clip_models import CLIPWrapper
        variant = _variant_after_colon(model_name)
        model, _, image_preprocess = open_clip.create_model_and_transforms(
            model_name=variant, pretrained="laion2b_s34b_b79k", device=device)
        return CLIPWrapper(model, device), image_preprocess

    else:
        raise ValueError(f"Unknown model {model_name}")


# Per-channel mean / std passed to transforms.Normalize by every resize pipeline.
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)


def _resize_normalize_preprocess(size):
    """Return the bicubic resize to (size, size) -> ToTensor -> Normalize pipeline."""
    return transforms.Compose([
        transforms.Resize((size, size), interpolation=transforms.functional.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])


def _random_image_preprocess(mean, std=0.25, size=384):
    """Return a transform that ignores its input and yields a (3, size, size) normal-noise tensor."""
    import torch

    def rand_rgb_image(_):
        # The incoming image is deliberately discarded; a fresh sample is drawn per call.
        return torch.normal(mean, std, size=(3, size, size))

    return transforms.Compose([rand_rgb_image])


def _variant_after_colon(model_name):
    """Return the second ':'-separated field of model_name, or raise ValueError if it is absent."""
    fields = model_name.split(":")
    if len(fields) < 2:
        raise ValueError(
            f"Model name {model_name!r} must have the form '<family>:<variant>'")
    return fields[1]
