import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
import torchvision
import torch_pruning as tp
from mlh.defenses.membership_inference.Normal import TrainTargetNormal
from mlh.defenses.membership_inference.pruner import GradGapPruner
from models.models_non_image import Purchase,Texas
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from mlh.data_preprocessing.data_loader import GetDataLoader
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import argparse
import numpy as np
import torch.optim as optim
# Fix the PyTorch and NumPy RNG seeds so repeated runs draw the same random numbers.
torch.manual_seed(0)
np.random.seed(0)
# Limit PyTorch to a single CPU thread for intra-op parallelism.
torch.set_num_threads(1)


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the process command line, so existing
            callers are unaffected; passing a list makes the function usable
            from tests and notebooks.

    Returns:
        The parsed namespace, post-processed as follows:

        * ``training_type`` is ``'Reg'`` when the ``Reg`` sub-command was
          given and ``'Normal'`` otherwise.
        * ``input_shape`` is converted from a comma-separated string to a
          list of ints.
        * ``device`` is ``'cuda'`` when CUDA is available, else ``'cpu'``.

    Note:
        The ``reg_*`` attributes exist on the namespace only when the ``Reg``
        sub-command is used. Options of the main parser must appear *before*
        the sub-command on the command line, because argparse hands every
        argument after ``Reg`` to the sub-parser.

    Raises:
        ValueError: If ``--input-shape`` contains a non-integer component.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--batch-size', type=int, default=512,
                        help='batch_size')
    parser.add_argument('--num-workers', type=int, default=10,
                        help='num of workers to use')

    # Optional 'Reg' sub-command (labelled "PAST" by the original author).
    subparsers = parser.add_subparsers(dest='training_type', required=False)
    parser_h = subparsers.add_parser('Reg')
    parser_h.add_argument('--reg_weight', type=float, default=1e-5, help='')
    parser_h.add_argument('--reg_alpha', type=float, default=4, help='')
    parser_h.add_argument('--reg_epoch', type=int, default=50, help='')
    parser_h.add_argument('--reg_clamp', type=int, default=10000, help='')
    parser_h.add_argument('--reg_norm', type=str, default="l1", help='')

    parser.add_argument('--mode', type=str, default="shadow",
                        help='target, shadow')

    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    parser.add_argument('--weight_l2', type=float, default=5e-04, help='')
    parser.add_argument('--lr', type=float, default=0.01, help='')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu index used for training')

    # pruning
    parser.add_argument('--prune', type=str, default="f",
                        help='t(true), f(false)')
    parser.add_argument('--pruner', type=str, default="norm",
                        help='norm, tylor, hessian, mia')
    parser.add_argument('--global_pruning', type=str, default="f",
                        help='t(true), f(false)')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet18')
    parser.add_argument('--load-pretrained', type=str, default='no')
    parser.add_argument('--task', type=str, default='mia',
                        help='specify the attack task, mia or ol')
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        help='dataset')
    parser.add_argument('--num_class', type=int, default=10,
                        help='number of classes')
    parser.add_argument('--inference-dataset', type=str, default='CIFAR10',
                        help='if yes, load pretrained the attack model to inference')
    parser.add_argument('--data-path', type=str, default='../datasets/',
                        help='data_path')
    parser.add_argument('--input-shape', type=str, default="32,32,3",
                        help='comma delimited input shape input')
    parser.add_argument('--log_path', type=str,
                        default='./save', help='data_path')

    # argparse falls back to sys.argv[1:] when argv is None.
    args = parser.parse_args(argv)

    # No sub-command given: plain training.
    if args.training_type is None:
        args.training_type = 'Normal'

    args.input_shape = [int(item) for item in args.input_shape.split(',')]
    # The device is the default CUDA device when available; --gpu is parsed
    # but is not used to pick a specific device index here.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    return args


def get_target_model(name="resnet18", num_classes=10):
    """Build the classifier selected by ``name`` with a ``num_classes``-way head.

    Args:
        name: One of ``"resnet18"``, ``"dense121"``, ``"TexasClassifier"`` or
            ``"PurchaseClassifier"``.
        num_classes: Number of output classes of the final linear layer.

    Returns:
        The constructed model (not moved to any device).

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    if name == "resnet18":
        # ImageNet-pretrained weights are used only in the 100-class setting;
        # every other class count starts from random initialisation.
        # ``weights=`` replaces the deprecated ``pretrained=True`` flag and
        # matches the DenseNet branch below.
        weights = "IMAGENET1K_V1" if num_classes == 100 else None
        model = torchvision.models.resnet18(weights=weights)
        # Keep the nn.Sequential wrapper so state_dict keys ("fc.0.*") stay
        # compatible with previously saved checkpoints.
        model.fc = nn.Sequential(nn.Linear(model.fc.in_features, num_classes))
    elif name == "dense121":
        model = torchvision.models.densenet121(weights="IMAGENET1K_V1")
        model.classifier = nn.Sequential(
            nn.Linear(model.classifier.in_features, num_classes))
    elif name == "TexasClassifier":
        model = Texas(num_classes=num_classes)
    elif name == "PurchaseClassifier":
        model = Purchase(num_classes=num_classes)
    else:
        raise ValueError("Model not implemented yet :P")
    return model


def evaluate(args, model, dataloader):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        args: Object exposing a ``device`` attribute; each batch is moved to
            that device before the forward pass.
        model: The network to evaluate. It is switched to eval mode for the
            duration of the pass and put back into train mode before a
            successful return.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        the dataloader yields no samples.

    Raises:
        ValueError: If the model produces NaN outputs for any batch.
    """
    model.eval()
    correct = 0
    total = 0
    # Inference only: no autograd graph is needed for an accuracy pass.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(args.device), labels.to(args.device)
            outputs = model(inputs)
            # Check with torch so this works for tensors on any device;
            # the predicted class indices are integers and cannot be NaN.
            if torch.isnan(outputs).any():
                raise ValueError("Input contains NaN values.")
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    model.train()
    # Guard against an empty dataloader instead of dividing by zero.
    return correct / total if total else 0.0


if __name__ == "__main__":
    # Script entry point: train a target or shadow model, optionally run the
    # "Reg" sparse-training stage with a GradGapPruner, then save the
    # weights and the whole model under the log directory.

    opt = parse_args()

    # Everything needed by the "Reg" stage is validated up front, so a typo in
    # --dataset or --pruner fails immediately rather than after the first
    # training stage has already finished.
    # Shapes (batch dimension of 1) of the dummy input handed to the pruner.
    example_input_shapes = {
        "CIFAR10": (1, 3, 32, 32),
        "CIFAR100": (1, 3, 32, 32),
        "texas": (1, 6169),
        "purchase": (1, 600),
        "imagenet": (1, 3, 224, 224),
        "imagenet_r": (1, 3, 224, 224),
    }
    known_pruners = ("norm", "tylor", "hessian", "mia")
    if opt.training_type == "Reg":
        if opt.dataset not in example_input_shapes:
            raise ValueError(
                f"No example input shape known for dataset {opt.dataset!r}; "
                f"expected one of {sorted(example_input_shapes)}")
        if opt.pruner not in known_pruners:
            raise ValueError(
                f"Unknown pruner {opt.pruner!r}; expected one of {list(known_pruners)}")

    s = GetDataLoader(opt)
    (target_train_loader, target_inference_loader, target_test_loader,
     shadow_train_loader, shadow_inference_loader, shadow_test_loader) = s.get_data_supervised()

    if opt.mode == "target":
        train_loader, inference_loader, test_loader = target_train_loader, target_inference_loader, target_test_loader
    elif opt.mode == "shadow":
        train_loader, inference_loader, test_loader = shadow_train_loader, shadow_inference_loader, shadow_test_loader
    else:
        raise ValueError("opt.mode should be target or shadow")

    # Use the device chosen by parse_args() instead of unconditionally calling
    # .cuda(), so the script also starts on CPU-only machines.
    target_model = get_target_model(name=opt.model, num_classes=opt.num_class).to(opt.device)

    save_pth = f'{opt.log_path}/{opt.dataset}/{opt.training_type}/{opt.mode}'

    if opt.training_type in ("Normal", "Reg"):
        if opt.training_type == "Reg":
            # Encode the regularisation hyper-parameters in the directory name
            # that precedes the final "<mode>" path component.
            save_pth_before_last_slash, save_pth_after_last_slash = save_pth.rsplit('/', 1)
            if opt.reg_norm == "l1":
                save_pth = f'{save_pth_before_last_slash}-{opt.reg_weight}-{opt.epochs}{opt.reg_epoch}-{opt.reg_clamp}_{opt.reg_alpha}/{save_pth_after_last_slash}'
            else:
                save_pth = f'{save_pth_before_last_slash}-{opt.reg_weight}-{opt.epochs}{opt.reg_epoch}-{opt.reg_clamp}-{opt.reg_norm}_{opt.reg_alpha}/{save_pth_after_last_slash}'

        # First stage: regular training for opt.epochs epochs.
        total_evaluator = TrainTargetNormal(
            model=target_model, epochs=opt.epochs, log_path=save_pth, num_class=opt.num_class, weight_decay=opt.weight_l2)
        total_evaluator.train(train_loader, inference_loader, test_loader)

    else:
        raise ValueError(
            "opt.training_type should be Normal, LabelSmoothing, AdvReg, DP, MixupMMD, PATE")

    model = target_model
    if opt.training_type == "Reg":
        # Dummy input for the pruner, created on CPU and then moved to the
        # training device.
        example_inputs = torch.randn(*example_input_shapes[opt.dataset]).to(opt.device)

        # 1. Importance criterion
        if opt.pruner == "norm":
            imp = tp.importance.GroupNormImportance(p=2)
        elif opt.pruner == "tylor":
            imp = tp.importance.GroupTaylorImportance()
        elif opt.pruner == "hessian":
            imp = tp.importance.GroupHessianImportance()
        elif opt.pruner == "mia":
            # Currently an L1 group-norm criterion stands in for a dedicated
            # MIA importance.
            imp = tp.importance.GroupNormImportance(p=1)
        else:
            raise ValueError(
                f"Unknown pruner {opt.pruner!r}; expected one of {list(known_pruners)}")

        # 2. Initialize a pruner with the model and the importance criterion
        # DO NOT prune the final classifier: protect every Linear layer whose
        # width equals the number of classes (previously hard-coded to 10,
        # which left the classifier unprotected for other class counts).
        ignored_layers = [
            m for m in model.modules()
            if isinstance(m, torch.nn.Linear) and m.out_features == opt.num_class
        ]

        pruner = GradGapPruner(
            model,
            example_inputs,
            importance=imp,
            global_pruning=opt.global_pruning == "t",
            pruning_ratio=0.5,
            ignored_layers=ignored_layers,
        )

        # Second stage: opt.reg_epoch further epochs via train_sparse with the
        # pruner (NOTE(review): the exact schedule lives in
        # TrainTargetNormal.train_sparse, which is not visible here).
        total_evaluator = TrainTargetNormal(
            model=target_model, epochs=opt.reg_epoch, log_path=save_pth, learning_rate=opt.lr)
        total_evaluator.train_sparse(train_loader, inference_loader, test_loader, pruner=pruner, args=opt)

    # Make sure the output directory exists before writing; harmless if the
    # trainer has already created it.
    os.makedirs(save_pth, exist_ok=True)
    torch.save(model.state_dict(),
               os.path.join(save_pth, f"{opt.model}.pth"))
    # 4. Save & Load
    model.zero_grad()  # clear gradients to avoid a large file size
    # The full module (not only its state_dict) is saved as well.
    torch.save(model,
               os.path.join(save_pth, f"{opt.model}_model.pth"))

    print("Finish Training")
