# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import get_dataset
from models.utils.continual_model import ContinualModel
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser
from robust.attacks import *
from utils.buffer import Buffer, icarl_replay, fill_buffer_tinyimg

def get_parser() -> ArgumentParser:
    """
    Build the command-line parser used by the online EWC model.

    The parser receives the management, experiment and rehearsal argument
    groups (in that order) and one model-specific option, ``--e_lambda``.
    :return: the configured argument parser
    """
    ewc_parser = ArgumentParser(description='Continual learning via online EWC.')

    # Shared argument groups are registered in the same order as before.
    for register_group in (add_management_args, add_experiment_args, add_rehearsal_args):
        register_group(ewc_parser)

    ewc_parser.add_argument(
        '--e_lambda',
        default=25.0,
        type=float,
        help='lambda weight for EWC',
    )
    return ewc_parser

def fill_buffer(self, mem_buffer: "Buffer", dataset, t_idx: int, train_texts=None) -> None:
    """
    Adds examples from the current task to the memory buffer
    by means of the herding strategy.

    For the 'seq-tinyimg' dataset the work is delegated to
    ``fill_buffer_tinyimg``. Otherwise the network is switched to eval mode,
    the selection runs without gradient tracking, and the network's previous
    train/eval mode is restored afterwards, even if the selection raises.
    :param mem_buffer: the memory buffer
    :param dataset: the dataset from which take the examples
    :param t_idx: the task index
    :param train_texts: optional text input; when given, images are resized
                        to 224x224 and the network is called as
                        ``self.net(images, train_texts)`` instead of using
                        the feature/classifier path
    """
    if self.args.dataset == 'seq-tinyimg':
        fill_buffer_tinyimg(self, mem_buffer, dataset, t_idx, train_texts)
        return

    was_training = self.net.training
    self.net.eval()
    try:
        # Selection never needs gradients; guarding here keeps the stored
        # features/logits graph-free even if the caller forgot no_grad.
        with torch.no_grad():
            _herd_task_into_buffer(self, mem_buffer, dataset, t_idx, train_texts)
    finally:
        self.net.train(was_training)


def _herd_task_into_buffer(self, mem_buffer, dataset, t_idx, train_texts):
    """Shrink old classes in the buffer, then add herded samples of task ``t_idx``."""
    samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

    if t_idx > 0:
        # 1) First, subsample prior classes
        buf_x, buf_y, buf_l = self.buffer.get_all_data()

        mem_buffer.empty()
        for old_class in buf_y.unique():
            keep = (buf_y == old_class)
            mem_buffer.add_data(
                examples=buf_x[keep][:samples_per_class],
                labels=buf_y[keep][:samples_per_class],
                logits=buf_l[keep][:samples_per_class]
            )

    # 2) Then, fill with current tasks
    classes_start = t_idx * dataset.N_CLASSES_PER_TASK
    classes_end = (t_idx + 1) * dataset.N_CLASSES_PER_TASK

    # 2.1 Extract all features
    a_x, a_y, a_f, a_l = [], [], [], []
    for _, y, not_norm_x in dataset.train_loader:
        mask = (y >= classes_start) & (y < classes_end)
        y, not_norm_x = y[mask], not_norm_x[mask]
        if not y.size(0):
            continue
        a_x.append(not_norm_x.to('cpu'))
        a_y.append(y.to('cpu'))

        # Only the un-normalised images are fed to the network, so they are
        # the only tensor that has to live on the compute device.
        not_norm_x = not_norm_x.to(self.device)
        if train_texts is None:
            feats = self.net(not_norm_x, returnt='features')
            outs = self.net.classifier(feats)
        else:
            not_norm_x_224 = torch.nn.functional.interpolate(not_norm_x, size=(224, 224), mode='bicubic')
            features, text_embed = self.net(not_norm_x_224, train_texts)
            # index 0 along dim 1 is used as the image feature
            # (presumably the class token -- TODO confirm)
            feats = features[:, 0, :]
            outs = feats @ text_embed.t()
        a_f.append(feats.cpu())
        a_l.append(torch.sigmoid(outs).cpu())
    a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), torch.cat(a_f), torch.cat(a_l)

    # 2.2 Compute class means and greedily pick the samples whose running
    #     average stays closest to the class mean
    for current_class in a_y.unique():
        idx = (a_y == current_class)
        cls_x, cls_y, cls_l = a_x[idx], a_y[idx], a_l[idx]
        feats = a_f[idx]  # boolean indexing copies, so edits below stay local
        mean_feat = feats.mean(0, keepdim=True)

        running_sum = torch.zeros_like(mean_feat)
        n_picked = 0
        while n_picked < samples_per_class and n_picked < feats.shape[0]:
            cost = (mean_feat - (feats + running_sum) / (n_picked + 1)).norm(2, 1)

            idx_min = cost.argmin().item()

            mem_buffer.add_data(
                examples=cls_x[idx_min:idx_min + 1].to(self.device),
                labels=cls_y[idx_min:idx_min + 1].to(self.device),
                logits=cls_l[idx_min:idx_min + 1].to(self.device)
            )

            running_sum += feats[idx_min:idx_min + 1]
            # push the chosen sample far away so it is not selected twice
            feats[idx_min] = feats[idx_min] + 1e6
            n_picked += 1

    assert len(mem_buffer.examples) <= mem_buffer.buffer_size
    assert mem_buffer.num_seen_examples <= mem_buffer.buffer_size

class EwcOn(ContinualModel):
    """
    Online EWC continual-learning model.

    Training adds a quadratic penalty, weighted by ``args.e_lambda``, between
    the current parameters and a checkpoint taken at the end of the previous
    task; the penalty is scaled per-parameter by an accumulated diagonal
    Fisher estimate. When ``args.model_type == 'clip'`` only the parameters
    of ``self.net.visual`` take part in the penalty and logits are computed
    as image-embedding / text-embedding products. When
    ``args.buffer_size > 0`` a rehearsal buffer is refilled after every task.
    """
    NAME = 'ewc_on'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        """
        :param backbone: network forwarded to ContinualModel
        :param loss: loss function forwarded to ContinualModel
        :param args: parsed arguments; reads buffer_size, train_eps,
                     train_alpha, train_steps and template here
        :param transform: transform forwarded to ContinualModel
        """
        super(EwcOn, self).__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.logsoft = nn.LogSoftmax(dim=1)
        self.checkpoint = None  # flattened parameters saved by end_task
        self.fish = None        # accumulated diagonal Fisher estimate
        self.task = -1          # index of the current task; bumped in begin_task
        self.dataset = get_dataset(args)
        self.train_eps = args.train_eps
        self.train_alpha = args.train_alpha
        self.train_steps = args.train_steps
        self.template = args.template

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    def _uses_clip(self):
        """Return True when the configured backbone type is 'clip'."""
        return self.args.model_type == 'clip'

    def _seen_class_count(self):
        """Number of output columns covering every task up to the current one."""
        return (self.task + 1) * self.dataset.N_CLASSES_PER_TASK

    def _flat_params(self):
        """Return the penalised parameters as one flat tensor."""
        if self._uses_clip():
            return torch.cat([p.view(-1) for p in self.net.visual.parameters()])
        return self.net.get_params()

    def _flat_grads(self):
        """Return the gradients matching ``_flat_params`` as one flat tensor."""
        if not self._uses_clip():
            return self.net.get_grads()
        pieces = []
        for p in self.net.visual.parameters():
            if p.grad is None:
                # No gradient reached this parameter (frozen or unused in the
                # forward pass): it contributes zero Fisher information.
                pieces.append(torch.zeros_like(p).reshape(-1))
            else:
                pieces.append(p.grad.view(-1))
        return torch.cat(pieces)

    def _encode_text(self, text_tokens):
        """Encode and L2-normalise text tokens without tracking gradients."""
        with torch.no_grad():
            text_features = self.net.encode_text(text_tokens)
            return text_features / text_features.norm(dim=-1, keepdim=True)

    def _clip_logits(self, images, text_features):
        """Similarity logits between normalised image embeddings and text features."""
        image_embed = self.net.encode_image(images, None)
        image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
        # index 0 along dim 1 is used as the image representation
        # (presumably the class token -- TODO confirm)
        return image_embed[:, 0, :] @ text_features.t()

    def _track_classes(self, labels):
        """Merge the labels of this batch into the ``classes_so_far`` buffer."""
        if not hasattr(self, 'classes_so_far'):
            self.register_buffer('classes_so_far', labels.unique().to('cpu'))
        else:
            self.register_buffer('classes_so_far', torch.cat((
                self.classes_so_far, labels.to('cpu'))).unique())

    def _train_logits(self, inputs, text_tokens):
        """Logits for a training batch, via text similarity when tokens are given."""
        if text_tokens is not None:
            return self._clip_logits(inputs, self._encode_text(text_tokens))
        return self.net(inputs)

    def _penalised_step(self, outputs, labels):
        """Backpropagate task loss plus EWC penalty, step the optimizer, return the loss."""
        ac = self._seen_class_count()
        loss = self.loss(outputs[:, :ac], labels) + self.args.e_lambda * self.penalty()
        assert not torch.isnan(loss)
        loss.backward()
        self.opt.step()
        return loss.item()

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def forward(self, x, text=None):
        """
        Compute logits restricted to the classes seen so far.
        :param x: input images
        :param text: optional text features; when given, logits are the
                     products of normalised image embeddings with ``text``
        :return: logits sliced to the first (task + 1) * N_CLASSES_PER_TASK columns
        """
        num_class = self._seen_class_count()
        if text is None:
            return self.net(x)[:, :num_class]
        return self._clip_logits(x, text)[:, :num_class]

    def penalty(self):
        """
        EWC penalty: Fisher-weighted squared distance to the last checkpoint.
        :return: a scalar tensor; zero before the first call to end_task
        """
        if self.checkpoint is None:
            return torch.tensor(0.0).to(self.device)
        return (self.fish * ((self._flat_params() - self.checkpoint) ** 2)).sum()

    def begin_task(self, text_features, dataset):
        """
        Advance the task counter and, with a non-empty buffer, call icarl_replay.
        :param text_features: unused here; kept for the caller's signature
        :param dataset: the dataset passed on to icarl_replay
        """
        self.task += 1
        if self.args.buffer_size > 0:
            icarl_replay(self, dataset)

    def end_task(self, dataset, train_texts=None):
        """
        Update the Fisher estimate and parameter checkpoint after a task.

        Iterates over ``dataset.train_loader``, accumulates squared gradients
        of the mean log-likelihood of each batch (weighted by the batch's mean
        predicted probability of the true label), adds the result to
        ``self.fish``, stores the current parameters in ``self.checkpoint``
        and, when ``args.buffer_size > 0``, refills the buffer via fill_buffer.
        :param dataset: the dataset whose train_loader is used
        :param train_texts: text tokens, required when model_type is 'clip'
        """
        ac = self._seen_class_count()
        is_clip = self._uses_clip()
        fish = torch.zeros_like(self._flat_params())

        # Text features do not depend on the batch and their gradients are
        # never read, so encode them once, graph-free (as observe does).
        text_features = self._encode_text(train_texts) if is_clip else None

        for inputs, labels, _ in dataset.train_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.opt.zero_grad()
            # The optimizer may own only a subset of the parameters whose
            # gradients are read below; clear the whole network so that
            # nothing stale leaks into the estimate.
            self.net.zero_grad()
            if is_clip:
                output = self._clip_logits(inputs, text_features)
            else:
                output = self.net(inputs)
            loss = - F.nll_loss(self.logsoft(output[:, :ac]), labels,
                                reduction='none')
            exp_cond_prob = torch.mean(torch.exp(loss.detach().clone()))
            loss = torch.mean(loss)
            loss.backward()
            fish += exp_cond_prob * self._flat_grads() ** 2

        # NOTE(review): gradients are taken per batch while the normaliser
        # counts samples (batches * batch_size) -- confirm the scaling.
        fish /= (len(dataset.train_loader) * self.args.batch_size)

        if self.fish is None:
            self.fish = fish
        else:
            # previous estimate is kept undecayed (decay factor of 1.0)
            self.fish += fish

        self.checkpoint = self._flat_params().data.clone()

        if self.args.buffer_size > 0:
            with torch.no_grad():
                fill_buffer(self, self.buffer, dataset, self.task, train_texts)

    def observe(self, inputs, labels, not_aug_inputs, num_class, logits=None, train_texts=None, text_tokens=None):
        """
        Run one training step on a clean batch.
        :param inputs: batch of input images
        :param labels: batch of labels
        :param not_aug_inputs: unused here
        :param num_class: unused here
        :param logits: unused here
        :param train_texts: unused here (text features are derived from text_tokens)
        :param text_tokens: optional text tokens; selects the text-similarity path
        :return: the scalar loss value (task loss + e_lambda * penalty)
        """
        self._track_classes(labels)

        self.opt.zero_grad()
        outputs = self._train_logits(inputs, text_tokens)
        return self._penalised_step(outputs, labels)

    def robust_observe(self, inputs, labels, not_aug_inputs, num_class, logits=None, train_texts=None, text_tokens=None):
        """
        Run one training step on PGD-perturbed inputs.
        :param inputs: batch of input images
        :param labels: batch of labels
        :param not_aug_inputs: unused here
        :param num_class: unused here
        :param logits: unused here
        :param train_texts: forwarded unchanged to PGD
        :param text_tokens: optional text tokens; forwarded to PGD and used
                            to select the text-similarity path
        :return: the scalar loss value (task loss + e_lambda * penalty)
        """
        self._track_classes(labels)

        inputs_adv = PGD(inputs, labels, self, train_texts, text_tokens, eps=self.train_eps, alpha=self.train_alpha, steps=self.train_steps)
        # Zero gradients after the attack: PGD differentiates through the
        # model and may leave gradients on its parameters (its implementation
        # is not visible here), which must not be mixed into this update.
        self.opt.zero_grad()
        outputs_adv = self._train_logits(inputs_adv, text_tokens)
        return self._penalised_step(outputs_adv, labels)

    def parameters(self, args=None):
        """
        Return the parameters to optimise.
        :param args: arguments providing model_type and last_num_ft;
                     defaults to ``self.args`` so that the zero-argument
                     ``nn.Module.parameters()`` call style keeps working
        :return: every parameter of the network, or for 'clip' a list with
                 the last ``last_num_ft`` visual resblocks plus ``ln_post``
        """
        if args is None:
            args = self.args
        if args.model_type != 'clip':
            return self.net.parameters()
        # NOTE(review): last_num_ft == 0 slices [-0:], i.e. every block.
        return list(self.net.visual.transformer.resblocks[-args.last_num_ft:].parameters()) + \
            list(self.net.visual.ln_post.parameters())