import argparse
import os
import errno
import time
import random
import pickle

import numpy as np

import torch
import torch.nn as nn
import torch.utils.data

from sklearn.linear_model import Lasso
from sklearn.linear_model import LinearRegression
from load_model_post_stonet import load_data_post_stonet

# Command-line interface for the post-StoNet training script.
# NOTE: arguments are parsed at import time, so importing this module consumes sys.argv.

# Basic Setting
parser = argparse.ArgumentParser(description='Train StoNet Using ASGMCMC')

parser.add_argument('--dataset_name_id', default=0, type=int, help='index of data set')
parser.add_argument('--base_dataset_path', default='./data/', type=str,
                    help='folder path of data sets')
parser.add_argument('--base_path', default='./result/', type=str,
                    help='base path for saving result')
parser.add_argument('--model_path', default='post_stonet/', type=str, help='folder name for saving model')
# The two flags below are parsed as integers (pass 0 or 1 on the command line).
parser.add_argument('--regression_flag', default=True, type=int,
                    help='true for regression and false for classification')
parser.add_argument('--confidence_interval_flag', default=True, type=int,
                    help='whether to store result to compute confidence interval')


# model
parser.add_argument('--layer', default=1, type=int, help='number of hidden layer')
parser.add_argument('--unit', default=[20], type=int, nargs='+', help='number of hidden unit in each layer')
parser.add_argument('--sigma', default=[0.001, 0.0001], type=float, nargs='+',
                    help='variance of each layer for the hidden variable model')

# Training Setting
parser.add_argument('--nepoch', default=1001, type=int, help='total number of training epochs')
parser.add_argument('--save_interval', default=100, type=int, help='epoch interval for model saving')
parser.add_argument('--batch_train', default=50, type=int, help='batch size for training')

parser.add_argument('--MH_step', default=1, type=int, help='number of SGMCMC step for imputation')
parser.add_argument('--impute_lr', default=[0.000001], type=float, nargs='+', help='step size in imputation')
parser.add_argument('--impute_alpha', default=1, type=float, help='momentum parameter for HMC')
parser.add_argument('--temperature', default=[0.1], type=float, nargs='+', help='temperature parameter for HMC')

parser.add_argument('--para_update_step', default=1, type=int, help='number of parameter update step in each iteration')
parser.add_argument('--para_lr', default=[0.001, 0.000001], type=float, nargs='+', help='step size in parameter update')
parser.add_argument('--para_momentum', default=0.9, type=float, help='momentum parameter for parameter update')

# LASSO penalty schedule: the penalty is 1e-5 before `lasso_anneal_start`, is
# interpolated linearly between the two epochs, and equals `--lasso` afterwards.
parser.add_argument('--lasso', default=0.001, type=float, help='lambda parameter for LASSO')
parser.add_argument('--lasso_anneal_start', default=0, type=int,
                    help='epoch at which the LASSO lambda starts annealing towards --lasso')
parser.add_argument('--lasso_anneal_end', default=0, type=int,
                    help='epoch from which the LASSO lambda equals --lasso')

parser.add_argument('--split_start', default=0, type=int, help='split start index')
parser.add_argument('--split_end', default=20, type=int, help='split end index')


args = parser.parse_args()

class StoNet(nn.Module):
    """Fully connected network with tanh activations on every hidden layer.

    The first linear map is stored as ``self.fc``. Every following linear map
    (remaining hidden layers, then the output layer) is kept in the plain
    Python list ``self.fc_list`` and additionally registered as a sub-module
    named ``fc2``, ``fc3``, ... so that its parameters are tracked.

    Args:
        num_hidden: number of hidden layers.
        hidden_dim: sequence with the width of each hidden layer.
        input_dim: number of input features.
        output_dim: number of outputs of the final linear layer.
    """

    def __init__(self, num_hidden, hidden_dim, input_dim, output_dim):
        super().__init__()
        self.num_hidden = num_hidden

        # Input -> first hidden layer.
        self.fc = nn.Linear(input_dim, hidden_dim[0])
        self.fc_list = []

        # Hidden -> hidden layers; `depth` is the 1-based index of the source layer.
        for depth in range(1, num_hidden):
            hidden_layer = nn.Linear(hidden_dim[depth - 1], hidden_dim[depth])
            self.fc_list.append(hidden_layer)
            self.add_module('fc{}'.format(depth + 1), hidden_layer)

        # Last hidden layer -> output.
        output_layer = nn.Linear(hidden_dim[-1], output_dim)
        self.fc_list.append(output_layer)
        self.add_module('fc{}'.format(num_hidden + 1), output_layer)

    def forward(self, x):
        """Return the network output for input ``x`` (no activation on the output layer)."""
        activation = torch.tanh(self.fc(x))
        for hidden_layer in self.fc_list[:self.num_hidden - 1]:
            activation = torch.tanh(hidden_layer(activation))
        return self.fc_list[-1](activation)




def _expand_per_layer(values, count):
    """Broadcast a one-element hyper-parameter list to ``count`` entries.

    A list that already holds several values, or a target ``count`` of at most
    one, is returned unchanged.
    """
    if len(values) == 1 and count > 1:
        return [values[0]] * count
    return values


def _dump_pickle(payload, filename, **dump_kwargs):
    """Pickle ``payload`` to ``filename``, closing the file even on error."""
    with open(filename, 'wb') as f:
        pickle.dump(payload, f, **dump_kwargs)


def _refit_layer_nodes(linear, features, targets, num_node, lasso_lambda):
    """Overwrite each output row of ``linear`` with a per-node regression fit.

    For every node a Lasso model (or ordinary least squares when
    ``lasso_lambda`` is exactly zero) is fitted from ``features`` to the
    corresponding column of ``targets``; its coefficients and intercept
    replace that node's weight row and bias.
    """
    for node_index in range(num_node):
        print(node_index)
        if lasso_lambda != 0.0:
            predictor = Lasso(alpha=lasso_lambda, random_state=0, max_iter=999999999)
        else:
            predictor = LinearRegression()
        predictor.fit(features, targets[:, node_index].data.cpu())
        linear.weight.data[node_index, :] = torch.FloatTensor(predictor.coef_)
        linear.bias.data[node_index] = float(predictor.intercept_)


def main():
    """Train and evaluate a StoNet for every requested split and fold.

    For each train/test split index in ``[args.split_start, args.split_end)``
    the data is loaded through ``load_data_post_stonet`` and six runs are
    performed: five cross-validation folds carved out of the training data
    and one final run on the full training set evaluated on the test set.

    Each epoch alternates, per mini-batch, between (1) imputing the hidden
    units with SGHMC/SGLD steps and (2) updating each layer's parameters with
    SGD under an L1 penalty. In the last epoch the whole training set is
    processed as one batch and every layer is refitted node by node with
    scikit-learn's Lasso (or LinearRegression when the penalty is zero).

    Model checkpoints, optional hidden states, loss curves and the data used
    are written below ``args.base_path``.
    """
    save_interval = args.save_interval

    dataset_name_id = args.dataset_name_id

    split_start = args.split_start
    split_end = args.split_end

    dataset_names = ['Wine',
                     'CCPP',
                     'Protein',
                     'Year'
                     ]
    random_state_train_test = np.arange(20)

    for random_state_train_test_id in range(split_start, split_end):
        dataset_name = dataset_names[dataset_name_id]
        random_state = random_state_train_test[random_state_train_test_id]

        # np.arange yields numpy integers, which random.seed() rejects on
        # Python >= 3.11; a plain int is accepted by all three seeders.
        seed = int(random_state)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        num_hidden = args.layer
        hidden_dim = args.unit
        regression_flag = args.regression_flag
        num_epochs = args.nepoch

        base_path = args.base_path
        model_path = args.model_path

        # Per-layer hyper-parameters: a single CLI value is shared by all layers.
        proposal_lr = _expand_per_layer(args.impute_lr, num_hidden)
        sigma_list = _expand_per_layer(args.sigma, num_hidden + 1)
        temperature = _expand_per_layer(args.temperature, num_hidden)
        para_lr = _expand_per_layer(args.para_lr, num_hidden + 1)

        para_momentum = args.para_momentum
        MH_step = args.MH_step
        alpha = args.impute_alpha
        num_para_update_step = args.para_update_step
        confidence_interval_flag = args.confidence_interval_flag

        lasso_end = args.lasso
        lasso_anneal_start = args.lasso_anneal_start
        lasso_anneal_end = args.lasso_anneal_end

        # load data
        x_train_orig, y_train_orig, x_test_orig, y_test_orig = load_data_post_stonet(dataset_name, random_state, base_path, model_path)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        for cross_validate_index in range(6):
            np.random.seed(cross_validate_index)
            torch.manual_seed(cross_validate_index)
            print("cross validation index: {}".format(cross_validate_index))
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(cross_validate_index)

            if cross_validate_index < 5:
                # Folds 0-4: hold out one contiguous 20% slice of a random permutation.
                permutation = np.random.choice(range(x_train_orig.shape[0]), x_train_orig.shape[0], replace=False)
                size_test = np.round(x_train_orig.shape[0] * 0.2).astype(int)
                divid_index = np.arange(x_train_orig.shape[0])
                lower_bound = cross_validate_index * size_test
                upper_bound = (cross_validate_index + 1) * size_test
                test_index = (divid_index >= lower_bound) * (divid_index < upper_bound)

                index_train = permutation[[not _ for _ in test_index]]
                index_test = permutation[test_index]

                x_train = x_train_orig[index_train, :]
                y_train = y_train_orig[index_train]

                x_test = x_train_orig[index_test, :]
                y_test = y_train_orig[index_test]
            else:
                # Fold 5: train on the full training set, evaluate on the real test set.
                x_train = x_train_orig
                y_train = y_train_orig

                x_test = x_test_orig
                y_test = y_test_orig

            PATH = base_path + dataset_name + '/' + 'data_split_' + str(random_state) + '/' + model_path + 'post_stonet_CI_lasso{:.6f}/'.format(lasso_end) + "cross_validate_{}/".format(cross_validate_index)

            ntrain = x_train.shape[0]
            ntest = x_test.shape[0]
            dim = x_train.shape[1]

            # The mini-batch size is recomputed for every fold. Previously the
            # full-batch size used in the last epoch (and any clamping) leaked
            # into all following folds, which then trained full-batch.
            subn = min(args.batch_train, ntrain)

            print("num train: {}, num test: {}".format(ntrain, ntest))

            # define loss function
            sse = nn.MSELoss(reduction='sum')
            if regression_flag:
                output_dim = 1
                loss_func = nn.MSELoss()
                loss_func_sum = nn.MSELoss(reduction='sum')
                train_loss_path = np.zeros(num_epochs)
                test_loss_path = np.zeros(num_epochs)
            else:
                output_dim = int((y_test.max() + 1).item())
                # NOTE(review): _ECELoss is neither defined nor imported in the
                # visible part of this file; the classification path raises
                # NameError unless it is provided elsewhere.
                ece_loss_func = _ECELoss().cuda()
                loss_func = nn.CrossEntropyLoss()
                loss_func_sum = nn.CrossEntropyLoss(reduction='sum')
                train_loss_path = np.zeros(num_epochs)
                test_loss_path = np.zeros(num_epochs)
                train_ece_loss_path = np.zeros(num_epochs)
                test_ece_loss_path = np.zeros(num_epochs)
                train_accuracy_path = np.zeros(num_epochs)
                test_accuracy_path = np.zeros(num_epochs)
            time_used_path = np.zeros(num_epochs)

            # define model (re-seeded so every fold starts from the same initialisation)
            np.random.seed(seed)
            torch.manual_seed(seed)
            net = StoNet(num_hidden, hidden_dim, dim, output_dim)
            net.to(device)

            # path to save result
            os.makedirs(PATH, exist_ok=True)

            # define hyper-parameter for parameter update: one SGD optimizer per layer
            optimizer_list = []
            optimizer_list.append(torch.optim.SGD(net.fc.parameters(), lr=para_lr[0], momentum=para_momentum))
            for i in range(num_hidden):
                optimizer_list.append(torch.optim.SGD(net.fc_list[i].parameters(), lr=para_lr[i+1], momentum=para_momentum))

            # training
            index = np.arange(ntrain)
            for epoch in range(num_epochs):
                start_time = time.process_time()
                np.random.shuffle(index)
                is_final_epoch = epoch == num_epochs - 1
                if is_final_epoch:
                    # Last epoch: one ordered full batch, so the imputed hidden
                    # units line up row by row with x_train for the refit.
                    index = np.arange(ntrain)
                    batch_size = ntrain
                else:
                    batch_size = subn

                # LASSO penalty schedule (linear ramp from 1e-5 to lasso_end)
                if epoch < lasso_anneal_start:
                    lasso_lambda = 0.00001
                elif epoch >= lasso_anneal_start and epoch < lasso_anneal_end:
                    lasso_lambda = 0.00001 * (lasso_anneal_end - epoch) / (lasso_anneal_end - lasso_anneal_start) + lasso_end * (epoch - lasso_anneal_start) / (lasso_anneal_end - lasso_anneal_start)
                else:
                    lasso_lambda = lasso_end

                for iter_index in range(ntrain // batch_size):
                    subsample = index[(iter_index * batch_size):((iter_index + 1) * batch_size)]

                    # Initialize hidden units by forward pass
                    hidden_list = []
                    momentum_list = []
                    with torch.no_grad():
                        hidden_list.append(net.fc(x_train[subsample, ]))
                        momentum_list.append(torch.zeros_like(hidden_list[-1]))
                        for i in range(num_hidden - 1):
                            hidden_list.append(net.fc_list[i](torch.tanh(hidden_list[-1])))
                            momentum_list.append(torch.zeros_like(hidden_list[-1]))
                    for hidden in hidden_list:
                        hidden.requires_grad = True
                    with torch.no_grad():
                        forward_hidden = torch.clone(hidden_list[0])

                    # backward imputation by SGHMC, When alpha = 1, it becomes SGLD
                    for repeat in range(MH_step):
                        for layer_index in reversed(range(num_hidden)):
                            if hidden_list[layer_index].grad is not None:
                                hidden_list[layer_index].grad.zero_()
                            # Likelihood of the next layer (or the response) given this layer.
                            if layer_index == num_hidden - 1:
                                hidden_likelihood = -loss_func_sum(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                                                   y_train[subsample, ]) / sigma_list[layer_index + 1]
                            else:
                                hidden_likelihood = -sse(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                                         hidden_list[layer_index + 1]) / sigma_list[layer_index + 1]
                            # Likelihood of this layer given the previous layer (or the input).
                            if layer_index == 0:
                                hidden_likelihood = hidden_likelihood - sse(forward_hidden, hidden_list[layer_index]) / sigma_list[layer_index]
                            else:
                                hidden_likelihood = hidden_likelihood - sse(
                                    net.fc_list[layer_index - 1](torch.tanh(hidden_list[layer_index - 1])),
                                    hidden_list[layer_index]) / sigma_list[layer_index]
                            hidden_likelihood.backward()
                            step_proposal_lr = proposal_lr[layer_index]
                            with torch.no_grad():
                                noise = torch.FloatTensor(hidden_list[layer_index].shape).to(device).normal_().mul(
                                    np.sqrt(alpha * step_proposal_lr * temperature[layer_index]))
                                momentum_list[layer_index] = (1 - alpha) * momentum_list[layer_index] \
                                    + step_proposal_lr / 2 * hidden_list[layer_index].grad + noise
                                hidden_list[layer_index].data += momentum_list[layer_index]

                    if is_final_epoch:
                        # Refit every layer node by node from the imputed hidden units.
                        for layer_index in range(num_hidden + 1):
                            if layer_index == 0:
                                linear = net.fc
                                features = x_train.data.cpu()
                                targets = hidden_list[layer_index]
                            else:
                                linear = net.fc_list[layer_index - 1]
                                features = torch.tanh(hidden_list[layer_index - 1]).cpu().detach()
                                if layer_index == num_hidden:
                                    targets = y_train
                                else:
                                    targets = hidden_list[layer_index]
                            if layer_index == num_hidden:
                                num_node = output_dim
                            else:
                                num_node = hidden_dim[layer_index]
                            _refit_layer_nodes(linear, features, targets, num_node, lasso_lambda)

                    else:
                        for step in range(num_para_update_step):
                            loss = sse(net.fc(x_train[subsample, ]), hidden_list[0]) / sigma_list[0] / batch_size
                            for para in net.fc.parameters():
                                loss = loss + lasso_lambda * para.abs().sum() / sigma_list[-1] * (para_lr[-1] / para_lr[0])
                            optimizer_list[0].zero_grad()
                            loss.backward()
                            optimizer_list[0].step()

                            for layer_index in range(num_hidden):
                                # update parameters layer by layer
                                if layer_index == num_hidden - 1:
                                    loss = loss_func_sum(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                                         y_train[subsample, ]) / sigma_list[layer_index + 1] / batch_size
                                else:
                                    loss = sse(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                               hidden_list[layer_index + 1]) / sigma_list[layer_index + 1] / batch_size

                                for para in net.fc_list[layer_index].parameters():
                                    loss = loss + lasso_lambda * para.abs().sum() / sigma_list[-1] * (para_lr[-1] / para_lr[layer_index + 1])

                                optimizer_list[layer_index + 1].zero_grad()
                                loss.backward()
                                optimizer_list[layer_index + 1].step()

                # print and save result at the end of each epoch
                with torch.no_grad():
                    if regression_flag:
                        print('epoch: ', epoch)

                        output = net(x_train)
                        train_loss = loss_func(output, y_train)
                        train_loss_path[epoch] = train_loss
                        print("train loss: ", train_loss)

                        output = net(x_test)
                        test_loss = loss_func(output, y_test)
                        test_loss_path[epoch] = test_loss
                        print("test loss: ", test_loss)

                    else:
                        print('epoch: ', epoch)

                        output = net(x_train)
                        train_loss = loss_func(output, y_train)
                        prediction = output.data.max(1)[1]
                        train_accuracy = prediction.eq(y_train.data).sum().item() / ntrain

                        train_ece_loss = ece_loss_func(output, y_train)
                        train_ece_loss_path[epoch] = train_ece_loss

                        train_loss_path[epoch] = train_loss
                        train_accuracy_path[epoch] = train_accuracy
                        print("train loss: ", train_loss, "train ece loss: ", train_ece_loss, 'train accuracy: ', train_accuracy)

                        output = net(x_test)
                        test_loss = loss_func(output, y_test)
                        prediction = output.data.max(1)[1]
                        test_accuracy = prediction.eq(y_test.data).sum().item() / ntest

                        test_ece_loss = ece_loss_func(output, y_test)
                        test_ece_loss_path[epoch] = test_ece_loss

                        test_loss_path[epoch] = test_loss
                        test_accuracy_path[epoch] = test_accuracy
                        print("test loss: ", test_loss, "test ece loss: ", test_ece_loss,'test accuracy: ', test_accuracy)

                    # An input feature counts as selected when at least one
                    # first-layer weight attached to it reaches the current lambda.
                    para = net.fc.weight
                    threshold = lasso_lambda
                    num_feature = np.sum(np.max(1 - (para.abs() < threshold).data.cpu().numpy(), 0) > 0)
                    print('number of selected:', num_feature)

                if epoch % save_interval == 0:
                    torch.save(net.state_dict(), PATH + 'model' + str(epoch) + '.pt')
                    if confidence_interval_flag:
                        # Stores the hidden units and inputs of the last mini-batch of this epoch.
                        _dump_pickle([hidden_list, x_train[subsample, ]],
                                     PATH + 'hidden_state' + str(epoch) + '.pt', protocol=4)

                end_time = time.process_time()

                time_used_path[epoch] = end_time - start_time

            # NOTE: 'result.txt' and 'data.txt' hold pickled (binary) data despite the extension.
            if regression_flag:
                _dump_pickle([train_loss_path, test_loss_path, time_used_path], PATH + 'result.txt')
            else:
                _dump_pickle([train_loss_path, test_loss_path, train_ece_loss_path, test_ece_loss_path, train_accuracy_path, test_accuracy_path, time_used_path], PATH + 'result.txt')
            if confidence_interval_flag:
                _dump_pickle([x_train, x_test, y_train, y_test], PATH + 'data.txt')

# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()




