import os
import sys

import numpy as np
import torch
import torch.optim
import time
# from torchvision import datasets
from sklearn import preprocessing
from sklearn import datasets

from torch.utils.data import DataLoader, Dataset, SequentialSampler
from sklearn import model_selection
from sklearn.datasets import load_svmlight_file
import pickle
sys.path.append('..')
from src.args import parse_args
from src.model import OneHiddenNN, LinearModel
from src.stocCBO import TTSA_denoise_lasso
from src import projection
from src import MyDataset
from torch.nn import functional as F
def out_f(model, images, labels):
    """Return the mean binary cross-entropy (with logits) of ``model`` on a batch.

    ``train_model`` passes this function to ``TTSA_denoise_lasso`` together
    with the validation tensors.

    Args:
        model: callable mapping ``images`` to raw (pre-sigmoid) scores.
        images: input batch forwarded to ``model``.
        labels: float targets; must have the same shape as ``model(images)``
            (a requirement of ``binary_cross_entropy_with_logits``).

    Returns:
        A scalar tensor holding the loss averaged over all elements.
    """
    # The original also concatenated every model parameter into an unused
    # local on each call; that dead work has been dropped.
    logits = model(images)
    return F.binary_cross_entropy_with_logits(logits, labels, reduction='mean')

def inner_g(model, weight, images, labels):
    """Return the sample-weighted mean binary cross-entropy (with logits).

    Each element's loss is scaled by ``sigmoid(weight)`` before averaging, so
    every per-sample factor lies strictly between 0 and 1.

    Args:
        model: callable mapping ``images`` to raw (pre-sigmoid) scores.
        weight: unconstrained per-sample weights, broadcastable against the
            element-wise loss.
        images: input batch forwarded to ``model``.
        labels: float targets with the same shape as ``model(images)``.

    Returns:
        A scalar tensor: the mean of the weighted element-wise losses.
    """
    per_sample = F.binary_cross_entropy_with_logits(
        model(images), labels, reduction='none'
    )
    return (torch.sigmoid(weight) * per_sample).mean()

def _binary_accuracy(model, inputs, targets, count, device):
    """Return the fraction of thresholded model outputs that equal ``targets``."""
    # NOTE(review): outputs are thresholded at 0.5, yet the losses in this file
    # use binary_cross_entropy_with_logits, which suggests the model emits
    # logits (where 0 would be the probability-0.5 boundary). Kept as in the
    # original to leave reported numbers unchanged -- confirm the intent.
    preds = model(inputs.to(device)).ge(0.5)
    return torch.sum(torch.eq(preds.to('cpu'), targets.to('cpu'))) / count


def train_model(args, train_x, train_y, test_x, test_y, val_x=None, val_y=None):
    """Run the momentum-based bilevel training loop and save accuracy/time logs.

    A per-sample weight vector (one entry per training row) is updated from
    the estimate returned by ``TTSA_denoise_lasso``, while the model
    parameters are updated from the gradient of ``inner_g`` followed by
    ``projection.projection_l1_ball``.  Every 100 iterations the accuracy on
    the train / test / validation sets is printed and recorded.

    Args:
        args: parsed options; reads ``gpu``, ``iterations``, ``tau``,
            ``gamma``, ``alpha``, ``beta``, ``rw``, ``Q``, ``eta``,
            ``save_folder`` and ``num``.
        train_x, train_y: training tensors (``train_x`` is indexed as 2-D).
        test_x, test_y: test tensors, used for evaluation only.
        val_x, val_y: validation tensors.  They default to ``None`` for
            signature compatibility but are required by the algorithm.

    Raises:
        ValueError: if ``val_x`` or ``val_y`` is ``None``.

    Side effects:
        Prints progress and writes ``<args.save_folder>/<args.num>.npy`` with
        columns (train acc, test acc, val acc, cumulative train time, 0).
    """
    # Fail early with a clear message instead of an AttributeError on None.
    if val_x is None or val_y is None:
        raise ValueError('train_model requires both val_x and val_y')

    if torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        device = torch.device('cpu')

    training_size = train_x.shape[0]
    val_size = val_x.shape[0]
    testing_size = test_x.shape[0]
    d = train_x.shape[1]

    # One learnable weight per training sample (passed through a sigmoid in inner_g).
    weight = torch.zeros((training_size, 1)).to(device)
    weight.requires_grad = True

    per_test = 100  # evaluation period, in iterations
    model = OneHiddenNN(d, 1).to(device)

    loss_time_results = np.zeros((int(args.iterations / per_test) + 1, 5))
    eps_time = 0
    w = 0  # momentum buffer for the sample-weight direction
    tau = args.tau
    gamma = args.gamma
    rw = args.rw
    Q = args.Q
    G_0 = 1e-6  # keeps the normalising denominators away from zero
    eta = args.eta
    delta = 1e-6
    N = 1
    train_x = train_x.to(device)
    train_y = train_y.to(device)
    val_x = val_x.to(device)
    val_y = val_y.to(device)

    # Initialise every model parameter to the constant 1 / rw.
    params = list(model.parameters())
    for param in params:
        param.data = torch.ones_like(param) / args.rw

    # One momentum buffer per parameter.  The original kept a single buffer fed
    # by the first parameter's gradient and applied it to whichever parameter
    # the initialisation loop visited last; that is only coherent for a
    # single-parameter model, for which this version behaves identically.
    v = [0] * len(params)

    for k in range(args.iterations + 1):
        start_time = time.time()
        etak = 1 / np.sqrt(1000 + k)  # decaying averaging step
        alpha = args.alpha * etak
        beta = args.beta * etak

        # Direction for the sample weights, then gradient of the weighted loss.
        outer_update = TTSA_denoise_lasso(model, weight, train_x, train_y, val_x, val_y,
                                          out_f, inner_g, eta, Q, rw, delta, N)
        g = inner_g(model, weight, train_x, train_y)
        grads = torch.autograd.grad(g, params)

        with torch.no_grad():
            # Momentum + normalised step + averaging for the sample weights.
            w = (1 - alpha) * w + alpha * (outer_update + 1e-4 * weight)
            weight_hat = weight - gamma * w / (torch.sqrt(torch.norm(w)) + G_0)
            weight = (1 - etak) * weight + etak * weight_hat

            # Same scheme for each parameter, with a projection of the step.
            for i, (param, grad) in enumerate(zip(params, grads)):
                v[i] = (1 - beta) * v[i] + beta * grad
                Z = param - tau * v[i] / (torch.sqrt(torch.norm(v[i])) + G_0)
                Pz = projection.projection_l1_ball(Z, rw)
                param.data = (1 - etak) * param + etak * Pz
        # `weight` was rebuilt under no_grad, so re-enable gradient tracking.
        weight.requires_grad = True

        # Only the update step is timed; evaluation below is excluded.
        eps_time += (time.time() - start_time)

        if k % per_test == 0:
            j = k // per_test
            # Evaluation needs no autograd graph.
            with torch.no_grad():
                train_acc = _binary_accuracy(model, train_x, train_y, training_size, device)
                test_acc = _binary_accuracy(model, test_x, test_y, testing_size, device)
                val_acc = _binary_accuracy(model, val_x, val_y, val_size, device)
            print('Iteration: {:d} Train Acc: {:.4f} '
                  'Validation Acc: {:.4f} Test Acc: {:.4f} Time: {:.4f}'.format(
                      k + 1, train_acc, val_acc, test_acc, eps_time))
            loss_time_results[j, 0] = train_acc
            loss_time_results[j, 1] = test_acc
            loss_time_results[j, 2] = val_acc
            loss_time_results[j, 3] = eps_time
            loss_time_results[j, 4] = 0

    os.makedirs(args.save_folder, exist_ok=True)
    file_addr = os.path.join(args.save_folder, str(args.num) + '.npy')
    with open(file_addr, 'wb') as f:
        np.save(f, loss_time_results)


def get_data_loaders(args):
    """Load the six pickled dataset splits selected by ``args``.

    The file ``../datasets/<name>/<name>.pkl<num>`` (relative to the current
    working directory) must contain six consecutive pickle records.

    Args:
        args: object exposing ``name`` (dataset name) and ``num`` (suffix
            appended after ``.pkl``).

    Returns:
        Tuple ``(train_x, train_y, test_x, test_y, val_x, val_y)`` in the
        order the records are stored in the file.
    """
    dataset = args.name
    data_path = ''.join(('../datasets/', dataset, '/', dataset, '.pkl', str(args.num)))
    # pickle.load can execute arbitrary code; only point this at trusted files.
    with open(data_path, 'rb') as handle:
        splits = tuple(pickle.load(handle) for _ in range(6))
    return splits

def main():
    '''
    Script entry point: parse the command-line options, seed the NumPy,
    PyTorch CPU and PyTorch CUDA random generators with ``args.seed``, load
    the dataset splits and run ``train_model`` on them.
    :return: None
    '''
    args = parse_args()

    # Seed every generator with the same value, in the same order as before.
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(args.seed)

    # Splits arrive as (train_x, train_y, test_x, test_y, val_x, val_y).
    splits = get_data_loaders(args)
    train_model(args, *splits)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
