import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

from metric import AverageMeter
from trainer.finetune import FinetuneCLIP, logging


def kl_loc_loss(pre, post, r):
    """Return the summed KL divergence between two temperature-scaled softmaxes.

    Both inputs are cast to float32 and flattened to ``(-1, pre.shape[-1])``;
    each row is divided by ``r`` and turned into a distribution over the last
    dimension. The result is ``KL(softmax(pre / r) || softmax(post / r))``
    summed over every row and every entry (it is NOT averaged over rows).

    Args:
        pre: Logits defining the reference distribution of the KL term.
        post: Logits with the same number of elements as ``pre``.
        r: Temperature both sets of logits are divided by.

    Returns:
        A 0-dim float32 tensor holding the summed KL divergence.

    Raises:
        RuntimeError: If ``post`` cannot be reshaped to the flattened shape
            of ``pre`` (element-count mismatch).
    """
    pre = pre.to(torch.float32)
    post = post.to(torch.float32)

    # reshape instead of view: view raises on non-contiguous inputs (e.g. a
    # transposed tensor), reshape falls back to a copy only when required.
    pre_ = pre.reshape(-1, pre.shape[-1]) / r
    # Forcing post into pre_'s shape already guarantees matching row counts,
    # so no separate shape assertion is needed.
    post_ = post.reshape(pre_.shape) / r

    kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum()

    return kl


def ps_distill_loss(pre, post, r=2.0):
    """Return a temperature-scaled soft cross-entropy between ``pre`` and ``post``.

    Both inputs are cast to float32, transposed, and divided by ``r``. The
    softmax of ``pre`` over the last dimension provides the soft targets and
    the log-softmax of ``post`` the log-probabilities; the cross-entropy is
    summed over the last dimension and averaged over the rows.

    Args:
        pre: 2-D tensor whose transposed rows supply the soft targets.
        post: 2-D tensor whose transposed rows supply the log-probabilities.
        r: Temperature both inputs are divided by.

    Returns:
        A 0-dim float32 tensor with the mean cross-entropy.
    """
    # Inputs are laid out per image, so transpose before normalising.
    # After the transpose, each row is presumably one text feature's similarities
    # to the images of the current batch (per the original author's note).
    target_logits = pre.to(torch.float32).t() / r
    student_logits = post.to(torch.float32).t() / r

    soft_targets = torch.softmax(target_logits, dim=-1)
    log_probs = torch.log_softmax(student_logits, dim=-1)

    per_row = (-soft_targets * log_probs).sum(dim=-1)
    return per_row.mean()


class Distill(FinetuneCLIP):
    """Fine-tuning trainer that adds a distillation term against a frozen model copy.

    At the start of every ``train`` call the incoming model is deep-copied and
    frozen. For ``task > 0`` the loss of each batch is the symmetric
    contrastive cross-entropy plus ``self.args.scale`` times a distillation
    loss between the current model and the frozen copy.
    """

    def compute_loss(self, batch, model, **kwargs):
        """Compute the training loss for one batch.

        Args:
            batch: For ``task == 0`` a ``(images, label, texts)`` triple. For
                ``task > 0`` a pair of such triples: the current batch
                followed by a batch drawn from the buffer loader.
            model: Model called as ``model(images, texts)``; it returns two
                values (per-image and per-text logits).
            **kwargs: ``task`` (default 0), ``distill_model`` (frozen model,
                required when ``task > 0``) and ``distill_loss_type``
                (``'visual'`` selects ``kl_loc_loss``, ``'text'`` selects
                ``ps_distill_loss``; default ``'visual'``).

        Returns:
            ``(total_loss, distill_loss)``. ``distill_loss`` is the int ``0``
            when ``task == 0`` and a tensor otherwise.

        Raises:
            NotImplementedError: If ``distill_loss_type`` is not supported.
        """
        task = kwargs.get('task', 0)
        distill_model = kwargs.get('distill_model', None)
        distill_loss_type = kwargs.get('distill_loss_type', 'visual')

        loss_img = nn.CrossEntropyLoss()
        loss_txt = nn.CrossEntropyLoss()
        if distill_loss_type == 'visual':
            dloss = kl_loc_loss
        elif distill_loss_type == 'text':
            dloss = ps_distill_loss
        else:
            raise NotImplementedError(f'unsupported distill_loss_type: {distill_loss_type!r}')

        if task == 0:
            images, label, texts = batch
            images = images.to(self.args.device)
            texts = texts.to(self.args.device)
            distill_loss = 0
        else:
            (images, label, texts), (images_prev, _, texts_prev) = batch
            images = images.to(self.args.device)
            texts = texts.to(self.args.device)
            images_prev = images_prev.to(self.args.device)
            texts_prev = texts_prev.to(self.args.device)
            # Distil on the buffer batch and the current batch together.
            images_prev = torch.cat([images_prev, images])
            texts_prev = torch.cat([texts_prev, texts])

            # The frozen model only provides targets, so no graph is needed.
            with torch.no_grad():
                frozen_out, _ = distill_model(images_prev, texts_prev)
            current_out, _ = model(images_prev, texts_prev)
            # NOTE(review): the current model's output is passed as `pre` and
            # the frozen model's as `post` -- confirm this order is intended.
            distill_loss = dloss(current_out, frozen_out, r=self.args.tem)

        # Matching image/text pairs share an index, so targets are 0..N-1.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=self.args.device)
        logits_per_image, logits_per_text = model(images, texts)
        total_loss = (loss_img(logits_per_image, ground_truth)
                      + loss_txt(logits_per_text, ground_truth)) / 2

        if task > 0:
            total_loss += self.args.scale * distill_loss
        return total_loss, distill_loss

    def train(self, model, dataset, task):
        """Train ``model`` on ``task`` while distilling from a frozen copy of it.

        For optimizers other than ``'adamw'`` the validation score drives
        early stopping and the best-scoring weights are restored at the end.
        With ``'adamw'`` neither early stopping nor weight restoring is done,
        and a cosine schedule is stepped once per epoch unless
        ``self.args.no_scheduler`` is set. ``dataset.update_buffer(task)`` is
        called once training is finished.

        Args:
            model: Model to train in place.
            dataset: Dataset object handed to ``get_iterator``,
                ``middle_evaluation`` and whose ``update_buffer`` is called.
            task: Index of the task being trained.

        Raises:
            NotImplementedError: If ``self.args.optimizer`` is not one of
                ``'adam'``, ``'sgd'`` or ``'adamw'``.
        """
        train_dataloader, buffer_loader, validset, total_batches = self.get_iterator(dataset, task)

        if self.args.optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(), lr=self.args.lr, betas=(0.9, 0.98), eps=1e-6,
                                   weight_decay=self.args.wd)
        elif self.args.optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=self.args.lr,
                                  weight_decay=self.args.wd)
        elif self.args.optimizer == 'adamw':
            # NOTE(review): weight decay is fixed at 0.2 here and ignores
            # self.args.wd -- confirm this is intended.
            optimizer = optim.AdamW(model.parameters(), lr=self.args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.2)
            if not self.args.no_scheduler:
                self.lr_scheduler = CosineAnnealingWarmupRestarts(
                    optimizer,
                    first_cycle_steps=self.args.epochs * 10,
                    cycle_mult=1.0,
                    max_lr=self.args.lr,
                    min_lr=0,
                    warmup_steps=1
                )
        else:
            raise NotImplementedError(f'unsupported optimizer: {self.args.optimizer!r}')

        self.unfreeze_model(model)
        batch_time = AverageMeter()
        loss = AverageMeter()
        dloss = AverageMeter()
        best = 0
        stopping_flag = 0

        # Restored after training unless a validation pass replaces it with a
        # snapshot of better weights.
        state_dict = model.state_dict()
        # Frozen copy of the incoming model, used as the distillation target.
        distill_model = copy.deepcopy(model)
        distill_model.eval()
        for param in distill_model.parameters():
            param.requires_grad = False
        optimizer.zero_grad()
        self.compute_importance(train_dataloader, model)

        for epoch in range(self.args.epochs):
            buffer_iterator = iter(buffer_loader) if buffer_loader else None
            for iiter, batch in enumerate(train_dataloader):
                end = time.time()
                optimizer.zero_grad()

                batch_size = self.get_batch_size(batch, task=task)

                if buffer_iterator:
                    try:
                        batch_b = next(buffer_iterator)
                    except StopIteration:
                        # Buffer exhausted before the training loader: restart it.
                        buffer_iterator = iter(buffer_loader)
                        batch_b = next(buffer_iterator)
                    batch = [batch, batch_b]

                total_loss, distill_loss = self.compute_loss(batch, model, task=task, distill_model=distill_model,
                                                             distill_loss_type=self.args.distill_loss)

                total_loss.backward()
                self.update_model(model, optimizer)

                batch_time.update(time.time() - end)
                loss.update(total_loss.item() / batch_size, n=batch_size)
                # distill_loss is the int 0 on task 0 and a tensor otherwise;
                # float() handles both.
                dloss.update(float(distill_loss) / batch_size, n=batch_size)
                step = iiter + epoch * total_batches
                logging('iter', step, f'train_loss/{task}_distill', dloss.val, self.args)
                logging('iter', step, f'train_loss/{task}', loss.val, self.args)
                if iiter % self.args.print_frequency == 0:
                    print(' Epoch: [{0}/{1}], Batch: [{2}/{3}]\t'.format(epoch, self.args.epochs, iiter,
                                                                         total_batches),
                          f'Batch Time {batch_time.val: .3f} ({batch_time.avg: .3f})\t'
                          f'Loss {loss.val:.4f} ({loss.avg: .4f}) \t'
                          f'Estimated Task Time {batch_time.avg * total_batches * self.args.epochs / 3600: .3f} H'
                          )

            if (epoch + 1) % self.args.val_frequency == 0:
                model.eval()

                avg = self.middle_evaluation(model, dataset, task, epoch, validset=validset)
                if self.args.optimizer != 'adamw':
                    if avg < best:  # early stop
                        stopping_flag += 1
                        if stopping_flag == self.args.stopping:
                            logging('task', task, f'training epoch', epoch, self.args)
                            break
                    else:
                        best = avg
                        stopping_flag = 0
                        # Snapshot the weights; clone so later updates do not
                        # overwrite the saved copy.
                        state_dict = {}
                        cur_dict = model.state_dict()
                        for key in cur_dict:
                            state_dict[key] = cur_dict[key].detach().clone()
                    if avg > 99:
                        break
                # Validation switched the model to eval mode; restore the
                # training configuration.
                self.unfreeze_model(model)
            if self.args.optimizer == 'adamw' and not self.args.no_scheduler:
                self.lr_scheduler.step()

        model.eval()
        if self.args.optimizer != 'adamw':
            model.load_state_dict(state_dict)
            # Keep the model on the device the batches are moved to.
            model.to(self.args.device)
        print('Update Buffer....')
        dataset.update_buffer(task)


class DistillFFN(Distill):
    """Distill trainer that restricts training to selected visual parameters."""

    def unfreeze_model(self, model):
        """Put ``model`` in train mode and select the parameters to train.

        A parameter is trainable when its name contains both ``'c_proj'`` and
        ``'visual'``. When ``self.args.finetune_proj`` is set, the parameter
        named exactly ``'visual.proj'`` is trainable as well. Every other
        parameter is frozen.
        """
        model.train()
        for param_name, tensor in model.named_parameters():
            in_visual_c_proj = 'c_proj' in param_name and 'visual' in param_name
            extra_proj = self.args.finetune_proj and param_name == 'visual.proj'
            tensor.requires_grad = bool(in_visual_c_proj or extra_proj)


class DistillProj(Distill):
    """Distill trainer that restricts training to selected visual parameters.

    NOTE(review): the parameter selection is the same as ``DistillFFN``'s; the
    only visible difference is the debug print -- confirm this is intended.
    """

    def unfreeze_model(self, model):
        """Put ``model`` in train mode and select the parameters to train.

        A parameter is trainable when its name contains both ``'c_proj'`` and
        ``'visual'``. When ``self.args.finetune_proj`` is set, the parameter
        named exactly ``'visual.proj'`` is trainable as well. Every other
        parameter is frozen. With ``self.args.debug`` set, the name of each
        trainable parameter is printed.
        """
        model.train()
        for param_name, tensor in model.named_parameters():
            matches = 'c_proj' in param_name and 'visual' in param_name
            if self.args.finetune_proj:
                matches = matches or param_name == 'visual.proj'

            if not matches:
                tensor.requires_grad = False
                continue

            if self.args.debug:
                print(param_name)
            tensor.requires_grad = True
