import os

import time
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, RMSprop
from torch.utils.data import DataLoader, TensorDataset

from samplers.EWSG import EWSG
from samplers.SGHMC import SGHMC
from samplers.SGLD import SGLD
from samplers.pSGLD import pSGLD
import argparse

parser = argparse.ArgumentParser()
# Restrict --method to the samplers the script actually knows how to build, so a
# typo fails at argument parsing instead of as a NameError mid-training.
parser.add_argument('--method', required=True, choices=['EWSG', 'SGHMC', 'SGLD', 'pSGLD'],
                    help='Sampler to use.')
# torch.manual_seed(None) raises, so give the seed a concrete default.
parser.add_argument("--seed", type=int, default=0, help='Seed')
parser.add_argument('--epoch', type=int, default=1, help='Number of data passes.')
# There is no sensible default step size; require it explicitly.
parser.add_argument("--h", type=float, required=True, help="Step size to use.")
parser.add_argument("--minibatch_size", default=50, type=int, help="Minibatch size.")
parser.add_argument("--gamma", type=float, default=50.0, help="gamma")
parser.add_argument("--n_samples", type=int, default=20, help="Number of ensembles.")
parser.add_argument("--reg", type=float, default=0.1, help="Regularization strength.")
parser.add_argument("--M", type=int, default=1, help="The length of index chain in EWSG.")
# Defaults reproduce the locations that used to be hard-coded in the script.
parser.add_argument("--input_path", default='data/covtype',
                    help="Directory containing covtype.data.npy and covtype.target.npy.")
parser.add_argument("--output_path", default='results',
                    help="Directory the log file is written to.")

args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = args.seed
torch.manual_seed(seed)
# Trade cuDNN autotuning speed for run-to-run reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Unpack hyperparameters into module-level names used by the rest of the script.
method = args.method
gamma = args.gamma
sigma = np.sqrt(2 * gamma)  # noise scale tied to gamma: sigma**2 == 2 * gamma
batch_size = args.minibatch_size
epoches = args.epoch
n_samples = args.n_samples
reg = args.reg
M = args.M
h = args.h



def dataloader(X, y, batch_size=None):
    """Wrap features ``X`` and labels ``y`` in a ``DataLoader``.

    Args:
        X: array-like of feature rows; converted to a float32 tensor.
        y: array-like of labels; converted to float32 and reshaped to a
            column of shape ``(-1, 1)`` so it lines up with per-sample logits.
        batch_size: forwarded unchanged to ``DataLoader``.

    Returns:
        A ``DataLoader`` over ``TensorDataset(X, y)``. ``shuffle`` is not set,
        so minibatches come out in dataset order on every pass.
    """
    # One vectorised conversion instead of creating a tensor per row and
    # stacking them (a Python-level loop over every example).
    tensor_x = torch.tensor(np.asarray(X), dtype=torch.float32)
    tensor_y = torch.tensor(np.asarray(y), dtype=torch.float32).view((-1, 1))
    return DataLoader(TensorDataset(tensor_x, tensor_y), batch_size=batch_size)

# Load covtype features/labels. --input_path names the directory holding the
# two .npy files; fall back to the historical location when it is not given.
data_dir = args.input_path or 'data/covtype'
X = np.load(os.path.join(data_dir, 'covtype.data.npy'))
y = np.load(os.path.join(data_dir, 'covtype.target.npy'))

# Fixed 80/20 split (random_state=0), independent of --seed, so every run
# evaluates on the same test set.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# Convert the whole test set in one call rather than one tensor per row.
X_test = torch.tensor(np.asarray(X_test), dtype=torch.float32)
y_test = torch.tensor(np.asarray(y_test), dtype=torch.float32).view((-1, 1))

print('number of features: {}\n'.format(X_train.shape[1]))
print('number of train data: {}'.format(len(X_train)))
print('number of test data: {}'.format(len(X_test)))

N = X_train.shape[0]  # number of training examples
D = X_train.shape[1]  # number of features

class MultipleLR(nn.Module):
    """Ensemble of ``n_samples`` independent logistic-regression heads.

    Each head is an ``nn.Linear(D, 1)``. ``forward`` runs every head on the
    same input and stacks the results along a new leading dimension, giving
    logits of shape ``(n_samples, batch, 1)``.

    Args:
        n_samples: number of heads in the ensemble.
        D: input feature dimension.
        init_weight: optional value copied (as float32) into every head's
            weight; ``None`` keeps PyTorch's default initialisation.
        init_bias: optional value copied (as float32) into every head's bias.
    """

    def __init__(self, n_samples, D, init_weight=None, init_bias=None):
        super().__init__()
        heads = [nn.Linear(D, 1) for _ in range(n_samples)]
        self.linears = nn.ModuleList(heads)
        for head in self.linears:
            # Each head gets its own freshly constructed tensor.
            if init_weight is not None:
                head.weight.data = torch.tensor(init_weight, dtype=torch.float)
            if init_bias is not None:
                head.bias.data = torch.tensor(init_bias, dtype=torch.float)

    def forward(self, x):
        per_head_logits = [head(x) for head in self.linears]
        return torch.stack(per_head_logits, dim=0)
    
def evaluate_2(model, X_test, y_test):
    """Score the ensemble's averaged predictive distribution on a test set.

    Per-member logits from ``model`` are passed through a sigmoid and averaged
    over the leading (ensemble) dimension before scoring.

    Args:
        model: module whose output stacks member logits along dim 0
            (``MultipleLR`` produces shape ``(n_samples, batch, 1)``).
        X_test: feature tensor.
        y_test: label tensor shaped like the averaged prediction, i.e.
            ``(batch, 1)``.

    Returns:
        ``(llh, acc)`` as 0-dim tensors: the mean test log-likelihood (negative
        BCE of the averaged probabilities) and the accuracy obtained by
        thresholding those probabilities at 0.5.
    """
    # Evaluate wherever the model lives instead of relying on a module global.
    model_device = next((p.device for p in model.parameters()), X_test.device)
    data, target = X_test.to(model_device), y_test.to(model_device)
    # Pure evaluation: don't build an autograd graph over the whole test set.
    with torch.no_grad():
        logits = model(data)
        prob = torch.mean(torch.sigmoid(logits), dim=0)
        llh = -nn.BCELoss()(prob, target)  # mean log likelihood
        pred = (prob > 0.5).float()
        correct = torch.eq(pred, target).sum()
        acc = correct.float() / len(X_test)
    return llh, acc
    
train_loader = dataloader(X_train, y_train, batch_size=batch_size)

model = MultipleLR(n_samples=n_samples, D=D).to(device)
# One optimizer parameter group per nn.Linear head of the ensemble.
params = [{'params': m.parameters()} for m in model.modules() if isinstance(m, nn.Linear)]

# Mutually exclusive choice; an unknown name fails here instead of leaving
# `optimizer` unbound until the training loop touches it.
if method == 'EWSG':
    optimizer = EWSG(params, h=h, gamma=gamma, sigma=sigma)
elif method == 'SGHMC':
    optimizer = SGHMC(params, h, gamma, sigma)
elif method == 'SGLD':
    optimizer = SGLD(params, h)
elif method == 'pSGLD':
    optimizer = pSGLD(params, h)
else:
    raise ValueError(f"Unknown --method {method!r}; expected one of EWSG, SGHMC, SGLD, pSGLD.")
# Default reduction is 'mean' over the minibatch.
criterion = torch.nn.BCEWithLogitsLoss()

# Log file goes to --output_path (falling back to 'results'); create the
# directory so a fresh checkout does not fail on open().
output_dir = args.output_path or 'results'
os.makedirs(output_dir, exist_ok=True)
logger = open(os.path.join(output_dir, f'{method}-{seed}.txt'), 'w+')
logger.write('--- Experiment Configuration ---\n\n')
logger.write('dataset: covtype\n')
logger.write('number of samples: {}\n'.format(n_samples))
logger.write('minibatch_size: {}\n'.format(batch_size))
logger.write('h: {}\n'.format(h))
logger.write('reg: {}\n'.format(reg))
# gamma/sigma are only passed to the SGHMC and EWSG samplers.
if method in ('SGHMC', 'EWSG'):
    logger.write('gamma: {}\n'.format(gamma))
    logger.write('sigma: {}\n'.format(sigma))
if method == 'EWSG':
    logger.write('M: {}\n'.format(M))

logger.write('\n' + '-' * 80 + '\n\n')

template = 'Epoch: {:.2f}\tTest-Accuracy: {:.8f}%\tLog-Likelihood: {:.8f}\n'
# Accumulated seconds spent in sampler updates (test evaluation is not timed).
time_elapsed = 0

# Evaluate on the test set before the very first update and then every
# `eval_every` minibatches within each epoch.
eval_every = 463

try:
    for epoch in range(epoches):
        model.train()
        for idx, (data, target) in enumerate(train_loader):
            if (epoch == 0 and idx == 0) or (idx > 0 and idx % eval_every == 0):
                llh, acc = evaluate_2(model, X_test, y_test)
                info = template.format(epoch + idx / len(train_loader), acc * 100, llh)
                print(info)
                logger.write(info)

            # Only the sampler update below is timed; evaluation is excluded.
            # perf_counter is monotonic, unlike time.time().
            start = time.perf_counter()

            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            logits = model(data)
            # Per-member mean BCE, summed over ensemble members and rescaled
            # by N / batch_size.
            # NOTE(review): with a mean-reduced criterion the usual unbiased
            # full-data estimate is mean * N; multiplying by N / batch_size
            # additionally divides the likelihood term by the batch size.
            # Kept as-is since step sizes were presumably tuned against it;
            # confirm this scaling is intended.
            loss = sum([criterion(logit, target) for logit in logits]) * len(train_loader.dataset) / train_loader.batch_size
            if reg > 0:
                # L2 term (called `prior`): reg/2 * sum of squared parameters.
                prior = reg / 2 * sum([param.pow(2).sum() for param in model.parameters()])
                loss += prior
            loss.backward()
            if isinstance(optimizer, pSGLD):
                optimizer.update_preconditioner()

            if isinstance(optimizer, EWSG):
                # Cycles of M + 1 minibatches: the first gradient of a cycle is
                # accepted directly, later ones go through mh(), and step() is
                # taken on the last minibatch of the cycle.
                if idx % (M + 1) == 0:
                    optimizer.accept()
                else:
                    optimizer.mh()
                if (idx + 1) % (M + 1) == 0:
                    optimizer.step()
            else:
                optimizer.step()

            # CUDA kernels run asynchronously; wait for them so the interval
            # measures the actual work rather than just kernel launches.
            if device.type == 'cuda':
                torch.cuda.synchronize()
            time_elapsed += time.perf_counter() - start

    logger.write('\n' + '-' * 80 + '\n\n')
    # Reported per ensemble member.
    logger.write('Wall time: {:.2f}s\n'.format(time_elapsed / n_samples))
finally:
    # Close (and flush) the log even if training raises.
    logger.close()
print(f"Round {seed}: {method} finished.\n")
