import argparse
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import resnet
import os
import errno
from torch.utils.data.sampler import SubsetRandomSampler

from process_data import preprocess_data

# Command-line configuration for the script. Parsing happens at import time and
# the resulting namespace is read by main() through the module-level `args`.

# Basic Setting
parser = argparse.ArgumentParser(description='CIFAR10 Post StoNet')
# main() iterates `for seed in range(1, num_seed)`, i.e. seeds 1 .. num_seed - 1.
parser.add_argument('--num_seed', default=10, type=int, help='set number of seed')
parser.add_argument('--data_name', default='CIFAR10_resnet', type=str, help='specify the name of the data')
parser.add_argument('--base_path', default='./result/CIFAR10/', type=str,
                    help='base path for saving result')

parser.add_argument('--load_model_path', default='test_run/', type=str, help='folder name for loading saved model')

parser.add_argument('--model_path', default='post_stonet/', type=str, help='folder name for saving model')
# The two flags below are parsed as integers (type=int), so they must be given
# on the command line as 0 / 1; they are only ever tested for truthiness.
parser.add_argument('--regression_flag', default=False, type=int,
                    help='true for regression and false for classification')
parser.add_argument('--confidence_interval_flag', default=False, type=int,
                    help='whether to store result to compute confidence interval')

# model
parser.add_argument('--layer', default=1, type=int, help='number of hidden layer')
parser.add_argument('--unit', default=[100], type=int, nargs='+', help='number of hidden unit in each layer')
parser.add_argument('--sigma', default=[0.001, 0.0001], type=float, nargs='+',
                    help='variance of each layer for the hidden variable model')

# Training Setting
parser.add_argument('--nepoch', default=1000, type=int, help='total number of training epochs')
parser.add_argument('--save_interval', default=1, type=int, help='epoch interval for model saving')
parser.add_argument('--batch_train', default=128, type=int, help='batch size for training')

parser.add_argument('--MH_step', default=1, type=int, help='number of SGMCMC step for imputation')
parser.add_argument('--impute_lr', default=[0.0000001], type=float, nargs='+', help='step size in imputation')
parser.add_argument('--impute_alpha', default=1, type=float, help='momentum parameter for HMC')
parser.add_argument('--temperature', default=[0.1], type=float, nargs='+', help='temperature parameter for HMC')

parser.add_argument('--para_update_step', default=1, type=int, help='number of parameter update step in each iteration')
parser.add_argument('--para_lr', default=[0.0005, 0.000005], type=float, nargs='+', help='step size in parameter update')
parser.add_argument('--para_momentum', default=0.9, type=float, help='momentum parameter for parameter update')

parser.add_argument('--lasso', default=0.000001, type=float, help='lambda parameter for LASSO')
# The two options below are epoch indices, not lambda values: main() linearly
# interpolates the LASSO lambda towards --lasso between these two epochs.
parser.add_argument('--lasso_anneal_start', default=0, type=int,
                    help='epoch at which annealing of the LASSO lambda starts')
parser.add_argument('--lasso_anneal_end', default=0, type=int,
                    help='epoch at which annealing of the LASSO lambda ends (lambda equals --lasso afterwards)')


args = parser.parse_args()




class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).
    The input to this loss is the logits of a model, NOT the softmax scores.
    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:
    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |
    We then return a weighted average of the gaps, based on the number
    of samples in each bin
    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

class Net(nn.Module):
    """Fully connected network with tanh activations between layers.

    The input layer is stored as ``self.fc``. The remaining layers (the
    hidden-to-hidden ones followed by the output layer) are kept in the plain
    Python list ``self.fc_list`` and also registered as sub-modules named
    ``fc2``, ``fc3``, ... so that their parameters are tracked by the module.

    Args:
        num_hidden: number of hidden layers.
        hidden_dim: sequence with the width of each hidden layer.
        input_dim: size of the input features.
        output_dim: size of the network output.
    """

    def __init__(self, num_hidden, hidden_dim, input_dim, output_dim):
        super(Net, self).__init__()
        self.num_hidden = num_hidden

        self.fc = nn.Linear(input_dim, hidden_dim[0])
        self.fc_list = []

        # (name suffix, in_features, out_features) for every layer after `fc`:
        # first the hidden-to-hidden layers, then the output layer.
        layer_specs = [
            (idx + 2, hidden_dim[idx], hidden_dim[idx + 1])
            for idx in range(num_hidden - 1)
        ]
        layer_specs.append((num_hidden + 1, hidden_dim[-1], output_dim))

        for suffix, fan_in, fan_out in layer_specs:
            layer = nn.Linear(fan_in, fan_out)
            self.fc_list.append(layer)
            self.add_module('fc' + str(suffix), layer)

    def forward(self, x):
        """Apply tanh after every layer except the last, which is left linear."""
        activation = torch.tanh(self.fc(x))
        # Everything in fc_list except its final entry is a hidden layer.
        for hidden_layer in self.fc_list[:-1]:
            activation = torch.tanh(hidden_layer(activation))
        output_layer = self.fc_list[-1]
        return output_layer(activation)




def main():
    """Train a ``Net`` on pre-extracted features once per seed and save the results.

    All settings are read from the module-level ``args`` namespace. For every
    seed in ``range(1, args.num_seed)`` the function:

    1. loads ``x_train, y_train, x_test, y_test`` through ``preprocess_data``;
    2. builds a ``Net`` and one SGD optimizer per layer;
    3. for every mini-batch, initialises the hidden pre-activations with a
       forward pass, perturbs them with ``args.MH_step`` noisy gradient steps
       (SGHMC; SGLD when ``impute_alpha == 1``), and then updates each layer's
       parameters against the imputed hidden values with an L1 penalty;
    4. after every epoch, evaluates loss (plus accuracy and ECE for
       classification) on the full train and test tensors;
    5. writes model checkpoints and pickled metric curves under
       ``<load path>/<args.model_path>``.

    Raises:
        ValueError: if ``args.data_name`` is not one of the supported names.
    """
    import pickle

    # load hyper-parameter
    num_seed = args.num_seed
    save_interval = args.save_interval
    base_path = args.base_path
    load_model_path = args.load_model_path

    # Sub-folder of the pre-trained backbone for every supported data name.
    backbone_dirs = {
        'CIFAR10_densenet': 'densenet_40',
        'CIFAR10_wr': 'wr_28_10',
        'CIFAR10_resnet': 'resnet_110',
    }

    # NOTE: seeds run from 1 to num_seed - 1 inclusive (seed 0 is never used).
    for seed in range(1, num_seed):
        np.random.seed(seed)
        torch.manual_seed(seed)
        data_name = args.data_name
        num_hidden = args.layer
        hidden_dim = args.unit
        regression_flag = args.regression_flag
        num_epochs = args.nepoch

        # Fail with a clear message instead of hitting an unbound `load_PATH`.
        if data_name not in backbone_dirs:
            raise ValueError(
                'unsupported data_name {!r}; expected one of {}'.format(
                    data_name, sorted(backbone_dirs)))
        load_PATH = '{}{}/{}/seed{}/'.format(
            base_path, load_model_path, backbone_dirs[data_name], seed)

        PATH = load_PATH + args.model_path

        proposal_lr = args.impute_lr
        sigma_list = args.sigma
        temperature = args.temperature
        para_lr = args.para_lr
        para_momentum = args.para_momentum
        subn = args.batch_train
        MH_step = args.MH_step
        alpha = args.impute_alpha
        num_para_update_step = args.para_update_step
        confidence_interval_flag = args.confidence_interval_flag

        lasso_end = args.lasso
        lasso_anneal_start = args.lasso_anneal_start
        lasso_anneal_end = args.lasso_anneal_end

        # load data
        # NOTE(review): the tensors are used as returned and never moved to
        # `device` here -- presumably preprocess_data already places them on
        # the right device; confirm.
        x_train, y_train, x_test, y_test = preprocess_data(data_name, base_path, load_model_path, seed)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        ntrain = x_train.shape[0]
        ntest = x_test.shape[0]
        dim = x_train.shape[1]

        # define loss function
        sse = nn.MSELoss(reduction='sum')
        train_loss_path = np.zeros(num_epochs)
        test_loss_path = np.zeros(num_epochs)
        if regression_flag:
            output_dim = 1
            loss_func = nn.MSELoss()
            loss_func_sum = nn.MSELoss(reduction='sum')
        else:
            # Number of classes is inferred from the largest test label.
            output_dim = int((y_test.max() + 1).item())
            # Follow the selected device so CPU-only machines work as well.
            ece_loss_func = _ECELoss().to(device)
            loss_func = nn.CrossEntropyLoss()
            loss_func_sum = nn.CrossEntropyLoss(reduction='sum')
            train_ece_loss_path = np.zeros(num_epochs)
            test_ece_loss_path = np.zeros(num_epochs)
            train_accuracy_path = np.zeros(num_epochs)
            test_accuracy_path = np.zeros(num_epochs)
        time_used_path = np.zeros(num_epochs)

        # define model (re-seed so the initial weights depend only on the seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        net = Net(num_hidden, hidden_dim, dim, output_dim)
        net.to(device)

        # path to save result
        os.makedirs(PATH, exist_ok=True)

        # define hyper-parameter for imputation: a single value given on the
        # command line is repeated for every layer it applies to.
        if len(proposal_lr) == 1 and num_hidden > 1:
            proposal_lr = [proposal_lr[0]] * num_hidden

        if len(sigma_list) == 1 and num_hidden >= 1:
            sigma_list = [sigma_list[0]] * (num_hidden + 1)

        if len(temperature) == 1 and num_hidden > 1:
            temperature = [temperature[0]] * num_hidden

        if len(para_lr) == 1 and num_hidden >= 1:
            para_lr = [para_lr[0]] * (num_hidden + 1)

        # define hyper-parameter for parameter update: one SGD optimizer per
        # layer, index 0 for net.fc and index i + 1 for net.fc_list[i].
        optimizer_list = []
        optimizer_list.append(torch.optim.SGD(net.fc.parameters(), lr=para_lr[0], momentum=para_momentum))
        for i in range(num_hidden):
            optimizer_list.append(torch.optim.SGD(net.fc_list[i].parameters(), lr=para_lr[i + 1], momentum=para_momentum))

        # training
        index = np.arange(ntrain)
        for epoch in range(num_epochs):
            start_time = time.process_time()
            np.random.shuffle(index)

            # LASSO weight: 0.00001 before the annealing window, linearly
            # interpolated to `lasso_end` inside it, `lasso_end` afterwards.
            if epoch < lasso_anneal_start:
                lasso_lambda = 0.00001
            elif epoch >= lasso_anneal_start and epoch < lasso_anneal_end:
                lasso_lambda = 0.00001 * (lasso_anneal_end - epoch) / (lasso_anneal_end - lasso_anneal_start) + lasso_end * (epoch - lasso_anneal_start) / (lasso_anneal_end - lasso_anneal_start)
            else:
                lasso_lambda = lasso_end

            # Only full mini-batches are used; a trailing partial batch is dropped.
            for iter_index in range(ntrain // subn):
                subsample = index[(iter_index * subn):((iter_index + 1) * subn)]

                # Initialize hidden units by forward pass
                hidden_list = []
                momentum_list = []
                with torch.no_grad():
                    hidden_list.append(net.fc(x_train[subsample, ]))
                    momentum_list.append(torch.zeros_like(hidden_list[-1]))
                    for i in range(num_hidden - 1):
                        hidden_list.append(net.fc_list[i](torch.tanh(hidden_list[-1])))
                        momentum_list.append(torch.zeros_like(hidden_list[-1]))
                # The hidden values become leaves so gradients can be taken
                # with respect to them during imputation.
                for hidden in hidden_list:
                    hidden.requires_grad = True
                with torch.no_grad():
                    forward_hidden = torch.clone(hidden_list[0])

                # backward imputation by SGHMC, When alpha = 1, it becomes SGLD
                for repeat in range(MH_step):
                    for layer_index in reversed(range(num_hidden)):
                        if hidden_list[layer_index].grad is not None:
                            hidden_list[layer_index].grad.zero_()
                        # Term linking this hidden layer to the layer above
                        # (the labels for the last hidden layer).
                        if layer_index == num_hidden - 1:
                            hidden_likelihood = -loss_func_sum(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                                               y_train[subsample, ]) / sigma_list[layer_index + 1]
                        else:
                            hidden_likelihood = -sse(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                                     hidden_list[layer_index + 1]) / sigma_list[layer_index + 1]
                        # Term linking this hidden layer to the layer below
                        # (the initial forward pass for the first hidden layer).
                        if layer_index == 0:
                            hidden_likelihood = hidden_likelihood - sse(forward_hidden, hidden_list[layer_index]) / sigma_list[layer_index]
                        else:
                            hidden_likelihood = hidden_likelihood - sse(
                                net.fc_list[layer_index - 1](torch.tanh(hidden_list[layer_index - 1])),
                                hidden_list[layer_index]) / sigma_list[layer_index]
                        hidden_likelihood.backward()
                        step_proposal_lr = proposal_lr[layer_index]
                        with torch.no_grad():
                            # Gaussian noise scaled by sqrt(alpha * lr * temperature).
                            noise = torch.FloatTensor(
                                hidden_list[layer_index].shape).to(device).normal_().mul(
                                np.sqrt(alpha * step_proposal_lr * temperature[layer_index]))
                            momentum_list[layer_index] = (1 - alpha) * momentum_list[layer_index] + step_proposal_lr / 2 * \
                                                         hidden_list[layer_index].grad + noise
                            hidden_list[layer_index].data += momentum_list[layer_index]

                for step in range(num_para_update_step):
                    # update parameter for first layer
                    loss = sse(net.fc(x_train[subsample, ]), hidden_list[0]) / sigma_list[0] / subn

                    for para in net.fc.parameters():
                        loss = loss + lasso_lambda * para.abs().sum() / sigma_list[-1] * (para_lr[-1] / para_lr[0])

                    optimizer_list[0].zero_grad()
                    loss.backward()
                    optimizer_list[0].step()

                    for layer_index in range(num_hidden):
                        # update parameters layer by layer
                        if layer_index == num_hidden - 1:
                            loss = loss_func_sum(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                                 y_train[subsample, ]) / sigma_list[layer_index + 1] / subn
                        else:
                            loss = sse(net.fc_list[layer_index](torch.tanh(hidden_list[layer_index])),
                                       hidden_list[layer_index + 1]) / sigma_list[layer_index + 1] / subn

                        for para in net.fc_list[layer_index].parameters():
                            loss = loss + lasso_lambda * para.abs().sum() / sigma_list[-1] * (para_lr[-1] / para_lr[layer_index + 1])

                        optimizer_list[layer_index + 1].zero_grad()
                        loss.backward()

                        optimizer_list[layer_index + 1].step()

            # print and save result at the end of each epoch
            # Metrics are stored through .item() so a plain Python float, not a
            # (possibly CUDA, possibly 1-element) tensor, is written to numpy.
            with torch.no_grad():
                if regression_flag:
                    print('epoch: ', epoch)

                    output = net(x_train)
                    train_loss = loss_func(output, y_train)
                    train_loss_path[epoch] = train_loss.item()
                    print("train loss: ", train_loss)

                    output = net(x_test)
                    test_loss = loss_func(output, y_test)
                    test_loss_path[epoch] = test_loss.item()
                    print("test loss: ", test_loss)

                else:
                    print('epoch: ', epoch)

                    output = net(x_train)
                    train_loss = loss_func(output, y_train)
                    prediction = output.data.max(1)[1]
                    train_accuracy = prediction.eq(y_train.data).sum().item() / ntrain

                    train_ece_loss = ece_loss_func(output, y_train)
                    train_ece_loss_path[epoch] = train_ece_loss.item()

                    train_loss_path[epoch] = train_loss.item()
                    train_accuracy_path[epoch] = train_accuracy
                    print("train loss: ", train_loss, "train ece loss: ", train_ece_loss, 'train accuracy: ', train_accuracy)

                    output = net(x_test)
                    test_loss = loss_func(output, y_test)
                    prediction = output.data.max(1)[1]
                    test_accuracy = prediction.eq(y_test.data).sum().item() / ntest

                    test_ece_loss = ece_loss_func(output, y_test)
                    test_ece_loss_path[epoch] = test_ece_loss.item()

                    test_loss_path[epoch] = test_loss.item()
                    test_accuracy_path[epoch] = test_accuracy
                    print("test loss: ", test_loss, "test ece loss: ", test_ece_loss, 'test accuracy: ', test_accuracy)

                # Count the input features (columns of the first-layer weight)
                # that have at least one weight with magnitude >= lasso_lambda.
                para = net.fc.weight
                threshold = lasso_lambda
                num_feature = np.sum(np.max(1 - (para.abs() < threshold).data.cpu().numpy(), 0) > 0)
                print('number of selected:', num_feature)

            if epoch % save_interval == 0:
                torch.save(net.state_dict(), PATH + 'model' + str(epoch) + '.pt')
                if confidence_interval_flag:
                    # Stores the imputed hidden values and inputs of the LAST
                    # mini-batch processed in this epoch.
                    filename = PATH + 'hidden_state' + str(epoch) + '.pt'
                    with open(filename, 'wb') as f:
                        pickle.dump([hidden_list, x_train[subsample, ]], f, protocol=4)

            end_time = time.process_time()

            time_used_path[epoch] = end_time - start_time

        # Despite the .txt extension the result files are binary pickles.
        filename = PATH + 'result.txt'
        if regression_flag:
            with open(filename, 'wb') as f:
                pickle.dump([train_loss_path, test_loss_path, time_used_path], f)
        else:
            with open(filename, 'wb') as f:
                pickle.dump([train_loss_path, test_loss_path, train_ece_loss_path, test_ece_loss_path, train_accuracy_path, test_accuracy_path, time_used_path], f)
        if confidence_interval_flag:
            filename = PATH + 'data.txt'
            with open(filename, 'wb') as f:
                pickle.dump(
                    [x_train, x_test, y_train, y_test], f)


if __name__ == '__main__':
    main()




