from argparse import Namespace
from pathlib import Path
import os
import copy
import time

import torch
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from salun_src.utils import accuracy, prepare_dataset

import inversion_src.recovery as rs

# Per-dataset value assigned to ``args.num_indexes_to_replace`` when the
# unlearned checkpoint name ends in a ``rand*`` tag (presumably the number of
# randomly chosen samples to forget -- confirm against ``prepare_dataset``).
RAND_FORGET_INDICES = {
    'cifar10': 4500,
    'svhn': 6593,
}
def new_plot(tensor, title="", path=None):
    """Draw a batch of images with matplotlib and optionally save the figure.

    Args:
        tensor: Batch of images indexed along dim 0. Each image is permuted
            with ``(1, 2, 0)`` before display, i.e. it is expected in
            channel-first layout.
        title: Title text. Applied with ``plt.title`` (current axes) when
            non-empty.
        path: Destination passed to ``plt.savefig``. ``None`` (the default)
            skips saving instead of failing.

    Returns:
        The ``plt.imshow`` result when the batch holds exactly one image,
        otherwise ``None``.
    """
    n_images = tensor.shape[0]
    if n_images == 1:
        # Single image: draw on the current axes and keep returning the image
        # handle so existing callers that use it still work.
        shown = plt.imshow(tensor[0].detach().permute(1, 2, 0).cpu())
    else:
        shown = None
        _, axes = plt.subplots(1, n_images, figsize=(2 * n_images, 3))
        for ax, im in zip(axes, tensor):
            ax.imshow(im.detach().permute(1, 2, 0).cpu())

    # Title and saving used to be skipped for a single image, and saving was
    # attempted even when no path was given; both cases are handled uniformly.
    if title:
        plt.title(title)
    if path is not None:
        plt.savefig(path)
    return shown

def process_recons_results(result, ground_truth, figpath, recons_path, filename):
    """Plot a reconstruction next to its ground truth and save the raw results.

    Args:
        result: 4-tuple ``(output_list, stats, history_list, x_optimal)``;
            ``output_list`` and ``x_optimal`` must be tensors.
        ground_truth: Reference image batch that ``x_optimal`` is compared
            against (any device).
        figpath: Directory that receives ``<filename>.png``.
        recons_path: Directory that receives ``<filename>.pth``.
        filename: Base name (without extension) shared by both output files.
    """
    output_list, stats, history_list, x_optimal = result
    x_optimal = x_optimal.detach().cpu()
    # Bring the reference to the CPU once so that both the MSE and the
    # concatenation below work even when the caller passes a GPU tensor.
    ground_truth = ground_truth.detach().cpu()
    test_mse = (x_optimal - ground_truth).pow(2).mean()

    title = f"MSE: {test_mse:2.4f}"
    new_plot(torch.cat([ground_truth, x_optimal]), title, path=os.path.join(figpath, f'{filename}.png'))

    payload = {
        'output_list': output_list.cpu(),
        'stats': stats,
        'history_list': history_list,
        'x_optimal': x_optimal,
    }
    # Hand torch.save the path so it opens and closes the file itself; the
    # previous bare ``open(..., 'wb')`` handle was never closed.
    torch.save(payload, os.path.join(recons_path, f'{filename}.pth'))
def load_model(model, ckpt_path: os.PathLike, device='cuda'):
    """Load a state dict from ``ckpt_path`` into ``model`` and return it in eval mode.

    The checkpoint must be a flat state dict containing ``fc.weight`` (and
    ``fc.bias`` when the model's ``fc`` layer has a bias). If the checkpoint's
    classifier head has a different shape than ``model.fc``, the head tensors
    are swapped in first so that ``load_state_dict`` does not reject the
    size mismatch.

    Args:
        model: Module exposing an ``fc`` layer; it is modified in place.
        ckpt_path: Path of the checkpoint file.
        device: Device the model is moved to after loading. Defaults to
            ``'cuda'`` (the previously hard-coded target).

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')

    # Head size differs from the freshly built model (original author's note:
    # "11dims for BE"); replace the head tensors so the shapes agree.
    if ckpt['fc.weight'].shape != model.fc.weight.shape:
        model.fc.weight.data = ckpt['fc.weight']
        if getattr(model.fc, 'bias', None) is not None:
            model.fc.bias.data = ckpt['fc.bias']

    model.load_state_dict(ckpt)
    model = model.to(device)
    model.eval()
    return model

if __name__ == '__main__':
    # Script entry point: for the first ten items of the forget set, run the
    # gradient reconstructor against both the pretrained and the unlearned
    # model and store the raw results and comparison figures under ./recon.
    ckpt_dir = Path(__file__).parent / 'ckpt'
    pretrained_ckpt_path = ckpt_dir / 'pretrained_cifar10.pth'
    unlearned_ckpt_path = ckpt_dir / 'OPC_cifar10_cls30.pth'

    args = Namespace(
        arch='resnet18',
        batch_size=128,
        data='../data',
        workers=4,
        class_to_replace=None,
        num_indexes_to_replace=None,
        indexes_to_replace=None,
        seed=2,
        no_aug=False,
        train_seed=None,
        imagenet_arch=False,
    )

    # The checkpoint stem is split on "_": field 1 is the dataset name and the
    # last field is the forgetting tag ("cls*" or "rand*").
    stem_parts = unlearned_ckpt_path.stem.split('_')
    args.dataset = stem_parts[1]
    forget_tag = stem_parts[-1]

    # NOTE(review): hr_normalize and is_class_unlearning are not read anywhere
    # in the visible part of this script.
    hr_normalize = False
    if forget_tag.startswith('cls'):
        args.class_to_replace = list(range(3))
        is_class_unlearning = True
    elif forget_tag.startswith('rand'):
        args.num_indexes_to_replace = RAND_FORGET_INDICES[args.dataset]
        is_class_unlearning = False

    dataset_all, model = prepare_dataset(args)
    models_all = {}
    # Copy first: load_model modifies the model it is given in place.
    models_all['pretrained'] = load_model(copy.deepcopy(model), pretrained_ckpt_path)
    models_all['OPC'] = load_model(model, unlearned_ckpt_path)

    recons_config = dict(
        signed=True,
        boxed=True,
        cost_fn='sim',
        indices='def',
        weights='equal',
        lr=0.02,
        optim='adamw',
        restarts=1,
        max_iterations=5000,
        total_variation=1e-2,
        init='randn',
        filter='none',
        lr_decay=True,
        scoring_choice='loss',
    )
    dmlist = [0.4914672374725342, 0.4822617471218109, 0.4467701315879822]
    dslist = [0.24703224003314972, 0.24348513782024384, 0.26158785820007324]
    dm = torch.as_tensor(dmlist, device='cuda', dtype=torch.float32)[:, None, None]
    ds = torch.as_tensor(dslist, device='cuda', dtype=torch.float32)[:, None, None]

    results_saveroot = Path(__file__).parent / 'recon'
    os.makedirs(results_saveroot, exist_ok=True)
    setup = dict(device='cuda', dtype=torch.float32)
    for unlearn_item_ind in range(10):
        saveroot = results_saveroot / f'forget_{unlearn_item_ind}'
        recon_dir = saveroot / 'recon'
        figure_dir = saveroot / 'figures'
        os.makedirs(saveroot, exist_ok=True)
        os.makedirs(recon_dir, exist_ok=True)
        os.makedirs(figure_dir, exist_ok=True)
        for unlearn_id, target_model in models_all.items():
            print(unlearn_id)

            run_name = f'unlearn_unlearn_{unlearn_id}_{unlearn_item_ind}'
            # Direct existence check instead of scanning the whole tree with rglob.
            if (recon_dir / f'{run_name}.pth').exists():
                print('Already done')
                continue

            # dm/ds only provide the shape here: the reconstructor is given
            # mean 0 and std 1 rather than the statistics listed above.
            rec_machine_pretrain = rs.GradientReconstructor(
                target_model,
                (torch.zeros_like(dm), torch.ones_like(ds)),
                recons_config,
                num_images=1,
            )
            rec_machine_pretrain.model.eval()

            # Index the dataset once so the sample is loaded/transformed a
            # single time and image and label come from the same access.
            sample = dataset_all['forget'].dataset[unlearn_item_ind]
            X_unlearn = sample[0].unsqueeze(0).to('cuda')
            y_unlearn = torch.as_tensor(sample[1]).unsqueeze(0).to('cuda')

            # Disabled alternatives kept from the original script:
            #   - pass loss_fn=lambda *x: x[0].norm(p=2,dim=1) to loss_steps
            #   - use the parameter difference between this model and the
            #     pretrained one instead of the loss_steps output
            approx_diff = [
                p.detach().to(**setup)
                for p in rs.recovery_algo.loss_steps(
                    target_model, X_unlearn, y_unlearn, lr=1e-4, local_steps=1
                )
            ]

            result_approx = rec_machine_pretrain.reconstruct(
                approx_diff,
                X_unlearn.to(**setup),
                y_unlearn.to(setup['device']),
                img_shape=(3, 32, 32),
            )

            process_recons_results(
                result_approx,
                X_unlearn.cpu(),
                figpath=figure_dir,
                recons_path=recon_dir,
                filename=run_name,
            )
            # new_plot opens a new figure per call; close them so figures do
            # not pile up over the 20 reconstructions.
            plt.close('all')
            
