# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
## added
from robust.adv_evaluate import * #evaluate_PGD, evaluate_AA
import robust.robust_finetune
import pdb, os
mammoth_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import logging
import copy
import clip
import math
import  sys
from argparse import Namespace
from typing import Tuple

import torch
from dataset import get_dataset
from dataset.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel

from utils.loggers import *
from utils.status import ProgressBar
from utils.metrics import backward_transfer, forward_transfer, forgetting, learning_accuracy


def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> None:
    """
    Restrict ``outputs`` in place to the logits of task ``k``.

    Every column belonging to a task other than ``k`` (within the first
    ``N_TASKS * N_CLASSES_PER_TASK`` columns) is overwritten with -inf, so an
    argmax over the result only picks classes of task ``k``. This is how the
    task-il accuracy is derived from class-il logits.
    :param outputs: the output tensor (modified in place)
    :param dataset: the continual dataset (reads N_CLASSES_PER_TASK, N_TASKS)
    :param k: the task index
    """
    per_task = dataset.N_CLASSES_PER_TASK
    task_start = k * per_task
    task_stop = (k + 1) * per_task
    all_stop = dataset.N_TASKS * per_task
    neg_inf = float('-inf')
    # Columns of earlier tasks.
    outputs[:, :task_start] = neg_inf
    # Columns of later tasks, up to the last known class.
    outputs[:, task_stop:all_stop] = neg_inf

def evaluate(model: ContinualModel, dataset: ContinualDataset, texts=None, last=False) -> Tuple[list, list]:
    """
    Evaluates the accuracy of the model for each past task.

    The network is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if evaluation raises.
    :param model: the model to be evaluated
    :param dataset: the continual dataset at hand
    :param texts: optional extra positional input forwarded to the model right
                  after ``inputs`` (the CLIP path of this file passes encoded
                  text features); ``None`` for vision-only models
    :param last: if True, only the most recent test loader is evaluated
    :return: a tuple of lists, containing the class-il and task-il accuracy
             (in percent) for each evaluated task. The class-il entry is 0 when
             the model is not class-il compatible; the task-il entry is 0 unless
             ``dataset.SETTING == 'class-il'``.
    """
    status = model.net.training
    model.net.eval()
    # Positional arguments placed between ``inputs`` and the (optional) task id.
    extra_args = () if texts is None else (texts,)
    class_il = 'class-il' in model.COMPATIBILITY
    accs, accs_mask_classes = [], []
    n_loaders = len(dataset.test_loaders)
    try:
        with torch.no_grad():
            for k, test_loader in enumerate(dataset.test_loaders):
                if last and k < n_loaders - 1:
                    continue
                correct, correct_mask_classes, total = 0.0, 0.0, 0.0
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(model.device), labels.to(model.device)
                    if class_il:
                        outputs = model(inputs, *extra_args)
                    else:
                        # Models without class-il support also receive the task index.
                        outputs = model(inputs, *extra_args, k)

                    _, pred = torch.max(outputs.data, 1)
                    correct += torch.sum(pred == labels).item()
                    total += labels.shape[0]

                    if dataset.SETTING == 'class-il':
                        # Task-il accuracy: restrict the argmax to task k's classes.
                        mask_classes(outputs, dataset, k)
                        _, pred = torch.max(outputs.data, 1)
                        correct_mask_classes += torch.sum(pred == labels).item()

                accs.append(correct / total * 100 if class_il else 0)
                accs_mask_classes.append(correct_mask_classes / total * 100)
    finally:
        model.net.train(status)
    return accs, accs_mask_classes


def _build_logger(args: Namespace) -> logging.Logger:
    """
    Return this module's logger, attaching the per-run log file unless
    ``args.disable_log`` is set.

    The logger object is returned in both cases so callers can log
    unconditionally. A file handler for a given path is attached at most once,
    so repeated calls in the same process do not duplicate log lines.
    :param args: the arguments of the current execution (reads ``disable_log``,
                 ``dataset``, ``wandb_name``, ``wandb_tags``, ``model_type``,
                 ``architecture``)
    :return: the ``logging.Logger`` for this module
    """
    logger = logging.getLogger(__name__)
    if args.disable_log:
        return logger
    logger.setLevel(logging.INFO)
    log_dir = './save/loggers/{}/'.format(args.dataset)
    filename = '{}_{}_{}_{}'.format(args.wandb_name, args.wandb_tags, args.model_type,
                                    args.architecture.replace('/', ''))
    os.makedirs(log_dir, exist_ok=True)
    # logging.FileHandler stores the absolute path in ``baseFilename``; compare on that.
    log_path = os.path.abspath(os.path.join(log_dir, f'{filename}.log'))
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter("%(asctime)s [%(filename)s] => %(message)s"))
        logger.addHandler(file_handler)
    return logger


def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.

    CLIP-based runs (``args.model_type == 'clip'``) are delegated to
    ``_clip_train``; any other model is trained here task by task. After each
    task the model is evaluated on clean inputs (``evaluate``), under PGD
    (``evaluate_PGD``) and for input gradient norm / curvature
    (``evaluate_curvature_input``). After the final task, GF/CF vectors
    (``evaluate_gf_cf``) and AutoAttack accuracy (``evaluate_AA``) are computed
    too. When ``args.nowand`` is false, forgetting / backward-transfer /
    learning-accuracy summaries are logged and printed at the end.
    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    # Always defined (also with --disable_log): the logger.info calls below
    # and in _clip_train are unconditional.
    logger = _build_logger(args)
    print(args)
    for key, value in vars(args).items():
        print(f'{key}: {value}')
        logger.info(f'{key}: {value}')

    model.net.to(model.device)
    results, results_adv, results_mask_classes = [], [], []
    total_grad, total_curvature = [], []

    progress_bar = ProgressBar(verbose=not args.non_verbose)
    print(file=sys.stderr)
    # Snapshot of the model after each task; consumed by evaluate_gf_cf.
    MODELS = []

    if args.model_type == 'clip':
        _clip_train(args, dataset, model, results, results_mask_classes, progress_bar, MODELS,
                    results_adv, total_grad, total_curvature, logger)
    else:
        for t in range(dataset.N_TASKS):
            num_class = (t + 1) * dataset.N_CLASSES_PER_TASK

            model.net.train()
            train_loader, _ = dataset.get_data_loaders()
            if hasattr(model, 'begin_task'):
                model.begin_task(dataset)
            if t and not args.ignore_other_metrics:
                # Evaluate only the most recent test loader (last=True) before
                # training on task t and append it to the previous task's row.
                accs = evaluate(model, dataset, last=True)
                results[t-1] = results[t-1] + accs[0]
                if dataset.SETTING == 'class-il':
                    results_mask_classes[t-1] = results_mask_classes[t-1] + accs[1]

            ##### revised #####################################################
            # NOTE(review): real_epochs (an exponentially decayed epoch count for
            # robust training) is computed but the epoch loop below still runs
            # model.args.n_epochs epochs; confirm which one was intended.
            real_epochs = model.args.n_epochs
            if args.robust_method == 'AT' or args.robust_method == 'BOTH':
                if not hasattr(dataset, "get_robust_scheduler"):
                    print("Warning : No Robust Scheduler for this datset")
                    scheduler = dataset.get_scheduler(model, args)
                else:
                    scheduler = dataset.get_robust_scheduler(model, args)
                    real_epochs = max(1, int(model.args.n_epochs * math.exp(-t / 30)))
            else:
                scheduler = dataset.get_scheduler(model, args)
            ####################################################################

            for epoch in range(model.args.n_epochs):
                if args.model == 'joint':
                    continue
                for i, data in enumerate(train_loader):
                    if args.debug_mode and i > 3:
                        break
                    if hasattr(dataset.train_loader.dataset, 'logits'):
                        # Buffers that carry stored logits yield a 4-tuple.
                        inputs, labels, not_aug_inputs, logits = data
                        inputs = inputs.to(model.device)
                        labels = labels.to(model.device)
                        not_aug_inputs = not_aug_inputs.to(model.device)
                        logits = logits.to(model.device)
                        loss = model.meta_observe(inputs, labels, not_aug_inputs, logits, num_class, epoch)
                    else:
                        inputs, labels, not_aug_inputs = data
                        inputs, labels = inputs.to(model.device), labels.to(model.device)
                        not_aug_inputs = not_aug_inputs.to(model.device)
                        loss = model.meta_observe(inputs, labels, not_aug_inputs, num_class, epoch)
                    assert not math.isnan(loss)
                    progress_bar.prog(i, len(train_loader), epoch, t, loss)

                if scheduler is not None:
                    scheduler.step()

            ### For finetune added
            if args.robust_method == 'Finetune' or args.robust_method == 'BOTH':
                # TODO: robust fine-tuning step is not implemented yet.
                pass

            if hasattr(model, 'end_task'):
                model.end_task(dataset)

            MODELS.append(copy.deepcopy(model))
            accs_logit, accs_adv, accs_adv_mask_classes, _, _, _ = evaluate_PGD(
                model, dataset, args.test_eps, args.test_alpha, args.test_steps)
            accs = evaluate(model, dataset)

            results.append(accs_logit)
            results_adv.append(accs_adv)
            results_mask_classes.append(accs[1])
            mean_acc = np.mean(accs, axis=1)
            mean_acc_adv = round(np.mean(accs_adv), 2)
            mean_acc_logit = round(np.mean(accs_logit), 2)
            gradient_norms, curvatures = evaluate_curvature_input(model, dataset)
            total_grad.append(gradient_norms)
            total_curvature.append(curvatures)
            mean_curvature = np.mean(curvatures)
            mean_gradient_norm = np.mean(gradient_norms)
            print_mean_accuracy(accs, [accs_adv, accs_adv_mask_classes], t + 1, dataset.SETTING, logger)
            print("mean_accs_logit = {}, mean_accs_adv  = {}\n".format(mean_acc_logit, mean_acc_adv))
            logger.info("mean_accs_logit = {}, mean_accs_adv  = {}".format(mean_acc_logit, mean_acc_adv))

            if t == dataset.N_TASKS - 1:
                GF_vector, CF_vector = evaluate_gf_cf(model, dataset, args=args, MODELS=MODELS)
                if not args.nowand:
                    logger.info({'GF_vector': GF_vector, 'CF_vector': CF_vector})

                accs_adv_AA = evaluate_AA(model, dataset, eps=args.test_eps)
                print(accs_adv_AA)
                if not args.nowand:
                    logger.info({'RESULT_AutoAttack': accs_adv_AA})

            if not args.nowand:
                d2 = {'RESULT_class_mean_accs': mean_acc[0], 'RESULT_task_mean_accs': mean_acc[1],
                      'RESULT_pgd_mean_accs': mean_acc_adv, 'RESULT_logit_mean_accs': mean_acc_logit,
                      'RESULT_curvature': mean_curvature, 'RESULT_gradientNorm': mean_gradient_norm,
                      **{f'RESULT_pgd_acc_{i}': a for i, a in enumerate(accs_adv)},
                      **{f'RESULT_class_acc_{i}': a for i, a in enumerate(accs[0])},
                      **{f'RESULT_task_acc_{i}': a for i, a in enumerate(accs[1])}}
                for key, value in d2.items():
                    logger.info(f'{key}: {value}')

    if not args.nowand:
        forget = forgetting(results)
        forget_adv = forgetting(results_adv)
        bwt = backward_transfer(results)
        bwt_adv = backward_transfer(results_adv)
        la = learning_accuracy(results)
        la_adv = learning_accuracy(results_adv)

        grad_forgetting = backward_transfer(total_grad)
        curvature_forgetting = backward_transfer(total_curvature)
        d2 = {'forgetting': forget, 'forgetting_adv': forget_adv, 'bwt': bwt, 'bwt_adv': bwt_adv,
              'la': la, 'la_adv': la_adv, 'GF': grad_forgetting, 'CF': curvature_forgetting}
        for key, value in d2.items():
            logger.info(f'{key}: {value}')
            print(f'{key}: {value}')


def _clip_train(args, dataset, model, results, results_mask_classes, progress_bar, MODELS,
                results_adv, total_grad, total_curvature, logger):
    """
    Task-by-task training loop for CLIP-based models (called from ``train``).

    For every task, the class names of the current training set are turned into
    text prompts, tokenized with ``clip.tokenize`` and encoded into
    ``text_features``. These are passed to ``begin_task``, ``evaluate``,
    ``meta_observe``, ``evaluate_PGD`` and ``evaluate_curvature_input``.
    After each task, PGD / clean / curvature metrics are appended to the
    caller-owned lists ``results``, ``results_adv``, ``results_mask_classes``,
    ``total_grad`` and ``total_curvature`` (mutated in place). After the last
    task, AutoAttack accuracy is logged.

    NOTE(review): ``MODELS`` is accepted but never used here (no per-task
    snapshots are stored), so GF/CF vectors are not computed on this path.
    """
    for t in range(dataset.N_TASKS):
        # Number of classes seen so far (tasks 0..t).
        num_class = (t + 1) * dataset.N_CLASSES_PER_TASK

        # model.net.visual.train()
        model.net.train()
        train_loader, test_loader = dataset.get_data_loaders()
        # Tokenized prompts, one per class name of the current training set.
        texts = clip.tokenize(get_text_prompts_train(args, train_loader.dataset, template=model.template)).to(model.device)
        with torch.no_grad():
            if args.model == 'proof':
                text_features = model.net.convnet.encode_text(texts)
            elif args.model == 'attriclip':
                # AttriCLIP consumes the raw token ids rather than encoded features.
                text_features = texts
            else:
                text_features = model.net.encode_text(texts)
                # L2-normalize each text embedding along the last dimension.
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)


        if hasattr(model, 'begin_task'):
            # model.begin_task(text_features)
            model.begin_task(text_features, dataset)
        if t and not args.ignore_other_metrics:
            # Evaluate only the most recent test loader (last=True) before
            # training on task t and append it to the previous task's row.
            accs = evaluate(model, dataset, text_features, last=True)
            results[t-1] = results[t-1] + accs[0]
            if dataset.SETTING == 'class-il':
                results_mask_classes[t-1] = results_mask_classes[t-1] + accs[1]

        ##### revised #####################################################
        # NOTE(review): real_epochs is computed but the epoch loop below still
        # runs model.args.n_epochs epochs; confirm which one was intended.
        real_epochs = model.args.n_epochs
        if args.robust_method == 'AT' or args.robust_method == 'BOTH':
            if not hasattr(dataset, "get_robust_scheduler"):
                print("Warning : No Robust Scheduler for this datset")
                scheduler = dataset.get_scheduler(model, args)
            else :
                scheduler = dataset.get_robust_scheduler(model, args)
                real_epochs = int(model.args.n_epochs * math.exp(-t/30))
                if real_epochs <= 1:
                    real_epochs = 1
        else:
            scheduler = dataset.get_scheduler(model, args)


        for epoch in range(model.args.n_epochs):
            if args.model == 'joint':
                continue
            # Reset at the start of every epoch; presumably read/advanced by the
            # model during meta_observe — TODO confirm against the model code.
            args.batch_id = -1
            for i, data in enumerate(train_loader):
                if args.debug_mode and i > 3:
                    break
                if hasattr(dataset.train_loader.dataset, 'logits'):
                    # NOTE(review): this branch does not forward text_features/texts
                    # to meta_observe, unlike the branch below — confirm intended.
                    inputs, labels, not_aug_inputs, logits = data
                    inputs = inputs.to(model.device)
                    labels = labels.to(model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    logits = logits.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs, logits, num_class, epoch)
                else:
                    inputs, labels, not_aug_inputs = data
                    inputs, labels = inputs.to(model.device), labels.to(
                        model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs, num_class, epoch, text_features, texts)
                assert not math.isnan(loss)
                progress_bar.prog(i, len(train_loader), epoch, t, loss)

            if scheduler is not None:
                # TypeError is silently ignored — presumably for schedulers whose
                # step() requires an argument (e.g. a metric); TODO confirm and
                # prefer handling that case explicitly.
                try:
                    scheduler.step()
                except TypeError:
                    pass




        ### For finetune added
        if args.robust_method == 'Finetune' or args.robust_method == 'BOTH':
            # TODO: robust fine-tuning step is not implemented yet.
            pass



        if hasattr(model, 'end_task'):
            model.end_task(dataset, texts)

        accs_logit, accs_adv, accs_adv_mask_classes, per_class_output, per_class_output_adv, feat_distance = evaluate_PGD(model, dataset, args.test_eps, args.test_alpha, args.test_steps,
                                                                                                                        text_features, texts)
        accs = evaluate(model, dataset, text_features)

        results.append(accs_logit)
        results_adv.append(accs_adv)
        results_mask_classes.append(accs[1])
        mean_acc = np.mean(accs, axis=1)
        mean_acc_adv = round(np.mean(accs_adv), 2)
        mean_acc_logit = round(np.mean(accs_logit), 2)
        gradient_norms, curvatures = evaluate_curvature_input(model, dataset, texts=text_features)
        total_grad.append(gradient_norms)
        total_curvature.append(curvatures)
        mean_curvature = np.mean(curvatures)
        mean_gradient_norm = np.mean(gradient_norms)
        print_mean_accuracy(accs, [accs_adv, accs_adv_mask_classes], t + 1, dataset.SETTING, logger)
        print("mean_accs_logit = {}, mean_accs_adv  = {}\n".format(mean_acc_logit, mean_acc_adv))
        logger.info("mean_accs_logit = {}, mean_accs_adv  = {}".format(mean_acc_logit, mean_acc_adv))
        torch.cuda.empty_cache()


        if not args.nowand:
            d2={'RESULT_class_mean_accs': mean_acc[0], 'RESULT_task_mean_accs': mean_acc[1], 
                'RESULT_pgd_mean_accs': mean_acc_adv, 'RESULT_logit_mean_accs': mean_acc_logit, 
                'RESULT_curvature': mean_curvature, 'RESULT_gradientNorm': mean_gradient_norm,
                **{f'RESULT_pgd_acc_{i}': a for i, a in enumerate(accs_adv)},
                **{f'RESULT_class_acc_{i}': a for i, a in enumerate(accs[0])},
                **{f'RESULT_task_acc_{i}': a for i, a in enumerate(accs[1])}}
            for key, value in d2.items():
                logger.info(f'{key}: {value}')
        torch.cuda.empty_cache()

        if t == dataset.N_TASKS - 1:
            # NOTE(review): evaluate_AA receives the tokenized prompts (`texts`),
            # whereas evaluate/evaluate_curvature_input receive `text_features`;
            # confirm which one evaluate_AA expects.
            accs_adv_AA = evaluate_AA(model, dataset, eps=args.test_eps, texts=texts)
            if not args.nowand:
                AA_d = {'RESULT_AutoAttack': accs_adv_AA}
                logger.info(AA_d)
                
def get_text_prompts_train(args, train_dataset, template='This is a photo of a {}'):
    """
    Build one text prompt per class of ``train_dataset``.

    Class names are taken from ``train_dataset.classes``, normalized with
    ``refine_classname`` and substituted into ``template`` via ``str.format``.
    ``args`` is accepted for interface compatibility but not read.
    :return: list of prompt strings, in class order
    """
    refined_names = refine_classname(train_dataset.classes)
    return list(map(template.format, refined_names))

def refine_classname(class_names):
    """
    Normalize class names for use in natural-language prompts.

    Each name is lower-cased and every '_', '-' and '/' is replaced by a space
    (e.g. 'T-Shirt/Top' -> 't shirt top').
    A new list is returned; the input sequence is left untouched, so dataset
    attributes such as ``dataset.classes`` are not silently rewritten, and
    immutable inputs (tuples) are accepted.
    :param class_names: iterable of class-name strings
    :return: list of normalized names, in the same order
    """
    return [
        name.lower().replace('_', ' ').replace('-', ' ').replace('/', ' ')
        for name in class_names
    ]