

import functools
import torchattacks
import pdb, os
import numpy as np
from pathlib import Path


import math
import sys
from argparse import Namespace
from typing import Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import get_dataset
from dataset.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
from sklearn.metrics import confusion_matrix
from robust.attacks import *
from robust.autoattack import AutoAttack


def save_confusion_matrix(all_labels, all_preds, num_classes, save_path="confusion_matrix.png"):
    """
    Draws a row-normalised confusion matrix as a heat map and writes it to disk.
    :param all_labels: the reference class ids
    :param all_preds: the predicted class ids, aligned with all_labels
    :param num_classes: number of classes; the matrix covers ids 0..num_classes-1
    :param save_path: path of the image file to write
    """
    font = "Times New Roman"
    label_size = 20

    counts = confusion_matrix(all_labels, all_preds, labels=np.arange(num_classes)).astype(np.float32)
    row_totals = counts.sum(axis=1, keepdims=True)
    # The epsilon keeps rows of classes that never occur from dividing by zero.
    rates = counts / (row_totals + 1e-8)

    plt.figure(figsize=(6, 6))
    plt.imshow(rates, interpolation='nearest', cmap='viridis', vmin=0, vmax=1)

    colorbar = plt.colorbar(fraction=0.0465, pad=0.02)
    colorbar.ax.tick_params(labelsize=label_size)
    for tick_label in colorbar.ax.get_yticklabels():
        tick_label.set_fontname(font)

    # Only every `stride`-th class id gets a tick so large label sets stay legible.
    stride = max(1, num_classes // 10)
    ticks = np.arange(0, num_classes, stride)
    for place_ticks in (plt.xticks, plt.yticks):
        place_ticks(ticks, labels=ticks, fontsize=label_size, fontname=font)

    plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0)
    plt.close()
    print(f"✅ Confusion matrix saved to {save_path}")

def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> None:
    """
    Restricts a batch of logits to the classes of a single task, in place.
    Every column that belongs to another task is overwritten with -inf, so a
    following argmax can only pick a class of task k (task-il evaluation).
    :param outputs: the output tensor, modified in place
    :param dataset: the continual dataset (provides N_TASKS and N_CLASSES_PER_TASK)
    :param k: the task index
    """
    per_task = dataset.N_CLASSES_PER_TASK
    task_start = k * per_task
    task_end = (k + 1) * per_task
    all_classes = dataset.N_TASKS * per_task
    minus_inf = -float('inf')

    # Columns before the task's block ...
    outputs[:, :task_start] = minus_inf
    # ... and columns after it, up to the last class known to the dataset.
    outputs[:, task_end:all_classes] = minus_inf



def evaluate_PGD(model: ContinualModel, dataset: ContinualDataset, eps, alpha, steps, texts=None, text_tokens=None):
    """
    Evaluates clean and PGD-adversarial accuracy on every test loader of the dataset.
    :param model: the model to be evaluated; model.net is put in eval mode and
                  restored to its previous mode before returning
    :param dataset: the continual dataset (test_loaders, N_TASKS, N_CLASSES_PER_TASK, SETTING)
    :param eps: forwarded to PGD as eps
    :param alpha: forwarded to PGD as alpha
    :param steps: forwarded to PGD as steps
    :param texts: optional extra argument; when given it is passed to the model
                  after the images and features are read from the image encoder
    :param text_tokens: forwarded to PGD as text_tokens
    :return: a tuple (accs, accs_adv, accs_mask_classes, per_class_output,
             per_class_output_adv, feat_distance) where
             accs / accs_adv are per-task clean / adversarial accuracies in percent
             (0 when the model is not class-il compatible),
             accs_mask_classes is the per-task accuracy of the adversarial outputs
             after masking them to the current task (stays 0 unless
             dataset.SETTING == 'class-il'),
             per_class_output / per_class_output_adv are [label, prediction] count
             matrices divided by 10, and
             feat_distance is the per-class sum of L2 distances between clean and
             adversarial features divided by 1000.
    """
    status = model.net.training
    model.net.eval()
    accs, accs_adv = [], []
    accs_mask_classes = []
    num_total_class = dataset.N_TASKS * dataset.N_CLASSES_PER_TASK
    per_class_output = np.zeros((num_total_class, num_total_class))
    per_class_output_adv = np.zeros((num_total_class, num_total_class))
    feat_distance = np.zeros((num_total_class))

    class_il_model = 'class-il' in model.COMPATIBILITY
    text_args = () if texts is None else (texts,)

    # Only needed by the optional confusion-matrix dump at the end of the function.
    all_labels, all_preds, all_preds_adv = [], [], []
    for k, test_loader in enumerate(dataset.test_loaders):
        correct, correct_mask_classes, total = 0.0, 0.0, 0.0
        correct_adv = 0.0
        # Models without class-il support additionally receive the task index.
        task_args = () if class_il_model else (k,)
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(model.device), labels.to(model.device)

            inputs_adv = PGD(inputs, labels, model, train_texts=texts, text_tokens=text_tokens,
                             eps=eps, alpha=alpha, steps=steps)

            # The attack is an external routine; re-assert eval mode in case it changed it.
            model.eval()

            with torch.no_grad():
                outputs = model(inputs, *text_args, *task_args)
                outputs_adv = model(inputs_adv, *text_args, *task_args)

                if texts is None:
                    _, feat = model.net(inputs, returnt='all')
                    _, feat_adv = model.net(inputs_adv, returnt='all')
                else:
                    try:
                        feat = model.net.encode_image(inputs)[:, 0, :]
                        feat_adv = model.net.encode_image(inputs_adv)[:, 0, :]
                    except Exception:
                        # Best-effort fallback for backbones that expose
                        # `image_encoder` and are fed half-precision input.
                        feat = model.net.image_encoder(inputs.half())[:, 0, :]
                        feat_adv = model.net.image_encoder(inputs_adv.half())[:, 0, :]

                pred = outputs.argmax(dim=1)
                pred_adv = outputs_adv.argmax(dim=1)
                correct += (pred == labels).sum().item()
                correct_adv += (pred_adv == labels).sum().item()
                total += labels.shape[0]

                if dataset.SETTING == 'class-il':
                    # Masking happens in place, after pred_adv has been taken.
                    # The masked prediction gets its own name so that the clean
                    # prediction used for the statistics below is not overwritten.
                    mask_classes(outputs_adv, dataset, k)
                    pred_masked = outputs_adv.argmax(dim=1)
                    correct_mask_classes += (pred_masked == labels).sum().item()

                # One L2 distance per sample (any trailing dims are summed).
                sample_dist = torch.norm(feat - feat_adv, dim=1)
                if sample_dist.dim() > 1:
                    sample_dist = sample_dist.flatten(1).sum(dim=1)

                labels_np = labels.cpu().numpy()
                pred_np = pred.cpu().numpy()
                pred_adv_np = pred_adv.cpu().numpy()
                dist_np = sample_dist.double().cpu().numpy()

            all_labels.append(labels_np)
            all_preds.append(pred_np)
            all_preds_adv.append(pred_adv_np)

            # Scatter-add the whole batch at once instead of probing every
            # (label, prediction) cell separately. Ids outside the known class
            # range are ignored, exactly as the per-cell comparison did.
            known = (labels_np >= 0) & (labels_np < num_total_class)
            np.add.at(feat_distance, labels_np[known], dist_np[known])
            hit = known & (pred_np < num_total_class)
            np.add.at(per_class_output, (labels_np[hit], pred_np[hit]), 1)
            hit_adv = known & (pred_adv_np < num_total_class)
            np.add.at(per_class_output_adv, (labels_np[hit_adv], pred_adv_np[hit_adv]), 1)

        accs.append(correct / total * 100 if class_il_model else 0)
        accs_mask_classes.append(correct_mask_classes / total * 100)
        accs_adv.append(correct_adv / total * 100.0 if class_il_model else 0)

    # NOTE(review): fixed divisors; presumably chosen for 1000 test samples per
    # class (counts -> percent, distances -> mean) -- confirm for other datasets.
    per_class_output = per_class_output / 10
    per_class_output_adv = per_class_output_adv / 10
    feat_distance = feat_distance / 1000

    model.net.train(status)
    # Optional debugging aid:
    # num_classes = len(np.unique(np.concatenate(all_labels)))
    # save_confusion_matrix(np.concatenate(all_labels), np.concatenate(all_preds), num_classes=num_classes, save_path=f"cm_clean_{num_classes}classes.png")
    # save_confusion_matrix(np.concatenate(all_labels), np.concatenate(all_preds_adv), num_classes=num_classes, save_path=f"cm_adv_{num_classes}classes.png")
    return accs, accs_adv, accs_mask_classes, per_class_output, per_class_output_adv, feat_distance



def evaluate_FGSM(model: ContinualModel, dataset: ContinualDataset, eps=8/255):
    """
    Evaluates clean and FGSM-adversarial accuracy on every test loader of the dataset.
    :param model: the model to be evaluated; model.net is put in eval mode and
                  restored to its previous mode before returning
    :param dataset: the continual dataset (test_loaders, N_TASKS, N_CLASSES_PER_TASK)
    :param eps: forwarded to FGSM as eps
    :return: a tuple (accs, accs_adv, per_class_output, per_class_output_adv,
             feat_distance) where
             accs / accs_adv are per-task clean / adversarial accuracies in percent
             (0 when the model is not class-il compatible),
             per_class_output / per_class_output_adv are [label, prediction] count
             matrices divided by 10, and
             feat_distance is the per-class sum of L2 distances between clean and
             adversarial features divided by 1000.
    """
    status = model.net.training
    model.net.eval()
    accs, accs_adv = [], []

    num_total_class = dataset.N_TASKS * dataset.N_CLASSES_PER_TASK
    per_class_output = np.zeros((num_total_class, num_total_class))
    per_class_output_adv = np.zeros((num_total_class, num_total_class))
    feat_distance = np.zeros((num_total_class))

    class_il_model = 'class-il' in model.COMPATIBILITY

    for k, test_loader in enumerate(dataset.test_loaders):
        correct, total = 0.0, 0.0
        correct_adv = 0.0
        # Models without class-il support additionally receive the task index.
        # `k` must stay the task index for the whole loader: no inner loop may
        # reuse the name, otherwise later batches are evaluated on the wrong task.
        task_args = () if class_il_model else (k,)
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(model.device), labels.to(model.device)

            inputs_adv = FGSM(inputs, labels, model, eps=eps)

            # The attack is an external routine; re-assert eval mode in case it changed it.
            model.eval()

            with torch.no_grad():
                outputs = model(inputs, *task_args)
                outputs_adv = model(inputs_adv, *task_args)

                _, feat = model.net(inputs, returnt='all')
                _, feat_adv = model.net(inputs_adv, returnt='all')

                pred = outputs.argmax(dim=1)
                pred_adv = outputs_adv.argmax(dim=1)
                correct += (pred == labels).sum().item()
                correct_adv += (pred_adv == labels).sum().item()
                total += labels.shape[0]

                # One L2 distance per sample (any trailing dims are summed).
                sample_dist = torch.norm(feat - feat_adv, dim=1)
                if sample_dist.dim() > 1:
                    sample_dist = sample_dist.flatten(1).sum(dim=1)

                labels_np = labels.cpu().numpy()
                pred_np = pred.cpu().numpy()
                pred_adv_np = pred_adv.cpu().numpy()
                dist_np = sample_dist.double().cpu().numpy()

            # Scatter-add the whole batch at once instead of probing every
            # (label, prediction) cell separately. Ids outside the known class
            # range are ignored, exactly as the per-cell comparison did.
            known = (labels_np >= 0) & (labels_np < num_total_class)
            np.add.at(feat_distance, labels_np[known], dist_np[known])
            hit = known & (pred_np < num_total_class)
            np.add.at(per_class_output, (labels_np[hit], pred_np[hit]), 1)
            hit_adv = known & (pred_adv_np < num_total_class)
            np.add.at(per_class_output_adv, (labels_np[hit_adv], pred_adv_np[hit_adv]), 1)

        accs.append(correct / total * 100 if class_il_model else 0)
        accs_adv.append(correct_adv / total * 100.0 if class_il_model else 0)

    # NOTE(review): fixed divisors; presumably chosen for 1000 test samples per
    # class (counts -> percent, distances -> mean) -- confirm for other datasets.
    per_class_output = per_class_output / 10
    per_class_output_adv = per_class_output_adv / 10
    feat_distance = feat_distance / 1000

    model.net.train(status)
    return accs, accs_adv, per_class_output, per_class_output_adv, feat_distance


def evaluate_AA(model: ContinualModel, dataset: ContinualDataset, eps, texts=None):
    """
    Runs the standard AutoAttack (Linf) evaluation on the union of all test loaders.
    :param model: the model to be evaluated; model.net is put in eval mode and
                  restored to its previous mode before returning
    :param dataset: the continual dataset whose test_loaders yield (images, labels) batches
    :param eps: forwarded to AutoAttack as eps
    :param texts: optional text tokens; when given, the attack is run against
                  CLIP_image_logits bound to this model and these tokens
    :return: the robust accuracy reported by AutoAttack, multiplied by 100
    """
    status = model.net.training
    model.net.eval()

    # Walk every loader exactly once. Reading images and labels in two separate
    # passes doubles the loading cost and only pairs them correctly when the
    # loader yields the same order both times.
    x_batches, y_batches = [], []
    for test_loader in dataset.test_loaders:
        for x, y in test_loader:
            x_batches.append(x)
            y_batches.append(y)
    # With no data at all, None is handed on, as before.
    x_total = torch.cat(x_batches, 0) if x_batches else None
    y_total = torch.cat(y_batches, 0) if y_batches else None

    if texts is None:
        autoattack = AutoAttack(model, norm='Linf', eps=eps, version='standard')
    else:
        forward_pass = functools.partial(
            CLIP_image_logits,
            model=model, text_tokens=texts
        )
        autoattack = AutoAttack(forward_pass, norm='Linf', eps=eps, version='standard')
        # autoattack.attacks_to_run = ['square']
    _, robust_acc = autoattack.run_standard_evaluation(x_total, y_total)

    # Leaves the wrapper in eval mode; only model.net gets its previous mode back.
    model.eval()
    model.net.train(status)
    return robust_acc * 100



def evaluate_curvature(model: ContinualModel, dataset: ContinualDataset, mean, std, last=False) :
    '''
    Computes, per test loader, the L2 norm of the leading Hessian eigenvalues
    of the cross-entropy loss of model.net.

    Relies on compute_hessian_eigenthings from pytorch-hessian-eigenthings:
    pip install --upgrade git+https://github.com/noahgolmant/pytorch-hessian-eigenthings.git@master#egg=hessian-eigenthings
    'https://github.com/noahgolmant/pytorch-hessian-eigenthings'

    :param model: the model whose net is analysed; its train/eval mode is restored afterwards
    :param dataset: the continual dataset providing test_loaders
    :param mean: not used by this function
    :param std: not used by this function
    :param last: if True, only the final test loader is analysed
    :return: a list with one eigenvalue norm per analysed loader
    '''
    was_training = model.net.training
    model.net.eval()

    criterion = nn.CrossEntropyLoss()
    top_k = 20  # number of eigenvalue/eigenvector pairs requested per loader
    norms = []

    for task_id, loader in enumerate(dataset.test_loaders):
        is_earlier_task = task_id < len(dataset.test_loaders) - 1 if last else False
        if not is_earlier_task:
            spectrum, _ = compute_hessian_eigenthings(model.net, loader, criterion, top_k)
            norms.append(np.linalg.norm(spectrum))

    model.net.train(was_training)
    return norms



def evaluate_curvature_input(model: ContinualModel, dataset: ContinualDataset, texts=None, last=False, args = None):
    """
    Measures input-gradient norm and a finite-difference input curvature per test loader.
    For every batch the cross-entropy gradient w.r.t. the input is taken at x and
    at x + h * z, where z is that gradient scaled to unit L2 norm per sample and
    h = 0.01; the curvature term is sum((g(x) - g(x + h z))**2) / h**2.
    :param model: the model to be evaluated; put in eval mode and restored afterwards
    :param dataset: the continual dataset whose test_loaders yield (x, y) batches
    :param texts: optional extra argument passed to the model after the images
    :param last: not used by this function
    :param args: not used by this function
    :return: a tuple (gradient_norms, curvatures), each a list with one summed
             value per test loader
    """
    status = model.training
    model.eval()
    criterion = nn.CrossEntropyLoss()
    # Follow the model's device when it advertises one; otherwise keep the
    # historical behaviour of moving every batch to CUDA.
    device = getattr(model, 'device', 'cuda')
    text_args = () if texts is None else (texts,)
    h = 0.01  # finite-difference step along the unit gradient direction

    curvatures = []
    gradient_norms = []
    for test_loader in dataset.test_loaders:
        curvature = 0
        gradient_norm = 0
        for x, y in test_loader:
            # detach() yields a fresh tensor object, so switching on
            # requires_grad never touches the tensor owned by the loader.
            x = x.to(device).detach().requires_grad_(True)
            y = y.to(device)

            loss = criterion(model(x, *text_args), y)
            # First-order gradients suffice: every quantity below is used as a
            # plain number, so no differentiable graph of the gradient is built.
            grad_1 = torch.autograd.grad(loss, [x])[0]

            per_sample = grad_1.reshape(grad_1.shape[0], -1).norm(dim=1)
            # A sample with an exactly zero gradient would give 0/0; leave it unmoved.
            per_sample[per_sample == 0] = 1
            z_ = grad_1 / per_sample.reshape((-1,) + (1,) * (grad_1.dim() - 1))

            x_hat = (x.detach() + h * z_).requires_grad_(True)
            loss_hat = criterion(model(x_hat, *text_args), y)
            grad_2 = torch.autograd.grad(loss_hat, [x_hat])[0]

            curvature += (((grad_1 - grad_2) ** 2) / h ** 2).sum().item()
            # Norm of the whole batch gradient, not a per-sample average.
            gradient_norm += grad_1.norm().item()

        curvatures.append(curvature)
        gradient_norms.append(gradient_norm)

    model.train(status)
    return gradient_norms, curvatures

def _unit_input_gradient(net, x, y, criterion, extra_args):
    """Return the cross-entropy input gradient of ``net`` at ``x``, scaled to unit L2 norm per sample."""
    x = x.detach().requires_grad_(True)
    loss = criterion(net(x, *extra_args), y)
    # Callers only consume the values, so a first-order gradient is enough.
    grad = torch.autograd.grad(loss, [x])[0]
    norms = grad.reshape(grad.shape[0], -1).norm(dim=1)
    # Samples with an exactly zero gradient stay zero instead of becoming 0/0.
    norms[norms == 0] = 1
    return grad / norms.reshape((-1,) + (1,) * (grad.dim() - 1))


def evaluate_gf_cf(model: ContinualModel, dataset: ContinualDataset, texts=None, last=False, args = None, MODELS = None):
    """
    Computes gradient forgetting and curvature forgetting of `model` with respect
    to the per-task reference models in MODELS.
    Only test examples that MODELS[t] classifies correctly are used for task t.
    For those, normalised input gradients of the current and the past model are
    compared (gradient forgetting), and so are their finite-difference
    curvatures with step h = 1 (curvature forgetting). Both sums of squared
    differences are divided by the number of examples used.
    :param model: the current model; put in eval mode and restored afterwards
    :param dataset: the continual dataset whose test_loaders yield (x, y) batches
    :param texts: optional extra argument passed to both models after the images
    :param last: not used by this function
    :param args: not used by this function
    :param MODELS: sequence with one reference model per test loader (indexed by
                   task); each is switched to eval mode and left there. Assumed to
                   live on the same device as `model`.
    :return: a tuple (GradForgetting, CurvatureForgetting); both are NaN when no
             example was usable
    """
    MODELS = [] if MODELS is None else MODELS
    status = model.training
    model.eval()
    criterion = nn.CrossEntropyLoss()
    # Follow the model's device when it advertises one; otherwise keep the
    # historical behaviour of moving every batch to CUDA.
    device = getattr(model, 'device', 'cuda')
    text_args = () if texts is None else (texts,)
    h = 1  # finite-difference step along the unit gradient direction

    GradForgetting = []
    CurvatureForgetting = []
    num_test_examples = 0

    for t, test_loader in enumerate(dataset.test_loaders):
        past_model = MODELS[t]
        past_model.eval()
        model.eval()

        gf = 0
        gc = 0
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            # Keep only what the past model got right; this pass is pure
            # bookkeeping, so no autograd graph is recorded for it.
            with torch.no_grad():
                kept = past_model(x, *text_args).argmax(dim=1) == y
            x, y = x[kept], y[kept]
            n_kept = x.shape[0]
            if n_kept == 0:
                continue

            grad_1_cur = _unit_input_gradient(model, x, y, criterion, text_args)
            grad_1_past = _unit_input_gradient(past_model, x, y, criterion, text_args)
            gf += ((grad_1_past - grad_1_cur) ** 2).sum().item()

            # NOTE(review): the shifted gradient of the "past" curvature is taken
            # from the *current* model, as in the original code. If the past
            # model's own curvature is intended this should be `past_model` --
            # confirm with the metric's definition before changing it.
            grad_2_past = _unit_input_gradient(model, x + h * grad_1_past, y, criterion, text_args)
            curvature_past = (grad_1_past - grad_2_past) / h

            grad_2_cur = _unit_input_gradient(model, x + h * grad_1_cur, y, criterion, text_args)
            curvature_cur = (grad_1_cur - grad_2_cur) / h

            gc += ((curvature_past - curvature_cur) ** 2).sum().item()
            num_test_examples += n_kept

        GradForgetting.append(gf)
        CurvatureForgetting.append(gc)

    model.train(status)

    if num_test_examples == 0:
        # Nothing to average over: report "undefined" instead of dividing by zero.
        return float('nan'), float('nan')

    GradForgetting = (sum(GradForgetting) / num_test_examples)
    CurvatureForgetting = (sum(CurvatureForgetting) / num_test_examples)
    return GradForgetting, CurvatureForgetting


def CLIP_image_logits(images, model, text_tokens):
    """
    Computes the logits of `model` for `images` against text features derived
    from `text_tokens`.
    The text side is prepared without gradients and depends on model.NAME:
    'attriclip' uses text_tokens unchanged, 'proof' encodes them with
    model.net.convnet.encode_text, and every other model encodes them with
    model.net.encode_text and scales each feature vector to unit L2 norm.
    :param images: the image batch passed to the model
    :param model: the model; called as model(images, text_features)
    :param text_tokens: the text input (or ready-made features for 'attriclip')
    :return: the model output for (images, text_features)
    """
    model_name = model.NAME
    with torch.no_grad():
        if model_name == 'attriclip':
            features = text_tokens
        elif model_name == 'proof':
            features = model.net.convnet.encode_text(text_tokens)
        else:
            raw = model.net.encode_text(text_tokens)
            features = raw / raw.norm(dim=-1, keepdim=True)
    return model(images, features)


