import numpy as np
import os
from datetime import datetime
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import cvxpy as cp
from torch.utils.data import Dataset, DataLoader
import time
from sklearn.model_selection import train_test_split
import utils_meta_old as utils_meta

# Tensor defaults for the whole experiment: single-precision floats on the CPU.
dtype = torch.float32
device = torch.device("cpu")

# Seed the NumPy and torch (CPU + CUDA) random streams so runs are repeatable.
np.random.seed(1)
torch.manual_seed(1102)
torch.cuda.manual_seed(1102)

# Interactive run configuration, prompted for in this order:
#   input_method  - name of the outer-gradient method (compared against
#                   "VSGD" and "RTMLMC" in the training loop)
#   K_sample_max  - largest value in the support of the sampling distribution
#   breakpoint    - epoch index at which the outer step-size schedule switches
#                   (NOTE: this name shadows the built-in `breakpoint()`)
#   num_trial     - number of independent trials to run
input_method = input("Enter method: ")
K_sample_max, breakpoint, num_trial = (
    int(input(prompt))
    for prompt in ("Kmax: ", "Breakpoint: ", "Number of Trial: ")
)



# Load the stored feature / target arrays for the train (Tr) and test (Te) splits.
Feature_Tr, Feature_Te, Target_Tr, Target_Te = (
    np.load(path)
    for path in (
        "imagenet/Feature_Tr.npy",
        "imagenet/Feature_Te.npy",
        "imagenet/Target_Tr.npy",
        "imagenet/Target_Te.npy",
    )
)

# The training features are 2-D; d_x is the size of their second axis.
_, d_x = Feature_Tr.shape

# Stack train and test along the first axis and hand the pooled arrays to the
# task dataset (20 tasks). NOTE(review): the dataset class is named "cifar"
# although the arrays are read from "imagenet/" - confirm which is intended.
data = {
    "Feature": np.concatenate([Feature_Tr, Feature_Te]),
    "Target": np.concatenate([Target_Tr, Target_Te]),
}
cifar_dataset = utils_meta.cifar_dataset(num_task=20, data=data)

# Hyperparameters. Lambda is forwarded unchanged to the utils_meta routines;
# the two step sizes drive the inner SGD epochs and the outer theta update.
Lambda, batch_size = 2, 100
theta_inner_step_size, theta_outer_step_size0 = 5e-2, 1e-2

# Re-seed after data loading so that everything below starts from a known state.
np.random.seed(1)
torch.manual_seed(1102)
torch.cuda.manual_seed(1102)

# Per-iteration / per-evaluation histories filled in by the training loop.
residual_hist, Loss_outer_hist, error_out_hist, time_hist = [], [], [], []



# Truncated geometric distribution on {1, ..., K_sample_max}:
# the weight of k is p**(k-1), renormalised so the weights sum to one.
p = 0.5
elements = np.arange(1, K_sample_max + 1)
probabilities = np.power(p, elements - 1)
probabilities /= probabilities.sum()

# Outer iterations per trial, and the number of tasks exposed by the dataset.
num_epoch = 10000
num_task = len(cifar_dataset)

#if input_method == "VSGD":
    # theta_outer_step_size0 = 2e-2
    # if K_sample_max >= 8:
    #     theta_outer_step_size0 = 1e-2


# Fail fast on an unsupported method name: neither update branch below would
# match, so every epoch would otherwise run without ever updating theta.
if input_method not in ("VSGD", "RTMLMC"):
    raise ValueError(
        "Unknown method {!r}; expected 'VSGD' or 'RTMLMC'.".format(input_method))

# NOTE(review): residual_hist, Loss_outer_hist, error_out_hist and time_hist are
# created once before this loop, so they keep accumulating across trials.
for trial in range(num_trial):
    # Trial-dependent seeds: each trial is reproducible but uses its own stream.
    torch.manual_seed(1109+7*trial)
    torch.cuda.manual_seed(1109+7*trial)
    np.random.seed(8+7*trial)

    # Random (d_x, 10) initial point, scaled by 1/sqrt(d_x).
    theta0 = torch.randn(d_x,10) * 1/np.sqrt(d_x)
    theta = theta0.clone()

    # Report the test loss / error of the untrained initial point.
    loss_Te_mean, error_Te_mean = utils_meta.performance_Te_eval(theta, cifar_dataset, Lambda)
    print([loss_Te_mean, error_Te_mean])

    for i in range(num_epoch):
        # Outer step size decays like 1/sqrt(1+i) up to and including
        # `breakpoint`, and like 1/(1+i) afterwards.
        if i <= breakpoint:
            theta_outer_step_size = theta_outer_step_size0 / np.sqrt(1+i)
        else:
            theta_outer_step_size = theta_outer_step_size0 / (1+i)

        # Pick one task uniformly at random and unpack its train / test split.
        idx = np.random.randint(num_task)
        dataset_i = cifar_dataset[idx]
        x_Tr_i = dataset_i["x_Tr"]
        y_Tr_i = dataset_i["y_Tr"]
        x_Te_i = dataset_i["x_Te"]
        y_Te_i = dataset_i["y_Te"]

        train_data = torch.utils.data.TensorDataset(
            utils_meta.to_tensor(x_Tr_i), utils_meta.to_tensor(y_Tr_i))
        test_data = torch.utils.data.TensorDataset(
            utils_meta.to_tensor(x_Te_i), utils_meta.to_tensor(y_Te_i))

        # NOTE(review): this mini-batch is not used below. The draw is kept
        # because a shuffled DataLoader appears to advance torch's global RNG,
        # so dropping it would shift the random stream of every later step.
        train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        X_Tr_j, y_Tr_j = next(iter(train_dataloader))

        if input_method == "VSGD":
            start_time = time.time()

            # Inner SGD always runs at the maximum level K_sample_max; only the
            # last of the three returned iterates is used by the VSGD oracle.
            theta_hat = theta.clone()
            theta_10, theta_K0, theta_K10 = utils_meta.epoch_SGD(
                train_data, theta, theta_hat,
                K_sample_max, Lambda, theta_inner_step_size)
            nabla_theta_outer, Loss_outer = utils_meta.gradient_oracle_VSGD(
                train_data, test_data,
                theta_K10, theta, Lambda,
                N_2=120, Nmax=80, L_g_2=120)

            # Gradient step on theta; the gradient norm is tracked as residual.
            residual = torch.norm(nabla_theta_outer)
            theta = theta - theta_outer_step_size  * nabla_theta_outer
            running_time_iter = time.time() - start_time

            # Evaluate on the test tasks every 100 iterations (outside the timer).
            if i % 100 == 0:
                loss_Te_mean, error_Te_mean = utils_meta.performance_Te_eval(theta, cifar_dataset, Lambda)
                Loss_outer_hist.append(loss_Te_mean)
                print("Task: {}, Grad_res: {:.2e}, loss: {:.2e}, time: {:.2e}".format(
                    i, residual, loss_Te_mean, running_time_iter))
            residual_hist.append(residual)
            time_hist.append(running_time_iter)

        elif input_method == "RTMLMC":
            start_time = time.time()

            # Draw the level K from the truncated geometric distribution.
            # The weights must be passed as the keyword `p`: the third
            # positional argument of np.random.choice is `replace`, which
            # made the draw uniform and inconsistent with the probability
            # handed to the oracle below.
            K_sample = int(np.random.choice(elements, p=probabilities))

            theta_hat = theta.clone()
            theta_10, theta_K0, theta_K10 = utils_meta.epoch_SGD(
                train_data, theta, theta_hat,
                K_sample, Lambda, theta_inner_step_size)

            # The oracle receives the probability with which K_sample was drawn.
            nabla_theta_outer, Loss_outer = utils_meta.gradient_oracle_RTMLMC(
                train_data, test_data, K_sample,
                theta_10, theta_K0, theta_K10, probabilities[K_sample-1], theta, Lambda,
                N_2=120, Nmax=80, L_g_2=120)

            # Gradient step on theta; the gradient norm is tracked as residual.
            residual = torch.norm(nabla_theta_outer)
            theta = theta - theta_outer_step_size * nabla_theta_outer

            running_time_iter = time.time() - start_time

            # Evaluate on the test tasks every 100 iterations (outside the timer).
            if i % 100 == 0:
                loss_Te_mean, error_Te_mean = utils_meta.performance_Te_eval(theta, cifar_dataset, Lambda)
                Loss_outer_hist.append(loss_Te_mean)
                error_out_hist.append(error_Te_mean)
                print("Task: {}, K: {}, Grad_res: {:.2e}, loss: {:.2e}, error: {:.2e}, time: {:.2e}".format(
                    i, K_sample, residual, loss_Te_mean, error_Te_mean, running_time_iter))
            residual_hist.append(residual)
            time_hist.append(running_time_iter)





    # if input_method == "RTMLMC":
    #     np.save("new_results/residual_hist_meat_RTMLMC_"+str(K_sample_max)+"_"+str(trial)+".npy",   np.array(residual_hist))
    #     np.save("new_results/Loss_outer_hist_meat_RTMLMC_"+str(K_sample_max)+"_"+str(trial)+".npy", np.array(Loss_outer_hist))
    #     np.save("new_results/error_out_hist_meat_RTMLMC_"+str(K_sample_max)+"_"+str(trial)+".npy", np.array(error_out_hist))
    #     np.save("new_results/time_hist_meat_RTMLMC_"+str(K_sample_max)+"_"+str(trial)+".npy",       np.array(time_hist))
    # elif input_method == "VSGD":
    #     np.save("new_results/residual_hist_meat_VSGD_"+str(K_sample_max)+"_"+str(trial)+".npy",   np.array(residual_hist))
    #     np.save("new_results/Loss_outer_hist_meat_VSGD_"+str(K_sample_max)+"_"+str(trial)+".npy", np.array(Loss_outer_hist))
    #     np.save("new_results/time_hist_meat_VSGD_"+str(K_sample_max)+"_"+str(trial)+".npy",       np.array(time_hist))

