import torch
import torch.optim as optim
import torchvision
import torchsnooper
import math
import csv
import sys
sys.path.append('../')

import time
import matplotlib.pyplot as plt
from time import gmtime, strftime

import argparse
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from optimizers import *
from sgd_cam_hd import *
from adam_cam_hd import *

from train_hd import *
from models_hd import *
from utils import *

from enum import IntEnum
import inspect

# Baseline hyper-parameter dictionary used as the starting point for every run below.
# NOTE(review): `device` (and `np` further down) are not imported explicitly in this
# file; presumably they come from one of the wildcard imports above -- confirm.
args_HD = {
    "seed": 0,
    "dataset": "MNIST",
    "model": "lenet_5",
    "patience": 10,
    "lr": 0.1,
    "beta": 0.000001,
    "momentum": 0.9,
    "weight_decay": 0.0001,
    "batch_size": 128,
    "epochs": 2,
    "iterations": 0,
    "lossThreshold": 0,
    "silent": "layer_wise",
    "workers": 4,
    "parallel": "store_true",
    "save": "store_true",
    "num_outputs": 10,
    "new_acts": False,
    "loss_func": "NLL",
    "method": "sgd_hd",
    "level": "global",
    "cuda": False,
    "device": device,
}

# `args` is the same dict object as `args_HD` (an alias, not a copy), so every
# later write through `args` is visible through `args_HD` as well.
args = args_HD
args.update({"cuda": True, "flexible_act": False, "task": "CLA"})

# Pick a random seed in [0, 1000) for this run and report it.
args["seed"] = np.random.randint(1000)
print("seed:", args["seed"])

def args_update(args):
    """Write optimizer-family defaults into ``args`` according to ``args["method"]``.

    The dictionary is modified in place and the same object is returned.
    The method name and the matched family ("sgd" / "adam") are printed.

    NOTE(review): names such as "sgd_cam_hd" or "adam_cam_hd" are in neither
    family list, so for them ``args`` is returned without any change --
    confirm that this is intended.
    """
    sgd_family = ("sgd", "sgd_hd_1", "sgd_hd_2", "sgd_hd_3", "sgd_hd")
    adam_family = ("adam_hd_1", "adam_hd_2", "adam_hd_3", "adam_hd", "adam")

    chosen = args["method"]
    print("method", chosen)

    if chosen in sgd_family:
        print("sgd")
        sgd_defaults = {
            "momentum": 0.9,
            "dampening": 0,
            "weight_decay": 3e-4,
            "nesterov": False,
            "gamma_1": 0,
            "gamma_2": 0,
            "gamma_3": 1,
            "hypergrad_lr": 1e-6,
            "delta": 0.01,
            "reg_lr_layer": 0.000,
            "reg_lr_unit": 0.000,
            "reg_lr_para": 0.000,
            "reg_lr_ts": 0.01,
            "kappa": 0.000,
        }
        args.update(sgd_defaults)

    if chosen in adam_family:
        print("adam")
        adam_defaults = {
            "betas": (0.9, 0.999),
            "eps": 1e-8,
            "weight_decay": 1e-4,
            "hypergrad_lr": 1e-8,
            "gamma": 0.5,
            "delta": 0.01,
            "reg_lr_layer": 0.000,
            "reg_lr_unit": 0.000,
            "reg_lr_para": 0.000,
            "reg_lr_ts": 0.0000,
            "kappa": 0.000,
            "gamma_1": 0,
            "gamma_2": 0,
            "gamma_3": 1,
        }
        args.update(adam_defaults)

    return args

# Seed the torch RNG; when CUDA is requested also select GPU 0, seed its RNG
# and switch cuDNN on.
torch.manual_seed(args["seed"])
if args["cuda"]:
    torch.cuda.set_device(0)
    torch.cuda.manual_seed(args["seed"])
    torch.backends.cudnn.enabled = True

# Adam, Adam-L4, Adam-HD, Adam-CAM-HD for ResNet-18
data_list = ["CIFAR10"]
method_base_list = ["adam"]
L_mb = len(method_base_list)
n_trials = 10

# Per-variant tables, all indexed by the same position p in the training loop:
#   method_list[p]    -> suffix appended to the base method name
#   meta_list[p]      -> value stored under args["meta"]
#   meta_para_list[p] -> first three entries become delta, gamma_1, gamma_2
method_list = ["_cam_hd", "_hd", "", ""]
meta_list = ["", "", "L4", ""]
meta_para_list = [[0.001, 0.2, 0.8], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1]]
L_meta_list = len(meta_para_list)

# Experiment-wide settings written into the shared args dictionary.
args.update({
    "device": 0,
    "model_type": "resnet_18",
    "model": "ResNet",
    "print_lr": False,
    "hd_decay": True,
    "hd_decay_coef": 0.001,
    "lr": 0.001,
    "epochs": 200,
    "patience": 200,
    "print_gamma": False,
    "update_from_combination": False,
})

def lr_schedule_func(epoch):
    """Return the schedule value for ``epoch``.

    The value is 0.1 exactly at epochs 150, 250 and 350, and 1 at every
    other epoch.
    """
    decay_epochs = (150, 250, 350)
    return 0.1 if epoch in decay_epochs else 1

# This experiment passes no learning-rate schedule (the key is explicitly None).
args["lr_schedule"] = None

for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    # Draw and apply a fresh random seed once per dataset.
    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])
    print("learning rate", args["lr"])

    for method_base in method_base_list:

        # Result tables: one inner list per optimizer variant p, receiving one
        # entry per trial. gamma_table_list is flat (one entry per run).
        eva_list_list = [[] for _ in range(L_meta_list)]
        val_losses_table_list = [[] for _ in range(L_meta_list)]
        train_losses_table_list = [[] for _ in range(L_meta_list)]
        val_accs_table_list = [[] for _ in range(L_meta_list)]
        gamma_table_list = []

        for i in range(n_trials):

            # Split the held-out set: a fixed 10000 samples for validation,
            # whatever remains (possibly nothing) for testing.
            print("len(test_dataset)", len(test_dataset))
            val_len = 10000
            test_len = len(test_dataset) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(test_dataset, lengths=[val_len, test_len])

            # The whole training set is kept (train_lost_len is 0); the split
            # still shuffles it into a random order.
            train_len = len(train_dataset)
            train_lost_len = 0
            train_len_1 = train_len - train_lost_len
            train_dataset_1, lost_dataset = torch.utils.data.random_split(train_dataset, lengths=[train_len_1, train_lost_len])
            print("train_len", len(train_dataset_1))
            print("val_len", len(val_dataset_1))
            print("test_len", len(test_dataset_1))

            # Data loader

            data = [train_dataset_1, val_dataset_1, test_dataset_1]

            for p in range(L_meta_list):

                # Every (trial, variant) run gets its own random seed.
                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("learning rate", args["lr"])
                method = method_list[p]
                args["meta"] = meta_list[p]
                args["method"] = method_base + method
                print("method", args["method"])

                # Load the family defaults first, then apply the overrides
                # below, which take precedence over whatever args_update wrote.
                args = args_update(args)
                args["momentum"] = 0.9
                args["nesterov"] = True
                args["weightDecay"] = args["weight_decay"]
                args["level"] = "layer_global"
                args["batch_size"] = 256
                args["delta"] = meta_para_list[p][0]
                args["gamma_1"] = meta_para_list[p][1]
                args["gamma_2"] = meta_para_list[p][2]
                print("method", args["method"], meta_para_list[p])
                eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs, time_list, gamma_list_list = model_train_HD(data, args)

                # Convert to a plain Python number when `eva` offers .item().
                # Only the failures that call can raise are caught, so Ctrl-C
                # (KeyboardInterrupt) and SystemExit are no longer swallowed.
                try:
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    pass  # keep `eva` unchanged, as before
                print("eva:", eva)

                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()
                
# SGDN, SGDN-L4, SGDN-HD, SGDN-CAM-HD for ResNet-34
data_list = ["CIFAR10"]
method_base_list = ["sgd"]
L_mb = len(method_base_list)
n_trials = 10

# Per-variant tables, all indexed by the same position p in the training loop:
#   method_list[p]    -> suffix appended to the base method name
#   meta_list[p]      -> value stored under args["meta"]
#   meta_para_list[p] -> first three entries become delta, gamma_1, gamma_2
method_list = ["_cam_hd", "_hd", "", ""]
meta_list = ["", "", "L4", ""]
meta_para_list = [[0.001, 0.2, 0.8], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
L_meta_list = len(meta_para_list)

# Experiment-wide settings written into the shared args dictionary.
args.update({
    "model_type": "resnet_34",
    "print_lr": False,
    "hd_decay": True,
    "hd_decay_coef": 0.001,
    "lr": 0.1,
    "model": "ResNet",
    "epochs": 200,
    "patience": 200,
    "print_gamma": False,
    "update_from_combination": False,
})

# SGDN-CAM-HD
def lr_schedule_func(epoch):
    """Return 0.1 when ``epoch`` is exactly 150, 250 or 350, and 1 otherwise."""
    if epoch in (150, 250, 350):
        return 0.1
    return 1
  
# SGDN
def lr_schedule_func_1(epoch):
    """Return a step-wise schedule value for ``epoch``.

    0.1 for epoch < 150, 0.01 for epoch < 250, 0.001 for epoch < 350,
    and 0 from epoch 350 onwards.
    """
    steps = ((150, 0.1), (250, 0.01), (350, 0.001))
    for upper_bound, value in steps:
        if epoch < upper_bound:
            return value
    return 0

# Hand the schedule function to the training routine through the args dict.
args["lr_schedule"] = lr_schedule_func

for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    # Draw and apply a fresh random seed once per dataset.
    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])
    print("learning rate", args["lr"])

    for method_base in method_base_list:

        # Result tables: one inner list per optimizer variant p, receiving one
        # entry per trial. gamma_table_list is flat (one entry per run).
        # Rebinding these names drops whatever they referred to before.
        eva_list_list = [[] for _ in range(L_meta_list)]
        val_losses_table_list = [[] for _ in range(L_meta_list)]
        train_losses_table_list = [[] for _ in range(L_meta_list)]
        val_accs_table_list = [[] for _ in range(L_meta_list)]
        gamma_table_list = []

        for i in range(n_trials):

            # Split the held-out set: a fixed 10000 samples for validation,
            # whatever remains (possibly nothing) for testing.
            print("len(test_dataset)", len(test_dataset))
            val_len = 10000
            test_len = len(test_dataset) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(test_dataset, lengths=[val_len, test_len])

            # The whole training set is kept (train_lost_len is 0); the split
            # still shuffles it into a random order.
            train_len = len(train_dataset)
            train_lost_len = 0
            train_len_1 = train_len - train_lost_len
            train_dataset_1, lost_dataset = torch.utils.data.random_split(train_dataset, lengths=[train_len_1, train_lost_len])

            print("train_len", len(train_dataset_1))
            print("val_len", len(val_dataset_1))
            print("test_len", len(test_dataset_1))

            # Data loader

            data = [train_dataset_1, val_dataset_1, test_dataset_1]

            for p in range(L_meta_list):

                # Every (trial, variant) run gets its own random seed.
                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("learning rate", args["lr"])
                method = method_list[p]
                args["meta"] = meta_list[p]
                args["method"] = method_base + method
                print("method", args["method"])

                # Load the family defaults first, then apply the overrides
                # below, which take precedence over whatever args_update wrote.
                args = args_update(args)
                args["momentum"] = 0.9
                args["nesterov"] = True
                args["weightDecay"] = args["weight_decay"]
                args["level"] = "layer_global"
                args["batch_size"] = 256
                args["delta"] = meta_para_list[p][0]
                args["gamma_1"] = meta_para_list[p][1]
                args["gamma_2"] = meta_para_list[p][2]
                print("method", args["method"], meta_para_list[p])
                eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs, time_list, gamma_list_list = model_train_HD(data, args)

                # Convert to a plain Python number when `eva` offers .item().
                # Only the failures that call can raise are caught, so Ctrl-C
                # (KeyboardInterrupt) and SystemExit are no longer swallowed.
                try:
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    pass  # keep `eva` unchanged, as before
                print("eva:", eva)

                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()