import os
import math
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn

import numpy as np

from omegaconf import OmegaConf
from dinov2.configs import dinov2_default_config

from dinov2.models import build_model_from_cfg
import dinov2.utils.utils as dinov2_utils

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def rand_bbox(size, res, prop=None):
    """Sample the top-left corner of a random square box inside a ``size`` x ``size`` image.

    Args:
        size: side length of the square image, in pixels.
        res: side length of the box when ``prop`` is None.
        prop: optional ``(low, high)`` pair. When given, the box side is
            ``int(r * size)`` with ``r`` drawn uniformly from ``[low, high)``,
            and ``res`` is ignored.

    Returns:
        ``(x, y)`` top-left corner such that a box of the chosen side lies
        entirely inside the image (``0 <= x, y <= size - side``).
    """
    W = H = size
    cut_w = cut_h = res
    if prop is not None:
        low, high = prop[0], prop[1]
        # Uniform in [low, high); the original omitted ``low`` and sampled [0, high - low).
        ratio = low + np.random.rand() * (high - low)
        cut_w = int(ratio * W)
        cut_h = int(ratio * H)
    # randint's ``high`` is exclusive. The +1 makes the last valid offset (W - cut_w)
    # reachable and lets a box as large as the image return offset 0 instead of raising.
    bbx1 = np.random.randint(0, W - cut_w + 1)
    bby1 = np.random.randint(0, H - cut_h + 1)
    return bbx1, bby1

def rand_sampling_mult(sizes, content_image, out_image, crop_size=128, num_crops=16, prop=None):
    """Cut ``num_crops`` spatially aligned random square crops from two images.

    For each crop the same window is taken from ``content_image`` and ``out_image``,
    so the i-th crop of both outputs covers the same pixels.

    Args:
        sizes: image side length, forwarded to ``rand_bbox``.
        content_image: tensor indexed as ``[:, :, y, x]``; presumably (N, C, H, W)
            with H = W = ``sizes`` (TODO confirm against callers).
        out_image: tensor with the same layout as ``content_image``.
        crop_size: side length of every crop.
        num_crops: number of crops to draw.
        prop: forwarded to ``rand_bbox`` (random box-size range).

    Returns:
        ``(crop_image, tar_image)``: the crops of each input concatenated along dim 0.
    """
    # Largest top-left offset at which a crop_size window still fits in the image.
    max_offset = max(sizes - crop_size, 0)
    content_crops = []
    target_crops = []
    for _ in range(num_crops):
        bbx1, bby1 = rand_bbox(sizes, crop_size, prop)
        # With ``prop`` set, rand_bbox places a box of a random (possibly smaller) side,
        # so the offset may not leave room for crop_size. Clamp it so every window stays
        # in bounds and all crops share one shape. This is a no-op when prop is None.
        bbx1 = min(bbx1, max_offset)
        bby1 = min(bby1, max_offset)
        window = (
            slice(None),
            slice(None),
            slice(bby1, bby1 + crop_size),
            slice(bbx1, bbx1 + crop_size),
        )
        content_crops.append(content_image[window])
        target_crops.append(out_image[window])
    crop_image = torch.cat(content_crops, dim=0)
    tar_image = torch.cat(target_crops, dim=0)
    return crop_image, tar_image

class DinoLoss(torch.nn.Module):
    """Patch-wise contrastive (PatchNCE-style) loss computed on backbone features.

    Features of ``source_imgs`` (computed without gradient) act as keys and features
    of ``target_imgs`` as queries. Each query patch is scored against the key patch at
    the same position (positive) and the other key patches of the same image (negatives).

    The backbone is selected from ``args`` with one priority order, used both at
    construction and in ``forward``:
    DINOv2 (``args.dino_ckpt``) > MViT (``args.mvit_ckpt``) > HF ViT (``args.vit_ckpt``).
    """

    def __init__(self, device, args=None):
        """Build the preprocessing pipelines and load the configured backbone.

        Args:
            device: device the backbone is moved to.
            args: namespace providing ``dino_ckpt``, ``mvit_ckpt`` and ``batch``, plus
                ``dino_config`` / ``mvit_config`` / ``vit_ckpt`` for the selected
                backbone. Required despite the ``None`` default, because its
                attributes are read unconditionally.
        """
        super(DinoLoss, self).__init__()

        self.device = device
        self.args = args

        self.set_transform()
        # Backbone priority: DINOv2 > MViT > HF ViT (_extract_features uses the same order).
        if self.args.dino_ckpt is not None:
            self.set_model_dino()
        elif self.args.mvit_ckpt is not None:
            self.set_model_mvit()
        else:
            self.set_model_vit()

        self.tau = 1.0  # temperature dividing the logits in PatchContra
        # Default group count for PatchContra and the final divisor in forward.
        self.batch = self.args.batch
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self.mask_dtype = torch.bool

        # Not referenced by the methods of this class; presumably used by external callers.
        self.sim = nn.CosineSimilarity()
        self.sfm = nn.Softmax(dim=1)
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def set_transform(self):
        """Create ``self.preprocess_img`` and ``self.preprocess_gen``.

        ``preprocess_img``: PIL image -> RGB -> 224x224 (bicubic) tensor,
        ImageNet-normalized.
        ``preprocess_gen``: image tensor -> ``(x + 1) / 2`` (maps [-1, 1] to [0, 1];
        inputs are presumably in [-1, 1]) -> 224x224 (bicubic) -> ImageNet-normalized.
        """
        # Same values as timm.data.constants.IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD.
        IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
        IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
        try:
            from torchvision.transforms import InterpolationMode
        except ImportError:
            # Older torchvision without InterpolationMode: fall back to timm's helper.
            from timm.data.transforms import _pil_interp
        else:
            def _pil_interp(method):
                """Map an interpolation name to a torchvision ``InterpolationMode``."""
                if method == 'bicubic':
                    return InterpolationMode.BICUBIC
                if method == 'lanczos':
                    return InterpolationMode.LANCZOS
                if method == 'hamming':
                    return InterpolationMode.HAMMING
                # default bilinear, do we want to allow nearest?
                return InterpolationMode.BILINEAR

            # Replace timm's module-level ``_pil_interp`` with the version above
            # (presumably so timm transforms built elsewhere also get InterpolationMode).
            # timm is optional here: when it is not installed there is nothing to patch.
            try:
                import timm.data.transforms as timm_transforms
            except ImportError:
                pass
            else:
                timm_transforms._pil_interp = _pil_interp

        self.preprocess_img = transforms.Compose([
            transforms.Resize([224, 224], interpolation=_pil_interp('bicubic')),
            _convert_image_to_rgb,
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])
        self.preprocess_gen = transforms.Compose([
            # (x - (-1)) / 2 == (x + 1) / 2: maps [-1, 1] to [0, 1].
            transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
            transforms.Resize([224, 224], interpolation=_pil_interp('bicubic')),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def set_model_dino(self):
        """Build the DINOv2 teacher from ``args.dino_config``, load ``args.dino_ckpt`` and move it to ``self.device``."""
        default_cfg = OmegaConf.create(dinov2_default_config)
        cfg = OmegaConf.load(self.args.dino_config)
        config = OmegaConf.merge(default_cfg, cfg)

        self.dino_model, _ = build_model_from_cfg(config, only_teacher=True)
        dinov2_utils.load_pretrained_weights(self.dino_model, self.args.dino_ckpt, "teacher")
        self.dino_model = self.dino_model.to(self.device)

    def set_model_vit(self):
        """Load a Hugging Face ``ViTForImageClassification`` from ``args.vit_ckpt``.

        Example ids: 'google/vit-base-patch16-224', 'google/vit-large-patch16-224'.
        """
        from transformers import ViTForImageClassification

        self.dino_model = ViTForImageClassification.from_pretrained(self.args.vit_ckpt)
        self.dino_model.to(self.device)

    def set_model_mvit(self):
        """Build an MViT model (optionally merging ``args.mvit_config``) and load ``args.mvit_ckpt``."""
        from mvit.config.defaults import get_cfg
        from mvit.models import build_model
        from mvit.utils.checkpoint import load_checkpoint

        cfg = get_cfg()
        if self.args.mvit_config is not None:
            cfg.merge_from_file(self.args.mvit_config)
        self.mvit_model = build_model(cfg)
        load_checkpoint(self.args.mvit_ckpt, self.mvit_model, data_parallel=False)
        self.mvit_model.to(self.device)

    def PatchContra(self, feat_q, feat_k, batch=None):
        """Per-row PatchNCE-style contrastive loss.

        The rows are split into ``batch`` consecutive, equally sized groups, so they
        must be ordered group-major (all patches of image 0, then image 1, ...). For
        row ``r``:
        - the positive logit is cos(feat_q[r], feat_k[r]);
        - the negative logits are the cosine similarities with every key row of the
          same group, with its own position masked to -10.

        Args:
            feat_q: (N, D) query features; gradients flow through these.
            feat_k: (N, D) key features; detached before use.
            batch: number of groups; defaults to ``self.batch``. N must be divisible by it.

        Returns:
            (N,) tensor of per-row cross-entropy losses, where target class 0 is the
            positive logit.
        """
        num_groups = self.batch if batch is None else batch
        num_rows, dim = feat_q.shape[0], feat_q.shape[1]
        feat_k = feat_k.detach()
        feat_q_norm = feat_q.norm(p=2, dim=-1, keepdim=True)
        feat_k_norm = feat_k.norm(p=2, dim=-1, keepdim=True)

        # Positive logit: cosine similarity of each query with the key in the same row.
        l_pos = torch.bmm(feat_q.reshape(num_rows, 1, -1), feat_k.reshape(num_rows, -1, 1))
        l_pos = l_pos.reshape(num_rows, 1) / (feat_k_norm * feat_q_norm)

        # Negative logits: all query/key cosine similarities within each group.
        feat_q = feat_q.reshape(num_groups, -1, dim)
        feat_k = feat_k.reshape(num_groups, -1, dim)
        feat_q_norm = feat_q_norm.reshape(num_groups, -1, 1)
        feat_k_norm = feat_k_norm.reshape(num_groups, -1, 1)

        npatches = feat_q.size(1)
        l_neg_norm = torch.bmm(feat_q_norm, feat_k_norm.transpose(2, 1))
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1)) / l_neg_norm

        # The diagonal is the positive pair (already in l_pos). Fill it with -10 so it
        # contributes almost nothing as a negative.
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.reshape(-1, npatches)

        out = torch.cat((l_pos, l_neg), dim=1) / self.tau

        return self.cross_entropy_loss(
            out, torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device)
        )

    def _extract_features(self, imgs):
        """Run the configured backbone on ``preprocess_gen(imgs)``.

        Returns a per-layer sequence where ``feats[layer][0]`` is expected to be a
        patch-token tensor indexed ``[image, patch, dim]``. Dispatch follows the same
        priority as ``__init__``.
        """
        pixels = self.preprocess_gen(imgs)
        if self.args.dino_ckpt is not None:
            return self.dino_model.get_intermediate_layers(pixels, n=12, return_class_token=True)
        if self.args.mvit_ckpt is not None:
            return [[self.mvit_model.get_intermediate_layers(pixels)]]
        # ViTForImageClassification returns a ModelOutput (logits), which cannot be
        # sliced as tokens. Use the last encoder hidden state and drop the CLS token.
        outputs = self.dino_model(pixels, output_hidden_states=True)
        return [[outputs.hidden_states[-1][:, 1:, :]]]

    def forward(self, target_imgs, source_imgs):
        """Contrastive patch loss of ``target_imgs`` against ``source_imgs``.

        Args:
            target_imgs: images whose features are the queries (gradients flow).
            source_imgs: images whose features are the keys (computed under no_grad).

        Returns:
            The mean per-patch PatchContra loss, averaged over ``depth`` and divided
            by ``self.batch``.
        """
        with torch.no_grad():
            src_feats = self._extract_features(source_imgs)
        mix_feats = self._extract_features(target_imgs)

        loss = 0
        depth = [0]  # indices into the per-layer feature lists
        for layer in depth:
            src_feat = src_feats[layer][0]
            mix_feat = mix_feats[layer][0]
            num_imgs = src_feat.shape[0]
            # Flatten image-major (rows b*P .. b*P+P-1 belong to image b) so that
            # PatchContra's grouping draws negatives only from the same image.
            src_flat = src_feat.reshape(-1, src_feat.shape[-1])
            mix_flat = mix_feat.reshape(-1, mix_feat.shape[-1])
            loss += self.PatchContra(mix_flat, src_flat, batch=num_imgs).mean()

        # Scaling by the configured self.batch is kept unchanged.
        return loss / len(depth) / self.batch