import os.path as osp

import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

import copy

# Module-level tokenizer instance, created once when this module is imported.
# NOTE(review): not referenced anywhere in the visible part of this file.
_tokenizer = _Tokenizer()


# Prompt template per dataset name. Each template has exactly one ``{}``
# placeholder that is filled with the (underscore-free) class name.
CUSTOM_TEMPLATES = {
    # Fine-grained datasets: the template names the super-category.
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    # ImageNet and its variants all share the generic template.
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}


def load_clip_to_cpu(cfg):
    """Download the CLIP checkpoint named in ``cfg`` and build the model on CPU.

    Args:
        cfg: config object; ``cfg.MODEL.BACKBONE.NAME`` must be a key of
            ``clip._MODELS``.

    Returns:
        The model produced by ``clip.build_model`` from the checkpoint weights.
    """
    checkpoint_path = clip._download(clip._MODELS[cfg.MODEL.BACKBONE.NAME])

    try:
        # First attempt: treat the file as a TorchScript (JIT) archive.
        jit_model = torch.jit.load(checkpoint_path, map_location='cpu').eval()
    except RuntimeError:
        # Fallback: load the file with plain ``torch.load``.
        weights = torch.load(checkpoint_path, map_location='cpu')
    else:
        weights = None

    if not weights:
        # No usable weights from the fallback path: take them from the archive.
        weights = jit_model.state_dict()

    return clip.build_model(weights)

def clip_classifier(classnames, clip_model, template):
    """Encode one prompt per class and stack the normalised embeddings.

    Args:
        classnames: iterable of class-name strings; underscores are replaced
            by spaces before formatting.
        clip_model: callable mapping tokenised text to embeddings. It is
            switched to eval mode as a side effect.
        template: format string with a single ``{}`` placeholder.

    Returns:
        A CUDA tensor with the per-class embeddings concatenated along dim 0,
        each L2-normalised along the last dimension.
    """
    clip_model.eval()
    per_class_embeddings = []
    with torch.no_grad():
        for raw_name in classnames:
            # A single prompt is built and tokenised for each class.
            prompt = template.format(raw_name.replace('_', ' '))
            tokens = clip.tokenize(prompt).cuda()
            embedding = clip_model(tokens)
            embedding /= embedding.norm(dim=-1, keepdim=True)
            per_class_embeddings.append(embedding)
        stacked = torch.concat(per_class_embeddings, dim=0).cuda()
    return stacked


class TextEncoder(nn.Module):
    """Text branch of a CLIP model wrapped as a standalone module.

    All sub-modules and parameters are taken from ``clip_model`` by reference;
    nothing is copied here.
    """

    # Attributes lifted from the CLIP model, in registration order.
    _BORROWED = (
        'transformer',
        'vocab_size',
        'token_embedding',
        'positional_embedding',
        'ln_final',
        'text_projection',
        'dtype',
    )

    def __init__(self, clip_model):
        super().__init__()
        for attr in self._BORROWED:
            setattr(self, attr, getattr(clip_model, attr))

    def forward(self, text):
        """Encode tokenised text.

        Args:
            text: integer token tensor, indexed as [batch_size, n_ctx].

        Returns:
            One projected feature vector per sequence, read at the position of
            the largest token id in that sequence.
        """
        embedded = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        embedded = embedded + self.positional_embedding.type(self.dtype)

        # The transformer is fed sequence-first tensors: NLD -> LND and back.
        hidden = self.transformer(embedded.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        # hidden.shape = [batch_size, n_ctx, transformer.width]
        # Pick the eot embedding (eot_token is the highest number in each sequence).
        row_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        return hidden[row_index, eot_index] @ self.text_projection
    
   
class KDLoss(nn.Module):
    """Temperature-scaled KL-divergence between two sets of logits.

    Both inputs are divided by ``temp_factor`` before the softmax. The summed
    KL term is multiplied by ``temp_factor ** 2`` and divided by the batch
    size (``input.size(0)``).
    """

    def __init__(self, temp_factor):
        super().__init__()
        self.temp_factor = temp_factor
        self.kl_div = nn.KLDivLoss(reduction="sum")

    def forward(self, input, target):
        """Return the scalar loss for ``input`` logits against ``target`` logits."""
        temperature = self.temp_factor
        input_log_probs = torch.log_softmax(input / temperature, dim=1)
        target_probs = torch.softmax(target / temperature, dim=1)
        summed_kl = self.kl_div(input_log_probs, target_probs)
        return summed_kl * (temperature ** 2) / input.size(0)
    
    
class CustomCLIP(nn.Module):
    """CLIP wrapper with a trainable and a frozen copy of each encoder.

    Training (``forward`` with labels) uses the trainable ``image_encoder`` /
    ``text_encoder`` / ``text_classifier`` and the frozen ``freeze_visual``.
    Inference (``forward`` without labels) uses ``evaluation_vision_encoder``
    and ``ensemble_classifier``, which ``ensemble_model`` refreshes from a
    weighted average of the trainable and frozen weights.
    """

    def __init__(self, cfg, classnames, clip_model):
        """Build all encoders and classifiers from ``clip_model``.

        Args:
            cfg: config object; ``cfg.DATASET.NAME`` selects the prompt template.
            classnames: sequence of class-name strings, indexed by label id.
            clip_model: CLIP model providing ``visual``, the text-tower
                attributes used by ``TextEncoder``, ``logit_scale`` and ``dtype``.
        """
        super().__init__()

        self.cfg = cfg
        self.classnames = classnames

        # Trainable copy of the visual tower; the frozen one is the original
        # object owned by clip_model (not a copy).
        self.image_encoder = copy.deepcopy(clip_model.visual)
        self.freeze_visual = clip_model.visual

        # The trainable text tower is copied before the frozen one is moved to
        # CUDA, because the frozen tower shares its tensors with clip_model.
        self.text_encoder = copy.deepcopy(TextEncoder(clip_model))
        self.freeze_language = TextEncoder(clip_model).cuda()

        self.template = CUSTOM_TEMPLATES[self.cfg.DATASET.NAME]

        # Linear classifier whose weight is initialised from the frozen text
        # embeddings of every class.
        # NOTE(review): the Parameter is created from a view of clip_weights,
        # so presumably clip_weights follows in-place optimizer updates -
        # confirm whether that is intended.
        self.clip_weights = clip_classifier(classnames, self.freeze_language, template=self.template).t()
        self.text_classifier = nn.Linear(self.clip_weights.shape[0], self.clip_weights.shape[1], bias=False).cuda()
        self.text_classifier.weight = nn.Parameter(self.clip_weights.t(), requires_grad=True)

        # Independent copy of the frozen class embeddings (plain tensor).
        self.freeze_text_classifier = clip_classifier(classnames, self.freeze_language, template=self.template).t()

        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.kd_criterion = KDLoss(4.0)

        self.ensemble_model_flag = False
        self.evaluation_vision_encoder = copy.deepcopy(self.freeze_visual)
        self.evaluation_text_encoder = copy.deepcopy(self.freeze_language)
        self._rebuild_ensemble_classifier()

    def _rebuild_ensemble_classifier(self):
        """Recompute ``ensemble_classifier`` from ``evaluation_text_encoder``."""
        self.ensemble_classifier_weights = clip_classifier(self.classnames, self.evaluation_text_encoder, template=self.template).t()
        self.ensemble_classifier = nn.Linear(self.ensemble_classifier_weights.shape[0], self.ensemble_classifier_weights.shape[1], bias=False).cuda()
        self.ensemble_classifier.weight = nn.Parameter(self.ensemble_classifier_weights.t(), requires_grad=False)

    @staticmethod
    def _interpolate_state_dicts(theta_0, theta_1, alpha):
        """Return ``(1 - alpha) * theta_0 + alpha * theta_1`` key by key."""
        assert set(theta_0.keys()) == set(theta_1.keys())
        return {
            key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
            for key in theta_0
        }

    def info_nce_loss(self, vis_features, text_features, labels, tau=0.01, s_text_embeddings=None):
        """Label-aware contrastive loss between two sets of features.

        Row ``i`` of ``vis_features`` is contrasted against every row of
        ``text_features``; all columns whose label equals ``labels[i]`` count
        as positives.

        Args:
            vis_features: 2-D feature tensor (one row per sample).
            text_features: 2-D feature tensor with the same number of rows.
            labels: tensor with one label per sample.
            tau: temperature dividing the similarity matrix.
            s_text_embeddings: optional non-empty sequence of extra feature
                tensors; their mean similarity (scaled by 100) is blended in
                with weights 0.6 / 0.4.

        Returns:
            Scalar loss tensor (mean over samples).
        """
        similarity_matrix = torch.matmul(vis_features, text_features.T) / tau

        if s_text_embeddings is not None:
            s_sim = 0
            for s_feat in s_text_embeddings:
                s_sim += 100 * torch.matmul(vis_features, s_feat.T)
            s_sim /= len(s_text_embeddings)
            similarity_matrix = similarity_matrix * 0.6 + s_sim * 0.4

        # Subtract the row maximum before exponentiating for numerical stability.
        logits_max, _ = torch.max(similarity_matrix, dim=1, keepdim=True)
        logits = similarity_matrix - logits_max.detach()

        # mask[i, j] == 1 when samples i and j share a label. It is kept on the
        # logits' device instead of being forced onto CUDA.
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(logits.device)

        exp_logits = torch.exp(logits)
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # The diagonal of mask is always 1, so mask.sum(1) is never zero.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        return (-mean_log_prob_pos).mean()

    def ensemble_model(self, alpha=0.5):
        """Refresh the evaluation encoders from a weight-space average.

        Args:
            alpha: weight of the frozen encoder; the trainable encoder gets
                ``1 - alpha``. Defaults to 0.5.
        """
        print("==== Ensemble Model ====")
        with torch.no_grad():
            self.evaluation_vision_encoder.load_state_dict(
                self._interpolate_state_dicts(
                    self.image_encoder.state_dict(),
                    self.freeze_visual.state_dict(),
                    alpha,
                )
            )
            self.evaluation_text_encoder.load_state_dict(
                self._interpolate_state_dicts(
                    self.text_encoder.state_dict(),
                    self.freeze_language.state_dict(),
                    alpha,
                )
            )

        self._rebuild_ensemble_classifier()

    def forward(self, images, labels=None):
        """Run the training or the inference path.

        Args:
            images: image batch; cast to ``self.dtype`` before encoding.
            labels: label tensor. When given, the training path is used.

        Returns:
            ``logits`` when ``labels`` is None, otherwise the tuple
            ``(logits, loss_triplet, loss_kd)``.
        """
        logit_scale = self.logit_scale.exp()
        images = images.type(self.dtype)

        if labels is None:
            # Inference only needs the ensembled encoder and classifier; the
            # frozen encoder is not run here.
            image_features = self.evaluation_vision_encoder(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            return logit_scale * self.ensemble_classifier(image_features)

        with torch.no_grad():
            ori_features = self.freeze_visual(images)

        image_features = self.image_encoder(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # One prompt per sample, built from that sample's class name.
        batch_classnames = [
            self.template.format(self.classnames[i].replace("_", " "))
            for i in labels.tolist()
        ]
        batch_texts = clip.tokenize(batch_classnames).to(images.device)
        batch_text_embedding = self.text_encoder(batch_texts)
        batch_text_embedding = batch_text_embedding / batch_text_embedding.norm(dim=-1, keepdim=True)

        logits = logit_scale * self.text_classifier(image_features)

        # Symmetric contrastive term: image->text plus text->image.
        loss_triplet = self.info_nce_loss(image_features, batch_text_embedding, labels) + self.info_nce_loss(batch_text_embedding, image_features, labels)

        # NOTE(review): image_features is L2-normalised but ori_features is
        # not, so the two similarity matrices are on different scales -
        # confirm this is intended.
        sim_1 = 10 * image_features @ batch_text_embedding.t().detach()
        sim_2 = 10 * ori_features @ batch_text_embedding.t().detach()

        loss_kd = self.kd_criterion(sim_1, sim_2.detach())

        return logits, loss_triplet, loss_kd


@TRAINER_REGISTRY.register()
class SemanticCLIP(TrainerX):
    """Trainer that fine-tunes a ``CustomCLIP`` model.

    After every epoch the model's evaluation encoders are refreshed through
    ``CustomCLIP.ensemble_model`` before validation and checkpointing.
    """

    def build_model(self):
        """Create the model, select trainable parameters, build optimizer and scheduler."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f'Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})')
        clip_model = load_clip_to_cpu(cfg)
        clip_model.float()

        print('Building custom CLIP')
        self.model = CustomCLIP(cfg, classnames, clip_model)
        # The message below is kept verbatim; what the loop actually does is
        # enable gradients for the three trainable components and disable
        # them everywhere else.
        print('Turning off gradients in both the image and the text encoder')

        # Match on the leading component of the parameter name. A substring
        # test would also match "evaluation_text_encoder.*".
        trainable_prefixes = ('image_encoder.', 'text_encoder.', 'text_classifier.')
        for name, param in self.model.named_parameters():
            param.requires_grad_(name.startswith(trainable_prefixes))

        print("[INFO] model weights: ", cfg.MODEL.INIT_WEIGHTS)

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # The whole model is handed to the optimizer builder.
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)

        # Registered before any DataParallel wrapping, so the registered
        # object is the bare model.
        self.register_model('semantic_model', self.model, self.optim, self.sched)

        device_count = torch.cuda.device_count()
        self.multi_gpu = device_count > 1
        if self.multi_gpu:
            print(f'Multiple GPUs detected (n_gpus={device_count}), use all of them!')
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        """Run one optimisation step and return a dict with ``loss`` and ``acc``."""
        image, label = self.parse_batch_train(batch)
        output, loss_triplet, loss_kd = self.model(image, label)
        loss_ce = F.cross_entropy(output, label)

        alpha = 0.8  # weight of the contrastive term; cross-entropy gets 1 - alpha
        beta = 0.2   # weight of the distillation term
        loss = (1 - alpha) * loss_ce + alpha * loss_triplet
        if self.epoch <= 2:
            # The distillation term is only added while self.epoch is 0, 1 or 2.
            loss = loss + beta * loss_kd

        # The mean reduces the loss to a scalar if it is not one already.
        loss = loss.mean()
        self.model_backward_and_update(loss)
        loss_summary = {
            'loss': loss.item(),
            'acc': compute_accuracy(output, label)[0].item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move ``batch['img']`` and ``batch['label']`` to ``self.device``."""
        image = batch['img'].to(self.device)
        label = batch['label'].to(self.device)
        return image, label

    def after_epoch(self):
        """Refresh the ensemble, then optionally validate and checkpoint."""
        last_epoch = (self.epoch + 1) == self.max_epoch
        do_test = not self.cfg.TEST.NO_TEST
        checkpoint_freq = self.cfg.TRAIN.CHECKPOINT_FREQ
        meet_checkpoint_freq = (
            checkpoint_freq > 0 and (self.epoch + 1) % checkpoint_freq == 0
        )

        # Unwrap DataParallel to reach the custom method.
        bare_model = self.model.module if self.multi_gpu else self.model
        bare_model.ensemble_model()

        if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
            curr_result = self.test(split="val")
            if curr_result > self.best_result:
                self.best_result = curr_result
                self.save_model(
                    self.epoch,
                    self.output_dir,
                    model_name="model-best.pth.tar"
                )

        if meet_checkpoint_freq or last_epoch:
            self.save_model(self.epoch, self.output_dir)

    def after_train(self):
        """Run the final test (unless disabled), print elapsed time, close the writer."""
        print("Finished training")
        import time
        import datetime

        do_test = not self.cfg.TEST.NO_TEST
        if do_test:
            if self.cfg.TEST.FINAL_MODEL == "best_val":
                print("Deploy the model with the best val performance")
                self.load_model(self.output_dir)
            self.test()

        # Show elapsed time
        elapsed = round(time.time() - self.time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed: {}".format(elapsed))

        # Close writer
        self.close_writer()

    def load_model(self, directory, epoch=None):
        """Load checkpoint weights for every registered model.

        Args:
            directory: root directory holding one sub-directory per registered
                model name. Loading is skipped when it is empty/None.
            epoch: when given, ``model.pth.tar-<epoch>`` is loaded; otherwise
                ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: if an expected checkpoint file does not exist.
        """
        if not directory:
            print(
                'Note that load_model() is skipped as no pretrained model is given'
            )
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = 'model-best.pth.tar'

        if epoch is not None:
            model_file = 'model.pth.tar-' + str(epoch)

        # Entries dropped from the checkpoint before loading: fixed token
        # vectors and the two classifier weights.
        # NOTE(review): ensemble_classifier is not rebuilt here after
        # evaluation_text_encoder is loaded - confirm this is intended.
        skipped_keys = (
            'token_prefix',
            'token_suffix',
            'text_classifier.weight',
            'ensemble_classifier.weight',
        )

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path)
                )

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint['state_dict']
            loaded_epoch = checkpoint['epoch']

            # pop() tolerates checkpoints that lack any of these keys.
            for key in skipped_keys:
                state_dict.pop(key, None)

            print(
                'Loading weights to {} '
                'from "{}" (epoch = {})'.format(name, model_path, loaded_epoch)
            )
            # strict=False because some keys were removed above.
            info = self._models[name].load_state_dict(state_dict, strict=False)
            print("[Load info] : ", info)