import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import argparse
from tensorboardX import SummaryWriter
import torchvision.transforms as transforms
import torchvision
from models.resnet import ResNet18
from models.resnet_orig import ResNet18_orig
from models.vgg import VGG, VGG_rw
from datasets import get_dataset, unnormalize
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

import copy
import random
import numpy as np
import time
import sys
import os
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

# Put the directory two levels above this file on the import path so the
# project-level modules imported just below (utils_ensemble, trainer) can be
# resolved -- presumably they live there; confirm against the repo layout.
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.normpath(os.path.join(currentdir, os.pardir, os.pardir))
sys.path.append(parentdir)

from utils_ensemble import test
from trainer import Naive_Trainer

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('device: ', device)


# Supported --unlearn_method values, mapped to the tag used in args.outdir.
# main() builds VGG_rw for the "layers" methods and plain VGG for the others.
_MODEL_ARCH_BY_METHOD = {
    'retrain': 'original',
    'RW_FT': 'original',
    'RW_FT_par': 'original',
    'BS': 'original',
    'RW': 'layers',
    'RW_multi': 'layers',
}


def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``, so
    ``--unnormalize False`` used to leave the flag enabled. This accepts the
    usual spellings case-insensitively; an empty string stays ``False`` as it
    did with ``bool``.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values (argparse turns
            this into a usage error).
    """
    if isinstance(value, bool):
        return value
    text = value.lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--unlearn_method', default='RW', type=str, choices=list(_MODEL_ARCH_BY_METHOD))
parser.add_argument('--arch', default='ResNet18', type=str)
parser.add_argument('--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=201, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch', default=128, type=int, metavar='N', help='batchsize (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, help='initial learning rate', dest='lr')
parser.add_argument('--lr_step_size', type=int, default=40, help='How often to decrease learning by gamma.')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W', help='weight decay (default: %(default)s)')
parser.add_argument('--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--num-models', default=3, type=int)
parser.add_argument('--resume', action='store_true', help='if true, tries to resume training from existing checkpoint')
parser.add_argument('--unnormalize', default=True, type=_str2bool)
parser.add_argument('--adv-training', action='store_true')
parser.add_argument('--epsilon', default=512, type=float)
parser.add_argument('--num-steps', default=4, type=int)
parser.add_argument('--adv-eps', default=0.04, type=float)

# TRADES-style attack settings; explicit types so CLI overrides are numeric
# (previously they arrived as strings).
parser.add_argument('--epsilon_t', default=0.031, type=float,
                    help='perturbation')
parser.add_argument('--num-steps_t', default=10, type=int,
                    help='perturb number of steps')
parser.add_argument('--step-size_t', default=0.007, type=float,
                    help='perturb step size')
parser.add_argument('--beta_t', default=6.0, type=float,
                    help='regularization, i.e., 1/lambda in TRADES')

parser.add_argument('--method', default='orig', help='clipping method (use orig for no clipping)')
parser.add_argument('--mode', default='wBN', help='what to do with BN layers (leave empty for keeping it as it is)')
parser.add_argument('--seed', default=1, type=int, help='seed value')
parser.add_argument('--convsn', default=1., type=float, help='clip value for conv and dense layers')

parser.add_argument('--widen_factor', default=1, type=int, help='widen factor for WideResNet')

parser.add_argument('--coeff', default=2.0, type=float)
parser.add_argument('--lamda', default=2.0, type=float)
parser.add_argument('--scale', default=5.0, type=float)
# NOTE: store_false -- args.plus_adv is True by default and passing
# --plus-adv turns it OFF.
parser.add_argument('--plus-adv', action='store_false')
parser.add_argument('--init-eps', default=0.01, type=float)

args = parser.parse_args()

# Run name: adversarial-training settings, or vanilla + clipping/BN settings.
if args.adv_training:
    mode = f"adv_{args.epsilon}_{args.num_steps}"
elif args.method == 'orig':
    mode = f"vanilla_orig_{args.mode}"
else:
    mode = f"vanilla_clip{args.convsn}_{args.mode}"

# `choices` above guarantees the lookup succeeds (the old if/elif chain left
# model_arch unbound for other values and crashed with NameError).
model_arch = _MODEL_ARCH_BY_METHOD[args.unlearn_method]
norm_tag = "unnorm" if args.unnormalize else "norm"
args.outdir = f"/{args.dataset}_{norm_tag}_{model_arch}/{args.arch}_{mode}_{args.seed}/"

# --epsilon is divided by 256 only after its raw value went into the run name.
args.epsilon /= 256.0

args.outdir = ("resume" if args.resume else "scratch") + args.outdir

# Alternative log root used previously: /class_unlearn/logs/correct/
args.outdir = "/class_unlearn/logs/RW/" + args.outdir

print(args.outdir)
print('learning rate: ', args.lr)
print('dataset: ', args.dataset)
print('unlearn_method: ', args.unlearn_method)

# Directory holding the per-class index CSVs written by this script.
_CLASS_INDEX_DIR = '/class_unlearn/class_indices'

# Unlearning methods that use the VGG_rw variant vs. the plain VGG.
_RW_METHODS = ('RW', 'RW_multi')
_FULL_MODEL_METHODS = ('BS', 'retrain', 'RW_FT', 'RW_FT_par')


def _make_loaders(trainset, testset, batch_size, num_workers):
    """Wrap the datasets in loaders: training data shuffled, test data in order."""
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


def _crop_flip_transforms(crop_size, normalize_stats=None):
    """Return ``(transform_train, transform_test)`` for colour image datasets.

    Training: random crop (padding 4) + horizontal flip + ToTensor.
    Test: ToTensor only. When ``normalize_stats`` is a ``(mean, std)`` pair,
    both pipelines end with the matching ``Normalize``.
    """
    normalize = [] if normalize_stats is None else [transforms.Normalize(*normalize_stats)]
    transform_train = transforms.Compose([
        transforms.RandomCrop(crop_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        *normalize,
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), *normalize])
    return transform_train, transform_test


def _mnist_transforms():
    """Return ``(transform_train, transform_test)`` for MNIST and ``args.arch``.

    Raises:
        ValueError: if ``args.arch`` is neither ``'ResNet18'`` nor ``'VGG'``
            (previously this left the transforms unbound).
    """
    if args.arch == 'ResNet18':
        if args.unnormalize:
            return (transforms.Compose([transforms.ToTensor()]),
                    transforms.Compose([transforms.ToTensor()]))
        shared = transforms.Compose([
            transforms.Resize((28, 28)),            # ensure images are 28x28
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),   # [0, 1] -> [-1, 1]
        ])
        return shared, shared
    if args.arch == 'VGG':
        # Resize to 32 px -- presumably what the VGG configuration expects.
        if args.unnormalize:
            return (transforms.Compose([transforms.Resize(32), transforms.ToTensor()]),
                    transforms.Compose([transforms.Resize(32), transforms.ToTensor()]))
        shared = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return shared, shared
    raise ValueError(f"No MNIST transforms defined for arch {args.arch!r}")


def _build_data():
    """Create the datasets and loaders selected by ``args.dataset``.

    Returns:
        ``(trainset, testset, train_loader, test_loader, in_chan, num_classes)``

    Raises:
        ValueError: if ``args.dataset`` is not ``mnist``, ``cifar10``,
            ``cifar100`` or ``imagenet`` (Tiny-ImageNet-200). The old generic
            ``get_dataset`` fallback never set ``trainset``/``in_chan``/
            ``num_classes`` and always crashed later with NameError.
    """
    # NOTE(review): batch sizes and worker counts are hard-coded and ignore
    # --batch / --workers; kept as-is to preserve existing results.
    if args.dataset == 'mnist':
        print('using mnist')
        in_chan, num_classes = 1, 10
        transform_train, transform_test = _mnist_transforms()
        trainset = torchvision.datasets.MNIST(root='/mnist', train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.MNIST(root='/mnist', train=False, download=True, transform=transform_test)
        train_loader, test_loader = _make_loaders(trainset, testset, 128, 1)

    elif args.dataset == 'cifar10':
        print('using cifar 10')
        in_chan, num_classes = 3, 10
        # CIFAR-10 inputs are never normalized here; --unnormalize is not consulted.
        transform_train, transform_test = _crop_flip_transforms(32)
        trainset = torchvision.datasets.CIFAR10(root='/cifar10', train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root='/cifar10', train=False, download=True, transform=transform_test)
        train_loader, test_loader = _make_loaders(trainset, testset, 128, 1)

    elif args.dataset == 'cifar100':
        print('using cifar 100')
        in_chan, num_classes = 3, 100
        # CIFAR-100 mean and std
        stats = None if args.unnormalize else ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        transform_train, transform_test = _crop_flip_transforms(32, stats)
        trainset = torchvision.datasets.CIFAR100(root='/cifar100', train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root='/cifar100', train=False, download=True, transform=transform_test)
        train_loader, test_loader = _make_loaders(trainset, testset, 128, 1)

    elif args.dataset == 'imagenet':
        print('using Imagenet')
        in_chan, num_classes = 3, 200
        stats = None if args.unnormalize else ((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
        transform_train, transform_test = _crop_flip_transforms(64, stats)
        # NOTE(review): ImageFolder needs one sub-folder per class; assumes the
        # Tiny-ImageNet val/ split has been reorganized that way -- confirm.
        trainset = torchvision.datasets.ImageFolder(root='/tiny-imagenet-200/train', transform=transform_train)
        testset = torchvision.datasets.ImageFolder(root='/tiny-imagenet-200/val', transform=transform_test)
        train_loader, test_loader = _make_loaders(trainset, testset, 256, 4)

    else:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of "
            "'mnist', 'cifar10', 'cifar100', 'imagenet'"
        )

    return trainset, testset, train_loader, test_loader, in_chan, num_classes


def _class_index_paths(label):
    """Return the (train, test) CSV paths listing the indices of class *label*."""
    train_filename = f'{_CLASS_INDEX_DIR}/{args.dataset}_label_{label}.csv'
    test_filename = f'{_CLASS_INDEX_DIR}/{args.dataset}_label_{label}_test.csv'
    return train_filename, test_filename


def _indices_by_label(dataset):
    """Group the indices of *dataset* by integer class label.

    Reads the dataset's ``targets`` attribute when present and no
    ``target_transform`` is set, which avoids decoding and augmenting every
    sample. Otherwise it falls back to ``dataset[i][1]``, which yields the
    same labels.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is None or getattr(dataset, 'target_transform', None) is not None:
        targets = (dataset[i][1] for i in range(len(dataset)))
    elif hasattr(targets, 'tolist'):  # e.g. MNIST stores targets as a tensor
        targets = targets.tolist()
    groups = defaultdict(list)
    for idx, label in enumerate(tqdm(targets, total=len(dataset))):
        groups[int(label)].append(idx)
    return groups


def _ensure_class_index_files(trainset, testset, num_classes):
    """Write any missing per-class index CSVs (column ``unlearn_idx``).

    Existing files are never overwritten. Labels are only extracted when at
    least one file is missing.
    """
    paths = [_class_index_paths(label) for label in range(num_classes)]
    missing = any(not os.path.exists(path) for pair in paths for path in pair)
    print("flags", int(missing))
    if not missing:
        return

    print("saving labels")
    os.makedirs(_CLASS_INDEX_DIR, exist_ok=True)
    train_labels = _indices_by_label(trainset)
    test_labels = _indices_by_label(testset)
    for label, (train_filename, test_filename) in enumerate(paths):
        for filename, groups in ((train_filename, train_labels), (test_filename, test_labels)):
            if not os.path.exists(filename):
                print('label:', label, 'num:', len(groups[label]))
                df = pd.DataFrame(groups[label], columns=['unlearn_idx'])
                df.to_csv(filename, index=False)


def _build_model(in_chan, num_classes, *, clip_flag, bn_flag, clip_steps, opt_iter, elu_flag, writer):
    """Instantiate the (unwrapped) network selected by ``args``.

    ``clip_flag`` selects the clipping ``ResNet18`` (clip value ``args.convsn``).
    Otherwise ``ResNet18_orig``, ``VGG_rw`` (RW methods) or ``VGG`` (other
    methods) is built. ``tinynet=True`` is passed for Tiny-ImageNet.

    Raises:
        ValueError: for arch / method combinations with no model (previously
            an UnboundLocalError on ``submodel``).
    """
    tiny_kwargs = {'tinynet': True} if args.dataset == 'imagenet' else {}
    if clip_flag:
        if args.arch == 'ResNet18':
            # NOTE(review): num_classes is not passed here; confirm the class's
            # default matches the dataset (e.g. cifar100 / imagenet).
            return ResNet18(concat_sv=False, in_chan=in_chan, device=device, clip=args.convsn, clip_flag=True, bn=bn_flag, clip_steps=clip_steps, clip_outer=False, clip_opt_iter=opt_iter, summary=True, writer=writer, save_info=False, elu_flag=elu_flag, identifier=1000)
    elif args.arch == 'ResNet18':
        return ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=elu_flag, num_classes=num_classes, unlearn_method=args.unlearn_method, **tiny_kwargs)
    elif args.arch == 'VGG':
        if args.unlearn_method in _RW_METHODS:
            return VGG_rw('VGG19', in_chan=in_chan, num_classes=num_classes, **tiny_kwargs)
        if args.unlearn_method in _FULL_MODEL_METHODS:
            return VGG('VGG19', in_chan=in_chan, num_classes=num_classes, **tiny_kwargs)
    raise ValueError(
        f"No model for arch={args.arch!r}, method={args.method!r}, "
        f"unlearn_method={args.unlearn_method!r}"
    )


def main():
    """Train one model on ``args.dataset`` and checkpoint it under ``args.outdir``.

    Uses the module-level ``args`` and ``device``. Side effects: seeds the
    torch/numpy/random RNGs, may override ``args.epochs`` / ``args.adv_eps``,
    writes per-class index CSVs, TensorBoard logs, checkpoints
    (``*_best`` and every 10th epoch) and a final loss/accuracy CSV.

    Raises:
        NotImplementedError: if ``--resume`` is given (see below).
        ValueError: for unsupported dataset / architecture combinations.
    """
    elu_flag = False  # passed to the ResNet variants as elu_flag (ELU activation toggle)

    seed_val = args.seed
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    np.random.seed(seed_val)
    random.seed(seed_val)

    # BN handling: 'noBN' drops BatchNorm and changes the clipping schedule
    # and epoch budget; every other --mode keeps BatchNorm.
    bn_flag = True
    opt_iter = 5
    clip_steps = 50
    if args.mode == 'noBN':
        bn_flag = False
        opt_iter = 1
        clip_steps = 100
        args.epochs = 121

    # Any --method other than 'orig' selects the clipping model.
    clip_flag = args.method != 'orig'

    if args.resume:
        # The old resume code loaded three ensemble checkpoints into model[i],
        # but the model is a single nn.DataParallel (not indexable), so it
        # always raised TypeError. Fail fast, before any data loading.
        raise NotImplementedError(
            "--resume is not supported for the single-model trainer: the "
            "previous implementation indexed the model as an ensemble"
        )

    os.makedirs(args.outdir, exist_ok=True)

    trainset, testset, train_loader, test_loader, in_chan, num_classes = _build_data()

    # Per-dataset budgets; these override --epochs (including the noBN value).
    if args.dataset == 'mnist':
        args.epochs = 51
        args.adv_eps = 0.1
    if args.dataset == 'cifar100':
        args.epochs = 121

    _ensure_class_index_files(trainset, testset, num_classes)

    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    writer = SummaryWriter(args.outdir)
    try:
        print('elu flag: ', elu_flag)
        print('arch: ', args.arch)
        submodel = _build_model(in_chan, num_classes, clip_flag=clip_flag, bn_flag=bn_flag, clip_steps=clip_steps, opt_iter=opt_iter, elu_flag=elu_flag, writer=writer)
        # DataParallel requires parameters on device_ids[0]; .to() is a no-op
        # if the model already placed itself there.
        model = nn.DataParallel(submodel.to(device))
        print("Model loaded")

        # .to(device) instead of .cuda(): the latter crashed on CPU-only hosts.
        criterion = nn.CrossEntropyLoss().to(device)

        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

        loss_acc_list = []
        best_acc = 0.0
        print("START WORKING", args.epochs)
        for epoch in range(args.epochs):
            start = time.time()
            train_loss = Naive_Trainer(args, train_loader, model, criterion, optimizer, epoch, device, writer, scheduler, unlearn_method=args.unlearn_method)
            print('time: ', time.time() - start)

            # Evaluate every epoch.
            test_acc, test_loss = test(test_loader, model, criterion, epoch, device, writer, unlearn_method=args.unlearn_method)
            loss_acc_list.append((epoch, train_loss, test_loss, test_acc))

            # '>=' keeps the most recent of equally good epochs.
            if test_acc >= best_acc:
                torch.save({
                    'epoch': epoch,
                    'arch': args.arch,
                    'scheduler': scheduler.state_dict(),
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_path + '_best')
                best_acc = test_acc

            writer.add_scalar('test/best_acc', best_acc, epoch)

            # NOTE: periodic checkpoints store epoch + 1 (the next epoch to
            # run), while *_best stores the epoch itself; kept for compatibility.
            if epoch % 10 == 0:
                torch.save({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'scheduler': scheduler.state_dict(),
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, f"{model_path}_{epoch}")

            scheduler.step()

        loss_acc_res_path = os.path.join(args.outdir, 'loss_acc_e' + str(args.adv_eps) + '_s' + str(args.seed) + '.csv')
        df = pd.DataFrame(loss_acc_list, columns=['epoch', 'train_loss', 'test_loss', 'test_acc'])
        df.to_csv(loss_acc_res_path)
    finally:
        # Flush pending TensorBoard events even if training fails.
        writer.close()

if __name__ == "__main__":
    main()


