import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

from clip.clip import tokenize
from metric import AverageMeter
from trainer.distill import ps_distill_loss, kl_loc_loss
from trainer.finetune import FinetuneCLIP, logging, FinetuneFFN, FinetuenProj


class CondDistill(FinetuneCLIP):
    """Fine-tune CLIP with a contrastive loss plus a conditional distillation term.

    Each training step pairs a batch from the current task's dataloader with a batch
    from the buffer dataloader returned by ``get_iterator``. The symmetric CLIP
    contrastive loss is computed on the current-task batch. A distillation loss
    (``kl_loc_loss`` or ``ps_distill_loss``) compares the trained model's first output
    on the buffer batch with that of a frozen copy of the model taken at the start
    of ``train``.
    """

    def unfreeze_model(self, model):
        """Put ``model`` in training mode and apply ``model.freeze(text=False)``.

        Which parameters end up trainable is decided by the model's ``freeze`` method,
        which is defined outside this file.
        """
        model.train()
        model.freeze(text=False)

    def compute_loss(self, batch, model, **kwargs):
        """Return the contrastive loss plus ``args.scale`` times the distillation loss.

        Args:
            batch: ``((images, label, texts), (images_cond, texts_cond))``.
                ``texts`` is moved to the device as-is (presumably already tokenised).
                ``texts_cond`` is tokenised here with truncation. ``label`` is unused.
            model: the model being trained. It is called as ``model(images, texts)``
                and its two outputs are used as ``(logits_per_image, logits_per_text)``.
            **kwargs:
                distill_model: frozen reference model (required).
                distill_loss_type: ``'visual'`` selects ``kl_loc_loss`` and ``'text'``
                    selects ``ps_distill_loss``. Defaults to ``'visual'``.

        Returns:
            Scalar loss tensor.

        Raises:
            NotImplementedError: for an unknown ``distill_loss_type``.
            TypeError: if ``distill_model`` is not supplied.
        """
        distill_model = kwargs.get('distill_model', None)
        distill_loss_type = kwargs.get('distill_loss_type', 'visual')

        if distill_loss_type == 'visual':
            dloss = kl_loc_loss
        elif distill_loss_type == 'text':
            dloss = ps_distill_loss
        else:
            raise NotImplementedError(f"unknown distill_loss_type: {distill_loss_type!r}")
        if distill_model is None:
            # Fail clearly instead of with "'NoneType' object is not callable" below.
            raise TypeError("compute_loss() requires a 'distill_model' keyword argument")

        (images, _label, texts), (images_cond, texts_cond) = batch
        device = self.args.device
        images = images.to(device)
        texts = texts.to(device)
        images_cond = images_cond.to(device)
        texts_cond = tokenize(texts_cond, truncate=True).to(device)

        # Distillation on the buffer batch. The frozen model provides the targets
        # without building a graph; the trained model provides the predictions.
        with torch.no_grad():
            target_logits, _ = distill_model(images_cond, texts_cond)
        pred_logits, _ = model(images_cond, texts_cond)
        distill_loss = dloss(pred_logits, target_logits, r=self.args.tem)

        # Symmetric CLIP contrastive loss: the i-th image matches the i-th text.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        logits_per_image, logits_per_text = model(images, texts)
        ce = nn.CrossEntropyLoss()
        contrastive_loss = (ce(logits_per_image, ground_truth) + ce(logits_per_text, ground_truth)) / 2

        return contrastive_loss + distill_loss * self.args.scale

    def get_batch_size(self, batch, **kwargs):
        """Return the size of the current-task image batch; the buffer half is ignored."""
        return batch[0][0].size(0)

    def train(self, model, dataset, task):
        """Train ``model`` on ``task`` with distillation against a frozen copy of itself.

        Steps:
        1. Build the optimizer selected by ``args.optimizer``. For ``adamw``, unless
           ``args.no_scheduler`` is set, also create a cosine warm-restart scheduler
           and store it on ``self.lr_scheduler``.
        2. Snapshot ``model`` into a frozen ``distill_model``.
        3. Iterate the task and buffer dataloaders in lockstep.
        4. Every ``args.val_frequency`` epochs, evaluate the model. Keep the weights
           with the best score (ties count as improvement). Stop early after
           ``args.stopping`` consecutive evaluations without improvement, or once
           the score exceeds 95.
        5. Load the kept weights back, move the model to ``args.device``, and call
           ``dataset.update_buffer(task)``.
        """
        train_dataloader, buffer_dataloader, validset, total_batches = self.get_iterator(dataset, task)

        if self.args.optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(), lr=self.args.lr, betas=(0.9, 0.98), eps=1e-6,
                                   weight_decay=self.args.wd)
        elif self.args.optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=self.args.lr,
                                  weight_decay=self.args.wd)
        elif self.args.optimizer == 'adamw':
            # NOTE(review): weight decay is fixed at 0.2 here and ignores args.wd,
            # unlike the adam/sgd branches. Confirm this is intended.
            optimizer = optim.AdamW(model.parameters(), lr=self.args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.2)
            if not self.args.no_scheduler:
                # NOTE(review): only this branch assigns self.lr_scheduler. Confirm how
                # update_model behaves for the other optimizers.
                self.lr_scheduler = CosineAnnealingWarmupRestarts(
                    optimizer,
                    first_cycle_steps=self.args.epochs * 10,
                    cycle_mult=1.0,
                    max_lr=self.args.lr,
                    min_lr=0,
                    warmup_steps=1
                )
        else:
            raise NotImplementedError

        self.unfreeze_model(model)
        batch_time = AverageMeter()
        loss = AverageMeter()

        # Until the first evaluation this is a live view of the model's own tensors,
        # so loading it at the end is a no-op and the final weights are kept.
        state_dict = model.state_dict()
        distill_model = copy.deepcopy(model)
        distill_model.eval()
        for param in distill_model.parameters():
            param.requires_grad = False
        best = 0
        stopping_flag = 0

        for epoch in range(self.args.epochs):
            # zip stops at the shorter of the two dataloaders.
            for step, batch in enumerate(zip(train_dataloader, buffer_dataloader)):
                # The timer starts after the batch is fetched, so data loading is not timed.
                end = time.time()
                optimizer.zero_grad()

                batch_size = self.get_batch_size(batch)
                total_loss = self.compute_loss(batch, model, distill_model=distill_model,
                                               distill_loss_type=self.args.distill_loss)

                total_loss.backward()
                self.update_model(model, optimizer)

                batch_time.update(time.time() - end)
                # NOTE(review): the cross-entropy terms are already batch means (default
                # reduction='mean'), so this divides by batch_size a second time for
                # logging. Kept as-is; confirm intended.
                loss.update(total_loss.item() / batch_size, n=batch_size)
                logging('iter', step + epoch * total_batches, f'train_loss/{task}', loss.val, self.args)
                if step % self.args.print_frequency == 0:
                    print(' Epoch: [{0}/{1}], Batch: [{2}/{3}]\t'.format(epoch, self.args.epochs, step,
                                                                         total_batches),
                          f'Batch Time {batch_time.val: .3f} ({batch_time.avg: .3f})\t'
                          f'Loss {loss.val:.4f} ({loss.avg: .4f}) \t'
                          f'Estimated Task Time {batch_time.avg * total_batches * self.args.epochs / 3600: .3f} H'
                          )

            if (epoch + 1) % self.args.val_frequency == 0:
                model.eval()
                avg = self.middle_evaluation(model, dataset, task, epoch, validset=validset)

                if avg < best:
                    # No improvement: stop after args.stopping consecutive misses.
                    stopping_flag += 1
                    if stopping_flag == self.args.stopping:
                        break
                else:
                    # New best score (ties included): keep a detached copy of the weights.
                    best = avg
                    stopping_flag = 0
                    state_dict = {key: value.detach().clone() for key, value in model.state_dict().items()}
                if avg > 95:
                    # NOTE(review): hard-coded "good enough" threshold on the evaluation score.
                    break
                self.unfreeze_model(model)

        model.eval()
        model.load_state_dict(state_dict)
        # Use the configured device, as compute_loss does, rather than a hard-coded
        # "cuda"/"cpu", so the model stays where later batches will be sent.
        model.to(self.args.device)
        print('Update Buffer....')
        dataset.update_buffer(task)


class CondDistillFFN(CondDistill, FinetuneFFN):
    """CondDistill training loop combined with FinetuneFFN's choice of trainable parameters."""

    def unfreeze_model(self, model):
        # Bypass CondDistill.unfreeze_model (which comes first in the MRO) and use
        # FinetuneFFN's policy explicitly; its return value is discarded.
        ffn_unfreeze = FinetuneFFN.unfreeze_model
        ffn_unfreeze(self, model)


class CondDistillProj(CondDistill, FinetuenProj):
    """CondDistill training loop combined with FinetuenProj's choice of trainable parameters."""

    def unfreeze_model(self, model):
        # Bypass CondDistill.unfreeze_model (which comes first in the MRO) and use
        # FinetuenProj's policy explicitly; its return value is discarded.
        proj_unfreeze = FinetuenProj.unfreeze_model
        proj_unfreeze(self, model)
