import argparse
import csv
import inspect
import math
import sys
import time
from enum import IntEnum
from time import gmtime, strftime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchsnooper
import torchvision
from torch.autograd import Variable
from torchvision import datasets, transforms

# The project modules below are presumably located one directory up; the
# path must be extended before the star imports can resolve them.
sys.path.append('../')

from optimizers import *
from sgd_cam_hd import *
from adam_cam_hd import *

from train_hd import *
from models_hd import *
from utils import *

# Device configuration: use CUDA when torch reports it available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

# LeNet-5: MNIST, CIFAR-10, SVHN

# Base hyper-parameters shared by every experiment in this script.
# "parallel" and "save" hold argparse action strings rather than booleans;
# they are passed through unchanged.
args_HD = {
    "seed": 0, "dataset": "MNIST", "model": "lenet_5", "patience": 10,
    "lr": 0.1, "beta": 0.000001, "momentum": 0.9, "weightDecay": 0.0001,
    "epochs": 2, "iterations": 0, "lossThreshold": 0, "silent": "layer_wise",
    "workers": 4, "parallel": "store_true", "save": "store_true",
    "num_outputs": 10, "new_acts": False, "loss_func": "NLL",
    "method": "sgd_hd", "level": "global",
    # Gates the explicit GPU seeding / cuDNN setup; deliberately independent
    # of whether `device` resolved to CUDA.
    "cuda": False,
    "device": device,
}

# `args` is an alias of args_HD (not a copy): later updates change both.
args = args_HD
args["flexible_act"] = False
args["task"] = "CLA"

# numpy's RNG is not seeded here, so every run draws a new seed; it is
# printed so the run can be reproduced.
args["seed"] = np.random.randint(1000)
print("seed:", args["seed"])

def args_update(args):
    """Overwrite ``args`` with the fixed Adam-family hyper-parameters.

    Acts only when ``args["method"]`` is exactly ``"adam_cam_hd"``,
    ``"adam_hd"`` or ``"adam"``; for any other method the dict is returned
    untouched. The dict is modified in place and also returned, so callers
    may use either ``args_update(args)`` or ``args = args_update(args)``.

    Keys written: lr, betas, eps, weight_decay, hypergrad_lr, gamma, delta,
    reg_lr_layer, reg_lr_unit, reg_lr_para, reg_lr_ts, kappa, gamma_1,
    gamma_2, gamma_3.
    """
    adam_methods = ("adam_cam_hd", "adam_hd", "adam")
    if args["method"] not in adam_methods:
        # Not an Adam variant: nothing to override.
        return args

    adam_defaults = dict(
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-4,
        hypergrad_lr=1e-8,
        gamma=0.5,
        delta=0.01,
        reg_lr_layer=0.000,
        reg_lr_unit=0.000,
        reg_lr_para=0.000,
        reg_lr_ts=0.0000,
        kappa=0.000,
        gamma_1=0,
        gamma_2=0,
        gamma_3=1,
    )
    args.update(adam_defaults)
    return args

# Seed torch's global RNG with the run seed stored in args.
_seed = args["seed"]
torch.manual_seed(_seed)

# GPU setup only runs when the "cuda" flag is truthy: select device 0, seed
# its RNG with the same value and make sure cuDNN is enabled.
if args["cuda"]:
    torch.cuda.set_device(0)
    torch.cuda.manual_seed(_seed)
    torch.backends.cudnn.enabled = True
    
# ---------------------------------------------------------------------------
# Experiment: LeNet-5 on MNIST.
# Each trial re-draws the val/test split, then trains every optimiser
# configuration p described by method_list[p], meta_list[p] and
# meta_para_list[p].
# ---------------------------------------------------------------------------
args["device"] = 0
args["model_type"] = "lenet_5"
data_list = ["MNIST"]
method_base_list = ["adam"]

args["print_lr"] = False
args["hd_decay"] = False
args["hd_decay_coef"] = 0.002

L_mb = len(method_base_list)
# Suffix appended to the base optimiser name (e.g. "adam" + "_cam_hd").
method_list = ["_cam_hd", "_hd", "", ""]
meta_list = ["", "", "L4", ""]
# Entries [0], [1], [2] are copied into args["delta"], args["gamma_1"] and
# args["gamma_2"] for the matching configuration.
meta_para_list = [[0.03, 0.2, 0.8], [0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0]]

L_meta_list = len(meta_para_list)
n_trials = 10
args["epochs"] = 10
args["patience"] = 30
args["print_gamma"] = False
args["update_from_combination"] = False
args["lr_schedule"] = None

for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])
    # Log every freshly drawn seed so the run can be reproduced.
    print("seed:", args["seed"])
    print("learning rate", args["lr"])

    for method_base in method_base_list:

        # Result tables indexed by configuration p, one entry per trial.
        eva_list_list = [[] for i in range(L_meta_list)]
        val_losses_table_list = [[] for i in range(L_meta_list)]
        train_losses_table_list = [[] for i in range(L_meta_list)]
        val_accs_table_list = [[] for i in range(L_meta_list)]
        # Flat list: one gamma history per (trial, configuration) run.
        gamma_table_list = []

        for i in range(n_trials):

            # Split the provided test set into a 5000-sample validation set
            # and a held-out test set; re-drawn every trial.
            print("len(test_dataset)", len(test_dataset))
            val_len = 5000
            test_len = int(len(test_dataset)) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(
                test_dataset, lengths=[val_len, test_len])

            # train_lost_len = 0 keeps the whole training set; raise it to
            # train on a random subset instead.
            train_len = len(train_dataset)
            train_lost_len = 0
            train_len_1 = train_len - train_lost_len
            train_dataset_1, lost_dataset = torch.utils.data.random_split(
                train_dataset, lengths=[train_len_1, train_lost_len])

            print("train_len", len(train_dataset_1))
            print("val_len", len(val_dataset_1))
            print("test_len", len(test_dataset_1))

            args["batch_size"] = 256

            # Data loaders: only the training loader is shuffled.
            train_loader = torch.utils.data.DataLoader(
                dataset=train_dataset_1, batch_size=args["batch_size"], shuffle=True)
            val_loader = torch.utils.data.DataLoader(
                dataset=val_dataset_1, batch_size=args["batch_size"], shuffle=False)
            test_loader = torch.utils.data.DataLoader(
                dataset=test_dataset_1, batch_size=args["batch_size"], shuffle=False)
            data = [train_loader, val_loader, test_loader]

            for p in range(L_meta_list):

                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("seed:", args["seed"])
                method = method_list[p]
                args["method"] = method_base + method
                print("method", args["method"])
                args = args_update(args)
                # Printed after args_update so it shows the lr this run
                # actually uses, not the previous run's value.
                print("learning rate", args["lr"])
                args["momentum"] = 0
                args["nesterov"] = False
                args["meta"] = meta_list[p]

                args["level"] = "layer_global"
                args["delta"] = meta_para_list[p][0]
                args["gamma_1"] = meta_para_list[p][1]
                args["gamma_2"] = meta_para_list[p][2]
                print("method", args["method"], meta_para_list[p])
                # eva is presumably the final evaluation metric; confirm in
                # train_hd.model_train_HD.
                (eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs,
                 time_list, gamma_list_list) = model_train_HD(data, args)
                try:
                    # Unwrap a scalar tensor/array into a plain Python number.
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    # Already a plain number, or not a single-element value:
                    # keep it as-is rather than abort a finished run.
                    pass
                print("eva:", eva)
                print("avg_valid_losses", avg_valid_losses)
                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()
                
# ---------------------------------------------------------------------------
# Experiment: LeNet-5 on CIFAR-10.
# Each trial re-draws the val/test split, then trains every optimiser
# configuration p described by method_list[p], meta_list[p] and
# meta_para_list[p].
# ---------------------------------------------------------------------------
args["device"] = 0
args["model_type"] = "lenet_5"
data_list = ["CIFAR10"]
method_base_list = ["adam"]
args["print_lr"] = False
args["hd_decay"] = False
args["hd_decay_coef"] = 0.002

L_mb = len(method_base_list)
# Suffix appended to the base optimiser name. The first entry was "_hd_1",
# but "adam_hd_1" is not a name args_update recognises (so it only got Adam
# defaults left over from an earlier run), while the MNIST and SVHN
# experiments use "_cam_hd" with the same meta parameters.
# NOTE(review): revert if train_hd really implements a distinct
# "adam_hd_1" variant.
method_list = ["_cam_hd", "_hd", "", ""]
meta_list = ["", "", "L4", ""]
# Entries [0], [1], [2] are copied into args["delta"], args["gamma_1"] and
# args["gamma_2"] for the matching configuration.
meta_para_list = [[0.03, 0.2, 0.8], [0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0]]

L_meta_list = len(meta_para_list)
n_trials = 10
args["epochs"] = 10
args["patience"] = 30
args["print_gamma"] = False
args["update_from_combination"] = False
args["lr_schedule"] = None

for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])
    # Log every freshly drawn seed so the run can be reproduced.
    print("seed:", args["seed"])
    print("learning rate", args["lr"])

    for method_base in method_base_list:

        # Result tables indexed by configuration p, one entry per trial.
        eva_list_list = [[] for i in range(L_meta_list)]
        val_losses_table_list = [[] for i in range(L_meta_list)]
        train_losses_table_list = [[] for i in range(L_meta_list)]
        val_accs_table_list = [[] for i in range(L_meta_list)]
        # Flat list: one gamma history per (trial, configuration) run.
        gamma_table_list = []

        for i in range(n_trials):

            # Split the provided test set into a 5000-sample validation set
            # and a held-out test set; re-drawn every trial.
            print("len(test_dataset)", len(test_dataset))
            val_len = 5000
            test_len = int(len(test_dataset)) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(
                test_dataset, lengths=[val_len, test_len])

            # train_lost_len = 0 keeps the whole training set; raise it to
            # train on a random subset instead.
            train_len = len(train_dataset)
            train_lost_len = 0
            train_len_1 = train_len - train_lost_len
            train_dataset_1, lost_dataset = torch.utils.data.random_split(
                train_dataset, lengths=[train_len_1, train_lost_len])

            print("train_len", len(train_dataset_1))
            print("val_len", len(val_dataset_1))
            print("test_len", len(test_dataset_1))

            args["batch_size"] = 256

            # Data loaders: only the training loader is shuffled.
            train_loader = torch.utils.data.DataLoader(
                dataset=train_dataset_1, batch_size=args["batch_size"], shuffle=True)
            val_loader = torch.utils.data.DataLoader(
                dataset=val_dataset_1, batch_size=args["batch_size"], shuffle=False)
            test_loader = torch.utils.data.DataLoader(
                dataset=test_dataset_1, batch_size=args["batch_size"], shuffle=False)
            data = [train_loader, val_loader, test_loader]

            for p in range(L_meta_list):

                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("seed:", args["seed"])
                method = method_list[p]
                args["method"] = method_base + method
                print("method", args["method"])
                args = args_update(args)
                # Printed after args_update so it shows the lr this run
                # actually uses, not the previous run's value.
                print("learning rate", args["lr"])
                args["momentum"] = 0
                args["nesterov"] = False
                args["meta"] = meta_list[p]

                args["level"] = "layer_global"
                args["delta"] = meta_para_list[p][0]
                args["gamma_1"] = meta_para_list[p][1]
                args["gamma_2"] = meta_para_list[p][2]

                print("method", args["method"], meta_para_list[p])
                # eva is presumably the final evaluation metric; confirm in
                # train_hd.model_train_HD.
                (eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs,
                 time_list, gamma_list_list) = model_train_HD(data, args)
                try:
                    # Unwrap a scalar tensor/array into a plain Python number.
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    # Already a plain number, or not a single-element value:
                    # keep it as-is rather than abort a finished run.
                    pass
                print("eva:", eva)
                print("avg_valid_losses", avg_valid_losses)
                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()
                
# ---------------------------------------------------------------------------
# Experiment: LeNet-5 on SVHN.
# Each trial re-draws the val/test split, then trains every optimiser
# configuration p described by method_list[p], meta_list[p] and
# meta_para_list[p].
# ---------------------------------------------------------------------------
args["device"] = 0
args["model_type"] = "lenet_5"
data_list = ["SVHN"]
method_base_list = ["adam"]

args["print_lr"] = False
args["hd_decay"] = False
args["hd_decay_coef"] = 0.002

L_mb = len(method_base_list)
# Suffix appended to the base optimiser name (e.g. "adam" + "_cam_hd").
method_list = ["_cam_hd", "_hd", "", ""]
meta_list = ["", "", "L4", ""]
# Entries [0], [1], [2] are copied into args["delta"], args["gamma_1"] and
# args["gamma_2"] for the matching configuration.
meta_para_list = [[0.03, 0.2, 0.8], [0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0]]

L_meta_list = len(meta_para_list)
n_trials = 5
args["epochs"] = 10
args["patience"] = 30
args["print_gamma"] = False
args["update_from_combination"] = False
args["lr_schedule"] = None

for data_set in data_list:

    args["dataset"] = data_set
    train_dataset, test_dataset = data_prepare(args)

    args["seed"] = np.random.randint(1000)
    torch.manual_seed(args["seed"])
    # Log every freshly drawn seed so the run can be reproduced.
    print("seed:", args["seed"])
    print("learning rate", args["lr"])

    for method_base in method_base_list:

        # Result tables indexed by configuration p, one entry per trial.
        eva_list_list = [[] for i in range(L_meta_list)]
        val_losses_table_list = [[] for i in range(L_meta_list)]
        train_losses_table_list = [[] for i in range(L_meta_list)]
        val_accs_table_list = [[] for i in range(L_meta_list)]
        # Flat list: one gamma history per (trial, configuration) run.
        gamma_table_list = []

        for i in range(n_trials):

            # Split the provided test set into a 5000-sample validation set
            # and a held-out test set; re-drawn every trial.
            print("len(test_dataset)", len(test_dataset))
            val_len = 5000
            test_len = int(len(test_dataset)) - val_len
            val_dataset_1, test_dataset_1 = torch.utils.data.random_split(
                test_dataset, lengths=[val_len, test_len])

            # train_lost_len = 0 keeps the whole training set; raise it to
            # train on a random subset instead.
            train_len = len(train_dataset)
            train_lost_len = 0
            train_len_1 = train_len - train_lost_len
            train_dataset_1, lost_dataset = torch.utils.data.random_split(
                train_dataset, lengths=[train_len_1, train_lost_len])

            print("train_len", len(train_dataset_1))
            print("val_len", len(val_dataset_1))
            print("test_len", len(test_dataset_1))

            args["batch_size"] = 128

            # Data loaders: only the training loader is shuffled.
            train_loader = torch.utils.data.DataLoader(
                dataset=train_dataset_1, batch_size=args["batch_size"], shuffle=True)
            val_loader = torch.utils.data.DataLoader(
                dataset=val_dataset_1, batch_size=args["batch_size"], shuffle=False)
            test_loader = torch.utils.data.DataLoader(
                dataset=test_dataset_1, batch_size=args["batch_size"], shuffle=False)
            data = [train_loader, val_loader, test_loader]

            for p in range(L_meta_list):

                args["seed"] = np.random.randint(1000)
                torch.manual_seed(args["seed"])
                print("seed:", args["seed"])
                method = method_list[p]
                args["method"] = method_base + method
                print("method", args["method"])
                args = args_update(args)
                # Printed after args_update so it shows the lr this run
                # actually uses, not the previous run's value.
                print("learning rate", args["lr"])
                args["momentum"] = 0
                args["nesterov"] = False
                args["meta"] = meta_list[p]

                args["level"] = "layer_global"
                args["delta"] = meta_para_list[p][0]
                args["gamma_1"] = meta_para_list[p][1]
                args["gamma_2"] = meta_para_list[p][2]
                print("method", args["method"], meta_para_list[p])
                # eva is presumably the final evaluation metric; confirm in
                # train_hd.model_train_HD.
                (eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs,
                 time_list, gamma_list_list) = model_train_HD(data, args)
                try:
                    # Unwrap a scalar tensor/array into a plain Python number.
                    eva = eva.item()
                except (AttributeError, ValueError, RuntimeError):
                    # Already a plain number, or not a single-element value:
                    # keep it as-is rather than abort a finished run.
                    pass
                print("eva:", eva)
                print("avg_valid_losses", avg_valid_losses)
                eva_list_list[p].append(eva)
                val_losses_table_list[p].append(avg_valid_losses)
                val_accs_table_list[p].append(avg_valid_accs)
                train_losses_table_list[p].append(avg_train_losses)
                gamma_table_list.append(gamma_list_list)
                torch.cuda.empty_cache()