import os
import importlib
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import torchvision.transforms
import torch.utils.data.sampler as sampler
import matplotlib.pyplot as plt

def weights_init_uniform_rule(m):
    """Initialise a ``Linear`` layer in place; intended for ``model.apply``.

    A module whose class name contains ``"Linear"`` gets its weights drawn
    uniformly from ``[-1/sqrt(n), 1/sqrt(n)]`` with ``n = m.in_features``,
    and its bias (when it has one) filled with the constant 1. Every other
    module is left untouched.

    Args:
        m: Module handed in by ``nn.Module.apply``.
    """
    if 'Linear' not in m.__class__.__name__:
        return

    # Bound shrinks with the fan-in of the layer.
    bound = 1.0 / np.sqrt(m.in_features)
    m.weight.data.uniform_(-bound, bound)

    # Layers created with ``bias=False`` expose ``bias`` as None; skip them
    # instead of failing with an AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(1)

def to_tensor(x):
    """Return ``x`` as a torch tensor.

    Args:
        x: Either a ``numpy.ndarray`` (converted with ``torch.from_numpy``
            and cast with ``.float()``) or a ``torch.Tensor`` (returned
            unchanged; subclasses such as ``nn.Parameter`` are accepted).

    Returns:
        The tensor described above.

    Raises:
        TypeError: If ``x`` is neither a numpy array nor a torch tensor.
    """
    # isinstance (rather than an exact type comparison) lets tensor and
    # ndarray subclasses through.
    if isinstance(x, torch.Tensor):
        return x
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).float()
    # Fail loudly: returning None here only postponed the error to a less
    # informative place in the caller.
    raise TypeError("Type error. Input should be either numpy array or torch tensor")

def adjust_lr_zt(optimizer, lr0, epoch, breakpoint=50):
    """Overwrite the learning rate of every parameter group of ``optimizer``.

    The rate is ``lr0 / sqrt(epoch)`` while ``epoch <= breakpoint`` and
    ``lr0 / epoch`` afterwards.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        lr0: Base learning rate.
        epoch: Epoch counter. Values below 1 are clamped to 1, so a 0-based
            loop does not produce an infinite (or NaN) rate.
        breakpoint: Last epoch that still uses the ``1/sqrt(epoch)`` decay.
    """
    # Guard against division by zero / sqrt of a negative number.
    epoch = max(epoch, 1)

    if epoch <= breakpoint:
        lr = lr0 * (1.0 / np.sqrt(epoch))
    else:
        lr = lr0 * (1.0 / (epoch))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

####################################################################
############## NN Architecture #####################################
####################################################################
class myNN_train(nn.Module):
    """Three-layer fully connected network: 100 -> 50 -> 25 -> 1.

    Both hidden layers are followed by ELU; the output layer is followed by
    a sigmoid whose result is multiplied by 10.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as before
        # (linear1, linear2, linear3, Sigmoid, ELU).
        self.linear1 = nn.Linear(in_features=100, out_features=50)
        self.linear2 = nn.Linear(in_features=50, out_features=25)
        self.linear3 = nn.Linear(in_features=25, out_features=1)
        self.Sigmoid = nn.Sigmoid()
        self.ELU = nn.ELU()

    def forward(self, input):
        """Map ``input`` (last dimension 100) to an output of last dimension 1."""
        hidden = self.linear1(input)
        hidden = self.ELU(hidden)
        hidden = self.linear2(hidden)
        hidden = self.ELU(hidden)
        squashed = self.Sigmoid(self.linear3(hidden))
        return squashed * 10

class myNN_train_test(nn.Module):
    """Fully connected 100 -> 50 -> 25 -> 1 network.

    ELU follows each of the first two linear layers; the last linear layer
    feeds a sigmoid and the result is scaled by 10.
    """

    def __init__(self):
        super().__init__()
        # Registration order (linear1, linear2, linear3, Sigmoid, ELU) is
        # kept as in the original definition.
        self.linear1 = nn.Linear(100, 50)  # input dimension: 100
        self.linear2 = nn.Linear(50, 25)
        self.linear3 = nn.Linear(25, 1)
        self.Sigmoid = nn.Sigmoid()
        self.ELU = nn.ELU()

    def forward(self, input):
        """Run the forward pass on ``input`` (last dimension 100)."""
        first = self.ELU(self.linear1(input))
        second = self.ELU(self.linear2(first))
        gate = self.Sigmoid(self.linear3(second))
        return gate * 10

####################################################################
############## Epoch SGD ###########################################
####################################################################

def epoch_SGD(Feature_, Label_, K_sample, y_step_size0, model, h, b, beta0, Lambda=None):
    """Run epoch-SGD on the inner variable ``y`` and return three iterates.

    Each gradient step decreases::

        -sum(mean(h*softplus(model(y) - Label_) + b*softplus(Label_ - model(y)), dim=1))
            + Lambda * sum((y - Feature_)**2)

    Epoch ``k`` (0-based) performs ``2**k`` plain gradient steps with step
    size ``y_step_size0 / 2**k``; the next epoch restarts from the average
    of the iterates produced in the current one.

    Args:
        Feature_: Feature tensor, also the starting point of ``y``
            (documented by the original author as ``[N, d]``).
        Label_: Label tensor; it must broadcast against ``model(y)`` and the
            resulting loss must have at least two dimensions because it is
            averaged over dimension 1.
        K_sample: Number of epochs.
        y_step_size0: Step size of the first epoch.
        model: Network evaluated at ``y``.
        h, b: Weights of the two softplus terms.
        beta0: ``beta`` of ``nn.Softplus``.
        Lambda: Penalty weight. ``None`` (default) falls back to the
            module-level ``Lambda``, which is what this function used
            before the argument existed.

    Returns:
        ``(y_10, y_K0, y_K10)``, all detached: the starting point, the
        averaged iterate after ``K_sample - 1`` epochs and the averaged
        iterate after ``K_sample`` epochs. Iterates of epochs that never
        ran are equal to the starting point.
    """
    if Lambda is None:
        # Backward-compatible path for callers that rely on the global.
        Lambda = globals()["Lambda"]

    m = nn.Softplus(beta=beta0)
    Feature_, Label_ = Feature_.detach(), Label_.detach()

    # generate initial guess y_1^0
    y_hat = Feature_.clone().requires_grad_(True)
    y_10 = y_hat.detach().clone()
    # Defaults cover K_sample <= 1 (and K_sample == 0, which used to raise
    # UnboundLocalError for y_K10).
    y_K0 = y_10.clone()
    y_K10 = y_10.clone()

    for k in range(K_sample):
        n_steps = 2**k
        lr = y_step_size0 / n_steps
        y_hat_sum = torch.zeros_like(y_10)

        for _ in range(n_steps):
            logit_y = model(y_hat)
            loss_y = h*m(logit_y - Label_) + b*m(-logit_y + Label_)
            # Same value as the summed squared row-wise 2-norms, written
            # without the intermediate square root.
            penalty = Lambda * torch.sum((y_hat - Feature_) ** 2)
            loss_g = -torch.sum(torch.mean(loss_y, 1)) + penalty

            # Differentiate w.r.t. y only, so no gradient is accumulated in
            # the model parameters (backward() used to pollute them).
            grad_y = torch.autograd.grad(loss_g, y_hat)[0]
            y_hat = (y_hat - lr * grad_y).detach().requires_grad_(True)

            # Accumulate values, not aliases: storing ``y_hat.data`` kept
            # references to one buffer that was updated in place, so the
            # "average" collapsed to the last iterate.
            y_hat_sum += y_hat.detach()

        # Restart the next epoch from the epoch average. The result is a
        # fresh leaf that requires grad, so later epochs keep optimising it
        # (previously the optimizer still pointed at the old tensor).
        y_hat = (y_hat_sum / n_steps).requires_grad_(True)

        if k == K_sample - 2:
            y_K0 = y_hat.detach().clone()
        if k == K_sample - 1:
            y_K10 = y_hat.detach().clone()
    return y_10, y_K0, y_K10

def epoch_SGD_revision(train_data_loader, x_Tr, y_hat0, model, K_sample, Lambda, y_step_size0,
                                        h, b, beta0):
    """Run epoch-SGD on the inner variable ``y`` with freshly drawn batches.

    Each gradient step decreases::

        Lambda * sum((y - x_Tr)**2)
            - mean(h*softplus(model(y) - z) + b*softplus(z - model(y)))

    where ``z`` is the first element of a batch drawn from
    ``train_data_loader``. Epoch ``k`` (0-based) performs ``2**k`` steps
    with step size ``y_step_size0 / 2**(k+1)``; the next epoch restarts
    from the running average of the iterates of the current one.

    Args:
        train_data_loader: Iterable yielding batches whose first element is
            the sample ``z``; a new iterator is opened for every step.
        x_Tr: Anchor of the quadratic penalty at the inner level.
        y_hat0: Initial guess (numpy array or tensor, passed through
            ``to_tensor``).
        model: Network evaluated at ``y``.
        K_sample: Number of epochs.
        Lambda: Scalar weight of the quadratic penalty.
        y_step_size0: Initial step size.
        h, b, beta0: Parameters of the loss (softplus weights and ``beta``).

    Returns:
        ``(y_10, y_K0, y_K10)``, all detached: the initial guess, the
        averaged iterate after ``K_sample - 1`` epochs and the averaged
        iterate after ``K_sample`` epochs. Iterates of epochs that never
        ran are equal to the initial guess.
    """
    m = nn.Softplus(beta=beta0)

    y_hat = to_tensor(y_hat0).detach().requires_grad_(True)
    y_10 = y_hat.detach().clone()
    # Defaults cover K_sample <= 1 (and K_sample == 0, which used to raise
    # UnboundLocalError for y_K10).
    y_K0 = y_10.clone()
    y_K10 = y_10.clone()

    for k in range(K_sample):
        y_hat_avg = torch.zeros_like(y_hat)
        y_step_sizek = y_step_size0 / (2**(k+1))

        for j in range(2**k):
            # A fresh iterator per step draws a new batch each time.
            z_Tr_j = next(iter(train_data_loader))[0]

            outputs = model(y_hat)
            loss_y = h*m(outputs - z_Tr_j) + b*m(-outputs + z_Tr_j)
            Loss = Lambda * torch.sum((y_hat - x_Tr)**2) - torch.mean(loss_y)
            # Gradient w.r.t. y only; model parameters are not touched.
            grad_y = torch.autograd.grad(Loss, y_hat)[0]

            # Detach after every step so the iterates are not chained
            # together through the autograd graph.
            y_hat = (y_hat - y_step_sizek * grad_y).detach().requires_grad_(True)
            # Running mean of the iterates of this epoch.
            y_hat_avg = y_hat_avg * (1 - 1/(j+1)) + 1/(j+1) * y_hat.detach()

        y_hat = y_hat_avg.detach().requires_grad_(True)

        if k == K_sample - 2:
            y_K0 = y_hat.detach().clone()
        if k == K_sample - 1:
            y_K10 = y_hat.detach().clone()

    return y_10, y_K0, y_K10

####################################################################
############## Evaluate model performance ##########################
####################################################################

def evaluate(model, x_Te, z_Te, h, b, beta0):
    """Return the mean loss of ``model`` on ``(x_Te, z_Te)`` as a float.

    The loss is ``mean(h*softplus(model(x) - z) + b*softplus(z - model(x)))``
    with ``softplus = nn.Softplus(beta=beta0)``.

    Args:
        model: Network to evaluate.
        x_Te, z_Te: Inputs and targets. Tensors are used as given; anything
            else (e.g. a numpy array) is converted to a float32 tensor.
        h, b: Weights of the two softplus terms.
        beta0: ``beta`` of ``nn.Softplus``.

    Returns:
        The scalar loss as a Python float.
    """
    m = nn.Softplus(beta=beta0)

    # The model only accepts tensors; arrays loaded with np.load are
    # converted here.
    if not isinstance(x_Te, torch.Tensor):
        x_Te = torch.as_tensor(x_Te, dtype=torch.float32)
    if not isinstance(z_Te, torch.Tensor):
        z_Te = torch.as_tensor(z_Te, dtype=torch.float32)

    was_training = model.training
    model.eval()
    try:
        # No autograd graph is needed for a pure evaluation.
        with torch.no_grad():
            logit_x = model(x_Te)
            loss_y = h*m(logit_x - z_Te) + b*m(-logit_x + z_Te)
            loss = torch.mean(loss_y)
    finally:
        # Hand the model back in the mode it arrived in, so evaluating in
        # the middle of training does not leave it in eval mode.
        model.train(was_training)
    return loss.item()

def RTMLMC_obj_oracle(x_Tr, z_Tr, batch_size_z):
    """Return one randomized multilevel sample of the adversarial loss.

    Steps:
      1. Draw the level ``K_sample`` from the module-level ``elements``
         with probabilities ``probabilities``.
      2. Pick one random row of ``x_Tr`` / ``z_Tr`` and wrap that row of
         ``z_Tr`` in a shuffled ``DataLoader``.
      3. Run ``epoch_SGD_revision`` from a zero initial guess to obtain
         ``y_10``, ``y_K0`` and ``y_K10``.
      4. Evaluate the loss at the three points on a fresh batch and return
         ``mean(loss(y_10)) + (mean(loss(y_K10)) - mean(loss(y_K0))) / p_K``.

    Relies on the module-level names ``elements``, ``probabilities``,
    ``NN_train``, ``Lambda``, ``y_step_size0``, ``h``, ``b``, ``beta`` and
    ``m``.

    Args:
        x_Tr: 2-D array or tensor; one row is used per call.
        z_Tr: 2-D array or tensor indexed with the same row index.
        batch_size_z: Batch size of the loader over the selected ``z`` row.

    Returns:
        Scalar tensor attached to the parameters of ``NN_train``.
    """
    # ``p`` has to be passed by keyword: the third positional parameter of
    # np.random.choice is ``replace``, so the weights used to be ignored and
    # the level was drawn uniformly while the estimator below divides by
    # ``probabilities``. Drawing a scalar also avoids int() on a 1-element
    # array.
    K_sample = int(np.random.choice(elements, p=probabilities))
    x_Tr_size, _ = x_Tr.shape
    idx = np.random.randint(x_Tr_size)
    # to_tensor lets the caller pass the arrays straight from np.load.
    x_Tr_idx = to_tensor(x_Tr[idx, :])
    z_Tr_idx = to_tensor(z_Tr[idx, :])

    train_data_z_Tr = torch.utils.data.TensorDataset(z_Tr_idx.float())
    train_data_loader_z = torch.utils.data.DataLoader(train_data_z_Tr, batch_size=batch_size_z, shuffle=True)
    y_hat0 = torch.zeros_like(x_Tr_idx)
    y_10, y_K0, y_K10 = epoch_SGD_revision(train_data_loader_z, x_Tr_idx, y_hat0, NN_train,
                                                                K_sample, Lambda, y_step_size0, h, b, beta)
    # Fresh batch of z for the outer loss.
    z_Tr_j = next(iter(train_data_loader_z))[0]

    logit_y10, logit_yK0, logit_yK10 = NN_train(y_10), NN_train(y_K0), NN_train(y_K10)
    loss_y10  = h*m(logit_y10 - z_Tr_j) + b*m(-logit_y10 + z_Tr_j)
    loss_yK0  = h*m(logit_yK0 - z_Tr_j) + b*m(-logit_yK0 + z_Tr_j)
    loss_yK10 = h*m(logit_yK10 - z_Tr_j) + b*m(-logit_yK10 + z_Tr_j)

    # Plain Python float so the scaling never goes through NumPy scalar
    # arithmetic with a tensor.
    inv_prob = 1 / float(probabilities[K_sample-1])
    loss_adv = torch.mean(loss_y10) + inv_prob * (torch.mean(loss_yK10) - torch.mean(loss_yK0))
    return loss_adv

def VSGD_obj_oracle(x_Tr, z_Tr, batch_size_z, K_sample):
    """Return one sample of the adversarial loss at the final inner iterate.

    Picks one random row of ``x_Tr`` / ``z_Tr``, runs
    ``epoch_SGD_revision`` for ``K_sample`` epochs from a zero initial
    guess, and evaluates the loss of ``NN_train`` at the last averaged
    iterate on a fresh batch of ``z``.

    Relies on the module-level names ``NN_train``, ``Lambda``,
    ``y_step_size0``, ``h``, ``b``, ``beta`` and ``m``.

    Args:
        x_Tr: 2-D array or tensor; one row is used per call.
        z_Tr: 2-D array or tensor indexed with the same row index.
        batch_size_z: Batch size of the loader over the selected ``z`` row.
        K_sample: Number of epochs of the inner solver.

    Returns:
        Scalar tensor attached to the parameters of ``NN_train``.
    """
    x_Tr_size, _ = x_Tr.shape
    idx = np.random.randint(x_Tr_size)
    # to_tensor lets the caller pass the arrays straight from np.load
    # (numpy rows have no .float() and are rejected by torch.zeros_like).
    x_Tr_idx = to_tensor(x_Tr[idx, :])
    z_Tr_idx = to_tensor(z_Tr[idx, :])

    train_data_z_Tr = torch.utils.data.TensorDataset(z_Tr_idx.float())
    train_data_loader_z = torch.utils.data.DataLoader(train_data_z_Tr, batch_size=batch_size_z, shuffle=True)
    y_hat0 = torch.zeros_like(x_Tr_idx)
    # Only the last averaged iterate is needed here.
    _, _, y_K10 = epoch_SGD_revision(train_data_loader_z, x_Tr_idx, y_hat0, NN_train,
                                                                K_sample, Lambda, y_step_size0, h, b, beta)
    # Fresh batch of z for the outer loss.
    z_Tr_j = next(iter(train_data_loader_z))[0]

    logit_y = NN_train(y_K10)

    loss_y = h*m(logit_y - z_Tr_j) + b*m(-logit_y + z_Tr_j)
    loss_y_summary = torch.mean(loss_y)
    return loss_y_summary






# ------------------------------------------------------------------
# Experiment driver: trains a fresh network `num_trial` times with the
# method chosen on stdin and reports the final test loss of every trial.
# ------------------------------------------------------------------
torch.manual_seed(1103)
torch.cuda.manual_seed(1102)
np.random.seed(1)
input_method = input("Enter method: ")
if input_method not in ("VSGD", "RTMLMC"):
    # An unknown method used to fall through both branches of the training
    # loop, so every trial silently reported the loss of an untrained model.
    raise ValueError(
        "Unknown method {!r}; expected 'VSGD' or 'RTMLMC'.".format(input_method))
z_Tr_size    = int(input("Z Tr Size: "))
# NOTE: shadows the builtin `breakpoint` and is not used below; kept so the
# prompts and the module-level name stay as they were.
breakpoint   = int(input("Breakpoint: "))

# Penalty weight of the inner problem.
Lambda = 100
if z_Tr_size >= 50:
    Lambda = 150
h, b, beta = 0.5, 1, 5
m = nn.Softplus(beta=beta)


NN_Ground_Truth = myNN_train()
NN_Ground_Truth.apply(weights_init_uniform_rule)
y_step_size0 = 1e-2
# np.load returns numpy arrays; the model and the oracles work on torch
# tensors, so convert once here.
x_Tr, z_Tr = to_tensor(np.load("x_Tr.npy")), to_tensor(np.load("z_Tr.npy"))
x_Te, z_Te = to_tensor(np.load("x_Te.npy")), to_tensor(np.load("z_Te.npy"))
batch_size_z = 50
num_epoch = 100
num_trial = 200


p=0.5
K_sample_max = 6
# sampling from truncated geometric distribution
elements = np.arange(K_sample_max)+1
probabilities = p ** (elements-1)
probabilities = probabilities / np.sum(probabilities)

x_Tr_size, _ = x_Tr.shape


acc_test_hist        = []
for trial in range(num_trial):
    # Trial-specific seeds.
    torch.manual_seed(1103 + 7*trial)
    torch.cuda.manual_seed(1103 + 7*trial)
    np.random.seed(10 + 7*trial)

    NN_train = myNN_train_test()
    theta_step_size0 = 9e-3
    NN_train.train()

    optimizer_theta = torch.optim.SGD(NN_train.parameters(), lr=theta_step_size0)

    for epoch in range(num_epoch):
        epoch_start_time = time.time()

        # One SGD step on theta from `batch_size_inner` oracle samples.
        batch_size_inner = 2
        optimizer_theta.zero_grad()
        if input_method == "VSGD":
            loss_adv_list = [VSGD_obj_oracle(x_Tr, z_Tr, batch_size_z, K_sample_max)
                             for _ in range(batch_size_inner)]
        else:  # "RTMLMC" (validated above)
            loss_adv_list = [RTMLMC_obj_oracle(x_Tr, z_Tr, batch_size_z)
                             for _ in range(batch_size_inner)]
        loss_adv_mean   = torch.mean(torch.stack(loss_adv_list), dim=0)
        loss_adv_mean.backward()
        optimizer_theta.step()

        epoch_end_time  = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time

        # The test loss is only needed when it is printed and after the
        # last epoch (it is the value stored for the trial), so it is not
        # recomputed on every epoch.
        should_report = epoch % 49 == 0
        if should_report or epoch == num_epoch - 1:
            acc_test    = evaluate(NN_train, x_Te, z_Te, h, b, beta)
        if should_report:
            print('epoch: {:}, acc test: {:.3f}, time: {:.3f}.'.format(epoch, acc_test, per_epoch_ptime))
    # Despite the name, this is the test loss returned by `evaluate`.
    acc_test_hist.append(acc_test)



print("-------------------------------")
print(input_method)
print("Lambda: ", Lambda)
print(acc_test_hist)
print(np.mean(np.array(acc_test_hist)))
print(np.std(np.array(acc_test_hist)))





