import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader, Sampler
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchsummary import summary
import sys 
import argparse 
from sklearn.cluster import KMeans
from Samplers import *

# Command-line configuration. Arguments that have no usable default are marked
# required so a missing flag fails here with a usage message instead of a
# TypeError/NameError further down the script.
parser = argparse.ArgumentParser()
parser.add_argument('--method', required=True,
                    choices=['SGLD', 'pSGLD', 'SGHMC', 'CG-SGHMC', 'EWSG'],
                    help='Sampler to use.')
parser.add_argument("--seed", type=int, required=True, help='Seed')
parser.add_argument('--epoch', type=int, default=200, help='Number of data passes.')
parser.add_argument("--h", type=float, required=True, help="Step size to use.")
parser.add_argument("--minibatch_size", type=int, required=True, help="Minibatch size.")
# Only the SGHMC / CG-SGHMC / EWSG optimizers take gamma; checked after parsing.
parser.add_argument("--gamma", type=float, help="gamma")
parser.add_argument("--n_samples", type=int, default=10, help="Number of ensembles.")
parser.add_argument("--reg", type=float, default=1.0, help="Regularization strength.")
parser.add_argument("--M", type=int, default=1, help="The length of index chain in EWSG.")
parser.add_argument("--input_path", required=True, help="Path to training and test data set.")
parser.add_argument("--output_path", required=True, help="Path to logger.")
parser.add_argument("--evaluation_dataset", default="test", choices=["test", "train", "both"])
args = parser.parse_args()
if args.method in ('SGHMC', 'CG-SGHMC', 'EWSG') and args.gamma is None:
    parser.error('--gamma is required for method {}'.format(args.method))
# Run configuration derived from the parsed command-line arguments.
dataset = 'MNIST'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = args.seed
torch.manual_seed(seed)
# Prefer reproducible cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
method = args.method
gamma = args.gamma
# 1 / (size of the MNIST training split). Presumably this matches the 1/N loss
# scaling that train() uses for the non-SGLD samplers -- TODO confirm.
T = 1 / 60000
# sigma is only consumed by the gamma-based optimizers. SGLD/pSGLD are run
# without --gamma, and computing 2 * None * T would raise a TypeError.
sigma = np.sqrt(2 * gamma * T) if gamma is not None else None
minibatch_size = args.minibatch_size
epoches = args.epoch
n_samples = args.n_samples

reg = args.reg
M = args.M
h = args.h


# Per-run log file, named <output_path>/<method>-<seed>.txt, opened with a
# header describing the experiment configuration.
logger = open(args.output_path + '/' + args.method + f'-{seed}.txt', 'w+')
logger.writelines([
    '--- Experiment Configuration ---\n\n',
    f'seed: {seed}\n',
    f'dataset: {dataset}\n',
    f'minibatch_size: {minibatch_size}\n',
    f'reg: {reg}\n',
    f'h: {h}\n',
])

# gamma/sigma (and M for EWSG) only matter for the non-Langevin samplers.
if args.method not in ('SGLD', 'pSGLD'):
    logger.writelines([f'gamma: {gamma}\n', f'sigma: {sigma}\n'])
    if args.method == 'EWSG':
        logger.write(f'M: {M}\n')
logger.write('\n' + '-' * 80 + '\n\n')

# Clustering-based preprocessing Sampler
class CP_Sampler(Sampler):
    """Clustering-based sampler that stratifies minibatches by k-means cluster.

    The flattened inputs of ``data_source`` are clustered with k-means. Cluster
    ``c``, with ``n_c`` members and mean squared deviation ``v_c`` from its
    centroid, gets a share of ``batch_size`` proportional to
    ``n_c * sqrt(v_c)``. Every share is at least 1.

    Iteration walks the clusters round-robin. Each round emits that many
    consecutive indices from every remaining cluster, until all clusters are
    exhausted. Each dataset index is yielded exactly once per pass.

    NOTE(review): the order is identical on every pass (no shuffling inside
    clusters) -- confirm this is intended.

    Args:
        data_source: indexable dataset whose items' first element is a tensor
            (only that element is used).
        batch_size: target number of indices emitted per round, summed over
            clusters.
        n_clusters: number of k-means clusters.
        seed: ``random_state`` for k-means. ``None`` (the default) falls back
            to the script-level ``seed`` taken from ``--seed``, which is what
            the sampler always used before this argument was honoured.
    """

    def __init__(self, data_source, batch_size=50, n_clusters=10, seed=None):
        # k-means needs a 2-D (n_items, n_features) array of flattened inputs.
        self.data = np.array([data_source[i][0].numpy().reshape(-1) for i in range(len(data_source))])
        self.batch_size = batch_size
        self.n_clusters = n_clusters
        self.seed = seed

        self._kmeans()
        self._compute_within_cluster_variance()
        self._compute_batch_size_per_cluster()

    def _kmeans(self):
        """Cluster ``self.data``; store ``{cluster label: [dataset indices]}`` in ``self.cluster_indices``."""
        # ``seed`` here is the module-level --seed value (backward-compatible fallback).
        random_state = seed if self.seed is None else self.seed
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=random_state).fit(self.data)

        # Keys are k-means cluster ids, inserted in order of first appearance.
        # That order fixes the round-robin order used by __iter__.
        self.cluster_indices = {}
        for index, label in enumerate(kmeans.labels_):
            self.cluster_indices.setdefault(label, []).append(index)

    def _compute_within_cluster_variance(self):
        """Store each cluster's mean squared deviation from its centroid (over all features)."""
        self.within_variance = {}
        for label, indices in self.cluster_indices.items():
            members = self.data[indices]  # fancy indexing copies; self.data is untouched
            centred = members - members.mean(axis=0)
            self.within_variance[label] = np.mean(centred ** 2)

    def _compute_batch_size_per_cluster(self):
        """Split ``self.batch_size`` across clusters in proportion to ``n_c * sqrt(v_c)``.

        Every cluster gets at least one slot. A zero quota would never advance
        that cluster's pointer in ``__iter__``, and iteration would never end.
        """
        weights = {
            label: len(indices) * np.sqrt(self.within_variance[label])
            for label, indices in self.cluster_indices.items()
        }
        total = sum(weights.values())
        self.batch_size_per_cluster = {
            label: max(1, int(round(weight / total * self.batch_size)))
            for label, weight in weights.items()
        }

    def __iter__(self):
        """Yield every index once, ``batch_size_per_cluster[c]`` consecutive indices per cluster per round."""
        indices = []
        pointers = {label: 0 for label in self.cluster_indices}
        while pointers:
            exhausted = []
            for label, start in pointers.items():
                members = self.cluster_indices[label]
                stop = start + self.batch_size_per_cluster[label]
                indices.extend(members[start:stop])  # slice clamps at the end of the cluster
                if stop < len(members):
                    pointers[label] = stop
                else:
                    exhausted.append(label)
            for label in exhausted:
                del pointers[label]

        return iter(indices)

    def __len__(self):
        """Number of indices yielded per pass: one per item in the dataset."""
        return len(self.data)

# Shared preprocessing for both splits: to tensor, then normalise with
# mean 0.1307 / std 0.3081.
mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST(args.input_path + '/data', train=True, download=True, transform=mnist_transform)
test_dataset = MNIST(args.input_path + '/data', train=False, download=True, transform=mnist_transform)

if method == 'CG-SGHMC':
    # Size the sampler's per-round quota to the DataLoader's minibatch so each
    # minibatch roughly corresponds to one stratified round over the clusters.
    cp_sampler = CP_Sampler(train_dataset, batch_size=minibatch_size, n_clusters=10, seed=seed)
    train_loader = DataLoader(train_dataset, batch_size=minibatch_size, sampler=cp_sampler)
else:
    train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2000, shuffle=True)

class Net(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        dim_input: number of input features. Inputs with more than two
            dimensions are viewed as ``(-1, dim_input)`` before the first layer.
        dim_hidden: width of the hidden layer.
        dim_output: number of output logits.
    """

    def __init__(self, dim_input=784, dim_hidden=100, dim_output=10):
        super().__init__()

        self.dim_input = dim_input
        # Layer creation order (fc1, then fc2) fixes how the global RNG is
        # consumed during initialisation.
        self.fc1 = nn.Linear(dim_input, dim_hidden)
        self.fc2 = nn.Linear(dim_hidden, dim_output)

    def forward(self, x):
        """Return the logits for ``x`` (shape ``(batch, dim_output)``)."""
        flat = x.view(-1, self.dim_input) if x.dim() > 2 else x
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
    
class MultipleNets(nn.Module):
    """Ensemble of ``n_samples`` independently initialised ``Net`` instances.

    Members are registered as submodules ``net-0``, ``net-1``, ... (so
    state-dict keys look like ``net-0.fc1.weight``) and are also kept, in
    order, in the plain list ``self.nets`` for iteration.
    """

    def __init__(self, n_samples=1, dim_input=784, dim_hidden=100, dim_output=10):
        super(MultipleNets, self).__init__()
        self.n_samples = n_samples
        self.dim_input = dim_input
        self.dim_hidden = dim_hidden
        self.dim_output = dim_output

        # nn.Module does not track modules stored in a plain list, hence the
        # explicit add_module calls. Keeping these names preserves the existing
        # state-dict key layout.
        self.nets = [Net(dim_input, dim_hidden, dim_output) for _ in range(n_samples)]
        for i, net in enumerate(self.nets):
            self.add_module('net-{}'.format(i), net)

    def forward(self, x):
        """Return a list with each member's output for ``x``, in member order."""
        # Call the module rather than ``.forward`` so registered hooks run.
        return [net(x) for net in self.nets]

# net = Net().to(device)
# summary(net, (1, 784))

def train(model, device, train_loader, criterion, optimizer):
    """Run one pass over ``train_loader``, updating the ensemble with the configured sampler.

    The objective is the sum over ensemble members of ``criterion``. When
    ``reg > 0``, an L2 term ``reg/2 * sum(theta**2) / N`` is added, where ``N``
    is the dataset size. For SGLD and pSGLD both terms are additionally
    multiplied by ``N``.

    How the optimizer is driven depends on the module-level ``method``:

    * EWSG: ``accept()`` on the first minibatch of every group of ``M + 1``,
      ``mh()`` on the others, and ``step()`` after the last minibatch of
      each group.
    * SGLD / SGHMC / CG-SGHMC: ``step()`` on every minibatch.
    * pSGLD: ``update_preconditioner()`` then ``step()`` on every minibatch.

    Reads the module-level globals ``method``, ``reg`` and ``M``.
    """
    model.train()
    n_train = len(train_loader.dataset)
    full_data_scale = method in ['SGLD', 'pSGLD']

    for batch_no, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        member_losses = [criterion(out, labels) for out in model(inputs)]
        loss = sum(member_losses)
        if full_data_scale:
            loss *= n_train

        if reg > 0:
            squared_norm = sum([p.pow(2).sum() for p in model.parameters()])
            prior = reg / 2 * squared_norm / n_train
            if full_data_scale:
                prior *= n_train
            loss += prior

        loss.backward()

        if method == 'EWSG':
            (optimizer.accept if batch_no % (M + 1) == 0 else optimizer.mh)()
            if (batch_no + 1) % (M + 1) == 0:
                optimizer.step()
        elif method in ['SGLD', 'SGHMC', 'CG-SGHMC']:
            optimizer.step()
        elif method == 'pSGLD':
            optimizer.update_preconditioner()
            optimizer.step()
  
def evaluate(model, device, dataset_loader, dataset, epoch=None):
    """Log (and return) the ensemble's classification accuracy on ``dataset_loader``.

    The ensemble prediction is the argmax of the summed per-member softmax
    probabilities. One line ``Epoch: <epoch>\\t<dataset> Error: <pct>%`` is
    written to the module-level ``logger``.

    Args:
        model: module whose output is an iterable of per-member logit tensors
            (e.g. ``MultipleNets``).
        device: device the batches are moved to.
        dataset_loader: loader yielding ``(inputs, targets)`` batches.
        dataset: split name used in the log line, e.g. ``"Test"``.
        epoch: epoch number for the log line. ``None`` (the default) keeps the
            previous behaviour of reading the module-level ``epoch`` loop
            variable.

    Returns:
        The accuracy as a Python float in [0, 1].
    """
    if epoch is None:
        # Backward compatibility: existing call sites rely on the global loop variable.
        epoch = globals()['epoch']
    model.eval()

    correct = 0
    with torch.no_grad():
        for data, target in dataset_loader:
            data, target = data.to(device), target.to(device)
            prob = sum([F.softmax(logit, dim=1) for logit in model(data)])
            # Accumulate on-device; converted once after the loop.
            correct = correct + torch.eq(prob.argmax(dim=1), target).sum()
    acc = float(correct) / len(dataset_loader.dataset)
    template = 'Epoch: {}\t{} Error: {:.2f}%\n'

    logger.write(template.format(epoch, dataset, (1 - acc) * 100))
    return acc
    
model = MultipleNets(n_samples=n_samples).to(device)
# One optimizer parameter group per ensemble member.
params = [{'params': net.parameters()} for net in model.nets]
criterion = nn.CrossEntropyLoss()

# CG-SGHMC uses the plain SGHMC update; only its data sampler differs.
if method == 'SGLD':
    optimizer = SGLD(params, h=h)
elif method == 'pSGLD':
    optimizer = pSGLD(params, h=h)
elif method in ['SGHMC', 'CG-SGHMC']:
    optimizer = SGHMC(params, h=h, gamma=gamma, sigma=sigma, device=device)
elif method == 'EWSG':
    optimizer = EWSG(params, h=h, gamma=gamma, sigma=sigma, device=device)
else:
    # Fail here rather than with a NameError on ``optimizer`` inside train().
    raise ValueError(
        'Unknown method {!r}; expected one of SGLD, pSGLD, SGHMC, CG-SGHMC, EWSG'.format(method)
    )



def _evaluate_selected_splits():
    """Log the error on the split(s) chosen by ``--evaluation_dataset`` ("test", "train" or "both")."""
    if args.evaluation_dataset in ("test", "both"):
        evaluate(model, device, test_loader, "Test")
    if args.evaluation_dataset in ("train", "both"):
        evaluate(model, device, train_loader, "Train")


try:
    for epoch in range(epoches):
        # Evaluation happens before each pass, so "Epoch: k" reports the
        # ensemble after k training passes.
        _evaluate_selected_splits()
        train(model, device, train_loader, criterion, optimizer)
    # The loop never evaluates the model after its final pass; log it too.
    # evaluate() reads the module-level ``epoch`` for the log line.
    epoch = epoches
    _evaluate_selected_splits()
finally:
    logger.close()

