import torch
import torch.optim as optim
import torchvision
# import torchsnooper
import math
import csv
import sys
sys.path.append('../')

import time
import matplotlib.pyplot as plt
from time import gmtime, strftime

import argparse
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from optimizers import *
from sgd_cam_hd import *
from adam_cam_hd import *

from train_hd import *
from models_hd import *
from utils import *

from enum import IntEnum
import inspect

# numpy is used below (and throughout this script) to draw seeds; import it
# explicitly instead of relying on one of the wildcard imports above to
# re-export it.
import numpy as np

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

# LeNet-5: MNIST, CIFAR-10
# NOTE(review): this header looks stale — the experiments below configure an
# "FFNN" model on MNIST. Confirm and update.

# Base hyper-parameter dictionary; many entries are overwritten per section and
# per method further down.
# NOTE(review): "parallel" and "save" hold the string "store_true" (an argparse
# action name). A non-empty string is truthy, so any `if args["save"]:` check
# will pass. Confirm whether False was intended.
args_HD = {
    "seed": 0, "dataset": "MNIST", "patience": 10, "lr": 0.001, "beta": 0.000001,
    "momentum": 0.9, "weightDecay": 0.0001,
    "batch_size": 64, "epochs": 2, "iterations": 0, "lossThreshold": 0,
    "silent": "layer_wise", "workers": 4,
    "parallel": "store_true", "save": "store_true", "num_outputs": 10, "new_acts": False,
    "loss_func": "NLL", "method": "sgd_hd", "level": "global", "cuda": False,
    "device": device,
}

# `args` is the same dict object as `args_HD` (an alias, not a copy).
args = args_HD
# NOTE(review): forced to False even when `device` above resolved to CUDA —
# confirm this is intended.
args["cuda"] = False
args["flexible_act"] = False
args["task"] = "CLA"

# numpy's global RNG is not seeded, so this seed differs between runs; print it
# so the run can be traced.
args["seed"] = np.random.randint(1000)
print("seed:", args["seed"])

def args_update(args):
    """Add Adam-family optimizer defaults to ``args`` in place.

    Always prints the configured method name. When ``args["method"]`` is one of
    "adam_cam_hd", "adam_hd" or "adam", it also prints "adam" and merges a fixed
    set of Adam / hypergradient defaults into ``args``, overwriting any existing
    values for those keys. Any other method leaves ``args`` untouched.

    Returns the same ``args`` dict object that was passed in.
    """
    chosen = args["method"]
    print("method", chosen)

    if chosen not in ("adam_cam_hd", "adam_hd", "adam"):
        return args

    print("adam")
    adam_defaults = dict(
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0005,
        hypergrad_lr=1e-7,
        gamma=0.5,
        delta=0.01,
        reg_lr_layer=0.000,
        reg_lr_unit=0.000,
        reg_lr_para=0.000,
        reg_lr_ts=0.0000,
        kappa=0.000,
        gamma_1=0,
        gamma_2=0,
        gamma_3=1,
    )
    args.update(adam_defaults)
    return args

# Seed torch's global RNG with the seed drawn earlier.
torch.manual_seed(args["seed"])

# NOTE(review): args["cuda"] is set to False earlier in this script, so this
# branch does not run as configured — confirm whether that is intended.
if args["cuda"]:
    torch.cuda.set_device(0)
    torch.cuda.manual_seed(args["seed"])
    torch.backends.cudnn.enabled = True

args["model"] = "FFNN"
args["model_type"] = args["model"]
data_list = ["MNIST"]
method_base_list = ["adam"]

# Per-method settings: entry p of every list below configures the p-th method
# run within each trial, so all of them must have len(method_list) entries.
method_list = ["adam_cam_hd", "adam_cam_hd", "adam_hd", "adam"]
meta_list = ["", "", "", ""]
# Each entry is [delta, gamma_1, gamma_2, gamma_3]. gamma_3 is written out as 0
# where it used to be omitted; 0 is the value the training loops used as the
# fallback for a missing gamma_3.
meta_para_list = [[0.01, 0.3, 0.3, 0.4], [0.01, 0.5, 0.5, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
meta_level_list = ["para_layer_global", "layer_global", "layer_global", ""]

hypergrad_lr_list = [1e-7, 1e-7, 1e-9, 0]
args["print_lr"] = False
args["hd_decay"] = False
args["hd_decay_coef"] = 0.000
L_mb = len(method_base_list)

L_meta_list = len(method_list)
n_trials = 10
args["epochs"] = 30
args["patience"] = 30
args["print_gamma"] = False
args["lr_schedule"] = None
args["update_from_combination"] = False

# Feed-forward network with hidden sizes [100, 100].
lr_list = [0.0003, 0.0003, 0.0003, 0.0003]
args["layer_size"] = [784, 100, 100, 10]
args["acts"] = ["relu", "relu", "sigmoid"]

# Fail fast on a mis-sized per-method list. Otherwise the mistake only shows up
# as an IndexError after earlier (long) training runs have already finished.
for _name, _values in (
    ("meta_list", meta_list),
    ("meta_para_list", meta_para_list),
    ("meta_level_list", meta_level_list),
    ("hypergrad_lr_list", hypergrad_lr_list),
    ("lr_list", lr_list),
):
    if len(_values) != L_meta_list:
        raise ValueError(
            f"{_name} has {len(_values)} entries, expected {L_meta_list} "
            f"(one per method in method_list)"
        )

# Run every method in method_list for n_trials trials on each dataset
# (hidden sizes [100, 100], batch size 32).
for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])

    for method_base in method_base_list:

        # Result tables: index p collects one entry per trial for method p.
        eva_list_list = [[] for _ in range(L_meta_list)]
        val_losses_table_list = [[] for _ in range(L_meta_list)]
        train_losses_table_list = [[] for _ in range(L_meta_list)]
        val_accs_table_list = [[] for _ in range(L_meta_list)]
        gamma_table_list = []

        for i in range(n_trials):

            print("len(test_dataset)", len(test_dataset))
            # The validation set is carved out of the test set; the remainder is
            # the held-out test split.
            # NOTE(review): with a 10000-sample test set, test_len is 0 and
            # test_dataset_1 is empty — confirm this is intended.
            val_len = 10000  # int(0.2*len(train_dataset))
            test_len = int(len(test_dataset)) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(test_dataset, lengths=[val_len, test_len])

            # train_lost_len is 0, so the whole training set is kept (randomly permuted).
            train_len = len(train_dataset)
            train_lost_len = 0  # int(0.5*train_len)
            train_len_1 = train_len - train_lost_len
            train_dataset_1, _ = torch.utils.data.random_split(train_dataset, lengths=[train_len_1, train_lost_len])

            data = [train_dataset_1, val_dataset_1, test_dataset_1]

            for p in range(L_meta_list):

                # Fresh torch seed for every method run; logged so the run can be traced.
                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("seed:", args["seed"])

                args["method"] = method_list[p]
                # args_update prints the method name and merges Adam-family defaults.
                args = args_update(args)
                # Per-method settings are applied after args_update so they
                # override its defaults (hypergrad_lr, delta, gamma_*).
                args["lr"] = lr_list[p]
                # Printed after the assignment so it shows this method's lr.
                print("learning rate", args["lr"])
                args["momentum"] = 0.9
                args["nesterov"] = False
                args["hypergrad_lr"] = hypergrad_lr_list[p]
                args["level"] = meta_level_list[p]
                args["batch_size"] = 32
                meta_para = meta_para_list[p]
                args["delta"] = meta_para[0]
                args["gamma_1"] = meta_para[1]
                args["gamma_2"] = meta_para[2]
                # gamma_3 is optional; default it to 0 when the setting omits it.
                args["gamma_3"] = meta_para[3] if len(meta_para) > 3 else 0
                args["meta"] = meta_list[p]

                print("method", args["method"], meta_para)
                eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs, time_list, gamma_list_list = model_train_HD(data, args)
                try:
                    # Convert a single-element tensor / numpy scalar to a Python number.
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    # No .item() (already a plain number) or not single-element:
                    # keep the value unchanged.
                    pass
                print("eva:", eva)
                print("avg_valid_losses", avg_valid_losses)
                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()

        # The tables above are rebuilt for the next method_base/section, so
        # report the collected evaluations before they are discarded.
        for p, method_name in enumerate(method_list):
            print("eva summary:", method_name, meta_level_list[p], eva_list_list[p])
                
# Section 2: feed-forward network whose hidden layers have sizes [1000, 100].
lr_list = [0.001] * 4
args.update(layer_size=[784, 1000, 100, 10], acts=["relu", "relu", "sigmoid"])

# Run every method in method_list for n_trials trials on each dataset
# (hidden sizes [1000, 100], batch size 64).
for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])
    # args["lr"] still holds the previous section's last value here, so log
    # this section's configured rates instead.
    print("learning rates", lr_list)

    for method_base in method_base_list:

        # Result tables: index p collects one entry per trial for method p.
        eva_list_list = [[] for _ in range(L_meta_list)]
        val_losses_table_list = [[] for _ in range(L_meta_list)]
        train_losses_table_list = [[] for _ in range(L_meta_list)]
        val_accs_table_list = [[] for _ in range(L_meta_list)]
        gamma_table_list = []

        for i in range(n_trials):

            print("len(test_dataset)", len(test_dataset))
            # The validation set is carved out of the test set; the remainder is
            # the held-out test split.
            # NOTE(review): with a 10000-sample test set, test_len is 0 and
            # test_dataset_1 is empty — confirm this is intended.
            val_len = 10000  # int(0.2*len(train_dataset))
            test_len = int(len(test_dataset)) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(test_dataset, lengths=[val_len, test_len])

            # train_lost_len is 0, so the whole training set is kept (randomly permuted).
            train_len = len(train_dataset)
            train_lost_len = 0  # int(0.5*train_len)
            train_len_1 = train_len - train_lost_len
            train_dataset_1, _ = torch.utils.data.random_split(train_dataset, lengths=[train_len_1, train_lost_len])

            data = [train_dataset_1, val_dataset_1, test_dataset_1]

            for p in range(L_meta_list):

                # Fresh torch seed for every method run; logged so the run can be traced.
                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("seed:", args["seed"])

                args["method"] = method_list[p]
                # args_update prints the method name and merges Adam-family defaults.
                args = args_update(args)
                # Per-method settings are applied after args_update so they
                # override its defaults (hypergrad_lr, delta, gamma_*).
                args["lr"] = lr_list[p]
                # Printed after the assignment so it shows this method's lr.
                print("learning rate", args["lr"])
                args["momentum"] = 0.9
                args["nesterov"] = False
                args["hypergrad_lr"] = hypergrad_lr_list[p]
                args["level"] = meta_level_list[p]
                args["batch_size"] = 64
                meta_para = meta_para_list[p]
                args["delta"] = meta_para[0]
                args["gamma_1"] = meta_para[1]
                args["gamma_2"] = meta_para[2]
                # gamma_3 is optional; default it to 0 when the setting omits it.
                args["gamma_3"] = meta_para[3] if len(meta_para) > 3 else 0
                args["meta"] = meta_list[p]

                print("method", args["method"], meta_para)
                eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs, time_list, gamma_list_list = model_train_HD(data, args)
                try:
                    # Convert a single-element tensor / numpy scalar to a Python number.
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    # No .item() (already a plain number) or not single-element:
                    # keep the value unchanged.
                    pass
                print("eva:", eva)
                print("avg_valid_losses", avg_valid_losses)
                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()

        # The tables above are rebuilt for the next method_base/section, so
        # report the collected evaluations before they are discarded.
        for p, method_name in enumerate(method_list):
            print("eva summary:", method_name, meta_level_list[p], eva_list_list[p])
                
# Section 3: feed-forward network whose hidden layers have sizes [1000, 1000].
lr_list = [0.001] * 4
args.update(layer_size=[784, 1000, 1000, 10], acts=["relu", "relu", "sigmoid"])

# Run every method in method_list for n_trials trials on each dataset
# (hidden sizes [1000, 1000], batch size 128).
for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])
    # args["lr"] still holds the previous section's last value here, so log
    # this section's configured rates instead.
    print("learning rates", lr_list)

    for method_base in method_base_list:

        # Result tables: index p collects one entry per trial for method p.
        eva_list_list = [[] for _ in range(L_meta_list)]
        val_losses_table_list = [[] for _ in range(L_meta_list)]
        train_losses_table_list = [[] for _ in range(L_meta_list)]
        val_accs_table_list = [[] for _ in range(L_meta_list)]
        gamma_table_list = []

        for i in range(n_trials):

            print("len(test_dataset)", len(test_dataset))
            # The validation set is carved out of the test set; the remainder is
            # the held-out test split.
            # NOTE(review): with a 10000-sample test set, test_len is 0 and
            # test_dataset_1 is empty — confirm this is intended.
            val_len = 10000  # int(0.2*len(train_dataset))
            test_len = int(len(test_dataset)) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(test_dataset, lengths=[val_len, test_len])

            # train_lost_len is 0, so the whole training set is kept (randomly permuted).
            train_len = len(train_dataset)
            train_lost_len = 0  # int(0.5*train_len)
            train_len_1 = train_len - train_lost_len
            train_dataset_1, _ = torch.utils.data.random_split(train_dataset, lengths=[train_len_1, train_lost_len])

            data = [train_dataset_1, val_dataset_1, test_dataset_1]

            for p in range(L_meta_list):

                # Fresh torch seed for every method run; logged so the run can be traced.
                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("seed:", args["seed"])

                args["method"] = method_list[p]
                # args_update prints the method name and merges Adam-family defaults.
                args = args_update(args)
                # Per-method settings are applied after args_update so they
                # override its defaults (hypergrad_lr, delta, gamma_*).
                args["lr"] = lr_list[p]
                # Printed after the assignment so it shows this method's lr.
                print("learning rate", args["lr"])
                args["momentum"] = 0.9
                args["nesterov"] = False
                args["hypergrad_lr"] = hypergrad_lr_list[p]
                args["level"] = meta_level_list[p]
                args["batch_size"] = 128
                meta_para = meta_para_list[p]
                args["delta"] = meta_para[0]
                args["gamma_1"] = meta_para[1]
                args["gamma_2"] = meta_para[2]
                # gamma_3 is optional; default it to 0 when the setting omits it.
                args["gamma_3"] = meta_para[3] if len(meta_para) > 3 else 0
                args["meta"] = meta_list[p]

                print("method", args["method"], meta_para)
                eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs, time_list, gamma_list_list = model_train_HD(data, args)
                try:
                    # Convert a single-element tensor / numpy scalar to a Python number.
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    # No .item() (already a plain number) or not single-element:
                    # keep the value unchanged.
                    pass
                print("eva:", eva)
                print("avg_valid_losses", avg_valid_losses)
                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()

        # The tables above are rebuilt for the next method_base/section, so
        # report the collected evaluations before they are discarded.
        for p, method_name in enumerate(method_list):
            print("eva summary:", method_name, meta_level_list[p], eva_list_list[p])

