import os
from pickletools import optimize
from statistics import variance
import yaml
import multiprocessing
import argparse
import time
import numpy as np
import random
import json

import torch
import torchvision
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn

from distributed_optimization_library.models.resnet import ResNet18
from distributed_optimization_library.models.resnet_without_bn import prepare_resnet_without_bn
from distributed_optimization_library.models.small_models import prepare_two_layer_nn
from distributed_optimization_library.function import parameters_to_tensor, tensor_to_parameters, BaseStochasticFunction

# Number of samples requested from the with-replacement RandomSampler used by
# StochasticModel; large enough that the mini-batch iterator never runs out in practice.
_LARGE_NUMBER = 10**12


def save_dump(path, dump):
    """Write ``dump`` as JSON to ``path`` through a temporary sibling file.

    The JSON text is first written to ``path + "_tmp_"`` and then moved over
    ``path`` with ``os.replace``, so ``path`` is only ever swapped in one step.

    Args:
        path: Destination file name (string).
        dump: JSON-serialisable object.
    """
    staging_path = path + "_tmp_"
    serialized = json.dumps(dump)
    with open(staging_path, 'w') as stream:
        stream.write(serialized)
    os.replace(staging_path, path)


def save_model(path, model, batch_index):
    """Store the model's ``state_dict`` in ``path + "_model_<batch_index>"``.

    The weights are first saved to ``path + "_tmp_model_"`` and then moved to
    the final name with ``os.replace``.

    Args:
        path: Common prefix (string) of the output files.
        model: Object exposing ``state_dict()``.
        batch_index: Iteration number appended to the final file name.
    """
    staging_path = path + "_tmp_model_"
    final_path = path + "_model_{}".format(batch_index)
    weights = model.state_dict()
    torch.save(weights, staging_path)
    os.replace(staging_path, final_path)


class StochasticModel(BaseStochasticFunction):
    """Expose a torch model trained on a dataset as a function of a parameter point.

    A "point" is whatever ``tensor_to_parameters`` accepts to overwrite the
    model parameters (presumably one flat tensor -- TODO confirm against the
    helper). Gradients are returned in the form produced by
    ``parameters_to_tensor(..., grad=True)``.

    Args:
        model: ``torch.nn.Module`` to optimise.
        loss_fn: Callable ``loss_fn(pred, labels)`` returning a scalar tensor.
        dataset: Map-style torch dataset yielding ``(features, labels)`` pairs.
        batch_size: Mini-batch size for stochastic gradients; also the default
            chunk size of the memory-optimised full gradient.
        num_workers: Number of ``DataLoader`` worker processes.
        use_cuda: Move the model and every batch to the GPU.
        optimize_memory: If true, full gradients are accumulated chunk by
            chunk; otherwise the whole dataset is loaded once and cached.
    """

    def __init__(self, model, loss_fn, dataset, batch_size, num_workers=1, use_cuda=False,
                 optimize_memory=True):
        self._model = model
        self._use_cuda = use_cuda
        if use_cuda:
            self._model = self._model.cuda()
        self._loss_fn = loss_fn
        self._dataset = dataset
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._optimize_memory = optimize_memory
        # Full-dataset cache, filled lazily by gradient() when optimize_memory is False.
        self._features_all = None
        self._labels_all = None
        # Statistics of the most recent gradient evaluation.
        self._last_loss = None
        self._last_accuracy = None
        # Sampling with replacement and a huge num_samples gives an effectively
        # endless stream of mini-batches.
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers,
            sampler=torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=_LARGE_NUMBER),
            **self._prefetch_kwargs(num_workers))
        self._dataloader_iter = iter(dataloader)

        # One forward pass in eval mode on a real batch before training starts
        # (presumably to materialise lazily-created parameters -- TODO confirm).
        self._model.eval()
        features, _ = next(self._dataloader_iter)
        if self._use_cuda:
            features = features.cuda()
        self._model(features)
        self._model.train()

    @staticmethod
    def _prefetch_kwargs(num_workers):
        """Return the ``DataLoader`` prefetch arguments for ``num_workers`` workers.

        ``prefetch_factor`` is only accepted for multi-process loading, so it is
        left out when ``num_workers`` is 0.
        """
        if num_workers > 0:
            return {'prefetch_factor': 100 * num_workers}
        return {}

    def stochastic_gradient_at_points(self, points):
        """Evaluate the gradient on ONE fresh mini-batch at every point in ``points``.

        All points see the same batch and the same RNG seed. The loss and
        accuracy of the last point are stored for ``last_loss_and_accuracy``,
        and the model parameters are left at the last point.

        Returns:
            List with one gradient per point, in the order of ``points``.
        """
        # Draw the shared seed from the global generator rather than calling
        # torch.seed(): torch.seed() reseeds from a non-deterministic source and
        # would break runs made reproducible with torch.manual_seed.
        current_seed = int(torch.randint(0, 2 ** 63 - 1, (1,)).item())
        stochastic_gradients = []
        features, labels = next(self._dataloader_iter)
        if self._use_cuda:
            features, labels = features.cuda(), labels.cuda()
        for point in points:
            torch.manual_seed(current_seed)
            tensor_to_parameters(self._model.parameters(), point)
            loss, pred, stochastic_gradient = self._gradient(features, labels)
            stochastic_gradients.append(stochastic_gradient)
        self._last_loss = loss
        self._last_accuracy = self._accuracy(pred, labels)
        return stochastic_gradients

    def gradient(self, point, batch_size=None):
        """Return the gradient over the whole dataset at ``point``.

        Also records the dataset loss and accuracy for ``last_loss_and_accuracy``.

        Args:
            point: Parameter point to evaluate at; the model is set to it.
            batch_size: Chunk size for the memory-optimised path (defaults to
                the constructor's ``batch_size``). Ignored when the dataset is
                cached in memory.
        """
        tensor_to_parameters(self._model.parameters(), point)
        if not self._optimize_memory:
            if self._features_all is None:
                self._dataloader = torch.utils.data.DataLoader(
                    self._dataset,
                    batch_size=len(self._dataset),
                    num_workers=self._num_workers,
                    **self._prefetch_kwargs(self._num_workers))
                # batch_size == len(dataset), so the loop yields the full dataset.
                for features, labels in self._dataloader:
                    pass
                if self._use_cuda:
                    features, labels = features.cuda(), labels.cuda()
                # Cache on CPU as well as on GPU; caching only under CUDA left
                # the cache empty on CPU and made the call below fail.
                self._features_all = features
                self._labels_all = labels
            loss, pred, gradient = self._gradient(self._features_all, self._labels_all)
            self._last_loss = loss
            self._last_accuracy = self._accuracy(pred, self._labels_all)
        else:
            batch_size = batch_size if batch_size is not None else self._batch_size
            dataloader = torch.utils.data.DataLoader(
                self._dataset,
                batch_size=batch_size,
                num_workers=self._num_workers,
                **self._prefetch_kwargs(self._num_workers))
            total_number = 0
            sum_gradients = 0
            total_loss = 0
            total_labels = []
            total_pred = []
            for features, labels in dataloader:
                if self._use_cuda:
                    features, labels = features.cuda(), labels.cuda()
                loss, pred, gradient = self._gradient(features, labels)
                # Weight by chunk size so the final division yields a per-sample
                # mean even if the last chunk is smaller. Detach so the running
                # totals do not keep autograd history alive.
                total_loss += loss.detach() * len(labels)
                total_labels.append(labels)
                total_pred.append(pred.detach())
                gradient = gradient * len(labels)
                sum_gradients = (sum_gradients + gradient).detach()
                total_number += len(labels)
            gradient = sum_gradients / total_number
            self._last_loss = total_loss / total_number
            total_labels = torch.cat(total_labels)
            total_pred = torch.cat(total_pred)
            self._last_accuracy = self._accuracy(total_pred, total_labels)
        return gradient

    def hessian(self, point):
        """Return the Hessian of the full-dataset loss at ``point`` as one 2-D tensor.

        Requires ``optimize_memory=False`` and a previous ``gradient`` call,
        which is what fills the cached dataset used here.
        """
        assert not self._optimize_memory
        assert self._features_all is not None
        tensor_to_parameters(self._model.parameters(), point)
        parameters = tuple(self._model.parameters())
        parameters_shape = [parameter.shape for parameter in parameters]
        parameters = tuple(torch.flatten(parameter) for parameter in parameters)
        names = [name for name, _ in self._model.named_parameters()]

        def hack_model(*params):
            # Rebuild the parameter shapes and run the model statelessly so
            # autograd can differentiate with respect to the flattened inputs.
            params = [torch.reshape(param, shape) for param, shape in zip(params, parameters_shape)]
            pred = torch.nn.utils.stateless.functional_call(
                self._model, dict(zip(names, params)), self._features_all)
            return self._loss_fn(pred, self._labels_all)

        hessians = torch.autograd.functional.hessian(hack_model, parameters)
        # hessians[i][j] is the block for parameters i and j; join the blocks
        # of each row, then stack the rows.
        hessians = [torch.concat(h, axis=1) for h in hessians]
        hessian = torch.concat(hessians, axis=0)
        return hessian

    def parameters(self):
        """Return the iterator over the wrapped model's parameters."""
        return self._model.parameters()

    def current_point(self):
        """Return the current model parameters converted by ``parameters_to_tensor``."""
        return parameters_to_tensor(self._model.parameters())

    def last_loss_and_accuracy(self):
        """Return ``(loss, accuracy)`` of the latest gradient evaluation.

        Both entries are ``None`` until a gradient has been computed.
        """
        return self._last_loss, self._last_accuracy

    def _gradient(self, features, labels):
        """Run forward and backward on one batch; return ``(loss, pred, gradient)``."""
        pred = self._model(features)
        loss = self._loss_fn(pred, labels)
        self._model.zero_grad()
        loss.backward()
        gradient = parameters_to_tensor(self._model.parameters(), grad=True)
        return loss, pred, gradient

    def _accuracy(self, pred, labels):
        """Return the fraction of rows of ``pred`` whose arg-max equals the label."""
        _, predicted = pred.max(1)
        accuracy = predicted.eq(labels).sum().item() / labels.size(0)
        return accuracy


class GradientDescent(object):
    """Plain gradient descent with a fixed step size.

    Args:
        model: Object providing ``gradient(point)``.
        point: Starting point.
        lr: Step size.
    """

    def __init__(self, model, point, lr):
        self._model, self._point, self._learning_rate = model, point, lr

    def step(self):
        """Move the point one step against the full gradient."""
        direction = self._model.gradient(self._point)
        update = self._learning_rate * direction
        self._point = self._point - update

    def get_point(self):
        """Return the current point."""
        return self._point


class HolderGradientDescent(object):
    """Gradient descent whose step is rescaled by ``||g|| ** (1 / nu - 1)``.

    With ``nu == 1`` the scaling factor is 1 and the update is ordinary
    gradient descent.

    Args:
        model: Object providing ``gradient(point)`` returning a torch tensor.
        point: Starting point.
        lr: Base step size.
        nu: Hölder exponent used in the scaling.
    """

    def __init__(self, model, point, lr, nu=1.0):
        self._model, self._point = model, point
        self._learning_rate, self._nu = lr, nu

    def step(self):
        """Take one norm-rescaled gradient step."""
        gradient = self._model.gradient(self._point)
        power = 1 / self._nu - 1
        scale = torch.pow(torch.linalg.vector_norm(gradient), power)
        shift = self._learning_rate * scale * gradient
        self._point = self._point - shift

    def get_point(self):
        """Return the current point."""
        return self._point


class AdaptiveGradientDescent(object):
    """Gradient descent with a backtracking (halving) step-size search.

    Args:
        model: Object providing ``gradient(point)`` and
            ``last_loss_and_accuracy()``; the loss of the latest ``gradient``
            call is used as the function value.
        point: Starting point (torch tensor).
        lr: Initial trial step size.
    """

    def __init__(self, model, point, lr):
        self._model, self._point, self._learning_rate = model, point, lr

    def step(self):
        """Halve the step until the quadratic upper bound holds, then move.

        After a trial point is accepted the step size is doubled, so the next
        call starts from a more optimistic value.
        """
        base_gradient = self._model.gradient(self._point)
        base_value, _ = self._model.last_loss_and_accuracy()
        accepted = False
        while not accepted:
            candidate = self._point - self._learning_rate * base_gradient
            # Only called to refresh the model's recorded loss at the candidate.
            self._model.gradient(candidate)
            candidate_value, _ = self._model.last_loss_and_accuracy()
            displacement = candidate - self._point
            linear_term = torch.dot(base_gradient, displacement)
            quadratic_term = (1 / (2. * self._learning_rate)) * torch.linalg.vector_norm(displacement) ** 2
            accepted = bool(candidate_value <= (base_value + linear_term + quadratic_term))
            if not accepted:
                self._learning_rate /= 2
        self._point = candidate
        self._learning_rate *= 2

    def get_point(self):
        """Return the current point."""
        return self._point


class AdaptiveHolderGradientDescent(object):
    """Norm-rescaled gradient descent with a backtracking step-size search.

    Args:
        model: Object providing ``gradient(point)`` and
            ``last_loss_and_accuracy()``.
        point: Starting point (torch tensor).
        lr: Initial trial step size.
        nu: Hölder exponent used both in the step scaling and in the
            acceptance bound.
    """

    def __init__(self, model, point, lr, nu=1.0):
        self._model, self._point = model, point
        self._learning_rate, self._nu = lr, nu

    def step(self):
        """Halve the step until the Hölder upper bound holds, then move.

        The step size is doubled after a trial point has been accepted.
        """
        base_gradient = self._model.gradient(self._point)
        base_value, _ = self._model.last_loss_and_accuracy()
        # Neither the exponent nor the gradient norm changes during the search.
        power = 1 / self._nu - 1
        scale = torch.pow(torch.linalg.vector_norm(base_gradient), power)
        order = 1 + self._nu
        accepted = False
        while not accepted:
            candidate = self._point - self._learning_rate * scale * base_gradient
            # Only called to refresh the model's recorded loss at the candidate.
            self._model.gradient(candidate)
            candidate_value, _ = self._model.last_loss_and_accuracy()
            displacement = candidate - self._point
            linear_term = torch.dot(base_gradient, displacement)
            penalty = (1 / ((1 + self._nu) * (self._learning_rate ** self._nu))) * torch.linalg.vector_norm(displacement) ** order
            accepted = bool(candidate_value <= (base_value + linear_term + penalty))
            if not accepted:
                self._learning_rate /= 2
        self._point = candidate
        self._learning_rate *= 2

    def get_point(self):
        """Return the current point."""
        return self._point


class SGD(object):
    """Stochastic gradient descent with a fixed step size.

    Args:
        model: Object providing ``stochastic_gradient(point)``.
        point: Starting point.
        lr: Step size.
        momentum: Accepted for interface compatibility; it is not used by
            this implementation.
    """

    def __init__(self, model, point, lr, momentum=None):
        self._model, self._point, self._learning_rate = model, point, lr

    def step(self):
        """Move the point one step against a fresh stochastic gradient."""
        noisy_gradient = self._model.stochastic_gradient(self._point)
        update = self._learning_rate * noisy_gradient
        self._point = self._point - update

    def get_point(self):
        """Return the current point."""
        return self._point


class MomentumVarianceReduction(object):
    """Momentum-based variance-reduced stochastic gradient method.

    Args:
        model: Object providing ``stochastic_gradient(point)`` and
            ``stochastic_gradient_at_points(points)``.
        point: Starting point.
        lr: Step size.
        momentum: Momentum parameter; the estimator keeps a
            ``1 - momentum`` share of the previous correction term.
    """

    def __init__(self, model, point, lr, momentum):
        self._model, self._point = model, point
        self._learning_rate, self._momentum = lr, momentum
        self._gradient_estimator = 0

    def init_gradient_estimator(self):
        """Set the estimator to the mean of ``int(1 / momentum)`` stochastic gradients."""
        draws = int(1 / self._momentum)
        assert draws > 0
        self._gradient_estimator = 0
        remaining = draws
        while remaining > 0:
            sample = self._model.stochastic_gradient(self._point)
            self._gradient_estimator = self._gradient_estimator + sample
            remaining -= 1
        self._gradient_estimator = self._gradient_estimator / np.float32(draws)

    def step(self):
        """Step along the estimator, then refresh it from one shared mini-batch.

        The gradients at the old and the new point come from a single
        ``stochastic_gradient_at_points`` call.
        """
        old_point = self._point
        new_point = old_point - self._learning_rate * self._gradient_estimator
        self._point = new_point
        old_gradient, new_gradient = self._model.stochastic_gradient_at_points([old_point, new_point])
        correction = self._gradient_estimator - old_gradient
        self._gradient_estimator = new_gradient + (1 - self._momentum) * correction

    def get_point(self):
        """Return the current point."""
        return self._point


def mean_and_sigma_2(model, batch_size_sigma, config):
    """Estimate the full-gradient norm and the stochastic-gradient variance.

    Args:
        model: Object providing ``current_point()``,
            ``gradient(point, batch_size=...)`` and
            ``stochastic_gradient(point)``, all returning torch tensors.
        batch_size_sigma: Number of stochastic gradients averaged for the
            variance estimate.
        config: Mapping with ``'batch_size'`` and, optionally,
            ``'batch_size_save_every'`` (chunk size for the full gradient).

    Returns:
        ``(norm_of_full_gradient, mean_squared_deviation)`` where the second
        entry averages ``||g_i - full_gradient|| ** 2`` over the draws.
    """
    chunk_size = config.get('batch_size_save_every', config['batch_size'])
    anchor = model.current_point()
    full_gradient = model.gradient(anchor, batch_size=chunk_size).cpu().numpy()
    full_gradient_norm = np.linalg.norm(full_gradient)
    squared_deviations = [
        np.linalg.norm(model.stochastic_gradient(anchor).cpu().numpy() - full_gradient) ** 2
        for _ in range(batch_size_sigma)
    ]
    spread = sum(squared_deviations) / batch_size_sigma
    return full_gradient_norm, spread


def run_experiments(path_to_dataset, dumps_path, config, basename):
    """Train a model on CIFAR10 with the configured optimizer and dump statistics.

    Args:
        path_to_dataset: Root directory of the CIFAR10 data (downloaded there
            if missing).
        dumps_path: Directory receiving the JSON statistics and model files.
        config: Experiment configuration. Keys read here: ``seed``,
            ``optimizer``, ``batch_size``, ``number_of_batches`` and, depending
            on the choices, ``model``, ``resnet_params``,
            ``two_layer_nn_params``, ``learning_rate``, ``momentum``,
            ``holder_constant``, ``use_double``, ``use_cuda``, ``num_workers``,
            ``first_5000_examples``, ``init_gradient_estimator``,
            ``save_every``, ``save_model_every``, ``mean_and_sigma``.
        basename: File-name prefix of every output inside ``dumps_path``.

    Raises:
        ValueError: If ``config['model']`` or ``config['optimizer']`` is not
            one of the supported names.
    """
    if config.get('use_double', False):
        torch.set_default_tensor_type(torch.DoubleTensor)
    dump_path = os.path.join(dumps_path, basename)
    seed = config['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model_name = config.get('model', 'resnet')
    optimize_memory = True
    if model_name == 'resnet':
        model = ResNet18(activation_name=config['resnet_params'].get('activation', 'relu'))
    elif model_name == 'resnet_without_bn':
        model = prepare_resnet_without_bn(**config['resnet_params'])
    elif model_name == 'two_layer_nn':
        model = prepare_two_layer_nn(**config['two_layer_nn_params'])
        # Only this model uses the cached full-dataset gradient path.
        optimize_memory = False
    else:
        # Fail here instead of with an unrelated UnboundLocalError further down.
        raise ValueError("Unknown model: {}".format(model_name))
    cudnn.benchmark = True
    loss_fn = torch.nn.CrossEntropyLoss()
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = torchvision.datasets.CIFAR10(root=path_to_dataset, train=True, download=True,
                                            transform=transform_train)
    if config.get('first_5000_examples', False):
        trainset = torch.utils.data.Subset(trainset, range(5000))
    model_wrapper = StochasticModel(model, loss_fn, trainset,
                                    batch_size=config['batch_size'],
                                    num_workers=config.get('num_workers', 1),
                                    use_cuda=config.get('use_cuda', False),
                                    optimize_memory=optimize_memory)
    point = model_wrapper.current_point()
    if config['optimizer'] == 'sgd':
        optimizer_wrapper = SGD(model_wrapper, point, lr=config['learning_rate'],
                                momentum=config['momentum'])
    elif config['optimizer'] == 'mvr':
        optimizer_wrapper = MomentumVarianceReduction(model_wrapper, point, lr=config['learning_rate'],
                                                      momentum=config['momentum'])
        if config.get('init_gradient_estimator', False):
            optimizer_wrapper.init_gradient_estimator()
    elif config['optimizer'] == 'gd':
        optimizer_wrapper = GradientDescent(model_wrapper, point, lr=config['learning_rate'])
    elif config['optimizer'] == 'adaptive_gd':
        # The adaptive methods start from a huge step and backtrack from there.
        optimizer_wrapper = AdaptiveGradientDescent(model_wrapper, point, lr=10**9)
    elif config['optimizer'] == 'holder_gd':
        optimizer_wrapper = HolderGradientDescent(model_wrapper, point, lr=config['learning_rate'],
                                                  nu=config['holder_constant'])
    elif config['optimizer'] == 'adaptive_holder_gd':
        optimizer_wrapper = AdaptiveHolderGradientDescent(model_wrapper, point, lr=10**9,
                                                          nu=config['holder_constant'])
    else:
        raise ValueError("Unknown optimizer: {}".format(config['optimizer']))
    stat = {'batch_loss': [], 'batch_accuracy': [], 'batch_norm_of_gradient': [],
            'estimate_sigma_2': [], 'estimate_norm_of_gradient': [],
            'norm_of_gradient': [], 'sigma_2': [],
            'batch_diff_gradient': [], 'batch_diff_points': [],
            'batch_dot_gradient': [], 'learning_rates': []}
    dump = {'config': config, 'stat': stat}
    batch_index = 0
    while True:
        start_time = time.time()
        if batch_index > 0:
            # Gradient left on the model by the previous step, read before
            # this step overwrites it.
            prev_gradient = parameters_to_tensor(model.parameters(), grad=True).cpu().numpy()
        # Point BEFORE this step; diff_points compares consecutive pre-step points.
        point = optimizer_wrapper.get_point().cpu().numpy()
        optimizer_wrapper.step()
        iter_time = time.time() - start_time
        gradient = parameters_to_tensor(model.parameters(), grad=True).cpu().numpy()
        norm_of_gradient = np.linalg.norm(gradient)
        if batch_index > 0:
            diff_gradient = np.linalg.norm(gradient - prev_gradient)
            dot_gradient = np.dot(gradient, prev_gradient)
            diff_points = np.linalg.norm(point - prev_point)
        prev_point = point
        loss, accuracy = model_wrapper.last_loss_and_accuracy()
        stat['batch_loss'].append(loss.item())
        stat['batch_norm_of_gradient'].append(norm_of_gradient.item())
        stat['batch_accuracy'].append(accuracy)
        stat['learning_rates'].append(optimizer_wrapper._learning_rate)
        if batch_index > 0:
            stat['batch_diff_gradient'].append(diff_gradient.item())
            stat['batch_dot_gradient'].append(dot_gradient.item())
            stat['batch_diff_points'].append(diff_points.item())
        print("iter: {iter} / {number_of_batches} loss: {loss}, norm_of_gradient**2 :{norm_of_gradient} acc: {accuracy}, time: {iter_time}".format(
              loss=loss, iter_time=iter_time, accuracy=accuracy, norm_of_gradient=norm_of_gradient**2,
              iter=batch_index, number_of_batches=config['number_of_batches']))
        if batch_index >= config['number_of_batches']:
            # Final iteration: always persist the model and the statistics.
            save_model(dump_path, model, batch_index)
            save_dump(dump_path, dump)
            break
        if config.get('save_model_every', None) is not None and batch_index % config['save_model_every'] == 0:
            save_model(dump_path, model, batch_index)
        if batch_index % config.get('save_every', 100) == 0:
            save_dump(dump_path, dump)
            if config.get('mean_and_sigma', False):
                # Appended after the dump above, so these values reach disk
                # with the next save.
                norm_mean_gradient, variance = mean_and_sigma_2(model_wrapper, batch_size_sigma=100, config=config)
                stat['norm_of_gradient'].append((batch_index, norm_mean_gradient.item()))
                stat['sigma_2'].append((batch_index, variance.item()))
        batch_index += 1


def main():
    """Parse the command line, load the YAML config and run the experiment.

    The config file must contain a list with exactly one configuration
    mapping. The output prefix is the config file name up to its first dot.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dumps_path', required=True)
    parser.add_argument('--config', required=True)
    parser.add_argument('--path_to_dataset')

    args = parser.parse_args()
    basename = os.path.basename(args.config).split('.')[0]
    # NOTE: yaml.load() without an explicit Loader raises TypeError on
    # PyYAML >= 6 and can build arbitrary objects on older versions.
    # safe_load parses plain YAML data only, which is all this config needs.
    with open(args.config) as config_file:
        configs = yaml.safe_load(config_file)
    assert len(configs) == 1
    config = configs[0]
    if config.get('parallel', False):
        torch.multiprocessing.set_sharing_strategy('file_system')
        if config.get('multiprocessing_spawn', False):
            multiprocessing.set_start_method("spawn")
    run_experiments(args.path_to_dataset, args.dumps_path, config, basename)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
