# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
import torch
import torch.nn.functional as F
from dataset import get_dataset
from torch.distributions import Normal
from models.utils.continual_model import ContinualModel
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser
from utils.batch_norm import bn_track_stats
from utils.buffer import Buffer, icarl_replay
from clip.model import VisualTransformer
import pdb
from robust.attacks import *
from torch.optim import SGD
from utils.adaptor import *
def get_parser() -> ArgumentParser:
    """
    Build the command-line parser for the SG model.

    Registers the shared management, experiment and rehearsal options, then the
    SG-specific ``--softness`` (soft-label sharpness) and ``--T`` (distillation
    temperature) hyper-parameters.
    """
    parser = ArgumentParser(description='Continual Learning via SG.')
    for register_args in (add_management_args, add_experiment_args, add_rehearsal_args):
        register_args(parser)

    parser.add_argument('--softness', type=float, default=13, required=False,
                        help='control the softening of generated labels')
    parser.add_argument('--T', type=float, default=0.5, required=False,
                        help='temperature scaling for distillation loss')
    return parser



def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int, train_texts=None) -> None:
    """
    Adds examples from the current task to the memory buffer
    by means of the herding strategy.

    Exemplars already in ``self.buffer`` are first truncated to
    ``samples_per_class`` per class (only when ``t_idx > 0``). Then every class
    of task ``t_idx`` contributes up to ``samples_per_class`` exemplars. They are
    picked greedily so that the running mean of the selected features stays
    closest to the class mean feature.

    :param self: the model owning ``net``, ``buffer``, ``device`` and ``classes_so_far``
    :param mem_buffer: the memory buffer to fill
    :param dataset: the dataset from which take the examples; its ``train_loader``
        yields ``(x, y, not_norm_x)`` triples and only ``not_norm_x`` is stored
    :param t_idx: the task index
    :param train_texts: optional text embeddings. When given, images are resized
        to 224x224 and scored with ``self.net(images, train_texts)``; otherwise
        ``self.net(..., returnt='features')`` + ``self.net.classifier`` are used
    """
    was_training = self.net.training
    self.net.eval()
    try:
        # Pure inference: without no_grad every stored feature/score would keep
        # its autograd graph (and the GPU activations behind it) alive.
        with torch.no_grad():
            samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

            if t_idx > 0:
                # 1) First, subsample prior classes
                buf_x, buf_y, buf_l = self.buffer.get_all_data()
                mem_buffer.empty()
                for cls in buf_y.unique():
                    cls_mask = buf_y == cls
                    mem_buffer.add_data(
                        examples=buf_x[cls_mask][:samples_per_class],
                        labels=buf_y[cls_mask][:samples_per_class],
                        logits=buf_l[cls_mask][:samples_per_class],
                    )

            # 2) Then, fill with current tasks
            classes_start = t_idx * dataset.N_CLASSES_PER_TASK
            classes_end = (t_idx + 1) * dataset.N_CLASSES_PER_TASK

            # 2.1 Extract all features of the current task's samples
            a_x, a_y, a_f, a_l = [], [], [], []
            for _, y, not_norm_x in dataset.train_loader:
                mask = (y >= classes_start) & (y < classes_end)
                if not mask.any():
                    continue
                y, not_norm_x = y[mask], not_norm_x[mask]
                a_x.append(not_norm_x.cpu())
                a_y.append(y.cpu())
                not_norm_x = not_norm_x.to(self.device)
                if train_texts is None:
                    feats = self.net(not_norm_x, returnt='features')
                    outs = self.net.classifier(feats)
                else:
                    not_norm_x_224 = F.interpolate(not_norm_x, size=(224, 224), mode='bicubic')
                    features, text_embed = self.net(not_norm_x_224, train_texts)
                    feats = features[:, 0, :]
                    outs = feats @ text_embed.t()
                a_f.append(feats.cpu())
                a_l.append(torch.sigmoid(outs).cpu())

            # Nothing of the current task in the loader: only the subsampling above applies.
            if a_x:
                a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), torch.cat(a_f), torch.cat(a_l)

                # 2.2 Herding selection around each class mean
                for cls in a_y.unique():
                    cls_mask = a_y == cls
                    cls_x, cls_y, cls_l = a_x[cls_mask], a_y[cls_mask], a_l[cls_mask]
                    # Boolean indexing copies, so the in-place masking below leaves a_f untouched.
                    feats = a_f[cls_mask]
                    mean_feat = feats.mean(0, keepdim=True)

                    running_sum = torch.zeros_like(mean_feat)
                    for i in range(min(samples_per_class, feats.shape[0])):
                        # Distance between the class mean and the mean obtained by
                        # adding each candidate to the already-selected exemplars.
                        cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)
                        idx_min = cost.argmin().item()

                        mem_buffer.add_data(
                            examples=cls_x[idx_min:idx_min + 1].to(self.device),
                            labels=cls_y[idx_min:idx_min + 1].to(self.device),
                            logits=cls_l[idx_min:idx_min + 1].to(self.device)
                        )

                        running_sum += feats[idx_min:idx_min + 1]
                        # Push the chosen sample far away so it is never picked again.
                        feats[idx_min] = feats[idx_min] + 1e6

            assert len(mem_buffer.examples) <= mem_buffer.buffer_size
            assert mem_buffer.num_seen_examples <= mem_buffer.buffer_size
    finally:
        self.net.train(was_training)


class SG(ContinualModel):
    """
    Continual learner built on a CLIP-style network (``encode_image``,
    ``encode_text``, ``logit_scale``).

    Only the last ``args.last_num_ft`` visual transformer blocks are optimized.
    The loss combines three terms:

    - a symmetric image/text contrastive loss;
    - a KL term towards soft labels derived from text-embedding similarities;
    - from the second task on, distillation towards a frozen copy of the
      network taken at the end of the previous task.
    """
    NAME = 'sg'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(SG, self).__init__(backbone, loss, args, transform)
        self.dataset = get_dataset(args)

        # Instantiate buffers
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.eye = torch.eye(self.dataset.N_CLASSES_PER_TASK *
                             self.dataset.N_TASKS).to(self.device)

        self.class_means = None
        # begin_task increments this before training, so the first task is 0.
        self.task = -1
        self.old_net = None
        # PGD settings used by robust_observe.
        self.train_eps = args.train_eps
        self.train_alpha = args.train_alpha
        self.train_steps = args.train_steps
        self.template = args.template

    def forward(self, x, text=None):
        """
        Score images against per-class text embeddings.

        :param x: images accepted by ``self.net.encode_image``
        :param text: text embeddings, one row per class (required despite the default)
        :return: similarities between the L2-normalised first-token image
            embedding and each text row, restricted to the first
            ``(task + 1) * N_CLASSES_PER_TASK`` classes
        :raises ValueError: if ``text`` is None
        """
        if text is None:
            raise ValueError('SG.forward requires the class text embeddings `text`.')
        seen_classes = (self.task + 1) * self.dataset.N_CLASSES_PER_TASK
        image_embed = self.net.encode_image(x, None)
        image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
        output = image_embed[:, 0, :] @ text.t()
        return output[:, :seen_classes]

    def observe(self, inputs, labels, not_aug_inputs, num_class, epoch=None, train_texts=None, text_tokens= None):
        """
        Perform one optimization step on a clean batch.

        :param inputs: training images
        :param labels: global class indices of ``inputs``
        :param not_aug_inputs: unused
        :param num_class: unused
        :param epoch: unused
        :param train_texts: class text embeddings (used when ``text_tokens`` is None)
        :param text_tokens: if given, embeddings are re-encoded from these tokens
        :return: the loss value as a Python float
        """
        self._update_seen_classes(labels)
        self.class_means = None
        train_texts = self._encode_texts(train_texts, text_tokens)
        loss = self._compute_loss(inputs, labels, train_texts)
        return self._optimize(loss)

    def robust_observe(self, inputs, labels, not_aug_inputs, num_class, epoch=None, train_texts=None, text_tokens= None):
        """
        Perform one optimization step on PGD adversarial examples of the batch.

        ``inputs`` are perturbed with ``PGD`` (using ``train_eps``, ``train_alpha``
        and ``train_steps``) before text re-encoding. Both the current network and,
        for ``task > 0``, the previous one are then fed the adversarial images.
        Parameters and return value are as in :meth:`observe`.
        """
        self._update_seen_classes(labels)
        self.class_means = None
        inputs_adv = PGD(inputs, labels, self, train_texts, text_tokens, eps=self.train_eps, alpha=self.train_alpha, steps=self.train_steps)
        train_texts = self._encode_texts(train_texts, text_tokens)
        loss = self._compute_loss(inputs_adv, labels, train_texts)
        return self._optimize(loss)

    def _update_seen_classes(self, labels):
        """Merge the batch labels into the CPU-resident ``classes_so_far`` buffer."""
        if hasattr(self, 'classes_so_far'):
            seen = torch.cat((self.classes_so_far, labels.to('cpu'))).unique()
        else:
            seen = labels.unique().to('cpu')
        self.register_buffer('classes_so_far', seen)

    def _encode_texts(self, train_texts, text_tokens):
        """Return ``train_texts``, or L2-normalised embeddings of ``text_tokens`` if given."""
        if text_tokens is None:
            return train_texts
        with torch.no_grad():
            encoded = self.net.encode_text(text_tokens)
            return encoded / encoded.norm(dim=-1, keepdim=True)

    @staticmethod
    def _encode_first_token(net, images):
        """L2-normalise ``net.encode_image`` output over its last dim and keep token 0."""
        features = net.encode_image(images, None)
        features = features / features.norm(dim=-1, keepdim=True)
        return features[:, 0, :]

    def _compute_loss(self, images, labels, train_texts):
        """
        Build the training loss for one batch.

        :param images: images fed to the current (and, if ``task > 0``, previous) network
        :param labels: global class indices of ``images``
        :param train_texts: text embeddings, one row per class
        :return: contrastive + alpha * soft-label KL (+ beta * distillation when task > 0)
        """
        pc = self.task * self.dataset.N_CLASSES_PER_TASK
        ac = (self.task + 1) * self.dataset.N_CLASSES_PER_TASK
        cur_text_features = train_texts[pc:ac]
        prev_text_features = train_texts[:pc]

        image_features = self._encode_first_token(self.net, images)
        text_features = train_texts[labels]

        # Symmetric contrastive loss: image i is paired with the text of its own label.
        logit_scale = self.net.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        targets = torch.arange(len(logits_per_image)).to(self.device)
        contrastive_loss = (F.cross_entropy(logits_per_image, targets)
                            + F.cross_entropy(logits_per_text, targets)) / 2

        # Soft-label term over the current task's classes.
        pred = F.log_softmax(logit_scale * image_features @ cur_text_features.t(), dim=1)
        self.soft_labels = self.generate_soft_labels(cur_text_features, self.args.softness)
        target_distribution = self.make_batch_soft_labels(self.soft_labels, labels - pc, ac - pc, len(pred))
        # NOTE: the default reduction ('mean') averages over batch *and* classes;
        # switching to 'batchmean' would rescale this term relative to args.alpha.
        soft_loss = F.kl_div(pred, target_distribution.to(self.device))

        loss = contrastive_loss + self.args.alpha * soft_loss
        if self.task > 0:
            kd_loss = self._distillation_loss(images, labels, train_texts, prev_text_features,
                                              image_features, logit_scale)
            loss = loss + self.args.beta * kd_loss
        return loss

    def _distillation_loss(self, images, labels, train_texts, prev_text_features, image_features, logit_scale):
        """
        Distil the previous network's scores over earlier-task classes.

        Teacher logits are biased before distillation. For each sample, the old
        class whose text is most similar to the sample's own class text gets
        ``+sim``; every other old class gets ``-sim``, where ``sim`` is that
        maximum similarity.
        """
        with torch.no_grad():
            features_old = self._encode_first_token(self.old_net, images)
            inter_task_sim = self.calculate_inter_task_similarity(train_texts[labels], prev_text_features)
            # The teacher side is a constant target: no graph is built for it.
            prev_pred = logit_scale * features_old @ prev_text_features.t()
            inter_task_sim = inter_task_sim.to(device=self.device, dtype=torch.float16)
            max_values, max_indices = torch.max(inter_task_sim, dim=1)
            max_values = max_values.unsqueeze(1)
            # Build the index row on the same device as max_indices (no hard-coded .cuda()).
            old_classes = torch.arange(prev_pred.shape[1], device=max_indices.device).unsqueeze(0)
            prev_pred += torch.where(old_classes == max_indices.unsqueeze(1), max_values, -max_values)

        new_pred = logit_scale * image_features @ prev_text_features.t()
        return self.dist_loss(new_pred, prev_pred.to(self.device))

    def _optimize(self, loss):
        """Run one optimizer step on ``loss`` and return its value as a float."""
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()

    def begin_task(self, text_features, dataset):
        """Advance the task counter; the arguments are unused."""
        self.task += 1

    def end_task(self, dataset, train_texts=None) -> None:
        """Snapshot the network (in eval mode) as the distillation teacher for the next task."""
        self.old_net = deepcopy(self.net.eval())
        self.net.train()
        self.class_means = None

    def compute_class_means(self) -> None:
        """
        Computes a vector representing mean features for each class.

        Buffered exemplars of each class in ``classes_so_far`` are passed through
        ``self.net(batch, returnt='features')`` in chunks of ``args.batch_size``.
        The exact per-class mean is cached in ``self.class_means``, one row per
        class, in ``classes_so_far`` order.

        :raises RuntimeError: if a seen class has no exemplar in the buffer
        """
        examples, labels, _ = self.buffer.get_all_data()
        labels_cpu = labels.cpu()
        class_means = []
        with torch.no_grad(), bn_track_stats(self, False):
            for _y in self.classes_so_far:
                class_mask = (labels_cpu == _y.cpu()).to(examples.device)
                x_buf = examples[class_mask].to(self.device)
                if not len(x_buf):
                    raise RuntimeError(f'no buffered exemplars for class {int(_y)}')

                # Accumulate a true (size-weighted) sum in float32, instead of
                # repeatedly averaging batch means, which over-weights later batches.
                feat_sum, n_feats, out_dtype = None, 0, None
                for start in range(0, len(x_buf), self.args.batch_size):
                    feats = self.net(x_buf[start:start + self.args.batch_size], returnt='features')
                    out_dtype = feats.dtype
                    batch_sum = feats.float().sum(0)
                    feat_sum = batch_sum if feat_sum is None else feat_sum + batch_sum
                    n_feats += feats.shape[0]
                class_means.append((feat_sum / n_feats).to(out_dtype).flatten())
        self.class_means = torch.stack(class_means)

    def parameters(self, args=None, recurse=True):
        """
        Return the trainable parameters.

        :param args: namespace with ``last_num_ft``. When given, only the
            parameters of the last ``args.last_num_ft`` visual transformer blocks
            are returned, as a list. When omitted, this behaves like
            ``nn.Module.parameters`` so generic module utilities keep working.
        :param recurse: forwarded to ``nn.Module.parameters`` when ``args`` is None
        """
        if args is None:
            return super().parameters(recurse=recurse)
        # NOTE(review): last_num_ft == 0 slices resblocks[-0:], i.e. *all* blocks.
        return list(self.net.visual.transformer.resblocks[-args.last_num_ft:].parameters())

    def get_optimizer(self, args):
        """Create the SGD optimizer over the fine-tuned blocks; return a constant LR scheduler."""
        self.opt = torch.optim.SGD(self.parameters(args), lr=args.lr, weight_decay=args.optim_wd, momentum=args.optim_mom)
        self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.opt, factor=1.0)
        return self.scheduler

    @torch.no_grad()
    def generate_soft_labels(self, text_features, softness):
        """
        Turn text-embedding similarities into per-class soft target distributions.

        :param text_features: class text embeddings, one row per class
        :param softness: multiplier applied to similarities before the softmax
        :return: float32 (C, C) matrix whose column ``c`` sums to 1 and is the
            soft target of class ``c``
        """
        sim_matrix = (text_features @ text_features.t()).float()
        # Column-wise softmax: identical to exp(s*x) / exp(s*x).sum(0), but max-shifted for stability.
        return torch.softmax(softness * sim_matrix, dim=0)

    def calculate_inter_task_similarity(self, batched_text_features, prev_text_features):
        """Dot-product similarity of each sample's class text with every previous-task class text."""
        sim_matrix = batched_text_features @ prev_text_features.t()
        return sim_matrix

    def make_batch_soft_labels(self, all_soft_labels, target, num_classes, batch_size):
        """
        Gather the soft target of each sample.

        :param all_soft_labels: matrix as returned by ``generate_soft_labels``
            (column ``c`` = soft target of class ``c``)
        :param target: task-local class index of each sample
        :param num_classes: number of columns of the result
        :param batch_size: number of rows of the result (first ``batch_size`` targets)
        :return: float16 tensor of shape (batch_size, num_classes) on ``self.device``
        """
        soft_labels = torch.zeros((batch_size, num_classes), dtype=torch.float16, device=self.device)
        columns = target[:batch_size].to(all_soft_labels.device)
        # One vectorized gather/copy instead of a per-sample Python loop.
        soft_labels[:] = all_soft_labels[:, columns].t()
        return soft_labels

    def dist_loss(self, y_s, y_t):
        """Temperature-scaled KL(teacher || student), summed and divided by batch size, times T^2."""
        p_s = F.log_softmax(y_s / self.args.T, dim=1)
        p_t = F.softmax(y_t / self.args.T, dim=1)
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.args.T ** 2) / y_s.shape[0]
        return loss