from re import A
import shutil
import os
import torch
import numpy as np
from torchvision.models import resnet18, resnet34, resnet50, resnet101
from utils.data_utils import get_imagenetr_images, get_imagenet_images, get_cifar_images, get_scipy_images, get_objectnet_images, get_cars_images
from torchvision import transforms
from open_clip.zero_shot_classifier import build_zero_shot_classifier
import random
import torch.nn as nn
import torch.backends.cudnn as cudnn

def set_seed(seed):
    """Seed the random number generators used by this project.

    Applies ``seed`` to Python's ``random``, NumPy's global generator,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all`` (in that order),
    then switches ``cudnn.deterministic`` on.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn.deterministic = True

def get_model(model_name, num_classes, patch_size, device, arch='base', d_pre='imagenet', pretrained=True, reg=None, mode='linear'):
    """Load a pretrained backbone together with its preprocessing and tokenizer.

    The model family is chosen from ``model_name`` in this order: substring
    'open_clip', substring 'resnet', exact 'clip', exact 'align', substring
    'siglip', then 'vit' with ``d_pre == 'in1k_orig'`` (torchvision / LoRA
    ViT), 'dinov2', and finally the generic timm 'vit'.

    Args:
        model_name: Model family selector (see above).
        num_classes: Not read by this function; kept so existing callers
            keep working.
        patch_size: Patch size used to pick the ViT checkpoint (e.g. 16, 32).
        device: Device the returned model is moved to.
        arch: Size name ('base', 'large', ...) used by the 'vit' and 'dinov2'
            families.
        d_pre: Shorthand naming the pretraining data / checkpoint variant.
        pretrained: Forwarded to ``timm.create_model`` for the generic 'vit'
            family. It is overwritten internally for open_clip models and
            for ``d_pre == 'none'``.
        reg: ``'lora'`` makes the ('vit', 'in1k_orig') branch import its
            builders from ``utils.lora_torch_vit`` instead of torchvision.
        mode: ``'zero-shot'`` additionally builds ImageNet text-classifier
            weights for the open_clip, align and siglip families and attaches
            them to the model.

    Returns:
        A ``(model, preprocess, tokenizer)`` tuple. ``tokenizer`` stays
        ``None`` for the branches that do not set one (resnet, vit, dinov2).

    Raises:
        ValueError: If the requested configuration is not handled by any
            branch (previously this surfaced as an UnboundLocalError or, for
            'resnet152', a NameError).
    """
    preprocess = None
    tokenizer = None
    if 'open_clip' in model_name:
        import open_clip

        # Translate the dataset shorthand into an open_clip pretrained tag.
        # Anything unlisted (including '100m', handled below) maps to None.
        if d_pre == 'laion':
            pretrained = {16: 'laion2b_s34b_b88k', 32: 'laion2b_s34b_b79k'}.get(patch_size)
        else:
            pretrained = {
                '400m': 'laion400m_e32',
                'datacomp': 'datacomp_xl_s13b_b90k',
                'openai': 'openai',
            }.get(d_pre)

        if 'resnet' in model_name:
            # NOTE(review): this path passes d_pre itself as the pretrained
            # tag rather than the mapped value above; kept as in the original
            # -- confirm this is intended.
            model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained=d_pre)
            tokenizer = open_clip.get_tokenizer('RN50')
        elif 'convnext' in model_name:
            model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k')
            tokenizer = open_clip.get_tokenizer('convnext_base_w')
        else:
            vit_name = f'ViT-B-{patch_size}'
            model, _, preprocess = open_clip.create_model_and_transforms(vit_name, pretrained=pretrained)
            if d_pre == '100m':
                # Weights come from a local file resolved against the current
                # working directory. torch.load unpickles: only use trusted files.
                state_dict = torch.load('open_clip_100m.pt', map_location='cpu')
                model.load_state_dict(state_dict)
            tokenizer = open_clip.get_tokenizer(vit_name)

        model = model.to(device)

        if mode == 'zero-shot':
            from open_clip.zero_shot_metadata import IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
            model.zero_shot_weights = build_zero_shot_classifier(
                model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, 50, device=device)
    elif 'resnet' in model_name:
        import torchvision.models as tv_models

        # Builder name -> weights enum name. Looking both up on the module
        # avoids relying on file-level imports (resnet152 was never imported).
        resnet_weights = {
            'resnet18': 'ResNet18_Weights',
            'resnet34': 'ResNet34_Weights',
            'resnet50': 'ResNet50_Weights',
            'resnet101': 'ResNet101_Weights',
            'resnet152': 'ResNet152_Weights',
        }
        if model_name not in resnet_weights:
            raise ValueError(f'Unsupported resnet model_name: {model_name!r}')
        builder = getattr(tv_models, model_name)
        weights = getattr(tv_models, resnet_weights[model_name]).DEFAULT
        model = builder(weights=weights).to(device)

        # Standard ImageNet evaluation preprocessing.
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    elif model_name == 'clip':
        import clip
        model, preprocess = clip.load(f'ViT-B/{patch_size}', device, jit=False)
        tokenizer = clip.tokenize
    elif model_name == 'align':
        if d_pre != 'coyo':
            raise ValueError(f"model_name='align' only supports d_pre='coyo', got {d_pre!r}")
        from lib.align import AlignProcessor, AlignModel
        from lib.align.zero_shot_classifier import build_zero_shot_classifier as build_align_zero_shot_classifier
        # The "kakaobrain/coyo-align-b7-base" checkpoint was left disabled in
        # the original code; "kakaobrain/align-base" is what gets loaded.
        model_path = "kakaobrain/align-base"
        processor = AlignProcessor.from_pretrained(model_path)
        model = AlignModel.from_pretrained(model_path)
        tokenizer = processor.tokenizer
        preprocess = lambda image: processor(images=image, return_tensors="pt")
        model = model.to(device)
        if mode == 'zero-shot':
            from open_clip.zero_shot_metadata import IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
            zero_shot_weights = build_align_zero_shot_classifier(
                model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, 50, device=device)
            # Follow the requested device instead of hard-coding CUDA.
            model.text_embeds = zero_shot_weights.to(device)
    elif 'siglip' in model_name:
        siglip_paths = {
            'siglip2': 'google/siglip2-base-patch16-224',
            'siglip': 'google/siglip-base-patch16-224',
        }
        if model_name not in siglip_paths:
            raise ValueError(f'Unsupported siglip model_name: {model_name!r}')
        from lib.siglip import SiglipModel
        from transformers import AutoProcessor

        path = siglip_paths[model_name]
        # Both checkpoints are loaded through SiglipModel, as in the original.
        model = SiglipModel.from_pretrained(path)
        preprocessor = AutoProcessor.from_pretrained(path)
        tokenizer = preprocessor.tokenizer
        preprocess = lambda image: preprocessor(images=image, return_tensors="pt")
        model = model.to(device)
        if mode == 'zero-shot':
            from open_clip.zero_shot_metadata import IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
            from lib.siglip2.zero_shot_classifier import build_zero_shot_classifier as build_siglip2_zero_shot_classifier
            zero_shot_weights = build_siglip2_zero_shot_classifier(
                model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, 50, device=device)
            # Follow the requested device instead of hard-coding CUDA.
            model.text_embeds = zero_shot_weights.to(device)
    elif model_name == 'vit' and d_pre == 'in1k_orig':
        # (arch, patch_size) -> (builder name, weights enum name).
        vit_variants = {
            ('base', 16): ('vit_b_16', 'ViT_B_16_Weights'),
            ('base', 32): ('vit_b_32', 'ViT_B_32_Weights'),
            ('large', 16): ('vit_l_16', 'ViT_L_16_Weights'),
        }
        variant = (arch, patch_size)
        if variant not in vit_variants:
            raise ValueError(
                f"d_pre='in1k_orig' does not support arch={arch!r} with patch_size={patch_size!r}")
        if reg == 'lora':
            import utils.lora_torch_vit as vit_module
        else:
            import torchvision.models as vit_module
        builder_name, weights_name = vit_variants[variant]
        weights_enum = getattr(vit_module, weights_name)
        model = getattr(vit_module, builder_name)(weights=weights_enum.DEFAULT)
        preprocess = weights_enum.IMAGENET1K_V1.transforms()
        # Every other branch returns the model on `device`; do the same here.
        model = model.to(device)
    elif model_name == 'dinov2' and 'dinov2' in d_pre:
        from transformers import AutoImageProcessor, AutoModelForImageClassification
        checkpoint = f'facebook/dinov2-{arch}-imagenet1k-1-layer'
        preprocess = AutoImageProcessor.from_pretrained(checkpoint)
        model = AutoModelForImageClassification.from_pretrained(checkpoint)
        model = model.to(device)
    elif model_name == 'vit':
        import timm

        # d_pre shorthand -> timm model-name template.
        timm_templates = {
            'imagenet': 'vit_{arch}_patch{patch_size}_224.augreg_in21k_ft_in1k',
            'orig': 'vit_{arch}_patch{patch_size}_224.orig_in21k_ft_in1k',
            'in1k': 'vit_{arch}_patch{patch_size}_224.augreg_in1k',
            'laion': 'vit_{arch}_patch{patch_size}_clip_224.laion2b_ft_in1k',
            '400m': 'immich-app/ViT-B-32__laion400m_e32',
            'openai': 'vit_{arch}_patch{patch_size}_clip_224.openai_ft_in1k',
            'laion_only': 'vit_{arch}_patch{patch_size}_clip_224.laion2b',
            'openai_only': 'vit_{arch}_patch{patch_size}_clip_224.openai',
            'imagenet_only': 'vit_{arch}_patch{patch_size}_224.augreg_in21k',
            'miil_only': 'vit_{arch}_patch{patch_size}_224_miil.in21k',
            'orig_only': 'vit_{arch}_patch{patch_size}_224.orig_in21k',
            'sam': 'vit_{arch}_patch{patch_size}_224.sam_in1k',
            'sw': 'vit_{arch}_patch{patch_size}_rpn_224.sw_in1k',
            'miil': 'vit_{arch}_patch{patch_size}_224_miil.in21k_ft_in1k',
            'none': 'vit_{arch}_patch{patch_size}_224.augreg_in21k_ft_in1k',
        }
        template = timm_templates.get(d_pre)
        # An unrecognised d_pre passes `arch` to timm unchanged, as before.
        timm_name = arch if template is None else template.format(arch=arch, patch_size=patch_size)
        if d_pre == 'none':
            # 'none' means: same architecture, no pretrained weights.
            pretrained = None
        model = timm.create_model(timm_name, pretrained=pretrained)
        model = model.to(device)
        data_config = timm.data.resolve_model_data_config(model)
        preprocess = timm.data.create_transform(**data_config, is_training=False)
    else:
        raise ValueError(
            f'Unsupported model configuration: model_name={model_name!r}, d_pre={d_pre!r}')
    return model, preprocess, tokenizer


"""
    Copy from another
"""

def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and any existing gradient, to float32.

    The conversion is done in place by reassigning ``.data`` on each
    parameter (and on its ``.grad`` when one is present).

    Args:
        model: Object exposing ``parameters()`` that yields torch parameters.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None: truth-testing a tensor raises for gradients
        # with more than one element and skips single-element zero gradients.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def refine_classname(class_names):
    """Normalise class names in place and return the same sequence.

    Each entry is lower-cased and has its underscores and hyphens replaced
    by spaces. ``class_names`` itself is modified and returned.
    """
    separators_to_space = str.maketrans('_-', '  ')
    for idx, raw_name in enumerate(class_names):
        class_names[idx] = raw_name.lower().translate(separators_to_space)
    return class_names


def save_checkpoint(state, args, is_best=False, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``args.model_folder`` and optionally mark it as best.

    The checkpoint is saved with ``torch.save`` under ``filename``. When
    ``is_best`` is true the saved file is also copied to
    ``model_best.pth.tar`` in the same folder and a message is printed.
    """
    folder = args.model_folder
    checkpoint_path = os.path.join(folder, filename)
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, os.path.join(folder, 'model_best.pth.tar'))
    print ('saved best file')


def assign_learning_rate(optimizer, new_lr):
    """Set the ``"lr"`` entry of every parameter group of ``optimizer`` to ``new_lr``."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    """Return the linear warm-up learning rate for the zero-based ``step``.

    The rate is ``base_lr * (step + 1) / warmup_length``, so it reaches
    ``base_lr`` at ``step == warmup_length - 1``.
    """
    scaled = base_lr * (step + 1)
    return scaled / warmup_length


def cosine_lr(optimizer, base_lr, warmup_length, steps):
    """Build a per-step learning-rate setter: linear warm-up, then cosine decay.

    The returned callable takes a step index, computes the learning rate
    (``_warmup_lr`` while ``step < warmup_length``, otherwise a half-cosine
    from ``base_lr`` over the remaining ``steps - warmup_length`` steps),
    writes it to ``optimizer`` via ``assign_learning_rate`` and returns it.
    """
    def _scheduled_lr(step):
        if step < warmup_length:
            return _warmup_lr(base_lr, warmup_length, step)
        elapsed = step - warmup_length
        decay_span = steps - warmup_length
        return 0.5 * (1 + np.cos(np.pi * elapsed / decay_span)) * base_lr

    def _lr_adjuster(step):
        lr = _scheduled_lr(step)
        assign_learning_rate(optimizer, lr)
        return lr

    return _lr_adjuster


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    ``output`` holds scores with samples along dim 0 and candidates along
    dim 1 (top-k is taken over dim 1); ``target`` holds the true index per
    sample. Returns a list with one single-element tensor per requested k,
    in the same order as ``topk``. Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the highest-scoring entries, transposed so that row i
        # holds every sample's i-th best guess.
        top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        percent_per_hit = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True) * percent_per_hit
            for k in topk
        ]


class AverageMeter:
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format-spec suffix (e.g. ':.3f') spliced into __str__.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Renders as "<name> <val> (<avg>)" with `fmt` applied to both numbers.
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format_map(vars(self))


class ProgressMeter:
    """Print a tab-separated progress line: prefix, batch counter, then meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for the given batch index."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch index to the width of the total, giving a
        # template such as "[{:3d}/100]".
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        return f'[{slot}/{slot.format(num_batches)}]'
