# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
import torch
import torch.nn.functional as F
from dataset import get_dataset
from torch.distributions import Normal
from models.utils.continual_model import ContinualModel
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser
from utils.batch_norm import bn_track_stats
from utils.buffer import Buffer, icarl_replay
from clip.model import VisualTransformer
import pdb
from robust.attacks import *
from torch.optim import SGD
from utils.adaptor import *
def get_parser() -> ArgumentParser:
    """Build the command-line parser for the ``continual_clip`` model.

    Registers the shared management, experiment and rehearsal argument
    groups provided by ``utils.args``.

    Returns:
        The populated ``ArgumentParser``.
    """
    # The description previously read 'via iCaRL' (copied from the iCaRL
    # model); this parser belongs to the CLIP-based continual model.
    parser = ArgumentParser(description='Continual Learning via CLIP.')

    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    return parser


class ContinualCLIP(ContinualModel):
    """Continual-learning wrapper around a CLIP-style network.

    Classification is done by comparing normalised image embeddings with
    externally supplied text embeddings (see ``forward``). Logits are
    restricted to the classes of all tasks seen so far.
    """
    NAME = 'continual_clip'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(ContinualCLIP, self).__init__(backbone, loss, args, transform)
        self.dataset = get_dataset(args)

        # Instantiate buffers
        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Identity matrix over every class of every task (one row per class).
        self.eye = torch.eye(self.dataset.N_CLASSES_PER_TASK *
                             self.dataset.N_TASKS).to(self.device)

        self.class_means = None
        # -1 until the first begin_task call increments it to 0.
        self.task = -1
        # Frozen snapshot of self.net taken at the end of each task.
        self.old_net = None
        self.train_eps = args.train_eps
        self.train_alpha = args.train_alpha
        self.train_steps = args.train_steps
        self.template = args.template

    def forward(self, x, text=None):
        """Score a batch of images against per-class text embeddings.

        Args:
            x: image batch passed to ``self.net.encode_image``.
            text: text-embedding matrix with one row per class. Despite the
                ``None`` default it is required. It is not normalised here;
                presumably the caller normalises it (TODO confirm).

        Returns:
            ``normalised_image_embedding @ text.T``, truncated to the first
            ``(self.task + 1) * N_CLASSES_PER_TASK`` columns (classes seen so
            far). Before the first ``begin_task`` call this is 0 columns.

        Raises:
            ValueError: if ``text`` is None.
        """
        if text is None:
            raise ValueError('ContinualCLIP.forward requires text embeddings')
        seen_classes = (self.task + 1) * self.dataset.N_CLASSES_PER_TASK
        image_embed = self.net.encode_image(x, None)
        # assumes encode_image returns (batch, tokens, dim) and token 0 is the
        # image-level embedding — TODO confirm. Selecting before normalising
        # gives the same values as normalising every token first.
        image_embed = image_embed[:, 0, :]
        image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
        output = image_embed @ text.t()
        return output[:, :seen_classes]

    def observe(self, inputs, labels, not_aug_inputs, num_class, epoch=None, train_texts=None, text_tokens=None):
        """Training step hook; currently a no-op that returns None."""
        return

    def robust_observe(self, inputs, labels, not_aug_inputs, num_class, epoch=None, train_texts=None, text_tokens=None):
        """Adversarial training step hook; currently a no-op returning None."""
        return

    def begin_task(self, text_features, dataset):
        """Advance the task counter; ``text_features``/``dataset`` are unused."""
        self.task += 1

    def end_task(self, dataset, train_texts=None) -> None:
        """Snapshot the network and reset per-task state.

        Stores a deep copy of ``self.net`` (in eval mode) as ``self.old_net``,
        puts ``self.net`` back in training mode and clears ``class_means``.
        """
        self.old_net = deepcopy(self.net.eval())
        self.net.train()
        self.class_means = None

    def parameters(self, args=None):
        """Return the parameters to optimise.

        Args:
            args: namespace providing ``model_type`` and, for CLIP,
                ``last_num_ft``. Defaults to ``self.args`` so that no-argument
                calls (as made by ``torch.nn.Module`` helpers) keep working.

        Returns:
            ``self.net.parameters()`` for non-CLIP models. For CLIP, a list
            with the parameters of the last ``last_num_ft`` visual transformer
            blocks followed by those of ``visual.ln_post``.
        """
        if args is None:
            args = self.args
        if args.model_type != 'clip':
            return self.net.parameters()
        blocks = self.net.visual.transformer.resblocks
        # Use an explicit start index: blocks[-0:] would select *every* block,
        # so last_num_ft == 0 must yield no transformer blocks instead.
        start = max(0, len(blocks) - max(0, args.last_num_ft))
        return list(blocks[start:].parameters()) + \
            list(self.net.visual.ln_post.parameters())