import os
import sys
sys.path.insert(0, './')
import json
import yaml
import pickle
import argparse

import time
import numpy as np
import math

import torch
import torch.nn as nn
from datetime import datetime

from alg.Vanilla import train
from util.DataParser import parse_data
from util.DeviceParser import parse_device
from util.ModelParser import parse_model
from util.OptimParser import parse_optim
from util.ParamParser import *
from util.SeqParser import parse_seq

from opacus import PrivacyEngine
from opacus.validators import ModuleValidator

# Effective values for CLI options that previously carried hard-coded argparse
# defaults (store_true -> False, split_seed=42, ...). A non-None argparse
# default always won the "command line overrides YAML" merge, so these keys
# could never be set from the YAML file. They are now applied only when
# neither the command line nor the YAML file provides a value.
_CLI_FALLBACKS = {
    'is_split': False,
    'split_seed': 42,
    'is_shadow': False,
    'shadow_ratio': 0.8,
    'shadow_seed': 0,
}


def _build_parser():
    """Return the argument parser for this training script.

    Every option except ``--yaml`` defaults to ``None`` so that "not given on
    the command line" can be told apart from an explicit value when the
    options are merged with the YAML file (see ``_merge_config``).
    """
    parser = argparse.ArgumentParser()

    # Required: nothing below can run without the YAML file.
    parser.add_argument('--yaml', type = str, required = True, help = 'The yaml file to load default setting.')

    parser.add_argument('--dataset', type = str, help = 'The dataset to use.')
    parser.add_argument('--valid_ratio', type = float, help = 'The ratio of validation set, None means no validation set.')
    parser.add_argument('--batch_size', type = int, help = 'The batch size')
    parser.add_argument('--mislabel_ratio', type = float, help = 'The ratio of mislabel data')
    # default=None (not False) so an absent flag does not override the YAML value.
    parser.add_argument('--is_split', action = 'store_true', default = None, help = 'The split of the dataset (fallback: False).')
    parser.add_argument('--split_seed', type = int, help = 'Split seed (fallback: 42).')
    parser.add_argument('--is_shadow', action = 'store_true', default = None, help = 'If training a shadow model (fallback: False).')
    parser.add_argument('--shadow_ratio', type = float, help = 'The shadow split ratio (fallback: 0.8).')
    parser.add_argument('--shadow_seed', type = int, help = 'shadow split seed (fallback: 0).')
    parser.add_argument('--arch', type = str, help = 'The model architecture.')
    parser.add_argument('--normalize', type = str, help = 'Whether the data is normalized, default is False.')

    parser.add_argument('--out_folder', type = str, help = 'The output folder.')
    parser.add_argument('--model_name', type = str, help = 'The name of the model.')

    parser.add_argument('--epoch_num', type = int, help = 'The number of epochs.')
    parser.add_argument('--model2load', type = str, help = 'The pretrained model to load.')
    parser.add_argument('--epoch2save', action = IntListParser, help = 'The list of checkpoints to save.')

    parser.add_argument('--optim', action = DictParser, help = 'The optimizer configuration.')
    parser.add_argument('--lr_schedule', action = DictParser, help = 'The learning rate scheduler.')

    parser.add_argument('--gpu', type = str, help = 'Specify the GPU to use.')
    parser.add_argument('--dp_sigma', type = float, help = 'Noise multiplier for DP-SGD; unset or 0 disables DP.')
    parser.add_argument('--max_grad_norm', type = float, help = 'Max grad norm for dp')

    return parser


def _merge_config(cli_values, yaml_config, fallbacks = None):
    """Merge settings with precedence: command line > YAML file > fallbacks.

    Args:
        cli_values: mapping of option name to parsed CLI value; ``None`` means
            the option was not given on the command line.
        yaml_config: mapping loaded from the YAML file, or ``None`` (what
            ``yaml.safe_load`` returns for an empty file). The string
            ``"None"`` is interpreted as ``None``.
        fallbacks: optional mapping whose values are used for keys that are
            still ``None`` after the CLI/YAML merge.

    Returns:
        A new dict holding every CLI key, every YAML key and every fallback key.
    """
    config = dict.fromkeys(cli_values)
    for key, value in (yaml_config or {}).items():
        config[key] = None if value == "None" else value
    config.update({key: value for key, value in cli_values.items() if value is not None})
    for key, value in (fallbacks or {}).items():
        if config.get(key) is None:
            config[key] = value
    return config


if __name__ == '__main__':

    args = _build_parser().parse_args()

    # Build the configuration: command line > YAML file > _CLI_FALLBACKS > None.
    with open(args.yaml, 'r') as fopen:
        yaml_config = yaml.safe_load(fopen)
    config = _merge_config(vars(args), yaml_config, _CLI_FALLBACKS)

    # DP-SGD is enabled only for a non-zero noise multiplier.
    use_dp = config['dp_sigma'] is not None and config['dp_sigma'] != 0.0
    if use_dp and config['max_grad_norm'] is None:
        # Fail before loading any data instead of inside the Opacus optimizer.
        raise ValueError('max_grad_norm must be set (YAML or --max_grad_norm) when dp_sigma is non-zero.')

    # Config GPUs
    parse_device(config['gpu'])
    use_gpu = config['gpu'] != 'cpu' and torch.cuda.is_available()
    device = torch.device('cuda:0' if use_gpu else 'cpu')

    # Config IO
    os.makedirs(config['out_folder'], exist_ok = True)

    # Parse model and dataset.
    # 'root', 'mislabel_seed', 'class_subset_path' and 'num_worker' have no CLI
    # option, so they must be present in the YAML file (KeyError otherwise).
    # NOTE(review): 'shadow_seed' is parsed but not forwarded to parse_data —
    # confirm whether parse_data should receive it.
    train_loader, valid_loader, test_loader, classes = parse_data(
        name = config['dataset'], root = config['root'], batch_size = config['batch_size'], valid_ratio = config['valid_ratio'],
        mislabel_ratio = config['mislabel_ratio'], mislabel_seed = config['mislabel_seed'], is_split = config['is_split'],
        split_seed = config['split_seed'], is_shadow = config['is_shadow'], shadow_ratio = config['shadow_ratio'],
        class_subset_path = config['class_subset_path'], num_worker = config['num_worker'])

    model = parse_model(dataset = config['dataset'], arch = config['arch'], normalize = config['normalize'])

    criterion = nn.CrossEntropyLoss()
    model = model.cuda() if use_gpu else model
    criterion = criterion.cuda() if use_gpu else criterion
    if config['model2load'] is not None:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only run.
        model.load_state_dict(torch.load(config['model2load'], map_location = device, weights_only = True))

    # Make the model Opacus-compatible before the optimizer captures its parameters.
    if use_dp:
        errors = ModuleValidator.validate(model, strict = False)
        if errors:
            model = ModuleValidator.fix(model)
            # Layers substituted by fix() may be created on the CPU; move the
            # whole model back to the GPU so all parameters share one device.
            model = model.cuda() if use_gpu else model

    # Parse the optimizer
    optimizer = parse_optim(policy = config['optim'], params = model.named_parameters(), model = model)
    lr_func = parse_seq(**config['lr_schedule']) if config['lr_schedule'] is not None else None

    # DP-SGD attachment: wraps model, optimizer and loader for per-sample
    # gradient clipping (max_grad_norm) and Gaussian noise (noise_multiplier).
    if use_dp:
        privacy_engine = PrivacyEngine()
        model, optimizer, train_loader = privacy_engine.make_private(
            module = model,
            optimizer = optimizer,
            data_loader = train_loader,
            noise_multiplier = config['dp_sigma'],
            max_grad_norm = config['max_grad_norm'],
        )

    # Prepare the item to save
    tosave = {'model_summary': str(model), 'config': config, 'lr_per_epoch': {}, 'runtime': {},
        'log': {'cmd': 'python ' + ' '.join(sys.argv), 'time': datetime.now().strftime('%Y/%m/%d, %H:%M:%S')},
        'train_loss': {}, 'train_acc': {}, 'valid_loss': {}, 'valid_acc': {}, 'test_loss': {}, 'test_acc': {}}

    for key in sorted(config):
        print('%s\t=>%s' % (key, config[key]))

    train(model = model, train_loader = train_loader, valid_loader = valid_loader, test_loader = test_loader,
        epoch_num = config['epoch_num'], epoch2save = config['epoch2save'], optimizer = optimizer, lr_func = lr_func,
        out_folder = config['out_folder'], model_name = config['model_name'], device = device, criterion = criterion, tosave = tosave)
