# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
import torch
import torch.nn.functional as F
from dataset import get_dataset
from torch.distributions import Normal
from models.utils.continual_model import ContinualModel
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser
from utils.batch_norm import bn_track_stats
from utils.buffer import Buffer, icarl_replay, fill_buffer_tinyimg
from clip.model import VisualTransformer
import pdb
from robust.attacks import *
from torch.optim import SGD
from utils.adaptor import *
import random
from backbone.rapf_clip import load_model, sample
from tqdm import tqdm


def _str2bool(value):
    """
    Convert a command-line token into a boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter parses the text.

    :param value: the raw command-line string (or an actual bool)
    :return: the parsed boolean; the empty string maps to ``False``
    :raises ValueError: if the text is not a recognised boolean spelling
        (argparse turns this into a regular usage error)
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError(f'invalid boolean value: {value!r}')


def get_parser() -> ArgumentParser:
    """
    Build the command-line parser for the RAPF model.

    Registers the shared management, experiment and rehearsal arguments and
    then the options specific to this model:
    ``--sample_beta``, ``--shrinkage``, ``--mix_bias`` and ``--rapf_threshold``.

    :return: the configured argument parser
    """
    parser = ArgumentParser(description='Continual Learning via RAPF.')

    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    parser.add_argument('--sample_beta', default=2, type=int, required=False)
    # Parsed with _str2bool so that "--shrinkage False" really yields False.
    parser.add_argument('--shrinkage', default=False, type=_str2bool, required=False)
    parser.add_argument('--mix_bias', default=0.6, type=float, required=False)
    parser.add_argument('--rapf_threshold', default=0.55, type=float, required=False)
    return parser



def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int, train_texts=None) -> None:
    """
    Adds examples from the current task to the memory buffer
    by means of the herding strategy.

    For the ``seq-tinyimg`` dataset the work is delegated to
    ``fill_buffer_tinyimg``. Otherwise the buffer is first trimmed so that every
    previously stored class keeps at most ``buffer_size // n_classes_so_far``
    examples, and then the same number of examples is selected for each class
    of the current task. The network is put in eval mode for the duration of
    the call and its previous training mode is restored at the end.

    :param self: the model; its ``args``, ``net``, ``buffer``, ``device`` and
        ``classes_so_far`` attributes are read
    :param mem_buffer: the memory buffer
    :param dataset: the dataset from which take the examples
    :param t_idx: the task index
    :param train_texts: optional text features; when given, images are resized
        to 224x224 and passed to the network together with them, otherwise the
        network is queried for features and ``net.classifier`` produces the logits
    """
    if self.args.dataset == 'seq-tinyimg':
        fill_buffer_tinyimg(self, mem_buffer, dataset, t_idx, train_texts)
        return
    mode = self.net.training
    self.net.eval()
    samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)
    if t_idx > 0:
        # 1) First, subsample prior classes
        buf_x, buf_y, buf_l = self.buffer.get_all_data()

        mem_buffer.empty()
        for _y in buf_y.unique():
            idx = (buf_y == _y)
            _y_x, _y_y, _y_l = buf_x[idx], buf_y[idx], buf_l[idx]
            mem_buffer.add_data(
                examples=_y_x[:samples_per_class],
                labels=_y_y[:samples_per_class],
                logits=_y_l[:samples_per_class]
            )

    # 2) Then, fill with current tasks
    loader = dataset.train_loader
    classes_start = t_idx * dataset.N_CLASSES_PER_TASK
    classes_end = (t_idx + 1) * dataset.N_CLASSES_PER_TASK

    # 2.1 Extract all features
    a_x, a_y, a_f, a_l = [], [], [], []
    # The first element of each batch (the normalised image) is not needed here.
    for _, y, not_norm_x in loader:
        mask = (y >= classes_start) & (y < classes_end)
        y, not_norm_x = y[mask], not_norm_x[mask]
        if not y.size(0):
            continue
        # Keep CPU copies for the buffer; only the network input goes to the device.
        a_x.append(not_norm_x.to('cpu'))
        a_y.append(y.to('cpu'))
        not_norm_x = not_norm_x.to(self.device)
        if train_texts is None:
            feats = self.net(not_norm_x, returnt='features')
            outs = self.net.classifier(feats)
        else:
            not_norm_x_224 = F.interpolate(not_norm_x, size=(224, 224), mode='bicubic')
            outs, feats, _, _ = self.net(not_norm_x_224, train_texts=train_texts)
        a_f.append(feats.cpu())
        a_l.append(torch.sigmoid(outs).cpu())
    a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), torch.cat(a_f), torch.cat(a_l)

    # 2.2 Compute class means and greedily pick the examples whose running
    #     average of features stays closest to the class mean
    for _y in a_y.unique():
        idx = (a_y == _y)
        _x, _y, _l = a_x[idx], a_y[idx], a_l[idx]
        feats = a_f[idx]  # boolean indexing copies, so a_f itself is not modified below
        mean_feat = feats.mean(0, keepdim=True)

        running_sum = torch.zeros_like(mean_feat)
        for i in range(min(samples_per_class, feats.shape[0])):
            cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)

            idx_min = cost.argmin().item()

            mem_buffer.add_data(
                examples=_x[idx_min:idx_min + 1].to(self.device),
                labels=_y[idx_min:idx_min + 1].to(self.device),
                logits=_l[idx_min:idx_min + 1].to(self.device)
            )

            running_sum += feats[idx_min:idx_min + 1]
            # Push the chosen example far away so it cannot be selected again.
            feats[idx_min] = feats[idx_min] + 1e6

    assert len(mem_buffer.examples) <= mem_buffer.buffer_size
    assert mem_buffer.num_seen_examples <= mem_buffer.buffer_size

    self.net.train(mode)


class RAPF(ContinualModel):
    """
    RAPF continual-learning model.

    During training on tasks after the first one, every batch is extended with
    feature vectors drawn (via ``sample``) from per-class means and covariances
    stored at the end of earlier tasks. Pairs of old/new classes whose text
    features are closer than a threshold ("hard pairs") additionally contribute
    a hinge loss. At the end of each task the adapter weight is blended with
    the adapter weight of the previous-task network (``mix_matrix``).
    """
    NAME = 'rapf'
    COMPATIBILITY = ['class-il', 'task-il']

    # Number of previously learned classes replayed in each batch, per dataset.
    # For example, with 100 classes and 50 batches per epoch, 2 classes per batch.
    _REPLAY_CLASSES_PER_BATCH = {
        'seq-cifar100': 2,
        'seq-cifar10': 1,
        'seq-stl10': 1,
        'seq-tinyimg': 5,
    }

    def __init__(self, backbone, loss, args, transform):
        """
        :param backbone: the network, forwarded to ``ContinualModel``
        :param loss: the loss function, forwarded to ``ContinualModel``
        :param args: parsed arguments; ``buffer_size``, ``train_eps``,
            ``train_alpha``, ``train_steps`` and ``template`` are read here
        :param transform: the transform, forwarded to ``ContinualModel``
        """
        super(RAPF, self).__init__(backbone, loss, args, transform)
        self.dataset = get_dataset(args)

        # Instantiate buffers
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.eye = torch.eye(self.dataset.N_CLASSES_PER_TASK *
                             self.dataset.N_TASKS).to(self.device)

        self.class_means = None
        self.task = -1          # incremented to 0 by the first begin_task call
        self.old_net = None     # copy of the network made at the end of the previous task

        # Per-class statistics, appended in class order by analyze_mean_cov.
        self.class_mean_list = []
        self.class_cov_list = []
        self.class_diff = None
        self.nearest_class = None
        self.class_edge_distance = []
        # Overwritten by adaptation(); defined here so that the training step
        # can be called safely even before the first begin_task.
        self.hard_pairs = None

        # Parameters forwarded to PGD for adversarial training.
        self.train_eps = args.train_eps
        self.train_alpha = args.train_alpha
        self.train_steps = args.train_steps
        self.template = args.template

    def forward(self, x, text=None):
        """
        Run the network and keep only the outputs of the classes seen so far.

        :param x: the input batch
        :param text: text features passed to the network as ``train_texts``
        :return: the first output of the network, restricted to the first
            ``(task + 1) * N_CLASSES_PER_TASK`` columns
        """
        ac = (self.task + 1) * self.dataset.N_CLASSES_PER_TASK
        output, _, _, _ = self.net(x, train_texts=text)
        return output[:, :ac]

    def observe(self, inputs, labels, not_aug_inputs, num_class, epoch=None, train_texts=None, text_tokens= None):
        """
        Perform one optimisation step on clean inputs.

        :param inputs: the batch of input images
        :param labels: the labels of the batch
        :param not_aug_inputs: unused
        :param num_class: unused
        :param epoch: unused
        :param train_texts: text features; replaced by the encoding of
            ``text_tokens`` when those are given
        :param text_tokens: optional tokenised class prompts
        :return: the value of the loss
        """
        return self._train_step(inputs, labels, train_texts, text_tokens, adversarial=False)

    def robust_observe(self, inputs, labels, not_aug_inputs, num_class, epoch=None, train_texts=None, text_tokens= None):
        """
        Perform one optimisation step on inputs perturbed by ``PGD``.

        Same contract as :meth:`observe`; the only difference is that the
        network is trained on the output of ``PGD`` instead of ``inputs``.

        :return: the value of the loss
        """
        return self._train_step(inputs, labels, train_texts, text_tokens, adversarial=True)

    def _pick_replay_classes(self, n_old_classes):
        """
        Choose which old classes are replayed in the current batch.

        :param n_old_classes: number of classes learned in previous tasks
        :return: list of class indices in ``[0, n_old_classes)``
        :raises ValueError: if the dataset has no configured replay size
        """
        per_batch = self._REPLAY_CLASSES_PER_BATCH.get(self.args.dataset)
        if per_batch is None:
            raise ValueError(
                f'RAPF has no replay configuration for dataset {self.args.dataset!r}; '
                f'supported datasets: {sorted(self._REPLAY_CLASSES_PER_BATCH)}')
        order = list(range(n_old_classes))
        # NOTE(review): the order is reshuffled on every call, so consecutive
        # batches index into different permutations — confirm this is intended.
        random.shuffle(order)
        start = self.args.batch_id * per_batch
        return [order[(start + k) % len(order)] for k in range(per_batch)]

    def _train_step(self, inputs, labels, train_texts, text_tokens, adversarial):
        """
        Shared implementation of :meth:`observe` and :meth:`robust_observe`.

        :param inputs: the batch of input images
        :param labels: the labels of the batch
        :param train_texts: text features (see ``text_tokens``)
        :param text_tokens: if not None, encoded with ``net.encode_text`` and
            L2-normalised to replace ``train_texts``
        :param adversarial: when True the inputs are first perturbed with ``PGD``
        :return: the value of the loss
        """
        # Keep track of every label seen so far (stored on CPU).
        if not hasattr(self, 'classes_so_far'):
            self.register_buffer('classes_so_far', labels.unique().to('cpu'))
        else:
            self.register_buffer('classes_so_far', torch.cat((
                self.classes_so_far, labels.to('cpu'))).unique())
        self.class_means = None

        if adversarial:
            # The attack receives the text arguments exactly as passed by the
            # caller, i.e. before train_texts is re-encoded below.
            inputs = PGD(inputs, labels, self, train_texts, text_tokens,
                         eps=self.train_eps, alpha=self.train_alpha, steps=self.train_steps)

        if text_tokens is not None:
            with torch.no_grad():
                train_texts = self.net.encode_text(text_tokens)
                train_texts = train_texts / train_texts.norm(dim=-1, keepdim=True)

        pc = self.task * self.dataset.N_CLASSES_PER_TASK        # classes of previous tasks
        ac = (self.task + 1) * self.dataset.N_CLASSES_PER_TASK  # classes up to the current task
        self.args.batch_id += 1
        targets = labels

        # Pseudo-features of old classes, sampled from their stored statistics.
        sg_inputs = None
        if self.task > 0:
            n_replay = int(10 * self.args.sample_beta)
            sg_inputs, sg_targets = [], []
            for cls in self._pick_replay_classes(pc):
                sg_inputs.append(sample(self.class_mean_list[cls], self.class_cov_list[cls],
                                        n_replay, shrink=self.args.shrinkage))
                sg_targets.append(torch.ones(n_replay, dtype=torch.long) * cls)
            sg_inputs = torch.cat(sg_inputs, dim=0)
            sg_targets = torch.cat(sg_targets, dim=0)
            targets = torch.cat([targets, sg_targets.to(targets.device)], dim=0)

        # Samples for the hard (old class, new class) pairs found by adaptation().
        edge_sample = None
        if self.hard_pairs is not None and self.hard_pairs.shape[0] > 0:
            n_edge = int(20 * self.args.sample_beta)
            edge_sample, edge_p_target, edge_n_target = [], [], []
            for hard_pair in self.hard_pairs:
                edge_sample.append(sample(self.class_mean_list[hard_pair[0]], self.class_cov_list[hard_pair[0]],
                                          n_edge, shrink=self.args.shrinkage))
                ones = torch.ones(n_edge, dtype=torch.long).to(hard_pair.device)
                edge_p_target.append(ones * hard_pair[0])
                edge_n_target.append(ones * hard_pair[1])
            edge_sample = torch.cat(edge_sample, dim=0)
            edge_p_target = torch.cat(edge_p_target, dim=0)
            edge_n_target = torch.cat(edge_n_target, dim=0)

        outputs, _, __, edge_sample_features = self.net(
            inputs,
            train_texts=train_texts[:ac, :],
            old_adapter=getattr(self.old_net, "adapter", None),
            memory_data=sg_inputs,
            not_ini=self.task > 0,
            edge_sample=edge_sample,
            prompt=False)

        loss = F.cross_entropy(outputs, targets.detach())
        if edge_sample is not None:
            # Hinge loss with margin 0.1: each sampled feature should be more
            # similar (cosine) to the text feature of its own class than to the
            # text feature of the paired class.
            edge_sample_features = edge_sample_features / edge_sample_features.norm(dim=-1, keepdim=True)
            edge_target_features = self.class_name_features[edge_p_target].type(edge_sample_features.dtype)
            edge_target_features = edge_target_features / edge_target_features.norm(dim=-1, keepdim=True)
            edge_nearest_class_features = self.class_name_features[edge_n_target].type(edge_sample_features.dtype)
            edge_nearest_class_features = edge_nearest_class_features / edge_nearest_class_features.norm(dim=-1, keepdim=True)
            pos_sim = (edge_sample_features * edge_target_features.detach()).sum(-1)
            neg_sim = (edge_sample_features * edge_nearest_class_features.detach()).sum(-1)
            loss = loss + torch.relu(-pos_sim + neg_sim + 0.1).mean()

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()

    def begin_task(self, text_features, dataset):
        """
        Advance the task counter and prepare the model for the new task.

        :param text_features: text features of the classes; only the rows of
            the classes seen up to the new task are used
        :param dataset: the dataset, passed to ``icarl_replay`` when a buffer is used
        """
        self.task += 1
        ac = (self.task + 1) * self.dataset.N_CLASSES_PER_TASK
        self.adaptation(self.task, text_features=text_features[:ac, :], threshold=self.args.rapf_threshold)
        if self.args.buffer_size > 0:
            icarl_replay(self, dataset)

    def end_task(self, dataset, train_texts=None) -> None:
        """
        Collect per-class feature statistics of the finished task, blend the
        adapter weights with those of the previous task, snapshot the network
        and, if a buffer is used, fill it.

        :param dataset: the dataset whose ``train_loader`` is iterated
        :param train_texts: text features, indexed as ``train_texts[:ac, :]``
        """
        ac = (self.task + 1) * self.dataset.N_CLASSES_PER_TASK
        pc = self.task * self.dataset.N_CLASSES_PER_TASK
        sample_data = []
        sample_target = []
        for images, target, _ in dataset.train_loader:
            images, target = images.to(self.net.device), target.to(self.net.device)
            if self.args.robust_method == "AT":
                # NOTE(review): train_texts is passed for both text arguments of
                # PGD here, whereas robust_observe passes text_tokens as the
                # second one — confirm this is intended.
                images = PGD(images, target, self, train_texts, train_texts,
                             eps=self.train_eps, alpha=self.train_alpha, steps=self.train_steps)
            with torch.no_grad():
                _, ori_ima_feat, _ = self.net(images, train_texts=train_texts[:ac, :], ori_ima_f=True)
            sample_data.append(ori_ima_feat)
            sample_target.append(target)
        sample_target = torch.cat(sample_target, dim=0)
        sample_data = torch.cat(sample_data, dim=0)
        self.analyze_mean_cov(sample_data, sample_target, pc, ac)
        # Must run before old_net is replaced: it compares against the previous snapshot.
        self.mix_matrix()

        self.old_net = dataset.get_backbone(self.args)
        self.old_net.load_state_dict(self.net.state_dict())
        self.old_net.eval()
        self.net.train()
        if self.args.buffer_size > 0:
            with torch.no_grad():
                fill_buffer(self, self.buffer, dataset, self.task, train_texts)
        self.class_means = None

    def compute_class_means(self) -> None:
        """
        Computes a vector representing mean features for each class.

        The buffer examples of every class in ``classes_so_far`` are passed
        through the network in chunks of ``args.batch_size``; the result is
        stored in ``self.class_means`` (one row per class).
        """
        class_means = []
        examples, labels, _ = self.buffer.get_all_data()
        batch_size = self.args.batch_size
        for _y in self.classes_so_far:
            x_buf = torch.stack(
                [example for example, label in zip(examples, labels)
                 if label.cpu() == _y]
            ).to(self.device)
            with bn_track_stats(self, False):
                # Accumulate the sum and divide by the number of examples, so
                # that the result is the true mean regardless of how many
                # chunks the class is split into.
                feat_sum = None
                for start in range(0, len(x_buf), batch_size):
                    batch = x_buf[start:start + batch_size]
                    batch_sum = self.net(batch, returnt='features').sum(0)
                    feat_sum = batch_sum if feat_sum is None else feat_sum + batch_sum
                class_means.append((feat_sum / len(x_buf)).flatten())
        self.class_means = torch.stack(class_means)

    def parameters(self, args):
        """
        Return the parameters to optimise.

        NOTE(review): the class uses ``register_buffer``, so this presumably
        shadows ``torch.nn.Module.parameters`` with a different signature —
        confirm no caller relies on the argument-less form.

        :param args: parsed arguments; ``last_num_ft`` selects the parameters
        :return: only the adapter parameters when ``last_num_ft == -1``,
            otherwise the parameters of the last ``last_num_ft`` visual
            transformer blocks, of ``visual.ln_post`` and of the adapter
        """
        if args.last_num_ft == -1:
            return self.net.adapter.parameters()
        return (list(self.net.visual.transformer.resblocks[-args.last_num_ft:].parameters())
                + list(self.net.visual.ln_post.parameters())
                + list(self.net.adapter.parameters()))

    def get_optimizer(self, args):
        """
        Create the SGD optimiser and its learning-rate schedule.

        The learning rate is multiplied by 0.1 at 48%, 62% and 80% of
        ``args.n_epochs``.

        :param args: parsed arguments (``lr``, ``optim_wd``, ``optim_mom``, ``n_epochs``)
        :return: the scheduler; the optimiser is stored in ``self.opt``
        """
        self.opt = torch.optim.SGD(self.parameters(args), lr=args.lr,
                                   weight_decay=args.optim_wd, momentum=args.optim_mom)
        milestones = [int(args.n_epochs * frac) for frac in (0.48, 0.62, 0.80)]
        # `verbose` is deprecated in recent PyTorch releases and False is its
        # default, so it is simply omitted.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.opt, milestones, gamma=0.1)
        return self.scheduler

    def adaptation(self, task_id, text_features, threshold=0):
        """
        Store the class text features and find the hard (old, new) class pairs.

        A pair is "hard" when the Euclidean distance between the text feature
        of an old class and that of a new class is below ``threshold``. The
        result is stored in ``self.hard_pairs`` as rows
        ``(old_class_index, new_class_index)``, both being global class
        indices; it is ``None`` for the first task.

        :param task_id: index of the task that is starting
        :param text_features: text features of all classes seen so far, old
            classes first
        :param threshold: distance below which a pair is considered hard
        """
        self.class_name_features = text_features
        self.queue_empty = True
        self.hard_pairs = None
        pc = self.task * self.dataset.N_CLASSES_PER_TASK
        if task_id > 0:
            # Rows are old classes [0, pc), columns are new classes [pc, end),
            # so that entry (i, j) matches the pair (i, pc + j) stored below.
            old_feats = self.class_name_features[:pc].type(torch.float32)
            new_feats = self.class_name_features[pc:].type(torch.float32)
            self.class_diff = torch.cdist(old_feats, new_feats)
            indices = torch.nonzero(self.class_diff < threshold)
            self.hard_new_class = torch.unique(indices[:, 1]) + pc
            # Shift the column index so that it becomes a global class index.
            indices[:, 1] = indices[:, 1] + pc
            self.hard_pairs = indices

    def analyze_mean_cov(self, features, labels, pc, ac):
        """
        Append the feature mean and covariance of every class in ``[pc, ac)``.

        For each such class present in ``labels`` this appends to
        ``class_mean_list``, ``class_cov_list`` and ``class_edge_distance``.
        The latter receives a tuple computed from the 10 largest distances to
        the class mean: ``(mean - min, max - mean, mean)``.

        :param features: feature matrix, one row per sample
        :param labels: label of each row of ``features``
        :param pc: first class index of the task (inclusive)
        :param ac: last class index of the task (exclusive)
        """
        labels_unique = torch.sort(torch.unique(labels))[0]
        print(len(labels_unique))
        task_labels = labels_unique[(labels_unique >= pc) & (labels_unique < ac)]
        print(len(task_labels))
        for cls in task_labels:
            # Boolean indexing keeps the sample axis even for a single sample.
            class_data = features[labels == cls]
            mean = class_data.mean(dim=0)
            # Small diagonal term added to the covariance.
            cov = torch.cov(class_data.t()) + 1e-4 * torch.eye(class_data.shape[-1], device=class_data.device)
            distance = torch.cdist(class_data.type(torch.float32),
                                   mean.type(torch.float32).unsqueeze(0)).squeeze(1)
            max_distance = torch.sort(distance)[0][-10:]
            self.class_edge_distance.append((max_distance.mean() - max_distance.min(),
                                             max_distance.max() - max_distance.mean(),
                                             max_distance.mean()))
            self.class_mean_list.append(mean)
            self.class_cov_list.append(cov)

    def mix_matrix(self):
        """
        Blend the current adapter weight with the previous-task adapter weight.

        Both weights are compared in the basis given by the SVD of the old
        weight. Each entry of the new weight is kept in proportion to how much
        it changed (normalised change plus ``args.mix_bias``, clamped to 1);
        the remainder is taken from the old weight. Does nothing when there is
        no previous network.
        """
        if self.old_net is None:
            return
        weight_new = self.net.adapter.weight.data
        weight_old = self.old_net.adapter.weight.data
        # torch.linalg.svd returns (U, S, Vh); diag(S) @ Vh is the old weight
        # expressed in the basis U (this form assumes a square weight matrix).
        U_old, S_old, V_old = torch.linalg.svd(weight_old)
        P_new = U_old.T @ weight_new
        P_old = torch.diag(S_old) @ V_old
        dist = (P_new - P_old).abs()
        peak = dist.max()
        if peak == 0:
            # Identical weights: nothing to blend, and dividing by zero would
            # fill the adapter with NaNs.
            return
        mask = dist / peak
        mask += self.args.mix_bias
        mask = torch.clamp(mask, max=1)
        right = P_new * mask + P_old * (1 - mask)
        self.net.adapter.weight.data = U_old @ right