import contextlib
import json
import math
import numpy as np
import os
import pathlib
import time
import signatory
import torch
import torch.optim.swa_utils
import tqdm

from .. import models
from . import common

_here = pathlib.Path(__file__).resolve().parent


def _loop(iterable):
    """Yield the items of `iterable` over and over, without ever terminating.

    The iterable is re-iterated from scratch on every pass (nothing is cached), so an object that
    produces a different order each time it is iterated will do so here too.
    """
    while True:
        for item in iterable:
            yield item


def _train_generator(t, model, device, optimizer, discrim_optimizer, y):
    """Take a single generator optimisation step.

    `y` is moved to `device`, the loss returned by `model.train_generator(t, y)` is backpropagated,
    and only `optimizer` is stepped. Gradients held by both optimisers are cleared afterwards, so
    anything accumulated on the discriminator during this backward pass is discarded.
    """
    generator_loss = model.train_generator(t, y.to(device))
    generator_loss.backward()
    optimizer.step()
    for opt in (optimizer, discrim_optimizer):
        opt.zero_grad()


def _train_disciminator(t, model, device, optimizer, discrim_optimizer, x, y):
    """Take a single discriminator optimisation step and return the (unpenalised) loss as a float.

    `x` and `y` are moved to `device` and passed to `model.train_discriminator(..., penalty=True)`,
    which returns a `(loss, penalty)` pair. Only `discrim_optimizer` is stepped; gradients held by
    both optimisers are cleared afterwards.
    """
    x_dev, y_dev = x.to(device), y.to(device)
    loss, penalty = model.train_discriminator(t, x_dev, y_dev, penalty=True)
    # For a 'gan' the discriminator is trying to maximise the loss, hence it enters with a minus
    # sign; every other model type simply minimises loss + penalty.
    objective = penalty - loss if model.model_type == 'gan' else penalty + loss
    objective.backward()
    discrim_optimizer.step()
    for opt in (optimizer, discrim_optimizer):
        opt.zero_grad()
    return loss.item()


def _train_loop(t, dataloader, model, device, optimizer, discrim_optimizer, pre_epochs, epochs, ratio, print_callback,
                epoch_per_metric, averaging):
    """Run the optional discriminator pre-training phase followed by the main training loop.

    For a model whose `model_type` is 'gan':
      * the discriminator alone is trained for `pre_epochs` epochs;
      * in the main loop each batch triggers one generator step followed by `ratio` discriminator
        steps, the latter on batches pulled from a second, endlessly repeating iterator over
        `dataloader`;
      * a `torch.optim.swa_utils.AveragedModel` is updated once per epoch from epoch `averaging`
        onwards, and its weights are loaded back into `model` before returning.
    For any other model type each batch just calls `_train_disciminator` once and no averaging
    takes place.

    Every `epoch_per_metric` epochs, and on the final epoch, the loss of the averaged model
    (without penalty) is computed over `dataloader`, a record is appended to the history and
    `print_callback` is invoked with the averaged module.

    Anything raised inside the loop -- including KeyboardInterrupt, since BaseException is caught --
    is captured and returned instead of propagating, so the caller can still use partial results.

    Returns:
        A tuple `(history, start_time, timespan, exception)` where `history` is a list of
        `common.AttrDict` records with `epoch`, `total_loss` and `total_averaged_loss`,
        `start_time` is the `time.time()` value taken on entry, `timespan` is the elapsed seconds,
        and `exception` is the captured exception or None.
    """
    history = []

    start_time = time.time()
    # Independent, never-ending stream of batches used for the extra discriminator steps.
    infinite_dataloader = _loop(dataloader)

    # Stand-in so that `averaged_model.module` is always available below; it is replaced by a real
    # AveragedModel for 'gan' models once pre-training has finished.
    averaged_model = common.AttrDict(module=model)

    # Pretrain discriminator
    try:
        print('Starting training for model:\n\n' + str(model) + '\n\n')
        if model.model_type == 'gan':
            tqdm_range = tqdm.tqdm(range(pre_epochs))
            tqdm_range.write("Pre-training discriminator")
            for epoch in tqdm_range:
                for x, y in dataloader:
                    loss = _train_disciminator(t, model, device, optimizer, discrim_optimizer, x, y)
                # Reports the loss of the last batch of the epoch only.
                tqdm_range.write("Pretraining Epoch: {}  Loss: {:.5}".format(epoch, loss))
            averaged_model = torch.optim.swa_utils.AveragedModel(model)
            if averaging == 0:
                averaged_model.update_parameters(model)

        tqdm_range = tqdm.tqdm(range(epochs))
        tqdm_range.write("Training both generator and discriminator")
        for epoch in tqdm_range:
            total_loss = 0
            total_dataset_size = 0
            for x, y in dataloader:
                if model.model_type == 'gan':
                    _train_generator(t, model, device, optimizer, discrim_optimizer, y)
                    for _ in range(ratio):
                        # Note this rebinds x and y, so the batch size used below is that of the
                        # last batch drawn from the infinite iterator.
                        x, y = next(infinite_dataloader)
                        loss = _train_disciminator(t, model, device, optimizer, discrim_optimizer, x, y)
                else:
                    loss = _train_disciminator(t, model, device, optimizer, discrim_optimizer, x, y)
                batch_size = y.size(0)
                # Only the most recent discriminator loss for this outer batch is accumulated.
                total_loss += loss * batch_size
                total_dataset_size += batch_size
            total_loss /= total_dataset_size

            if epoch >= averaging and model.model_type == 'gan':
                averaged_model.update_parameters(model)

            if epoch % epoch_per_metric == 0 or epoch == epochs - 1:
                with torch.no_grad():
                    total_averaged_loss = 0
                    total_averaged_dataset_size = 0
                    for x, y in dataloader:
                        x = x.to(device)
                        y = y.to(device)
                        loss, _ = averaged_model.module.train_discriminator(t, x, y, penalty=False)
                        batch_size = y.size(0)
                        total_averaged_loss += loss.item() * batch_size
                        total_averaged_dataset_size += batch_size
                    total_averaged_loss /= total_averaged_dataset_size
                tqdm_range.write('Epoch: {}  Loss: {:.5}  Average Loss: {:.5}'.format(epoch,
                                                                                      total_loss,
                                                                                      total_averaged_loss))
                history.append(common.AttrDict(epoch=epoch,
                                               total_loss=total_loss,
                                               total_averaged_loss=total_averaged_loss))
                print_callback(averaged_model.module)
    except BaseException as e:
        # `e` is unbound when the except clause ends, so keep a reference under another name.
        e_ = e
    else:
        e_ = None
    # Runs on failure too: the model ends up with whatever averaged weights exist at that point.
    if model.model_type == 'gan':
        model.load_state_dict(averaged_model.module.state_dict())
    timespan = time.time() - start_time
    return history, start_time, timespan, e_


def _evaluate_classification(t, dataloader, model, device, input_channels, classification_epochs, classification_lr,
                             classification_kwargs, classification_plateau_terminate):
    """Train a real-vs-fake classifier and report how well it separates real data from samples.

    The dataset behind `dataloader` is split 80/20 (with a fixed seed) into train and test
    subsets. A `models.NeuralCDE` classifier is trained with binary cross-entropy to output 1 for
    real batches and 0 for batches produced by `model.generate_sample(t, y)`. Training stops after
    `classification_epochs` epochs, or earlier once the training loss has not improved (by at
    least 0.1%) for more than `classification_plateau_terminate` epochs.

    The multiprocessing sharing strategy is switched to 'file_system' while the evaluation
    dataloaders are in use and is always restored afterwards, including when an exception or
    interrupt cuts the evaluation short.

    Returns:
        A tuple `(classification_loss, classification_accuracy)` computed on the test subset,
        both averaged over real and fake samples together.
    """
    # Rewrap dataset to split into train/test
    dataset = dataloader.dataset
    generator = torch.Generator().manual_seed(23456789)
    perm = torch.randperm(len(dataset), generator=generator)
    train = int(0.8 * len(dataset))
    train_indices, test_indices = perm[:train], perm[train:]
    sharing_strategy = torch.multiprocessing.get_sharing_strategy()
    torch.multiprocessing.set_sharing_strategy('file_system')
    try:
        train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=dataloader.batch_size,
                                                       num_workers=6, pin_memory=True, generator=generator,
                                                       sampler=torch.utils.data.SubsetRandomSampler(
                                                           train_indices, generator=generator))
        test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=dataloader.batch_size,
                                                      num_workers=6, pin_memory=True, generator=generator,
                                                      sampler=torch.utils.data.SubsetRandomSampler(
                                                          test_indices, generator=generator))

        classifier = models.NeuralCDE(in_size=input_channels, out_size=1, **classification_kwargs).to(device)
        optimiser = torch.optim.Adam(classifier.parameters(), lr=classification_lr)

        tqdm_range = tqdm.tqdm(range(classification_epochs))
        tqdm_range.write("Training real/fake classifier")
        best_loss_epoch = 0
        best_loss = math.inf
        for epoch in tqdm_range:
            total_dataset_size = 0
            total_loss = 0
            for x, y in train_dataloader:
                y = y.to(device)
                batch_size = y.size(0)
                total_dataset_size += 2 * batch_size  # 2 for real+fake

                # The real and fake halves are backpropagated separately (gradients accumulate)
                # and their intermediates deleted in between, before a single optimiser step.
                real_x = x.to(device)
                real_score = classifier(real_x)
                real_loss = torch.nn.functional.binary_cross_entropy_with_logits(real_score,
                                                                                 torch.ones_like(real_score))
                real_loss.backward()
                total_loss += real_loss.item() * batch_size
                del real_x, real_score, real_loss

                # NOTE(review): generate_sample is not called under no_grad here, so it may build
                # a graph through the generator -- confirm whether that is needed.
                fake_x = model.generate_sample(t, y)
                fake_score = classifier(fake_x)
                fake_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_score,
                                                                                 torch.zeros_like(fake_score))
                fake_loss.backward()
                total_loss += fake_loss.item() * batch_size
                del fake_x, fake_score, fake_loss

                optimiser.step()
                optimiser.zero_grad()

            total_loss /= total_dataset_size

            tqdm_range.write('Epoch: {}  Loss: {:.5}'.format(epoch, total_loss))

            # Require a relative improvement of 0.1% before counting an epoch as a new best.
            if total_loss * 1.001 < best_loss:
                best_loss_epoch = epoch
                best_loss = total_loss

            if epoch > best_loss_epoch + classification_plateau_terminate:
                tqdm_range.write("Breaking because of no loss improvement")
                break

        with torch.no_grad():
            total_dataset_size = 0
            classification_loss = 0
            total_correct = 0
            for x, y in test_dataloader:
                y = y.to(device)
                batch_size = y.size(0)
                total_dataset_size += 2 * batch_size  # 2 for real+fake

                real_x = x.to(device)
                real_score = classifier(real_x)
                real_loss = torch.nn.functional.binary_cross_entropy_with_logits(real_score,
                                                                                 torch.ones_like(real_score))

                fake_x = model.generate_sample(t, y)
                fake_score = classifier(fake_x)
                fake_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_score,
                                                                                 torch.zeros_like(fake_score))

                classification_loss += (real_loss + fake_loss).item() * batch_size
                # Positive logit == predicted real, negative logit == predicted fake.
                total_correct += (real_score > 0).sum().item() + (fake_score < 0).sum().item()

        classification_loss /= total_dataset_size
        classification_accuracy = total_correct / total_dataset_size
    finally:
        # The sharing strategy is process-wide state, so put it back even if evaluation fails.
        torch.multiprocessing.set_sharing_strategy(sharing_strategy)

    return classification_loss, classification_accuracy


def _evaluate_prediction(t, dataloader, model, device, input_channels, prediction_epochs, prediction_lr,
                         prediction_kwargs, prediction_plateau_terminate, prediction_split):
    """Train a seq2seq predictor on generated samples and score it on the real data.

    A `models.Seq2Seq` predictor is fitted on batches from `model.generate_sample(t, y)`; calling
    the predictor returns the loss directly. Training stops after `prediction_epochs` epochs, or
    earlier once the training loss has not improved (by at least 0.1%) for more than
    `prediction_plateau_terminate` epochs.

    Returns:
        The predictor's loss on the real `x` batches of `dataloader`, averaged per sample.
    """
    predictor = models.Seq2Seq(split=prediction_split, in_size=input_channels, **prediction_kwargs).to(device)
    optimiser = torch.optim.Adam(predictor.parameters(), lr=prediction_lr)

    progress = tqdm.tqdm(range(prediction_epochs))
    progress.write("Training seq2seq predictor")
    best_epoch, best_loss = 0, math.inf
    for epoch in progress:
        seen = 0
        running_loss = 0
        for _, y in dataloader:
            y = y.to(device)
            n_samples = y.size(0)
            seen += n_samples

            generated = model.generate_sample(t, y)
            batch_loss = predictor(t, generated)
            batch_loss.backward()
            running_loss += batch_loss.item() * n_samples

            optimiser.step()
            optimiser.zero_grad()

        epoch_loss = running_loss / seen

        progress.write('Epoch: {}  Loss: {:.5}'.format(epoch, epoch_loss))

        # Only a relative improvement of at least 0.1% counts as a new best.
        if epoch_loss * 1.001 < best_loss:
            best_epoch, best_loss = epoch, epoch_loss

        if epoch > best_epoch + prediction_plateau_terminate:
            progress.write("Breaking because of no loss improvement")
            break

    with torch.no_grad():
        seen = 0
        real_loss_sum = 0
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            n_samples = y.size(0)
            seen += n_samples

            real_loss_sum += predictor(t, x).item() * n_samples

        prediction_loss = real_loss_sum / seen

    return prediction_loss


def _evaluate_mmd(t, dataloader, model, device):
    """Compare real and generated data through their depth-5 signatures.

    For every batch, `t` is prepended as an extra channel to both the real paths and the paths from
    `model.generate_sample(t, y)`; the signatures of each are summed over the batch and the
    differences accumulated. The result is the mean absolute value of the per-sample average
    difference, divided by 10000, returned as a float.
    """
    signature_gap = 0
    sample_count = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            sample_count += y.size(0)

            # Broadcast the time grid to (x.size(0), x.size(1), 1) so it can be concatenated
            # along the last dimension.
            time_channel = t.unsqueeze(0).unsqueeze(-1).expand(x.size(0), x.size(1), 1)
            real_path = torch.cat([time_channel, x], dim=2)
            generated = model.generate_sample(t, y)
            fake_path = torch.cat([time_channel, generated], dim=2)

            real_signature = signatory.signature(real_path, 5).sum(dim=0)
            fake_signature = signatory.signature(fake_path, 5).sum(dim=0)

            signature_gap += real_signature - fake_signature

        signature_gap /= sample_count
    # The overall scale is an arbitrary part of the definition, so divide by 10000 purely to get
    # numbers that fit on a page.
    scaled = signature_gap.abs().mean() / 10000
    return scaled.item()


class _TensorEncoder(json.JSONEncoder):
    """JSON encoder that writes torch tensors and NumPy arrays/scalars as plain Python data."""

    def default(self, o):
        """Convert `o` via `tolist()` if it is a tensor, array or NumPy scalar.

        Anything else is delegated to the next encoder in the MRO, whose result is returned; the
        base `json.JSONEncoder.default` raises TypeError for unsupported objects.
        """
        # np.generic covers NumPy scalars (e.g. np.float32, np.int64), which json rejects.
        if isinstance(o, (torch.Tensor, np.ndarray, np.generic)):
            return o.tolist()
        return super().default(o)


def _save_results(name, result):
    """Write `result` to the next free numbered file under `results/<name>`.

    Existing entries whose names parse as integers are treated as earlier runs; the new run gets
    the highest such number plus one (starting from 0). The result is dumped as JSON with the
    'dataloader' entry removed and the model replaced by its string form, and the model's
    state dict is saved next to it as '<num>_model'.
    """
    loc = _here / '../../results' / name
    os.makedirs(loc, exist_ok=True)

    # Seed with -1 so that an empty directory yields run number 0.
    taken = [-1]
    for entry in os.listdir(loc):
        try:
            taken.append(int(entry))
        except ValueError:
            # Not a run file (e.g. a '<num>_model' checkpoint).
            continue
    num = max(taken) + 1

    result_to_save = result.copy()
    del result_to_save['dataloader']
    model = result_to_save['model']
    result_to_save['model'] = str(model)

    with open(loc / str(num), 'w') as f:
        json.dump(result_to_save, f, cls=_TensorEncoder)

    torch.save(model.state_dict(), loc / (str(num) + '_model'))


def _is_cpu_device(device):
    """Return True if `device` denotes the CPU, given either as the string 'cpu' or a torch.device."""
    if isinstance(device, torch.device):
        return device.type == 'cpu'
    return device == 'cpu'


def main(name, t, dataloader, model, device, save, pre_epochs, epochs, ratio, generator_lr, discriminator_lr,
         print_callback, epoch_per_metric, input_channels,
         classification_epochs, classification_lr, classification_kwargs, classification_plateau_terminate,
         prediction_epochs, prediction_lr, prediction_kwargs, prediction_plateau_terminate, prediction_split,
         averaging):
    """Train `model`, evaluate it, optionally save the outcome, and return a result record.

    Training is delegated to `_train_loop`, which captures rather than raises any exception. If
    training completed, the classification, prediction and MMD evaluations are run; if it did
    not, those metrics are recorded as None. The result is saved (when `save` is truthy) before
    the captured training exception, if any, is re-raised -- so partial results survive an
    interrupted run.

    `device` may be the string 'cpu', a torch.device, or anything else accepted by
    `torch.cuda.device`; on a non-CPU device peak memory usage during training is recorded
    relative to the memory already allocated on entry.

    Returns:
        A `common.AttrDict` with keys `memory_usage`, `baseline_memory`, `start_time`,
        `timespan`, `model` (moved to CPU), `parameters`, `history`, `evaluate` and `dataloader`.

    Raises:
        Whatever exception was captured during training, after the result has been built and
        (if requested) saved.
    """
    # Recognise torch.device('cpu') as well as 'cpu'; comparing a torch.device against a string
    # would otherwise send a CPU run down the CUDA path.
    on_cpu = _is_cpu_device(device)

    # makes pin_memory pin to the correct GPU
    # NOTE(review): the evaluations below run outside this context although
    # _evaluate_classification also builds pin_memory dataloaders -- confirm that is intended.
    with contextlib.nullcontext() if on_cpu else torch.cuda.device(device):
        if on_cpu:
            baseline_memory = None
        else:
            torch.cuda.reset_peak_memory_stats()
            baseline_memory = torch.cuda.memory_allocated()

        t = t.to(device)
        model.to(device)

        optimizer = model.generator_optimiser(generator_lr)
        discrim_optimizer = model.discriminator_optimiser(discriminator_lr)

        history, start_time, timespan, e = _train_loop(t, dataloader, model, device, optimizer, discrim_optimizer,
                                                       pre_epochs, epochs, ratio, print_callback, epoch_per_metric,
                                                       averaging)

        if on_cpu:
            memory_usage = None
        else:
            memory_usage = torch.cuda.max_memory_allocated() - baseline_memory

    if e is None:
        classification_loss, classification_accuracy = _evaluate_classification(t, dataloader, model, device,
                                                                                input_channels,
                                                                                classification_epochs,
                                                                                classification_lr,
                                                                                classification_kwargs,
                                                                                classification_plateau_terminate)
        prediction_loss = _evaluate_prediction(t, dataloader, model, device, input_channels, prediction_epochs,
                                               prediction_lr, prediction_kwargs, prediction_plateau_terminate,
                                               prediction_split)
        mmd_loss = _evaluate_mmd(t, dataloader, model, device)
    else:
        # Training did not finish, so there is nothing meaningful to evaluate.
        classification_loss = None
        classification_accuracy = None
        prediction_loss = None
        mmd_loss = None
    evaluate = common.AttrDict(classification_loss=classification_loss,
                               classification_accuracy=classification_accuracy,
                               prediction_loss=prediction_loss,
                               mmd_loss=mmd_loss)

    result = common.AttrDict(memory_usage=memory_usage,
                             baseline_memory=baseline_memory,
                             start_time=start_time,
                             timespan=timespan,
                             model=model.to('cpu'),
                             parameters=common.count_parameters(model),
                             history=history,
                             evaluate=evaluate,
                             dataloader=dataloader)
    if save:
        _save_results(name, result)

    # Re-raise only after the (possibly partial) result has been saved.
    if e is not None:
        raise e
    return result
