import numpy as np
import os
from datetime import datetime
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import cvxpy as cp
from torch.utils.data import Dataset, DataLoader
import time
from sklearn.model_selection import train_test_split

def one_hot_encoding(a, num_classes=None):
    """Return the one-hot encoding of an array of integer class labels.

    Args:
        a: array-like of N non-negative integer labels (flattened if it has
            more than one dimension).
        num_classes: number of columns C.  Defaults to ``max(a) + 1`` (0 for
            empty input); pass it explicitly when some classes may be absent
            from ``a`` so the width stays fixed.

    Returns:
        float64 array of shape [N, C] with a single 1 per row.

    Raises:
        ValueError: if a label is negative or not smaller than ``num_classes``.
    """
    labels = np.asarray(a).ravel()
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0
    # Negative labels would otherwise index from the end and silently
    # produce a wrong encoding.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            "labels must lie in [0, {}), got range [{}, {}]".format(
                num_classes, labels.min(), labels.max()))
    b = np.zeros((labels.size, num_classes))
    b[np.arange(labels.size), labels] = 1
    return b

def one_hot_decoding(a):
    """Convert a one-hot (or score) matrix back to class indices.

    Args:
        a: tensor whose dim 1 indexes the classes.

    Returns:
        LongTensor holding, for each row, the position of its largest entry.
    """
    return a.max(dim=1).indices


# Module-wide numeric defaults: single-precision floats on the CPU.
dtype, device = torch.float, torch.device("cpu")

# Fixed seeds make np.random sampling (used by the gradient oracles) and the
# torch RNG (used e.g. by shuffling DataLoaders) reproducible across runs.
np.random.seed(1)
torch.manual_seed(1102)
torch.cuda.manual_seed(1102)

def SAA_training(train_data, theta, Lambda, 
                 silence=False, 
                 test_iterations=100, 
                 maxiter = 2, 
                 step_size0 = 1e-2, 
                 tol=1e-2,
                 batch_size = 100):
    """Approximately solve the inner (task-level) problem with averaged SGD.

    Minimizes over theta_hat, starting from theta_hat = theta:
        mean_i [ logsumexp(x_i @ theta_hat) - <y_i, x_i @ theta_hat> ]
        + Lambda / 2 * ||theta_hat - theta||^2

    Args:
        train_data: dataset of (features, target) pairs, consumed through a
            shuffling DataLoader.  Targets are multiplied with the [B, C]
            logits and summed over dim 1, so presumably one-hot [C] per
            sample — confirm with callers.
        theta: prior / outer parameter (numpy array or tensor) shaped [d, C];
            treated as a constant (detached).
        Lambda: weight of the proximal term.
        silence: when False, print progress every ``test_iterations`` steps.
        test_iterations: printing period, in SGD steps.
        maxiter: number of passes (epochs) over ``train_data``.
        step_size0: constant SGD step size.
        tol: stop after an epoch whose last mini-batch gradient norm is <= tol.
        batch_size: mini-batch size.

    Returns:
        Detached tensor shaped like ``theta``: the uniform average of all SGD
        iterates (``theta`` itself if no step was taken).
    """
    theta = to_tensor(theta).detach()
    theta_hat = theta.clone().requires_grad_(True)
    # Overwritten exactly by the first iterate (weight 1 at n_steps == 1).
    theta_hat_avg = theta.clone()
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    n_steps = 0
    grad_norm = None
    for _ in range(maxiter):
        for data, target in train_dataloader:
            logits = data @ theta_hat
            # logsumexp is the overflow-safe form of log(sum(exp(.))).
            obj_vec = torch.logsumexp(logits, 1) - torch.sum(target * logits, 1)
            loss = torch.mean(obj_vec) + Lambda / 2 * torch.sum((theta_hat - theta) ** 2)
            grad = torch.autograd.grad(loss, theta_hat)[0]

            # Restart from a fresh leaf so the autograd graph does not keep
            # growing with the number of steps.
            theta_hat = (theta_hat - step_size0 * grad).detach().requires_grad_(True)
            n_steps += 1
            grad_norm = torch.linalg.norm(grad)

            if not silence and n_steps % test_iterations == 0:
                print("Iter: {}, Gradnorm: {:.2f}, Loss: {:.2f}".format(n_steps, grad_norm.item(), loss.item()))

            # Running mean of the n_steps iterates seen so far.
            theta_hat_avg = theta_hat_avg * (1 - 1 / n_steps) + theta_hat.detach() / n_steps

        if grad_norm is not None and grad_norm <= tol:
            break
    return theta_hat_avg

####################################################################
############## Convert to Torch Tensor Type ########################
####################################################################

def to_tensor(x):
    """Return ``x`` as a torch tensor.

    NumPy arrays are wrapped with ``torch.from_numpy`` and cast to float32.
    Tensors, including subclasses such as ``torch.nn.Parameter``, are
    returned unchanged.

    Raises:
        TypeError: if ``x`` is neither a NumPy array nor a torch tensor.
    """
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).float()
    if isinstance(x, torch.Tensor):
        return x
    # Returning None here would only move the failure to a confusing place
    # further down the caller.
    raise TypeError("Type error. Input should be either numpy array or torch tensor")

####################################################################
############## Convert to Numpy Array Type #########################
####################################################################
    
def to_numpy(x):
    """Return ``x`` as a NumPy array.

    NumPy arrays are returned unchanged.  Tensors are detached from the
    autograd graph and moved to CPU memory if needed; a CPU tensor's result
    shares memory with it.  Anything else goes through ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        # numpy() refuses tensors that require grad or live on a GPU.
        return x.detach().cpu().numpy()
    return np.asarray(x)

class MNIST_dataset(Dataset):
    """Binary-classification tasks built from pairs of consecutive labels.

    Task i keeps the samples labelled 2*i (relabelled 0) and 2*i+1
    (relabelled 1). It splits them into train/test parts with a fixed
    ``random_state``, so the split is reproducible. Each item is a dict of
    float tensors with keys x_Tr, y_Tr, x_Te, y_Te, x_all, y_all.

    NOTE(review): targets here are 0/1 scalars.  The loss code in this
    module (e.g. SAA_training, formulate_outer_loss) multiplies targets with
    [n, C] logits and sums over dim 1, i.e. it expects one-hot targets.
    Confirm how this dataset is consumed.
    """

    def __init__(self, data, num_task = 5, train_size_ratio=0.9):
        """
        data:             mapping with a feature array under "Feature"
                          (rows = samples) and a 1-D label array under "Target"
        num_task:         number of tasks (task i uses labels 2*i and 2*i+1)
        train_size_ratio: fraction of each task's samples used for training
        """
        super(MNIST_dataset, self).__init__()
        self.data             = data
        self.num_task         = num_task
        self.train_size_ratio = train_size_ratio
        self.task_data = self.task_generation(self.data, self.num_task)

    def task_generation(self, data, num_task):
        """Build and return the list of ``num_task`` task dicts."""
        Feature = data["Feature"]
        Target  = data["Target"]
        # 1 - 0.9 evaluates to 0.09999999999999998; rounding keeps sklearn's
        # ceil(test_size * n) test-set size identical to passing 0.1 directly.
        test_size = round(1.0 - self.train_size_ratio, 12)

        generated_task = []
        for i in range(num_task):
            neg_label, pos_label = 2 * i, 2 * i + 1
            neg_mask, pos_mask = Target == neg_label, Target == pos_label
            x_all = np.concatenate((Feature[neg_mask, :], Feature[pos_mask, :]), axis=0)
            y_all = np.concatenate((Target[neg_mask], Target[pos_mask]), axis=0)
            # Relabel pos_label -> 1 and neg_label -> 0, keeping Target's dtype.
            y_all = (y_all == pos_label).astype(y_all.dtype)

            x_Tr, x_Te, y_Tr, y_Te = train_test_split(
                x_all, y_all, test_size=test_size, random_state=42)

            task = dict(
                        x_Tr  = to_tensor(x_Tr),
                        y_Tr  = to_tensor(y_Tr),
                        x_Te  = to_tensor(x_Te),
                        y_Te  = to_tensor(y_Te),
                        x_all = to_tensor(x_all),
                        y_all = to_tensor(y_all)
                        )
            generated_task.append(task)
        return generated_task

    def __len__(self):
        """Return the number of tasks (for use with a DataLoader)."""
        return self.num_task

    def __getitem__(self, index):
        """Return the task dict at ``index`` (for use with a DataLoader)."""
        return self.task_data[index]
    
class cifar_dataset(Dataset):
    """Multi-class tasks formed from consecutive blocks of labels.

    Task i holds the samples whose label satisfies
    ``label // classes_per_task == i``. Its targets are one-hot encodings of
    ``label % classes_per_task`` with exactly ``classes_per_task`` columns,
    so every task matches a classifier with that many outputs. Each item is
    a dict of float tensors with keys x_Tr, y_Tr, x_Te, y_Te, x_all, y_all.

    NOTE(review): a task with no matching samples makes train_test_split
    raise.  Confirm that ``num_task`` fits the label range of the data.
    """

    def __init__(self, data, num_task = 20, train_size_ratio=0.9, classes_per_task=10):
        """
        data:             mapping with a feature array under "Feature"
                          (rows = samples) and a 1-D integer label array under
                          "Target"
        num_task:         number of tasks to build
        train_size_ratio: fraction of each task's samples used for training
        classes_per_task: size of each label block, and width of the one-hot
                          targets
        """
        super(cifar_dataset, self).__init__()
        self.data             = data
        self.num_task         = num_task
        self.train_size_ratio = train_size_ratio
        self.classes_per_task = classes_per_task
        self.task_data = self.task_generation(self.data, self.num_task)

    def task_generation(self, data, num_task):
        """Build and return the list of ``num_task`` task dicts."""
        Feature = data["Feature"]
        Target  = data["Target"]
        n_cls = self.classes_per_task
        # 1 - 0.9 evaluates to 0.09999999999999998; rounding keeps sklearn's
        # ceil(test_size * n) test-set size identical to passing 0.1 directly.
        test_size = round(1.0 - self.train_size_ratio, 12)
        task_of_sample = Target // n_cls

        generated_task = []
        for i in range(num_task):
            in_task = task_of_sample == i
            x_all = Feature[in_task, :]
            local_labels = (Target[in_task] % n_cls).astype(np.int64)
            # Fixed-width one-hot: a task missing some local label still gets
            # n_cls columns instead of max(label) + 1.
            y_all = np.eye(n_cls)[local_labels]

            x_Tr, x_Te, y_Tr, y_Te = train_test_split(
                x_all, y_all, test_size=test_size, random_state=42)

            task = dict(
                        x_Tr  = to_tensor(x_Tr),
                        y_Tr  = to_tensor(y_Tr),
                        x_Te  = to_tensor(x_Te),
                        y_Te  = to_tensor(y_Te),
                        x_all = to_tensor(x_all),
                        y_all = to_tensor(y_all)
                        )
            generated_task.append(task)
        return generated_task

    def __len__(self):
        """Return the number of tasks (for use with a DataLoader)."""
        return self.num_task

    def __getitem__(self, index):
        """Return the task dict at ``index`` (for use with a DataLoader)."""
        return self.task_data[index]
    
####################################################################
############## Epoch SGD ###########################################
####################################################################

def epoch_SGD(train_data_loader, theta, theta_hat, K_sample, Lambda, theta_step_size0):
    """Epoch-doubling SGD on the inner problem, returning multilevel checkpoints.

    Epoch k (k = 0, ..., K_sample-1) takes 2**k SGD steps with step size
    theta_step_size0 / 2**(k+1) on
        mean_i [ logsumexp(x_i @ theta_hat) - <y_i, x_i @ theta_hat> ]
        + Lambda / 2 * ||theta_hat - theta||^2.
    Each step uses a batch drawn as ``next(iter(train_data_loader))``.  The
    next epoch restarts from the average of the current epoch's iterates.

    Args:
        train_data_loader: iterable yielding (features, targets) batches;
            targets presumably one-hot — confirm with callers.
        theta: outer parameter (tensor), held fixed.
        theta_hat: starting point (numpy array or tensor) shaped like theta.
        K_sample: number of epochs.
        Lambda: weight of the proximal term.
        theta_step_size0: base step size.

    Returns:
        (theta_10, theta_K0, theta_K10), all detached tensors:
        - theta_10: the starting point.
        - theta_K0: the average after epoch K_sample-2, or the starting point
          when K_sample <= 1.
        - theta_K10: the average after the last epoch, or the starting point
          when K_sample == 0.
    """
    theta_hat = to_tensor(theta_hat).detach().clone().requires_grad_(True)
    theta_10 = theta_hat.detach().clone()
    theta_K0 = theta_10.clone()
    theta_K10 = theta_10.clone()

    for k in range(K_sample):
        theta_step_sizek = theta_step_size0 / (2**(k+1))
        theta_hat_avg = torch.zeros_like(theta_hat)

        for j in range(2**k):
            X_Tr_j, y_Tr_j = next(iter(train_data_loader))

            outputs = X_Tr_j @ theta_hat
            # Per-sample loss over the class dim (-1): handles both a single
            # unbatched sample and a [B, C] mini-batch.
            obj_vec = torch.logsumexp(outputs, -1) - torch.sum(y_Tr_j * outputs, -1)
            loss_theta = torch.mean(obj_vec) + Lambda/2 * torch.sum((theta_hat - theta)**2)
            grad_theta = torch.autograd.grad(loss_theta, theta_hat)[0]

            # Fresh leaf each step so the autograd graph does not grow.
            theta_hat = (theta_hat - theta_step_sizek * grad_theta).detach().requires_grad_(True)
            theta_hat_avg = theta_hat_avg * (1 - 1/(j+1)) + theta_hat.detach() / (j+1)

        theta_hat = theta_hat_avg.detach().requires_grad_(True)
        if k == K_sample - 2:
            theta_K0 = theta_hat.detach().clone()
        if k == K_sample - 1:
            theta_K10 = theta_hat.detach().clone()
    return theta_10, theta_K0, theta_K10


# def adjust_lr_zt(optimizer, lr0, epoch):
#     lr = lr0 * (1.0 / 2**(epoch))
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr      

####################################################################
############## Estimate gradient of obj at outer level #############
####################################################################

def gradient_oracle_VSGD(train_data, test_data, 
                         theta_inner, theta, Lambda, 
                         N_2=100, Nmax=10, L_g_2=10):
    """Stochastic hypergradient estimate at the outer variable ``theta``.

    Steps:
      1. grad_f: gradient of the outer loss (formulate_outer_loss on all of
         ``test_data``) with respect to ``theta_inner``.
      2. w_hat: approximate inverse-Hessian-vector product for grad_f from
         Hessian_inv_vec_solver.  It uses N' = randint(Nmax) training samples
         drawn with replacement.
      3. Cross term: nabla_12 g · w_hat from get_nabla_g_12_w_solver on N_2
         training samples drawn with replacement.

    Args:
        train_data, test_data: indexable datasets returning (features, targets).
        theta_inner: inner solution (tensor).
        theta: outer variable (tensor).
        Lambda: proximal weight of the inner objective.
        N_2: sample size for the cross-derivative term.
        Nmax, L_g_2: sample budget and scaling of the Hessian-inverse estimator.

    Returns:
        (-cross_term, outer loss value as a Python float).
    """
    X_test, y_test = test_data[:]

    # Detached leaves that require grad, so autograd can differentiate
    # with respect to both variables.
    inner = Variable(to_tensor(theta_inner.detach().clone()), requires_grad=True)
    outer = Variable(to_tensor(theta.detach().clone()), requires_grad=True)

    outer_loss = formulate_outer_loss(X_test, y_test, inner)
    (grad_f,) = torch.autograd.grad(outer_loss, inner)

    all_idx = np.arange(len(train_data))
    # RNG order matters for reproducibility: randint, then two choice() calls.
    n_hessian = np.random.randint(Nmax)
    X_h, y_h = train_data[np.random.choice(all_idx, size=n_hessian)]
    w_hat = Hessian_inv_vec_solver(X_h, y_h, inner, outer, grad_f, Lambda, Nmax, L_g_2)

    X_c, y_c = train_data[np.random.choice(all_idx, size=N_2)]
    cross_term = get_nabla_g_12_w_solver(X_c, y_c, inner, outer, w_hat, Lambda)
    return -cross_term, outer_loss.item()

def Hessian_inv_vec_solver(X_Tr_3, y_Tr_3, theta_inner, theta, nabla_2_f, Lambda, Nmax, L_g_2):
    """Neumann-series estimate of an inverse-Hessian-vector product.

    Starting from r = nabla_2_f, for each sample n = 0, ..., N'-1
    (N' = len(y_Tr_3)) applies
        r <- r - (1 / L_g_2) * H_n r,
    then returns Nmax / L_g_2 * r.  H_n is the Hessian with respect to
    theta_inner of the single-sample inner objective
        logsumexp(x_n @ theta_inner) - <y_n, x_n @ theta_inner>
        + Lambda / 2 * ||theta_inner - theta||^2.
    Each H_n r is a Hessian-vector product obtained by double backprop; no
    Hessian matrix is formed.

    Args:
        X_Tr_3: features, one row per sample ([N', d]).
        y_Tr_3: targets, one row per sample (presumably one-hot [N', C]).
        theta_inner: inner point (tensor requiring grad), shaped [d, C].
        theta: outer variable, shaped like theta_inner.
        nabla_2_f: vector to multiply, shaped like theta_inner.
        Lambda: proximal weight.
        Nmax: sample budget; also the output scale factor's numerator.
        L_g_2: scaling (shrinkage) parameter.

    Returns:
        Tensor shaped like nabla_2_f.  With no samples this is
        Nmax / L_g_2 * nabla_2_f.
    """
    r = nabla_2_f.clone()
    for n in range(len(y_Tr_3)):
        x_n, y_n = X_Tr_3[n, :], y_Tr_3[n]

        logits = x_n @ theta_inner
        # logsumexp: overflow-safe log(sum(exp(.))) over the class dim.
        data_loss = torch.logsumexp(logits, -1) - torch.sum(y_n * logits, -1)
        loss_g = torch.mean(data_loss) + Lambda / 2 * torch.sum((theta_inner - theta) ** 2)

        nabla_g_2 = torch.autograd.grad(loss_g, theta_inner, create_graph=True)[0]
        # d/d theta_inner of <nabla_g_2, r> / L is the scaled product H_n r / L.
        hvp = torch.autograd.grad(torch.sum(nabla_g_2 * r) / L_g_2, theta_inner)[0]
        r = r - hvp

    return Nmax / L_g_2 * r

def formulate_outer_loss(X_Te, y_Te, theta):
    """Mean multinomial logistic (softmax cross-entropy) loss of a linear model.

        loss = mean_i [ logsumexp(X_i @ theta) - <y_i, X_i @ theta> ]

    Args:
        X_Te: features, [n, d] (or a single unbatched sample [d]).
        y_Te: one-hot targets, [n, C] (or [C]).
        theta: weights, [d, C].

    Returns:
        0-dim tensor, differentiable with respect to theta.
    """
    logits = X_Te @ theta
    # logsumexp avoids exp() overflow for large logits.  Reducing over the
    # last dim also accepts a single unbatched sample.
    return torch.mean(torch.logsumexp(logits, -1) - torch.sum(y_Te * logits, -1))

def gradient_oracle_RTMLMC(train_data, test_data, K_sample,
                                 theta_10, theta_K0, theta_K10, p_K, theta, 
                                 Lambda, N_2=100, Nmax=10, L_g_2=10):
    """Randomized-truncation multilevel hypergradient estimate at ``theta``.

    For each inner point t among (theta_10, theta_K0, theta_K10):
      - grad_f(t): gradient of the outer loss on all of ``test_data`` at t.
      - w(t): Hessian_inv_vec_solver applied to grad_f(t).
      - est(t): get_nabla_g_12_w_solver applied with w(t).
    One Hessian sample (N' = randint(Nmax)) and one cross-term sample (N_2)
    are shared by all points.  The estimates are combined as
        est = est(theta_10) + (est(theta_K10) - est(theta_K0)) / p_K,
    and the outer-loss values are combined the same way.  When
    K_sample <= 1, theta_K0 is not used and theta_10's values stand in for
    it.

    Returns:
        (-est, combined outer-loss value as a Python float).
    """
    X_Te, y_Te = test_data[:]
    has_coarse = K_sample > 1

    def fresh_leaf(t):
        # Detached copy that requires grad, so autograd can differentiate w.r.t. it.
        return Variable(to_tensor(t.detach().clone()), requires_grad=True)

    outer = fresh_leaf(theta)
    points = [fresh_leaf(theta_10)]
    if has_coarse:
        points.append(fresh_leaf(theta_K0))
    points.append(fresh_leaf(theta_K10))

    losses = [formulate_outer_loss(X_Te, y_Te, point) for point in points]
    loss_values = [value.item() for value in losses]
    outer_grads = [torch.autograd.grad(value, point)[0]
                   for value, point in zip(losses, points)]

    all_idx = np.arange(len(train_data))
    # RNG order matters for reproducibility: randint, then two choice() calls.
    n_hessian = np.random.randint(Nmax)
    X_h, y_h = train_data[np.random.choice(all_idx, size=n_hessian)]
    ws = [Hessian_inv_vec_solver(X_h, y_h, point, outer, g, Lambda, Nmax, L_g_2)
          for point, g in zip(points, outer_grads)]

    X_c, y_c = train_data[np.random.choice(all_idx, size=N_2)]
    ests = [get_nabla_g_12_w_solver(X_c, y_c, point, outer, w, Lambda)
            for point, w in zip(points, ws)]

    if has_coarse:
        est_10, est_K0, est_K10 = ests
        loss_10, loss_K0, loss_K10 = loss_values
    else:
        est_10, est_K10 = ests
        loss_10, loss_K10 = loss_values
        est_K0, loss_K0 = est_10, loss_10

    combined_est = est_10 + (est_K10 - est_K0) / p_K
    combined_loss = loss_10 + (loss_K10 - loss_K0) / p_K
    return -combined_est, combined_loss

####################################################################
############## Estimating nabla_12 @ w product #####################
####################################################################

def get_nabla_g_12_w_solver(X_Tr, y_Tr, theta_inner, theta, w, Lambda):
    """Sample-average estimate of the cross-derivative product nabla_12 g · w.

    With
        g = mean_i [ log(sum(exp(X_i @ theta_inner))) - <y_i, X_i @ theta_inner> ]
            + Lambda / 2 * ||theta_inner - theta||^2,
    returns d/d theta of <nabla_{theta_inner} g, w>, computed by double
    backprop.  theta enters g only through the proximal term, so this
    equals -Lambda * w analytically.

    Args:
        X_Tr: features, [N, d].
        y_Tr: targets summed against [N, C] logits over dim 1 (one-hot).
        theta_inner: inner point (tensor requiring grad), [d, C].
        theta: outer variable (tensor requiring grad), shaped like theta_inner.
        w: vector to multiply, shaped like theta_inner.
        Lambda: proximal weight.

    Returns:
        Tensor shaped like theta.
    """
    logits = X_Tr @ theta_inner
    per_sample = -torch.sum(y_Tr * logits, 1) + torch.log(torch.sum(torch.exp(logits), 1))
    proximal = Lambda/2 * torch.sum((theta_inner - theta)**2)
    g_value = torch.mean(per_sample) + proximal

    (grad_inner,) = torch.autograd.grad(g_value, theta_inner, create_graph=True)
    directional = torch.sum(grad_inner * w)
    (cross,) = torch.autograd.grad(directional, theta)
    return cross


def performance_Te_eval(theta, cifar_dataset, Lambda):
    """Average test performance of the outer parameter ``theta`` over all tasks.

    For every task, SAA_training solves the inner problem on the task's
    training split (one pass, starting from theta).  The resulting linear
    classifier is then scored on the task's test split.

    Args:
        theta: outer parameter (numpy array or tensor) passed to SAA_training.
        cifar_dataset: dataset of task dicts with keys x_Tr, y_Tr, x_Te, y_Te
            and one-hot targets.  Any such dataset works, not only the class
            of the same name.
        Lambda: proximal weight used by SAA_training.

    Returns:
        (mean multinomial logistic loss, mean misclassification rate) over
        tasks, as Python floats.

    Raises:
        ValueError: if the dataset contains no tasks.
    """
    num_task = len(cifar_dataset)
    if num_task == 0:
        raise ValueError("performance_Te_eval needs at least one task")

    logistic_loss_mean           = 0
    mis_classification_loss_mean = 0
    for i in range(num_task):
        task = cifar_dataset[i]
        x_Tr_i = to_tensor(task["x_Tr"])
        y_Tr_i = to_tensor(task["y_Tr"])
        x_Te_i = to_tensor(task["x_Te"])
        y_Te_i = to_tensor(task["y_Te"])
        y_Te_i_decode = one_hot_decoding(y_Te_i)

        train_data = torch.utils.data.TensorDataset(x_Tr_i, y_Tr_i)
        theta_hat = SAA_training(train_data, theta, Lambda, silence=True, maxiter=1)

        logits = x_Te_i @ theta_hat
        # Predict from the logits directly.  exp() is monotone, so the
        # arg-max is the same, but exp() can overflow to inf and create
        # false ties.
        pred = one_hot_decoding(logits)

        logistic_loss = torch.mean(torch.logsumexp(logits, 1) - torch.sum(y_Te_i * logits, 1))
        mis_classification_loss = (pred != y_Te_i_decode).float().mean()

        logistic_loss_mean += logistic_loss
        mis_classification_loss_mean += mis_classification_loss

    logistic_loss_mean = logistic_loss_mean / num_task
    mis_classification_loss_mean = mis_classification_loss_mean / num_task
    return logistic_loss_mean.item(), mis_classification_loss_mean.item()