
import os
import torch
import argparse
import datasets
import models
import pandas as pd 
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Command-line interface. The string arguments are interpolated into the
# checkpoint paths and the results file name further down.
parser = argparse.ArgumentParser()
parser.add_argument("-net", type=str, required=True, help="net type")
parser.add_argument("-method", type=str, required=True,
                    help="unlearning method name (part of the unlearned checkpoint file name)")
parser.add_argument(
    "-dataset",
    type=str,
    required=True,
    # `nargs="?"` was dropped: an optional value contradicts required=True.
    choices=["Cifar10", "Cifar20", "Cifar100", "PinsFaceRecognition", "TinyImagenet", "Svhn"],
    help="dataset to evaluate on",
)
parser.add_argument("-forget_class", type=str,
                    help="forgotten class identifier (part of checkpoint and result file names)")
parser.add_argument("-seed", type=int,
                    help="seed of the retrained reference checkpoint")
parser.add_argument("-task", type=str, required=True,
                    help="task name (sub-directory for checkpoints and results)")
args = parser.parse_args()
device = 'cuda'
img_size = 32
batch_size = 1024
# Number of output classes and input resolution depend on the dataset/backbone.
if args.dataset.startswith('Cifar'):
    classes = int(args.dataset[5:])  # e.g. 'Cifar100' -> 100
elif args.dataset == 'PinsFaceRecognition':
    classes, img_size = 105, 64
elif args.dataset == 'TinyImagenet':
    classes, img_size = 200, 64
elif args.dataset == 'Svhn':
    classes = 10
else:
    # Guard against `classes` being silently undefined if choices ever grow.
    raise ValueError(f"unsupported dataset: {args.dataset}")
if args.net == "ViT":
    img_size = 224

data_root = 'data/pins_face_recognition' if args.dataset == 'PinsFaceRecognition' else 'data/'
# Explicit check instead of `assert`, which is removed under `python -O`.
if not os.path.exists(data_root):
    raise FileNotFoundError(f"data root not found: {data_root}")
dataset_cls = getattr(datasets, args.dataset)
trainset = dataset_cls(root=data_root, download=True, train=True, unlearning=True, img_size=img_size)
# These loaders are only used for a full evaluation pass: shuffling adds
# nothing and makes the floating-point accumulation order nondeterministic.
train_dl = DataLoader(trainset, batch_size=batch_size, shuffle=False)
valset = dataset_cls(root=data_root, download=True, train=False, unlearning=True, img_size=img_size)
val_dl = DataLoader(valset, batch_size=batch_size, shuffle=False)



# Build the unlearned model and the retrained reference with the same architecture.
net = getattr(models, args.net)(num_classes=classes)
retrained_net = getattr(models, args.net)(num_classes=classes)
weight_path = f'tmp_save/{args.task}/{args.method}-{args.net}-{args.dataset}-{args.forget_class}.pt'
retrained_weight_path = f'ckp/{args.task}/{args.net}-{args.dataset}-retrain-{args.forget_class}-{args.seed}.pth'
# Load onto CPU first so checkpoints saved from any GPU index load cleanly,
# then move the models to the GPU below.
net.load_state_dict(torch.load(weight_path, map_location='cpu'))
retrained_net.load_state_dict(torch.load(retrained_weight_path, map_location='cpu'))
# Both networks are used for inference only: eval() disables dropout and
# stops BatchNorm from using/updating batch statistics.
net = net.cuda().eval()
retrained_net = retrained_net.cuda().eval()


def compute_kl(dataloader, unlearned_model=None, retrained_model=None,
               target_device=None, progress=True):
    """Return the mean per-sample KL divergence between two models' predictions.

    For each batch, both models are run on the same images. The summed
    ``KL(softmax(retrained) || softmax(unlearned))`` is accumulated
    (``F.kl_div`` with ``reduction='sum'``). The total is then divided by the
    number of samples seen.

    Args:
        dataloader: iterable yielding 3-tuples whose first element is the
            image tensor; the other two elements are ignored.
        unlearned_model: model whose log-probabilities are the KL input.
            Defaults to the module-level ``net``.
        retrained_model: model whose distribution is the KL target.
            Defaults to the module-level ``retrained_net``.
        target_device: device the images are moved to. Defaults to the
            module-level ``device``.
        progress: wrap the loader in ``tqdm`` when True.

    Returns:
        0-dim tensor with the average KL divergence per sample.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    unlearned_model = net if unlearned_model is None else unlearned_model
    retrained_model = retrained_net if retrained_model is None else retrained_model
    target_device = device if target_device is None else target_device

    # Evaluation must not apply dropout or touch BatchNorm running statistics.
    unlearned_model.eval()
    retrained_model.eval()

    total = 0.0
    num = 0
    batches = tqdm(dataloader) if progress else dataloader
    with torch.no_grad():
        for images, _, _ in batches:
            images = images.to(target_device)
            num += images.size(0)
            # log_softmax is numerically stable; softmax(...).log() turns an
            # underflowed probability into -inf and the KL into inf/NaN.
            unlearned_log_probs = F.log_softmax(unlearned_model(images), dim=-1)
            retrained_log_probs = F.log_softmax(retrained_model(images), dim=-1)
            total = total + F.kl_div(unlearned_log_probs, retrained_log_probs,
                                     reduction='sum', log_target=True)
    if num == 0:
        raise ValueError("compute_kl: dataloader yielded no samples")
    return total / num

train_kldiv = compute_kl(train_dl)
val_kldiv = compute_kl(val_dl)
print('loader ratio', len(train_dl)/len(val_dl))

# Convert each 0-dim tensor to numpy once instead of repeating it per column.
train_kl = train_kldiv.detach().cpu().numpy()
val_kl = val_kldiv.detach().cpu().numpy()

# The output root can be overridden via KLDIV_DIR; the default is the original path.
kldiv_root = os.environ.get('KLDIV_DIR', '/opt/data/private/MUMis-pami/figure/kldiv')
kldir = f'{kldiv_root}/{args.task}'
os.makedirs(kldir, exist_ok=True)
filename = f'{kldir}/{args.net}-{args.dataset}-{args.forget_class}.csv'
# NOTE(review): "All KL Div" uses a fixed 0.8/0.2 train/val weighting rather
# than the actual sample counts - confirm this is intended.
new_row = pd.DataFrame({'Method': [args.method],
                        "Train KL Div": [train_kl],
                        "Val KL Div": [val_kl],
                        "All KL Div": [train_kl * 0.8 + val_kl * 0.2],
                        })
# Append one row per run; the header is written only when the file is new.
new_row.to_csv(filename, mode='a', index=False, header=not os.path.exists(filename))
