import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader, Sampler
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchsummary import summary
import sys
import argparse
from sklearn.cluster import KMeans
from Samplers import *

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--method', help='Sampler to use.')
parser.add_argument("--seed", type=int, help='Seed')
parser.add_argument('--epoch', type=int, default=200, help='Number of data passes.')
parser.add_argument("--h", type=float, help="Step size to use.")
parser.add_argument("--minibatch_size", type=int, help="Minibatch size.")
parser.add_argument("--gamma", type=float, help="gamma")
parser.add_argument("--n_samples", type=int, default=10, help="Number of ensembles.")
parser.add_argument("--reg", type=float, default=1.0, help="Regularization strength.")
parser.add_argument("--M", type=int, default=1, help="The length of index chain in EWSG.")
parser.add_argument("--input_path", help="Path to training and test data set.")
parser.add_argument("--output_path", help="Path to logger.")
parser.add_argument("--evaluation_dataset", default="test")

args = parser.parse_args()

# ---------------------------------------------------------------------------
# Global experiment configuration (read by the functions defined below)
# ---------------------------------------------------------------------------
dataset = 'MNIST'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = args.seed
torch.manual_seed(seed)
# Trade cuDNN autotuning for run-to-run reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
method = args.method
gamma = args.gamma
T = 1 / 60000     # for MNIST dataset
# --gamma is optional and is not used by SGLD / pSGLD, so only derive sigma
# when it was supplied; previously `2 * None` raised a TypeError here.
sigma = np.sqrt(2 * gamma * T) if gamma is not None else None
minibatch_size = args.minibatch_size
epoches = args.epoch
n_samples = args.n_samples

reg = args.reg
M = args.M
h = args.h


# ---------------------------------------------------------------------------
# Log file: <output_path>/<method>-<seed>.txt, starting with the configuration
# ---------------------------------------------------------------------------
logger = open(args.output_path + '/' + args.method + f'-{seed}.txt', 'w+')
logger.write('--- Experiment Configuration ---\n\n')
logger.write('seed: {}\n'.format(seed))
logger.write('dataset: {}\n'.format(dataset))
logger.write('minibatch_size: {}\n'.format(minibatch_size))
logger.write('reg: {}\n'.format(reg))
logger.write('h: {}\n'.format(h))

# gamma / sigma are only reported for the methods that are constructed with them.
if args.method not in ['SGLD', 'pSGLD']:
    logger.write('gamma: {}\n'.format(gamma))
    logger.write('sigma: {}\n'.format(sigma))
    if args.method == 'EWSG':
        logger.write('M: {}\n'.format(M))
logger.write('\n' + '-' * 80 + '\n\n')

# Clustering-based preprocessing Sampler
class CP_Sampler(Sampler):
    """Sampler that emits dataset indices grouped by k-means cluster.

    The flattened inputs of ``data_source`` are clustered with k-means.  Every
    cluster is given a share of ``batch_size`` proportional to
    ``cluster_size * sqrt(within_cluster_variance)``.  Iteration walks the
    clusters round-robin, emitting each cluster's share of indices per round
    until every cluster is exhausted, so each index is produced exactly once.

    Args:
        data_source: Indexable, sized dataset whose items are sequences with a
            tensor (anything exposing ``.numpy()``) as first element.
        batch_size: Number of indices to distribute over the clusters per round.
        n_clusters: Number of k-means clusters.
        seed: ``random_state`` handed to ``KMeans``.
    """

    def __init__(self, data_source, batch_size=50, n_clusters=10, seed=0):
        # One flattened row per example; this is the feature matrix for k-means.
        self.data = np.array([data_source[i][0].numpy().reshape(-1) for i in range(len(data_source))])
        self.batch_size = batch_size
        self.n_clusters = n_clusters
        # Stored so that clustering uses the constructor argument instead of
        # whatever module-level ``seed`` happens to exist.
        self.seed = seed

        self.__kmeans__()
        self.__compute_within_cluster_variance__()
        self.__compute_batch_size_per_cluster()

    def __kmeans__(self):
        """Cluster ``self.data`` and record the member indices of each cluster."""
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=self.seed).fit(self.data)

        # self.cluster_indices maps a k-means cluster label to the list of
        # dataset indices assigned to that cluster (in dataset order).
        self.cluster_indices = {}
        for index, label in enumerate(kmeans.labels_):
            self.cluster_indices.setdefault(int(label), []).append(index)

    def __compute_within_cluster_variance__(self):
        """Compute, per cluster, the mean squared deviation from the cluster mean."""
        self.within_variance = {}
        for label, indices in self.cluster_indices.items():
            # Indexing with a list copies the rows, so self.data is untouched.
            members = self.data[indices]
            centered = members - members.mean(axis=0)
            self.within_variance[label] = np.mean(centered ** 2)

    def __compute_batch_size_per_cluster(self):
        """Split ``self.batch_size`` over the clusters by size * sqrt(variance)."""
        weights = {
            label: len(indices) * np.sqrt(self.within_variance[label])
            for label, indices in self.cluster_indices.items()
        }
        total = sum(weights.values())
        if not total > 0:
            # Every cluster has zero (or undefined) variance: dividing by the
            # total would fail, so fall back to shares proportional to size.
            weights = {label: len(indices) for label, indices in self.cluster_indices.items()}
            total = sum(weights.values())

        # Each cluster gets at least one slot; a share rounded down to zero
        # would otherwise never advance that cluster during iteration.
        self.batch_size_per_cluster = {
            label: max(1, int(round(weight / total * self.batch_size)))
            for label, weight in weights.items()
        }

    def __iter__(self):
        """Return an iterator over all indices, interleaved cluster by cluster."""
        indices = []
        # Next unread position inside each not-yet-exhausted cluster.
        pointers = {label: 0 for label in self.cluster_indices}
        while pointers:
            exhausted = []
            for label, start in pointers.items():
                members = self.cluster_indices[label]
                # Guard against a non-positive share, which would loop forever.
                step = max(1, self.batch_size_per_cluster[label])
                stop = start + step
                if stop < len(members):
                    indices.extend(members[start:stop])
                    pointers[label] = stop
                else:
                    # Final (possibly short) slice of this cluster.
                    indices.extend(members[start:])
                    exhausted.append(label)
            for label in exhausted:
                del pointers[label]

        return iter(indices)

    def __len__(self):
        """Return the number of examples (one index is emitted per example)."""
        return len(self.data)

# Both splits share one preprocessing pipeline: convert to tensor, then
# normalize with the constants (0.1307, 0.3081).
mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST(args.input_path + '/data', train=True, download=True, transform=mnist_transform)
test_dataset = MNIST(args.input_path + '/data', train=False, download=True, transform=mnist_transform)

if method == 'CG-SGHMC':
    # The sampler lays out indices in rounds of roughly `batch_size` entries
    # and the DataLoader cuts that stream into chunks of `minibatch_size`, so
    # the two sizes must agree for a chunk to line up with a round.  The seed
    # is passed explicitly so the clustering follows --seed.
    cp_sampler = CP_Sampler(train_dataset, batch_size=minibatch_size, n_clusters=10, seed=seed)
    train_loader = DataLoader(train_dataset, batch_size=minibatch_size, sampler=cp_sampler)
else:
    train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2000, shuffle=True)

class Net(nn.Module):
    """Two-stage convolutional classifier for 28x28 single-channel images.

    Args:
        n_classes: Width of the output layer.
    """

    def __init__(self, n_classes=10):
        super().__init__()

        # Convolutional stages: 5x5 convolution, 2x2 max-pool, ReLU.
        self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5), nn.MaxPool2d(2), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5), nn.MaxPool2d(2), nn.ReLU())

        # Spatial size: 28 -> 24 -> 12 -> 8 -> 4, with 64 channels at the end.
        flat_features = 64 * 4 * 4
        hidden = 200
        self.fc1 = nn.Sequential(nn.Linear(flat_features, hidden), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
        # The output layer is followed by a ReLU as well, so outputs are >= 0.
        self.fc3 = nn.Sequential(nn.Linear(hidden, n_classes), nn.ReLU())

    def forward(self, x):
        """Return the ``(batch, n_classes)`` outputs for a batch of images.

        Inputs that are not 4-dimensional are reshaped to ``(batch, -1, 28, 28)``.
        """
        if x.dim() != 4:
            x = x.view(x.size(0), -1, 28, 28)

        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)
        return self.fc3(self.fc2(self.fc1(flat)))

class MultipleNets(nn.Module):
    """Container holding ``n_samples`` independent ``Net`` instances.

    Args:
        n_samples: Number of networks to create.
        n_classes: Output width passed to every ``Net``.
    """

    def __init__(self, n_samples=1, n_classes=10):
        super().__init__()
        self.n_samples = n_samples
        self.n_classes = n_classes

        # Plain Python list for direct access; each member is additionally
        # registered as a submodule named 'net-<i>' so that its parameters are
        # visible to .to(), .parameters() and state_dict().
        self.nets = []
        for index in range(n_samples):
            self.nets.append(Net(self.n_classes))
        for index, member in enumerate(self.nets):
            self.add_module('net-{}'.format(index), member)

    def forward(self, x):
        """Return a list with the output of every member network on ``x``."""
        outputs = []
        for member in self.nets:
            outputs.append(member.forward(x))
        return outputs

# net = Net().to(device)
# summary(net, (1, 784))

def train(model, device, train_loader, criterion, optimizer):
    """Run one pass over ``train_loader``, updating ``model`` via ``optimizer``.

    Reads the module-level settings ``method``, ``reg`` and ``M``.

    Args:
        model: Module returning a list of per-network outputs for a batch.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, target)`` batches.
        criterion: Loss applied to each network output and the targets.
        optimizer: Sampler/optimizer object matching ``method``.
    """
    model.train()

    for idx, batch in enumerate(train_loader):
        data = batch[0].to(device)
        target = batch[1].to(device)
        n_train = len(train_loader.dataset)
        is_langevin = method in ['SGLD', 'pSGLD']

        optimizer.zero_grad()

        # Sum of the per-network losses on this minibatch.
        loss = sum(criterion(logit, target) for logit in model(data))
        if is_langevin:
            # SGLD / pSGLD use the loss scaled by the dataset size.
            loss = loss * n_train
        if reg > 0:
            # Squared-norm penalty over all parameters, divided by dataset size.
            squared_norm = sum(param.pow(2).sum() for param in model.parameters())
            prior = reg / 2 * squared_norm / n_train
            if is_langevin:
                prior = prior * n_train
            loss = loss + prior

        loss.backward()

        if method == 'EWSG':
            # First minibatch of every group of M + 1 is accepted outright;
            # the remaining ones go through optimizer.mh().
            if idx % (M + 1) != 0:
                optimizer.mh()
            else:
                optimizer.accept()

            # The parameter update happens on the last minibatch of the group.
            if (idx + 1) % (M + 1) == 0:
                optimizer.step()
        elif method in ['SGLD', 'SGHMC', 'CG-SGHMC']:
            optimizer.step()
        elif method == 'pSGLD':
            optimizer.update_preconditioner()
            optimizer.step()

def evaluate(model, device, dataset_loader, dataset):
    """Write the ensemble classification error on ``dataset_loader`` to the log.

    Predictions come from the argmax of the summed softmax outputs of all
    networks.  Reads the module-level ``epoch`` and ``logger``.

    Args:
        model: Module returning a list of per-network outputs for a batch.
        device: Device the batches are moved to.
        dataset_loader: Loader yielding ``(data, target)`` batches.
        dataset: Label for the split, used only in the log line.
    """
    model.eval()

    n_correct = 0
    with torch.no_grad():
        for data, target in dataset_loader:
            data = data.to(device)
            target = target.to(device)
            # Summing probabilities gives the same argmax as averaging them.
            ensemble_prob = sum(F.softmax(logit, dim=1) for logit in model(data))
            pred = ensemble_prob.argmax(dim=1).long()
            n_correct += torch.eq(pred, target).sum()

    acc = n_correct.float() / len(dataset_loader.dataset)
    error_percent = (1 - acc) * 100
    logger.write('Epoch: {}\t{} Error: {:.2f}%\n'.format(epoch, dataset, error_percent))

model = MultipleNets(n_samples=n_samples).to(device)
# One parameter group per network in the ensemble.
params = [{'params': net.parameters()} for net in model.nets]
criterion = nn.CrossEntropyLoss()

# try/finally ensures the log file is flushed and closed even when optimizer
# construction or training raises.
try:
    if method == 'SGLD':
        optimizer = SGLD(params, h=h)
    elif method == 'pSGLD':
        optimizer = pSGLD(params, h=h)
    elif method in ['SGHMC', 'CG-SGHMC']:
        optimizer = SGHMC(params, h=h, gamma=gamma, sigma=sigma, device=device)
    elif method == 'EWSG':
        optimizer = EWSG(params, h=h, gamma=gamma, sigma=sigma, device=device)
        # NOTE(review): EWSG runs one additional training pass before the
        # first evaluation; kept as in the original -- confirm it is intended.
        train(model, device, train_loader, criterion, optimizer)
    else:
        # Fail fast: without this, `optimizer` stayed undefined and the script
        # died with a NameError only after a full evaluation pass.
        raise ValueError(
            'Unknown --method {!r}; expected one of '
            'SGLD, pSGLD, SGHMC, CG-SGHMC, EWSG.'.format(method)
        )

    # `epoch` is deliberately a module-level name: evaluate() reads it.
    for epoch in range(epoches):
        # Evaluation happens before the training pass of the same epoch, so
        # the entry logged for epoch 0 reflects the state before this loop's
        # first training pass.
        if args.evaluation_dataset == "test":
            evaluate(model, device, test_loader, "Test")
        elif args.evaluation_dataset == "train":
            evaluate(model, device, train_loader, "Train")
        elif args.evaluation_dataset == "both":
            evaluate(model, device, test_loader, "Test")
            evaluate(model, device, train_loader, "Train")

        train(model, device, train_loader, criterion, optimizer)
finally:
    logger.close()