from argparse import Namespace
from pathlib import Path
import os
import copy
import time

import torch
from torch.utils.data import DataLoader
import torch.nn as nn

import pandas as pd


from salun_src.utils import accuracy, prepare_dataset
from salun_src.evaluation import validate,SVC_MIA
from CKA_src.cka import CKACalculator

# Size of the random forget set for each dataset; assigned to
# ``args.num_indexes_to_replace`` when the unlearned checkpoint name ends
# with a 'rand…' field (see the ``__main__`` section).
RAND_FORGET_INDICES = {
    'cifar10': 4500,
    'svhn': 6593,
}
def cka_compute(model1, model2, sample_dataset:DataLoader):
    """Return the CKA matrix computed by ``CKACalculator`` for two models.

    ``hook_layer_types`` is set to the class of ``model1.fc`` and
    ``num_epochs`` to 1. The calculator is reset before the matrix is
    returned.

    Args:
        model1: first model; its ``fc`` attribute selects the hooked layer type.
        model2: second model.
        sample_dataset: loader whose samples are fed to both models.

    Returns:
        Whatever ``CKACalculator.calculate_cka_matrix`` produces (callers in
        this file index it with ``[-1, -1]``).
    """
    hooked_types = (type(model1.fc),)
    cka_calc = CKACalculator(
        model1,
        model2,
        sample_dataset,
        hook_layer_types=hooked_types,
        num_epochs=1,
    )
    similarity = cka_calc.calculate_cka_matrix()
    cka_calc.reset()
    return similarity
def load_model(model, ckpt_path: os.PathLike, device='cuda'):
    """Load a state-dict checkpoint into ``model``, move it to ``device`` and set eval mode.

    If the checkpoint's ``fc.weight`` differs in shape from ``model.fc.weight``,
    the model's ``fc`` weight (and bias, if the layer has one) is first replaced
    by the checkpoint tensors so that ``load_state_dict`` does not fail on a
    shape mismatch. The original note for this case was "11dims for BE";
    presumably some checkpoints carry an extra output unit. TODO confirm.

    Args:
        model: module with an ``fc`` layer; it is modified in place.
        ckpt_path: path to a file containing a state dict (loaded onto the CPU).
        device: device the loaded model is moved to. Defaults to ``'cuda'``,
            the previously hard-coded value.

    Returns:
        The same ``model`` object, loaded, on ``device``, in eval mode.
    """
    state = torch.load(ckpt_path, map_location='cpu')

    ckpt_fc_weight = state['fc.weight']
    if ckpt_fc_weight.shape != model.fc.weight.shape:
        # Resize fc in place so that the strict load below accepts the checkpoint.
        model.fc.weight.data = ckpt_fc_weight
        if getattr(model.fc, 'bias', None) is not None:
            model.fc.bias.data = state['fc.bias']

    model.load_state_dict(state)
    model = model.to(device)
    model.eval()
    return model
def evaluate_scores(model,dataset_all,criterion=None,args=None,is_class_unlearning=True):
    """Compute per-split accuracy and the SVC-MIA forget efficacy of ``model``.

    Args:
        model: classifier to evaluate; it is moved to CUDA.
        dataset_all: mapping of split name -> DataLoader. Must contain 'test',
            'forget' and 'retain'. The 'train' and 'val' splits are skipped.
        criterion: loss passed to ``validate``. Defaults to a fresh
            ``nn.CrossEntropyLoss()`` per call (it used to be a single shared
            instance created at import time).
        args: namespace passed to ``validate``. Defaults to a fresh
            ``Namespace(print_freq=50)`` per call.
        is_class_unlearning: if True, the 'test' split accuracy is not reported.

    Returns:
        dict mapping '<split>_acc' to the value returned by ``validate``, plus
        'SVC_MIA_forget_efficacy' (the 'confidence' entry of ``SVC_MIA``'s
        result, or None if absent).
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if args is None:
        args = Namespace(print_freq=50)

    model = model.to('cuda')

    test_len = len(dataset_all['test'].dataset)
    retain_len = len(dataset_all['retain'].dataset)

    # Shadow "member" set for the MIA: the first retain samples, as many as the
    # test set has. Capped at the retain size so indices stay in range.
    shadow_train = torch.utils.data.Subset(
        dataset_all['retain'].dataset, list(range(min(test_len, retain_len))))
    shadow_train_loader = torch.utils.data.DataLoader(
        shadow_train, batch_size=dataset_all['forget'].batch_size, shuffle=False)

    ret = {}
    for data_name, loader in dataset_all.items():
        if data_name in ('train', 'val'):
            continue
        if is_class_unlearning and data_name == 'test':
            continue
        ret[f'{data_name}_acc'] = validate(val_loader=loader, model=model, criterion=criterion, args=args)

    ret['SVC_MIA_forget_efficacy'] = SVC_MIA(
        shadow_train=shadow_train_loader, shadow_test=dataset_all['test'],
        target_train=None,
        target_test=dataset_all['forget'],
        model=model,).get('confidence', None)
    return ret
def feature_lstsq(teacher_model,unlearned_model,sample_dataloader:DataLoader,lstsq_device ='cpu'):
    r"""Fit a bias-free linear map from unlearned-model outputs to teacher outputs.

    Finds ``A`` minimising ``\sum_i ||A p_i - q_i||^2``, where
    ``p_i = unlearned_model(x_i)`` and ``q_i = teacher_model(x_i)`` over all
    batches of ``sample_dataloader``. It solves the normal equations
    ``(X^T X) M = X^T Y`` (rows of X / Y are the p_i / q_i) and returns
    ``A = M^T``.

    Args:
        teacher_model: reference model. Its ``head`` attribute (or ``fc`` if it
            has no ``head``) gives the feature dimension ``d`` through
            ``in_features``.
        unlearned_model: model whose outputs are mapped; assumed to also output
            ``d``-dimensional vectors. TODO confirm for other callers.
        sample_dataloader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        lstsq_device: device where outputs are gathered and the system is solved.

    Returns:
        ``torch.nn.Linear(d, d, bias=False)`` whose weight is the fitted ``A``.

    Note:
        ``torch.linalg.solve`` fails if ``X^T X`` is singular, e.g. when there
        are fewer samples than feature dimensions.
    """
    pred_head = teacher_model.head if hasattr(teacher_model, 'head') else teacher_model.fc
    ret = torch.nn.Linear(pred_head.in_features, pred_head.in_features, bias=False)

    # Feed each model on the device its parameters live on. Parameter-less
    # modules fall back to CUDA, the previously hard-coded device.
    teacher_device = next((p.device for p in teacher_model.parameters()), torch.device('cuda'))
    unlearned_device = next((p.device for p in unlearned_model.parameters()), torch.device('cuda'))

    teacher_features = []
    unlearned_features = []
    with torch.no_grad():
        for data, _ in sample_dataloader:
            teacher_features.append(teacher_model(data.to(teacher_device)).to(lstsq_device))
            unlearned_features.append(unlearned_model(data.to(unlearned_device)).to(lstsq_device))
    X = torch.cat(unlearned_features, dim=0)
    Y = torch.cat(teacher_features, dim=0)

    M_T = torch.linalg.solve(X.T @ X, X.T @ Y)
    ret.weight.data = M_T.T
    return ret
def lstsq_head(unlearned_model,sample_dataloader:DataLoader,lstsq_device ='cpu',normalize=False):
    r"""Fit a fresh linear classifier on ``unlearned_model``'s outputs (head recovery).

    Solves ``min_M ||X M - Y||^2`` through the normal equations. The rows of
    ``X`` are the model outputs for the samples in ``sample_dataloader``, with
    a trailing 1 appended when the reference head has a bias. The rows of ``Y``
    are the one-hot encoded labels.

    Args:
        unlearned_model: model whose outputs are used as features. Its ``head``
            attribute (or ``fc`` if it has no ``head``) supplies ``in_features``,
            ``out_features`` and whether the new head gets a bias. The model is
            called as-is, so its classifier is expected to have been replaced
            already (``__main__`` sets ``fc`` to ``nn.Identity``).
        sample_dataloader: iterable of ``(inputs, integer labels)`` batches.
        lstsq_device: device where features are gathered and the system is solved.
        normalize: if True, each output vector is L2-normalized before fitting.
            The returned head does not normalize its own inputs.

    Returns:
        ``torch.nn.Linear(in_features, out_features)`` holding the fitted
        weights. It has a bias if and only if the reference head has one.
    """
    head = unlearned_model.head if hasattr(unlearned_model, 'head') else unlearned_model.fc
    num_features = head.in_features
    num_classes = head.out_features
    has_bias = getattr(head, 'bias', None) is not None
    # Match the reference head: an unfitted bias would otherwise keep its random init.
    new_head = torch.nn.Linear(num_features, num_classes, bias=has_bias)

    # Parameter-less modules fall back to CUDA, the previously hard-coded device.
    model_device = next((p.device for p in unlearned_model.parameters()), torch.device('cuda'))

    features = []
    targets = []
    with torch.no_grad():
        for data, label in sample_dataloader:
            feats = unlearned_model(data.to(model_device))
            if normalize:
                feats = feats / feats.norm(dim=1)[:, None]
            if has_bias:
                # A constant column lets the solve recover the bias term.
                ones = torch.ones((feats.shape[0], 1), device=feats.device, dtype=feats.dtype)
                feats = torch.concat([feats, ones], dim=1)
            features.append(feats.to(lstsq_device))

            one_hot = torch.zeros((data.shape[0], num_classes), device=label.device)
            one_hot.scatter_(1, label.view(-1, 1), 1)
            targets.append(one_hot.to(lstsq_device, dtype=torch.float32))

    X = torch.cat(features, dim=0)
    Y = torch.cat(targets, dim=0)

    weights = torch.linalg.solve(X.T @ X, X.T @ Y).T
    if has_bias:
        new_head.weight.data = weights[:, :-1]
        new_head.bias.data = weights[:, -1]
    else:
        new_head.weight.data = weights
    return new_head
def TF_head(model,retain_loader,forget_loader,forget_target_type='EU',lstsq_device ='cpu'):
    r"""Build a training-free "unlearned" linear head by least squares.

    Solves ``min_M ||X M - Y||^2`` through the normal equations. ``X`` holds the
    model outputs for every retain sample followed by every forget sample, with
    a trailing 1 appended when the reference head has a bias. Retain rows of
    ``Y`` are the one-hot true labels. Forget rows depend on
    ``forget_target_type``:

    * ``'OPC'`` or ``'EU'`` (the default, treated as an alias): all-zero targets.
    * ``'RL'``: one-hot vectors of uniformly random labels (``torch.randint``).

    Args:
        model: model whose outputs are used as features. Its ``head`` attribute
            (or ``fc`` if it has no ``head``) supplies ``in_features``,
            ``out_features`` and whether the new head gets a bias.
        retain_loader: iterable of ``(inputs, integer labels)`` batches to keep.
        forget_loader: iterable of ``(inputs, labels)`` batches to forget; labels ignored.
        forget_target_type: one of ``'EU'``, ``'OPC'``, ``'RL'``.
        lstsq_device: device where features are gathered and the system is solved.

    Returns:
        ``torch.nn.Linear(in_features, out_features)`` holding the fitted
        weights. It has a bias if and only if the reference head has one.

    Raises:
        ValueError: if ``forget_target_type`` is not recognised. Previously an
            unrecognised value silently reused stale retain targets.
    """
    if forget_target_type not in ('EU', 'OPC', 'RL'):
        raise ValueError(
            f"Unknown forget_target_type {forget_target_type!r}; expected 'EU', 'OPC' or 'RL'")

    head = model.head if hasattr(model, 'head') else model.fc
    num_features = head.in_features
    num_classes = head.out_features
    has_bias = getattr(head, 'bias', None) is not None
    # Match the reference head: an unfitted bias would otherwise keep its random init.
    new_head = torch.nn.Linear(num_features, num_classes, bias=has_bias)

    # Parameter-less modules fall back to CUDA, the previously hard-coded device.
    model_device = next((p.device for p in model.parameters()), torch.device('cuda'))

    def _embed(data):
        """Model outputs for one batch (plus bias column), on ``lstsq_device``."""
        feats = model(data.to(model_device))
        if has_bias:
            ones = torch.ones((feats.shape[0], 1), device=feats.device, dtype=feats.dtype)
            feats = torch.concat([feats, ones], dim=1)
        return feats.to(lstsq_device)

    def _one_hot(label):
        """One-hot encode integer labels with ``num_classes`` columns."""
        encoded = torch.zeros((label.shape[0], num_classes), device=label.device)
        encoded.scatter_(1, label.view(-1, 1), 1)
        return encoded

    features = []
    targets = []
    with torch.no_grad():
        for data, label in retain_loader:
            features.append(_embed(data))
            targets.append(_one_hot(label).to(lstsq_device, dtype=torch.float32))
        for data, _ in forget_loader:
            features.append(_embed(data))
            if forget_target_type == 'RL':
                random_label = torch.randint(0, num_classes, (data.shape[0],))
                target = _one_hot(random_label)
            else:  # 'OPC' / 'EU': push forget-sample outputs towards zero
                target = torch.zeros((data.shape[0], num_classes))
            targets.append(target.to(lstsq_device, dtype=torch.float32))

    X = torch.cat(features, dim=0)
    Y = torch.cat(targets, dim=0)

    weights = torch.linalg.solve(X.T @ X, X.T @ Y).T
    if has_bias:
        new_head.weight.data = weights[:, :-1]
        new_head.bias.data = weights[:, -1]
    else:
        new_head.weight.data = weights
    return new_head
if __name__ == '__main__':
    # Compare an unlearned checkpoint with its pretrained origin: accuracy/MIA
    # scores, CKA similarity, feature-mapping and head-recovery attacks, and
    # training-free unlearning baselines.
    ckpt_dir = Path(__file__).parent / 'ckpt'
    pretrained_ckpt_path = ckpt_dir / 'pretrained_cifar10.pth'
    unlearned_ckpt_path = ckpt_dir / 'OPC_cifar10_cls30.pth'

    args = Namespace(arch='resnet18', batch_size=128, data='../data', workers=4,
                     class_to_replace=None, num_indexes_to_replace=None,
                     indexes_to_replace=None, seed=2, no_aug=False,
                     train_seed=None, imagenet_arch=False,)

    # The checkpoint stem is split on '_': field 1 is the dataset and the last
    # field ('cls…' or 'rand…') selects the forget setting.
    name_fields = unlearned_ckpt_path.stem.split('_')
    args.dataset = name_fields[1]
    forget_setting = name_fields[-1]
    hr_normalize = False  # forwarded to lstsq_head as `normalize`

    if forget_setting.startswith('cls'):
        # Class-wise forgetting of classes 0, 1 and 2.
        args.class_to_replace = list(range(3))
        is_class_unlearning = True
    elif forget_setting.startswith('rand'):
        # Random-sample forgetting with a dataset-specific forget-set size.
        args.num_indexes_to_replace = RAND_FORGET_INDICES[args.dataset]
        is_class_unlearning = False
    else:
        raise ValueError(
            f"Cannot infer forget setting from checkpoint name {unlearned_ckpt_path.name!r}; "
            "expected the last '_'-separated field to start with 'cls' or 'rand'")

    dataset_all, model = prepare_dataset(args)
    pretrained_model = load_model(copy.deepcopy(model), pretrained_ckpt_path)
    model = load_model(model, unlearned_ckpt_path)

    results_all = {}
    CKA_res = {}

    results_all['pretrained'] = evaluate_scores(pretrained_model, dataset_all, is_class_unlearning=is_class_unlearning)
    results_all['OPC'] = evaluate_scores(model, dataset_all, is_class_unlearning=is_class_unlearning)

    # Keep a copy of the pretrained classifier under 'head' so it survives the
    # replacement of 'fc' below.
    pretrained_model.register_module('head', copy.deepcopy(pretrained_model.fc))

    # CKA computed with the classifiers still attached.
    CKA_res['forget_logit'] = cka_compute(pretrained_model, model, sample_dataset=dataset_all['forget'])[-1, -1].item()
    CKA_res['retain_logit'] = cka_compute(pretrained_model, model, sample_dataset=dataset_all['retain'])[-1, -1].item()

    # Classifier removal: fc -> Identity on both models (the unlearned model's
    # classifier is kept under 'head'). They stay this way for the attacks below.
    pretrained_model.register_module('fc', nn.Identity())
    if not hasattr(model, 'head'):
        model.register_module('head', copy.deepcopy(model.fc))
    model.register_module('fc', nn.Identity())
    CKA_res['forget_feature'] = cka_compute(pretrained_model, model, sample_dataset=dataset_all['forget'])[-1, -1].item()
    CKA_res['retain_feature'] = cka_compute(pretrained_model, model, sample_dataset=dataset_all['retain'])[-1, -1].item()

    # Feature-mapping attack: map unlearned features onto pretrained features
    # (fitted on 'val'), then reuse the pretrained classifier on top.
    lstsq_start = time.time()
    fm_layer = feature_lstsq(
        teacher_model=pretrained_model,
        unlearned_model=model,
        sample_dataloader=dataset_all['val'],
        lstsq_device='cpu')
    print(f'feature mapping recovery took {time.time()-lstsq_start} sec')
    model.register_module('fc', nn.Sequential(
        copy.deepcopy(fm_layer),
        copy.deepcopy(pretrained_model.head),
    ))
    results_all['FM-recovery'] = evaluate_scores(model, dataset_all, is_class_unlearning=is_class_unlearning)

    # Head recovery: drop the mapped classifier again and fit a new head on 'val'.
    model.register_module('fc', nn.Identity())
    lstsq_start = time.time()
    hr_layer = lstsq_head(unlearned_model=model,
                          sample_dataloader=dataset_all['val'],
                          lstsq_device='cpu',
                          normalize=hr_normalize)
    print(f'head recovery took {time.time()-lstsq_start} sec')

    model.register_module('fc', hr_layer)
    results_all['HR-recovery'] = evaluate_scores(model, dataset_all, is_class_unlearning=is_class_unlearning)

    # Training-free unlearning, starting from the pretrained model with its
    # classifier restored.
    pretrained_model.register_module('fc', pretrained_model.head)

    for forget_type in ['OPC', 'RL']:
        # Fresh copy each time; the unlearned checkpoint is not used here.
        model = copy.deepcopy(pretrained_model)
        model.register_module('fc', nn.Identity())
        lstsq_start = time.time()
        new_head = TF_head(model, dataset_all['retain'], dataset_all['forget'], forget_type)
        print(f'training-free unlearning {forget_type} took {time.time()-lstsq_start} sec')
        model.register_module('fc', new_head)

        results_all[forget_type+'-TF'] = evaluate_scores(model, dataset_all, is_class_unlearning=is_class_unlearning)

    result_df = pd.DataFrame(results_all)
    print(result_df.T.to_markdown())

    print('CKA results compared pretrained:', CKA_res)