import os
import torch
import torch.nn as nn
from torchvision import models
import timm
import clip
from robustness.model_utils import make_and_restore_model
from robustness.datasets import ImageNet


def get_normalization_stats(model_name):
    """Return the ``(mean, std)`` normalization pair registered for an encoder.

    Names beginning with ``RN50`` or ``L2-RN50`` are matched by prefix; every
    other encoder must match one of the registered names exactly.

    Args:
        model_name: Encoder identifier string.

    Returns:
        A ``(mean, std)`` tuple of three per-channel values each. The CLIP
        statistics are lists; the other two families are tuples.

    Raises:
        ValueError: If no normalization is registered for ``model_name``.
    """
    clip_stats = (
        [0.48145466, 0.4578275, 0.40821073],
        [0.26862954, 0.26130258, 0.27577711],
    )
    standard_stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    inception_stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # The prefix families are checked first, before any exact-name lookup.
    if model_name.startswith(('RN50', 'L2-RN50')):
        return standard_stats

    stats_by_name = dict.fromkeys(
        (
            "densenet201_imagenet", "mobilenet_v2", "CORnet_RT",
            "inception_v3", "squeezenet1_1", "vgg16", "alexnet", "blip2",
            "VGG16-robust-l2-3",
        ),
        standard_stats,
    )
    stats_by_name.update(
        dict.fromkeys(("CLIP-RN50", 'dinov2', 'dreamsim_vitb32', 'nomic'), clip_stats)
    )
    stats_by_name['google_vit'] = inception_stats

    stats = stats_by_name.get(model_name)
    if stats is None:
        raise ValueError(f"Unsupported encoder for normalization: {model_name}")
    return stats


def _load_robust_encoder(arch, resume_path):
    """Restore a ``robustness`` checkpoint and return the bare network.

    Args:
        arch: Architecture name passed to ``make_and_restore_model``.
        resume_path: Path of the checkpoint to restore.

    Returns:
        The restored model; when the restored object exposes a ``normalizer``
        attribute, its inner ``.model`` is returned instead of the wrapper.
    """
    restored, _ = make_and_restore_model(
        arch=arch, dataset=ImageNet(""), resume_path=resume_path
    )
    if hasattr(restored, 'normalizer'):
        restored = restored.model
    return restored


def load_model(model_name):
    """Build the encoder registered under ``model_name`` in eval mode.

    Every returned model exposes an ``encode_image`` callable: either the
    model's own, a branch-specific one attached here, or a fallback that
    simply calls the model.

    Args:
        model_name: Encoder identifier string.

    Returns:
        ``(model, (mean, std))`` where ``mean`` and ``std`` come from
        ``get_normalization_stats(model_name)``.

    Raises:
        ValueError: If ``model_name`` is not a supported encoder or has no
            registered normalization statistics.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mean, std = get_normalization_stats(model_name)

    if model_name.startswith(('RN50', 'L2-RN50')):
        checkpoint_map = {
            "RN50-0": "../encoders/resnet50_l2_eps0.ckpt",
            "RN50-robust-0.5": "../encoders/resnet50_linf_eps0.5.ckpt",
            "RN50-robust-1": "../encoders/resnet50_linf_eps1.0.ckpt",
            "RN50-robust-2": "../encoders/resnet50_linf_eps2.0.ckpt",
            "RN50-robust-4": "../encoders/resnet50_linf_eps4.0.ckpt",
            "RN50-robust-8": "../encoders/resnet50_linf_eps8.0.ckpt",
            "L2-RN50-robust-0.1": "../encoders/resnet50_l2_eps0.1.ckpt",
            "L2-RN50-robust-1": "../encoders/resnet50_l2_eps1.ckpt",
            "L2-RN50-robust-3": "../encoders/resnet50_l2_eps3.ckpt",
            "L2-RN50-robust-5": "../encoders/resnet50_l2_eps5.ckpt",
        }
        resume_path = checkpoint_map.get(model_name)
        # Reject unknown names before constructing the dataset or model.
        if resume_path is None:
            raise ValueError(f"Unknown encoder: {model_name}")
        model = _load_robust_encoder('resnet50', resume_path)

    elif model_name == 'VGG16-robust-l2-3':
        model = _load_robust_encoder('vgg16_bn', "../encoders/vgg16_bn_l2_eps3_fixed.ckpt")

    elif model_name == "CLIP-RN50":
        model, _ = clip.load("RN50", device=device)

    elif model_name == "densenet201_imagenet":
        model = models.densenet201(pretrained=True).features.to(device)

    elif model_name == "mobilenet_v2":
        model = models.mobilenet_v2(pretrained=True, width_mult=1.0).features.to(device)

    elif model_name == "CORnet_RT":
        from cornet_rt import CORnet_RT
        model = CORnet_RT(times=5).to(device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load('../encoders/cornet_rt-933c001c.pth', map_location=device)
        # Strip the 'module.' prefix from the checkpoint keys so they match
        # this model's parameter names.
        model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})

    elif model_name == "inception_v3":
        model = timm.create_model('inception_v3', pretrained=True).to(device)

    elif model_name == "squeezenet1_1":
        model = models.squeezenet1_1(pretrained=True).to(device)

    elif model_name == 'dinov2':
        model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').to(device)

    elif model_name == 'dreamsim_vitb32':
        model = DreamSimVitB32(pretrained=True).to(device)

    elif model_name == 'google_vit':
        model = models.vit_b_32(weights="DEFAULT").to(device)
        model.encode_image = lambda x: model(x)[:, 0]

    elif model_name == 'blip2':
        from transformers import Blip2ForConditionalGeneration
        blip2 = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-itm-vit-g")
        model = blip2.vision_model.to(device)
        # Use the first position of the final hidden state as the embedding.
        model.encode_image = lambda x: model(x).last_hidden_state[:, 0]

    elif model_name == 'nomic':
        from transformers import AutoModel
        model = AutoModel.from_pretrained("nomic-ai/nomic-embed-vision-v1", trust_remote_code=True).to(device)

    elif model_name in ["vgg16", "alexnet"]:
        model = getattr(models, model_name)(pretrained=True).to(device)

    else:
        raise ValueError(f"Unsupported encoder: {model_name}")

    # Give every encoder the same entry point, whatever library it came from.
    if not hasattr(model, 'encode_image'):
        model.encode_image = lambda x: model(x)

    model.eval()
    return model, (mean, std)


class DreamSimVitB32(nn.Module):
    """Thin ``nn.Module`` around the first extractor of an ``open_clip_vitb32`` DreamSim model.

    Args:
        pretrained: Forwarded to ``dreamsim(...)``. When false, the direct
            children of the extracted backbone that define
            ``reset_parameters`` are re-initialized.
        download_root: Directory used as the DreamSim cache; created if it
            does not exist.
    """

    def __init__(self, pretrained: bool = True, download_root: str = "./checkpoints/dreamsim_vitb32_checkpoints"):
        super().__init__()
        # Import here so the `dreamsim` factory is actually bound in this
        # scope; an import local to another function does not reach this class.
        from dreamsim import dreamsim

        os.makedirs(download_root, exist_ok=True)
        # The model is built on CPU; callers move it to a device afterwards.
        model_wrapper, _ = dreamsim(pretrained=pretrained, dreamsim_type="open_clip_vitb32", device="cpu", cache_dir=download_root)
        self.model = model_wrapper.base_model.model.extractor_list[0].model
        if not pretrained:
            # NOTE(review): children() yields only immediate submodules, so
            # nested layers are left untouched — confirm this is intended.
            for layer in self.model.children():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)


class BrainEncoderWrapper(nn.Module):
    """Chain an encoder's recorded intermediate activation into ``brainmodel``.

    The encoder's own return value is discarded; the features are read from
    ``activation[layer_name]`` after the encoder has run.

    Args:
        encoder: Model to run; ``encode_image`` is used when it exists.
        brainmodel: Module applied to the processed features.
        activation: Mapping indexed by ``layer_name`` after the encoder runs
            (presumably filled by a forward hook — confirm with the caller).
        layer_name: Key of the activation to read.
        mapping: When equal to ``"cnn"``, features keep their shape.
    """

    def __init__(self, encoder, brainmodel, activation, layer_name, mapping=None):
        super().__init__()
        self.mapping = mapping
        self.layer_name = layer_name
        self.activation = activation
        self.brainmodel = brainmodel
        self.encoder = encoder

    def forward(self, x):
        """Run the encoder on ``x`` and return ``brainmodel`` applied to the recorded features."""
        # The call matters only for what it leaves in `self.activation`.
        if hasattr(self.encoder, "encode_image"):
            self.encoder.encode_image(x)
        else:
            self.encoder(x)

        feats = self.activation[self.layer_name]
        # getattr with a default tolerates instances lacking `mapping`.
        if getattr(self, "mapping", None) != "cnn":
            if feats.ndim == 3:
                # Rank-3 activations are averaged over dim 1.
                feats = feats.mean(dim=1)
            else:
                feats = torch.flatten(feats, start_dim=1)
        return self.brainmodel(feats.to(torch.float32))
