import argparse
import os
import errno
import random
import time
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data

from load_model_post_stonet import load_data_post_stonet

def dtanh(x):
    """Return the element-wise derivative of tanh at ``x``, i.e. 1 - tanh(x)^2."""
    tanh_x = torch.tanh(x)
    return 1 - tanh_x ** 2


class StoNet(nn.Module):
    """Fully connected network with tanh activations on every hidden layer.

    The first layer is exposed as ``self.fc``; every later layer (the remaining
    hidden layers followed by the linear output layer) is stored in the plain
    Python list ``self.fc_list`` and additionally registered as a sub-module
    named ``fc2``, ``fc3``, ... so its parameters show up in ``state_dict()``.

    Args:
        num_hidden: number of hidden layers.
        hidden_dim: sequence holding the width of each hidden layer.
        input_dim: number of input features.
        output_dim: number of outputs.
    """

    def __init__(self, num_hidden, hidden_dim, input_dim, output_dim):
        super().__init__()
        self.num_hidden = num_hidden

        self.fc = nn.Linear(input_dim, hidden_dim[0])
        self.fc_list = []

        # Hidden-to-hidden layers are numbered from 2 because ``fc`` is layer 1.
        for offset in range(num_hidden - 1):
            self._register_layer(
                offset + 2, nn.Linear(hidden_dim[offset], hidden_dim[offset + 1]))
        # Output layer: linear, no activation applied in ``forward``.
        self._register_layer(num_hidden + 1, nn.Linear(hidden_dim[-1], output_dim))

    def _register_layer(self, position, layer):
        """Append ``layer`` to ``fc_list`` and register it as ``fc<position>``."""
        self.fc_list.append(layer)
        self.add_module('fc' + str(position), layer)

    def forward(self, x):
        """Apply tanh after every layer except the last and return the output."""
        *hidden_layers, output_layer = self.fc_list
        activation = torch.tanh(self.fc(x))
        for layer in hidden_layers:
            activation = torch.tanh(layer(activation))
        return output_layer(activation)
# Basic Setting
# Command-line options, declared as (flag, default, type, help) rows and
# registered in order; the parsed namespace is read by main() via ``args``.
parser = argparse.ArgumentParser(description='Train StoNet Using ASGMCMC')

_OPTION_SPECS = (
    ('--dataset_name_id', 0, int, 'index of data set'),
    ('--base_path', './result/', str, 'base path for saving result'),
    ('--model_path', 'post_stonet/', str, 'folder name for saving model'),
    ('--lasso', 0.001, float, 'lambda parameter for LASSO'),
    ('--load_epoch', 2000, int, 'load epoch'),
    ('--split_start', 0, int, 'split start index'),
    ('--split_end', 20, int, 'split end index'),
)
for _flag, _default, _kind, _help in _OPTION_SPECS:
    parser.add_argument(_flag, default=_default, type=_kind, help=_help)

# Parsed at import time, exactly as before, so ``args`` is a module global.
args = parser.parse_args()




def _dump_interval_results(path, payload):
    """Pickle ``payload`` to ``path``, closing the file even if dumping fails."""
    with open(path, 'wb') as handle:
        pickle.dump(payload, handle)


def main():
    """Compute and save prediction intervals for previously trained StoNet models.

    For every train/test split index in ``[args.split_start, args.split_end)``
    six runs are performed: five cross-validation folds carved out of the
    training data (indices 0-4) and one run on the original train/test split
    (index 5).  Each run loads a saved model and hidden state from disk,
    builds an interval ``prediction +/- z_score * sqrt(variance)`` for every
    test point, counts how many targets fall strictly inside it, and pickles
    the results to ``CI.txt`` inside the run's model folder (a checkpoint is
    also written every 100 test points).

    Reads the module-level ``args`` namespace and returns ``None``.

    Differences from the straightforward per-point formulation:
      * everything that depends only on the training data (per-node inverse
        Gram matrices, the hidden-layer inverse, the training MSE) is computed
        once per run instead of once per test point;
      * inference runs under ``torch.no_grad()``, so the pickled
        ``x_prediction`` tensor carries no autograd history.
    """
    dataset_name_id = args.dataset_name_id
    split_start = args.split_start
    split_end = args.split_end
    load_epoch = args.load_epoch
    base_path = args.base_path
    model_path = args.model_path
    lasso_end = args.lasso

    dataset_names = ['Wine',
                     'CCPP',
                     'Protein',
                     'Year'
                     ]
    random_state_train_test = np.arange(20)
    num_hidden = 1
    hidden_dim = [20]
    output_dim = 1
    # 1.645 is the standard-normal 95th percentile, i.e. a two-sided 90% band.
    z_score = 1.645
    # Weights with magnitude below this are treated as pruned (masked out).
    epsilon = 1e-10

    loss_func = nn.MSELoss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for random_state_train_test_id in range(split_start, split_end):

        dataset_name = dataset_names[dataset_name_id]
        random_state = random_state_train_test[random_state_train_test_id]

        # random.seed() rejects numpy integers on Python >= 3.11, so convert
        # to a builtin int; the resulting seed value is unchanged.
        seed = int(random_state)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        x_train_orig, y_train_orig, x_test_orig, y_test_orig = load_data_post_stonet(
            dataset_name, random_state, base_path, model_path)

        for cross_validate_index in range(6):
            np.random.seed(cross_validate_index)
            torch.manual_seed(cross_validate_index)
            print("cross validation index: {}".format(cross_validate_index))
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(cross_validate_index)

            # 5 fold cross validation and results obtained by training on all data

            if cross_validate_index < 5:
                n_total = x_train_orig.shape[0]
                permutation = np.random.choice(range(n_total), n_total, replace=False)
                size_test = np.round(n_total * 0.2).astype(int)
                positions = np.arange(n_total)
                lower_bound = cross_validate_index * size_test
                upper_bound = (cross_validate_index + 1) * size_test
                # Positions (within the permutation) held out for this fold.
                fold_mask = (positions >= lower_bound) & (positions < upper_bound)

                index_train = permutation[~fold_mask]
                index_test = permutation[fold_mask]

                x_train = x_train_orig[index_train, :]
                y_train = y_train_orig[index_train]

                x_test = x_train_orig[index_test, :]
                y_test = y_train_orig[index_test]
            else:
                x_train = x_train_orig
                y_train = y_train_orig

                x_test = x_test_orig
                y_test = y_test_orig

            print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
            PATH = base_path + dataset_name + '/' + 'data_split_' + str(random_state) + '/' + model_path + 'post_stonet_CI_lasso{:.6f}/'.format(lasso_end) + "cross_validate_{}/".format(cross_validate_index)

            ntest = x_test.shape[0]
            dim = x_train.shape[1]

            np.random.seed(seed)
            torch.manual_seed(seed)
            net = StoNet(num_hidden, hidden_dim, dim, output_dim)
            net.to(device)

            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            net.load_state_dict(
                torch.load(PATH + 'model' + str(load_epoch) + '.pt', map_location=device))

            # NOTE: pickle executes arbitrary code on load; only point this at
            # hidden-state files produced by a trusted training run.
            with open(PATH + 'hidden_state' + str(load_epoch) + '.pt', 'rb') as f:
                hidden_list, _ = pickle.load(f)

            # True where a parameter is (numerically) zero.
            user_mask = {name: para.abs() < epsilon
                         for name, para in net.named_parameters()}

            lower_bound_list = np.zeros([ntest])
            upper_bound_list = np.zeros([ntest])
            hidden_sigma_2_list = np.zeros([ntest])
            linear_regression_sigma_list = np.zeros([ntest])
            count = 0
            result_file = PATH + 'CI.txt'

            with torch.no_grad():
                x_prediction = net(x_test)

                # ---- quantities that depend only on the training data ----
                ones_train = torch.ones([x_train.shape[0], 1]).to(device)
                temp_predict = net.fc(x_train)
                active_inputs = ~user_mask['fc.weight']

                # For each first-layer node: which inputs it uses, and the
                # inverse of X^T X for those inputs plus an intercept column.
                node_select_list = []
                input_inverse_list = []
                for node_index in range(hidden_dim[0]):
                    node_select = active_inputs[node_index,]
                    temp_input = torch.cat((x_train[:, node_select], ones_train), -1)
                    node_select_list.append(node_select)
                    input_inverse_list.append(
                        temp_input.transpose(0, 1).matmul(temp_input).inverse())

                # Hidden nodes that feed the output and have at least one live input.
                select_hidden = (~user_mask['fc2.weight'][0]) * active_inputs.max(1).values
                # Per-node mean squared gap between the first-layer output and
                # the loaded hidden state.
                hidden_residual = (temp_predict - hidden_list[0]).pow(2).mean(0)

                tanh_impute = torch.tanh(temp_predict)[:, select_hidden]
                tanh_impute = torch.cat((tanh_impute, ones_train), -1)
                impute_inverse = tanh_impute.transpose(0, 1).matmul(tanh_impute).inverse()

                # Training MSE of the full network.
                linear_regression_sigma = loss_func(net(x_train), y_train)
                linear_regression_sigma_value = linear_regression_sigma.item()
                w_1 = net.fc_list[0].weight.data[:, select_hidden]

                # ---- per-test-point interval ----
                for test_index in range(ntest):
                    z = x_test[test_index,].unsqueeze(0)
                    y_z = y_test[test_index]
                    ones_z = torch.ones([z.shape[0], 1]).to(device)

                    z_var_list = []
                    for node_select, input_inverse in zip(node_select_list, input_inverse_list):
                        temp_z_input = torch.cat((z[:, node_select], ones_z), -1)
                        z_var_list.append(
                            temp_z_input.matmul(input_inverse).matmul(temp_z_input.transpose(0, 1)).item())
                    z_var_list = torch.FloatTensor(z_var_list).to(device)

                    hidden_mu_1 = net.fc(z).squeeze()[select_hidden]
                    hidden_sigma_1 = torch.diag((z_var_list * hidden_residual)[select_hidden])

                    dtanh_z = torch.diag(dtanh(hidden_mu_1))
                    tanh_z = torch.tanh(hidden_mu_1.unsqueeze(0))
                    tanh_z = torch.cat((tanh_z, ones_z), -1)
                    # The [0:-1, 0:-1] slice drops the intercept row/column.
                    term1 = impute_inverse[0:-1,][:, 0:-1].matmul(dtanh_z).matmul(hidden_sigma_1).matmul(dtanh_z)
                    term2 = tanh_z.matmul(impute_inverse).matmul(tanh_z.transpose(0, 1))
                    term3 = w_1.matmul(dtanh_z).matmul(hidden_sigma_1).matmul(dtanh_z).matmul(w_1.transpose(0, 1))

                    hidden_sigma_2 = (term1.trace() + term2) * linear_regression_sigma + term3
                    hidden_mu_2 = net(z)

                    half_width = (hidden_sigma_2 + linear_regression_sigma).sqrt() * z_score
                    lower_bound = hidden_mu_2 - half_width
                    upper_bound = hidden_mu_2 + half_width

                    if lower_bound < y_z and upper_bound > y_z:
                        count += 1

                    # .item() yields a plain float regardless of device, which
                    # is what the numpy result arrays expect.
                    lower_bound_list[test_index] = lower_bound.item()
                    upper_bound_list[test_index] = upper_bound.item()

                    hidden_sigma_2_list[test_index] = hidden_sigma_2.item()
                    linear_regression_sigma_list[test_index] = linear_regression_sigma_value
                    if test_index % 100 == 0:
                        processed = test_index + 1
                        print('test index = ', test_index)
                        print('count = ', count)
                        print('cover ratio = ', 1.0 * count / processed)
                        # Average over every point handled so far, including
                        # the current one (an empty slice would give NaN).
                        print('ave interval = ', (upper_bound_list[0:processed] - lower_bound_list[0:processed]).mean())
                        _dump_interval_results(
                            result_file,
                            [lower_bound_list, upper_bound_list, hidden_sigma_2_list, linear_regression_sigma_list, x_test, x_prediction, y_test, count])

            _dump_interval_results(
                result_file,
                [lower_bound_list, upper_bound_list, hidden_sigma_2_list, linear_regression_sigma_list, x_test, x_prediction, y_test, count])


# Call main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
