"""Module contains functions and classes to train and evaluate a model on the CIFAR-10."""

from __future__ import annotations

import copy
import math

import torch
from .shared import batchsize, hyp, loss_fn
from torch import nn

#############################################
#            Data Preprocessing             #
#############################################


## This is actually (I believe) a pretty clean implementation of how to do something like this, since shifted-square masks unique to each image in the batch can actually be rather
## tricky in practice. That said, if there's a better way, please do feel free to submit it! This can be one of the harder parts of the code to understand (though I personally get
## stuck on the fold/unfold process for the lower-level convolution calculations).
def make_random_square_masks(inputs, mask_size, generator=None):
    """Build one randomly placed ``mask_size`` x ``mask_size`` boolean square per batch item.

    Only the shape and device of ``inputs`` are used; its values are never read.

    Args:
        inputs: 4-D tensor; dim 0 is the batch and the last two dims are the
            (height, width) grid the squares are placed on.
        mask_size: Side length of the square. ``0`` disables masking. It must
            not exceed the height or width of ``inputs``.
        generator: Optional ``torch.Generator`` used for the center draws.

    Returns:
        ``None`` when ``mask_size == 0``. Otherwise a boolean tensor of shape
        ``(batch, 1, height, width)`` that is ``True`` inside each item's square.
        Every square lies fully inside the grid, so each item has exactly
        ``mask_size ** 2`` ``True`` entries.
    """
    if mask_size == 0:
        return None  # no need to cutout or do anything like that since the patch_size is set to 0

    half = mask_size // 2
    is_even = int(mask_size % 2 == 0)
    batch, height, width = inputs.shape[0], inputs.shape[-2], inputs.shape[-1]

    # A square "centered" at c covers [c - half + is_even, c + half] (even sizes have no true center, so they
    # extend one pixel further towards the high side). For the square to fit we therefore need
    #   half - is_even <= c <= dim - 1 - half
    # and since random_'s upper bound is exclusive, that bound is `dim - half` for BOTH parities. (Subtracting
    # is_even from the upper bound as well meant the last valid position was never drawn for even sizes.)
    lowest_center = half - is_even
    mask_center_y = torch.empty(
        batch, dtype=torch.long, device=inputs.device,
    ).random_(lowest_center, height - half, generator=generator)
    mask_center_x = torch.empty(
        batch, dtype=torch.long, device=inputs.device,
    ).random_(lowest_center, width - half, generator=generator)

    # signed distance of every row / column from each item's center
    to_mask_y_dists = torch.arange(height, device=inputs.device).view(
        1, 1, height, 1,
    ) - mask_center_y.view(-1, 1, 1, 1)
    to_mask_x_dists = torch.arange(width, device=inputs.device).view(
        1, 1, 1, width,
    ) - mask_center_x.view(-1, 1, 1, 1)

    to_mask_y = (to_mask_y_dists >= (-half + is_even)) & (to_mask_y_dists <= half)
    to_mask_x = (to_mask_x_dists >= (-half + is_even)) & (to_mask_x_dists <= half)

    ## Turn (y by 1) and (x by 1) boolean masks into (y by x) masks through broadcasting. Their intersection is square, hurray! :D
    return to_mask_y & to_mask_x


def batch_cutmix(inputs, targets, patch_size, generator=None):
    """Apply CutMix within a batch: paste a square patch from another sample and mix the labels.

    Args:
        inputs: Image batch; dim 0 is the batch, the last two dims are (height, width).
        targets: Per-sample targets that can be linearly blended (e.g. float one-hot rows) --
            they are scaled by a float and summed.
        patch_size: Side length of the pasted square. ``0`` (or ``None``) disables CutMix.
        generator: Optional ``torch.Generator`` for the partner permutation and patch placement.

    Returns:
        ``(inputs, targets)`` unchanged when CutMix is disabled, otherwise
        ``(mixed_inputs, mixed_targets)`` as new tensors.
    """
    with torch.no_grad():
        # Partner indices live on the same device as the data (previously hard-coded to "cuda").
        # Drawn before the early exit, as before, so a seeded generator advances identically
        # whether or not cutmix is active.
        batch_permuted = torch.randperm(inputs.shape[0], device=inputs.device, generator=generator)
        if not patch_size:
            # patch size of 0 means we will not be using cutmix today.
            return inputs, targets

        cutmix_batch_mask = make_random_square_masks(inputs, patch_size, generator=generator)

        # We draw other samples from inside of the same batch
        partner_inputs = torch.index_select(inputs, 0, batch_permuted)
        partner_targets = torch.index_select(targets, 0, batch_permuted)
        cutmix_batch = torch.where(cutmix_batch_mask, partner_inputs, inputs)

        # Get the percentage of each target to mix for the labels by the % proportion of pixels in the mix
        portion_mixed = float(patch_size**2) / (inputs.shape[-2] * inputs.shape[-1])
        cutmix_labels = portion_mixed * partner_targets + (1.0 - portion_mixed) * targets
        return cutmix_batch, cutmix_labels


def batch_crop(inputs, crop_size, generator=None):
    """Cut one randomly positioned ``crop_size`` x ``crop_size`` window out of every image.

    A boolean square mask is generated per batch item, the selected pixels are
    gathered, and the flat result is reshaped to
    ``(batch, channels, crop_size, crop_size)``.
    """
    with torch.no_grad():
        keep = make_random_square_masks(inputs, crop_size, generator=generator)
        num_images, num_channels = inputs.shape[0], inputs.shape[1]
        kept_pixels = torch.masked_select(inputs, keep)
        return kept_pixels.view(num_images, num_channels, crop_size, crop_size)


def batch_flip_lr(batch_images, flip_chance=0.5, generator=None):
    """Mirror each image left-right independently with probability ``flip_chance``.

    Args:
        batch_images: 4-D floating-point image batch; the last dim is flipped.
        flip_chance: Probability in ``[0, 1]`` that a given image is flipped.
        generator: Optional ``torch.Generator`` for the per-image draw.

    Returns:
        A new tensor with the same shape as ``batch_images``.
    """
    with torch.no_grad():
        # One uniform draw in [0, 1) per image, broadcast over (C, H, W). This must be uniform, not normal:
        # with a standard-normal draw `x < 0.5` is true ~69% of the time, so flip_chance would not be a probability.
        flip_draw = torch.empty(
            (batch_images.shape[0], 1, 1, 1),
            dtype=batch_images.dtype,
            device=batch_images.device,
        ).uniform_(generator=generator)
        return torch.where(
            flip_draw < flip_chance,
            torch.flip(batch_images, (-1,)),
            batch_images,
        )

########################################
#          Training Helpers            #
########################################


class NetworkEMA(nn.Module):
    """Keeps a frozen, eval-mode exponential moving average copy of a network."""

    def __init__(self, net):
        super().__init__()  # register as a proper nn.Module before attaching the copy
        shadow = copy.deepcopy(net)
        shadow.eval()
        shadow.requires_grad_(False)
        self.net_ema = shadow

    def update(self, current_net, decay):
        """Blend ``current_net`` into the average, then write the average back into ``current_net``.

        Every half/float entry of the state dict is updated in place as
        ``ema = decay * ema + (1 - decay) * current``. The averaged value is then
        copied back into ``current_net`` (Lookahead-style), except for entries
        whose name contains both "norm" and "weight", or contains "whiten".
        Entries are paired by position, so both networks must share the same
        state-dict layout.
        """
        with torch.no_grad():
            live_state = current_net.state_dict()
            averaged_state = self.net_ema.state_dict()
            for averaged, (name, live) in zip(averaged_state.values(), live_state.items()):
                if live.dtype not in (torch.half, torch.float):
                    continue  # e.g. integer counters are left untouched

                # in-place update, similar to how optimizer momentum is coded
                averaged.mul_(decay).add_(live.detach().mul(1.0 - decay))

                is_norm_weight = "norm" in name and "weight" in name
                if is_norm_weight or "whiten" in name:
                    continue  # these entries are averaged but never written back
                live.copy_(averaged.detach())

    def forward(self, inputs):
        """Run ``inputs`` through the averaged network."""
        return self.net_ema(inputs)


# TODO: Could we jit this in the (more distant) future? :)
@torch.no_grad()
def get_batches(
    images: torch.Tensor,
    targets: torch.Tensor,
    crop_size: int,
    batchsize: int,
    generator=None,
    epoch_fraction=1.0,
    cutmix_size=None,
):
    """Augment the whole dataset once, then yield shuffled, full-size mini-batches.

    The augmentations (random crop, random left-right flip, optional CutMix) are applied to
    all of ``images`` up front; batches are then gathered via a random permutation of the
    example indices. A trailing batch smaller than ``batchsize`` is dropped.

    Args:
        images: Image batch; dim 0 is the example count.
        targets: Per-example targets, indexed along dim 0 in step with ``images``.
        crop_size: Side length passed to ``batch_crop``.
        batchsize: Number of examples per yielded batch.
        generator: Optional ``torch.Generator`` used for the shuffle and every augmentation.
            NOTE(review): presumably it must live on the same device as ``images`` -- confirm.
        epoch_fraction: Values below 1 keep only roughly that fraction of the epoch, rounded
            to a whole number of batches.
        cutmix_size: CutMix patch side length; ``None`` or ``0`` disables CutMix.

    Yields:
        ``(image_batch, target_batch)`` tuples with ``batchsize`` examples each.
    """
    num_epoch_examples = images.shape[0]
    # Shuffle on the data's own device rather than assuming "cuda".
    shuffled = torch.randperm(num_epoch_examples, device=images.device, generator=generator)
    if epoch_fraction < 1:
        # Partial epoch: keep a whole number of batches closest to the requested fraction.
        num_kept_batches = round(epoch_fraction * shuffled.shape[0] / batchsize)
        shuffled = shuffled[: batchsize * num_kept_batches]
        num_epoch_examples = shuffled.shape[0]

    ## Here, we prep the dataset by applying all data augmentations in batches ahead of time before each epoch, then we return an iterator below
    ## that iterates in chunks over a random permutation (i.e. shuffled indices) of the individual examples. So we get perfectly-shuffled
    ## batches (which skip the last batch if it's not a full batch).
    # The generator is forwarded to every augmentation (the crop used to be left unseeded).
    images = batch_crop(images, crop_size, generator=generator)
    images = batch_flip_lr(images, generator=generator)
    images, targets = batch_cutmix(
        images,
        targets,
        patch_size=cutmix_size or 0,  # the None default means "no cutmix"
        generator=generator,
    )

    # Send the images to an (in beta) channels_last to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training
    images = images.to(memory_format=torch.channels_last)

    # Integer division already guarantees every slice below is a full batch.
    for idx in range(num_epoch_examples // batchsize):
        batch_indices = shuffled[idx * batchsize : (idx + 1) * batchsize]
        yield images.index_select(0, batch_indices), targets.index_select(0, batch_indices)


def init_split_parameter_dictionaries(network):
    """Split the trainable parameters of ``network`` into two SGD keyword-argument dicts.

    Parameters whose name contains "bias" go into the bias group and everything else
    into the non-bias group; parameters with ``requires_grad == False`` are left out.
    Learning rate and weight decay for each group come from ``hyp["opt"]``.

    Returns:
        ``(params_non_bias, params_bias)``, each ready to be unpacked into ``torch.optim.SGD``.
    """
    opt_hyp = hyp["opt"]

    def _sgd_group(lr, weight_decay):
        # Settings shared by both groups, with only lr / decay differing.
        return {
            "params": [],
            "lr": lr,
            "momentum": 0.85,
            "nesterov": True,
            "weight_decay": weight_decay,
            "foreach": True,
        }

    params_non_bias = _sgd_group(opt_hyp["non_bias_lr"], opt_hyp["non_bias_decay"])
    params_bias = _sgd_group(opt_hyp["bias_lr"], opt_hyp["bias_decay"])

    for name, param in network.named_parameters():
        if not param.requires_grad:
            continue
        group = params_bias if "bias" in name else params_non_bias
        group["params"].append(param)

    return params_non_bias, params_bias



def train(
    images: torch.Tensor,
    targets: torch.Tensor,
    crop_size: int,
    net: nn.Module,
    num_epochs: float,
    generator: torch.Generator | None = None,
) -> nn.Module | None:
    """Train ``net`` in place and return an EMA wrapper around it.

    Non-bias and bias parameters get separate SGD optimizers (see
    ``init_split_parameter_dictionaries``), each driven by a linear ``OneCycleLR`` schedule.
    Batches come from ``get_batches``; CutMix is enabled for the last
    ``hyp["net"]["cutmix_epochs"]`` epochs. Once
    ``epoch >= floor(num_epochs) - hyp["misc"]["ema"]["epochs"]``, a ``NetworkEMA`` is created
    from the current weights and updated every ``hyp["misc"]["ema"]["every_n_steps"]`` steps.

    Args:
        images: Training images; dim 0 is the example count.
        targets: Training targets handed to ``loss_fn`` together with the network outputs.
        crop_size: Crop side length forwarded to ``get_batches``.
        net: Network to optimize. It is modified in place.
        num_epochs: Number of epochs; may be fractional (e.g. ``12.1``), in which case the
            fractional part is run as a final partial epoch.
        generator: Optional ``torch.Generator`` forwarded to ``get_batches``.

    Returns:
        The ``NetworkEMA`` built during training, or ``None`` if the EMA condition was never met.
    """
    ema_hyp = hyp["misc"]["ema"]
    opt_hyp = hyp["opt"]

    net_ema = None  ## Start without an EMA so that it can be initialized right from where the network is during training
    ## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)

    current_steps = 0

    # TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize)....
    num_steps_per_epoch = images.shape[0] // batchsize
    total_train_steps = math.ceil(num_steps_per_epoch * num_epochs)
    ema_epoch_start = math.floor(num_epochs) - ema_hyp["epochs"]

    ## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps
    ## to somewhat accommodate for whatever the expected information intake rate is. The tradeoff I believe, though, is that this is to some degree noisier as we
    ## are intaking fewer samples of our distribution-over-time, with a higher individual weight each. This can be good or bad depending upon what we want.
    projected_ema_decay_val = ema_hyp["decay_base"] ** ema_hyp["every_n_steps"]

    pct_start = opt_hyp["percent_start"]

    ## Stowing the creation of these into a helper function to make things a bit more readable....
    non_bias_params, bias_params = init_split_parameter_dictionaries(net)

    # One optimizer for the regular network, and one for the biases. This allows us to use the superconvergence onecycle training policy for our networks....
    opt = torch.optim.SGD(**non_bias_params)
    opt_bias = torch.optim.SGD(**bias_params)

    ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to final_lr_ratio * max_lr at the end.
    ##   OneCycleLR derives the final lr from the *starting* lr (max_lr / div_factor) and not from max_lr, hence the
    ##   final_div_factor of 1 / (initial_div_factor * final_lr_ratio): (max_lr / 1e16) * (1e16 * 0.07) = 0.07 * max_lr.
    initial_div_factor = 1e16  # basically to make the initial lr ~0 or so :D
    final_lr_ratio = 0.07  # Actually pretty important, apparently!

    def _one_cycle(optimizer, max_lr):
        # Both parameter groups share the same schedule shape, only max_lr differs.
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=max_lr,
            pct_start=pct_start,
            div_factor=initial_div_factor,
            final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
            total_steps=total_train_steps,
            anneal_strategy="linear",
            cycle_momentum=False,
        )

    lr_sched = _one_cycle(opt, non_bias_params["lr"])
    lr_sched_bias = _one_cycle(opt_bias, bias_params["lr"])

    # to scale to keep things at a relatively similar amount of regularization when we change our batchsize since we're summing over the whole batch
    loss_batchsize_scaler = 512 / batchsize

    for epoch in range(math.ceil(num_epochs)):
        #################
        # Training Mode #
        #################
        net.train()

        cutmix_size = (
            hyp["net"]["cutmix_size"]
            if epoch >= num_epochs - hyp["net"]["cutmix_epochs"]
            else 0
        )
        # We need to know if we're running a partial epoch or not. The comparison has to be `<=`: with a whole
        # number of epochs the last one satisfies epoch + 1 == num_epochs, and `num_epochs % 1` would be 0 there,
        # which would silently turn the final epoch into an empty one.
        epoch_fraction = 1 if epoch + 1 <= num_epochs else num_epochs % 1

        for batch_inputs, batch_targets in get_batches(
            images,
            targets,
            crop_size=crop_size,
            batchsize=batchsize,
            epoch_fraction=epoch_fraction,
            cutmix_size=cutmix_size,
            generator=generator,
        ):
            ## Run everything through the network
            outputs = net(batch_inputs)

            ## If you want to add other losses or hack around with the loss, you can do that here.
            loss = (
                loss_fn(outputs, batch_targets)
                .mul(opt_hyp["loss_scale_scaler"] * loss_batchsize_scaler)
                .sum()
                .div(opt_hyp["loss_scale_scaler"])
            )  ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling
            ## (and is thus batchsize dependent as a result). This can be somewhat good or bad, depending...

            loss.backward()

            ## Step for each optimizer, in turn.
            opt.step()
            opt_bias.step()

            # We only want to step the lr_schedulers while we have training steps to consume. Otherwise we get a not-so-friendly error from PyTorch
            # (the rounded batch count of a partial epoch can exceed the ceil()-based total_train_steps by one).
            if current_steps < total_train_steps:
                lr_sched.step()
                lr_sched_bias.step()

            ## Using 'set_to_none' I believe is slightly faster (albeit riskier w/ funky gradient update workflows) than under the default 'set to zero' method
            opt.zero_grad(set_to_none=True)
            opt_bias.zero_grad(set_to_none=True)
            current_steps += 1

            if epoch >= ema_epoch_start and current_steps % ema_hyp["every_n_steps"] == 0:
                ## Initialize the ema from the network at this point in time if it does not already exist.... :D
                if net_ema is None:
                    # The fresh copy already equals the network, so there is nothing to blend in on this step.
                    net_ema = NetworkEMA(net)
                    continue
                # We warm up our ema's decay/momentum value over training exponentially according to the hyp config dictionary (this lets us move fast, then average strongly at the end).
                net_ema.update(
                    net,
                    decay=projected_ema_decay_val
                    * (current_steps / total_train_steps) ** ema_hyp["decay_pow"],
                )

    return net_ema
