from __future__ import annotations

import logging
# from argparse import Namespace
from functools import partial
from pathlib import Path, PosixPath

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision import transforms

# import .data as data

DTYPE = torch.float16
# File-local default keyword arguments for convolution layers.
default_conv_kwargs = dict(kernel_size=3, padding="same", bias=False)

batchsize = 1024
bias_scaler = 64
# To replicate the ~95.79%-accuracy-in-110-seconds runs, change base_depth
# 64 -> 128, train_epochs 12.1 -> 90, ['ema'] epochs 10 -> 80,
# cutmix_size 3 -> 10, and cutmix_epochs 6 -> 80.
hyp = {
    "opt": {
        # TODO: Is there maybe a better way to express the bias and batchnorm scaling?
        "bias_lr": 1.525 * bias_scaler / 512,
        "non_bias_lr": 1.525 / 512,
        "bias_decay": 6.687e-4 * batchsize / bias_scaler,
        "non_bias_decay": 6.687e-4 * batchsize,
        "scaling_factor": 1.0 / 9,
        "percent_start": 0.23,
        # Regularizer inside the loss summing (range: ~1/512 - 16+).
        # FP8 should help with this somewhat too, whenever it comes out.
        "loss_scale_scaler": 1.0 / 32,
    },
    "net": {
        "whitening": {"kernel_size": 2, "num_examples": 50_000},
        # Don't forget momentum is 1 - momentum here (a quirk inherited from
        # the original paper).
        "batch_norm_momentum": 0.4,
        "cutmix_size": 3,
        "cutmix_epochs": 6,
        # Set to 0 to disable the reflect-padding of the training images.
        "pad_amount": 2,
        # "base_depth" (e.g. 64) is currently not set here; it should be a
        # multiple of 8 in some way to stay tensor core friendly.
    },
    "misc": {
        "ema": {
            # Slight bug: this counts only full epochs, and the EMA then
            # additionally runs for any fractional epochs at the end too.
            "epochs": 10,
            "decay_base": 0.95,
            "decay_pow": 3.0,
            "every_n_steps": 5,
        },
    },
}

# Compatibility-only settings (cache file location and target device).
# These must NEVER be relied upon for evals.
bad_hyp = {"misc": {"data_location": "data.pt", "device": "cuda"}}


#############################################
#                Dataloader                 #
#############################################

def get_data_dict(canary_index: int | None = None,
                  canary_image: torch.Tensor | None = None,
                  caching: bool = False) -> dict:
    """Load and preprocess CIFAR-10 dataset, optionally injecting a canary image.

    On a fresh load the images are normalized with the per-channel mean/std of
    the training set, converted to FP16, the targets are one-hot encoded, and
    the result is written to ``bad_hyp["misc"]["data_location"]``. The file on
    disk never contains the canary: it is injected afterwards, on both the
    fresh and the cached path, so a cached dataset can be reused with a
    different canary without carrying over an old one.

    Parameters
    ----------
    canary_index : int | None, optional
        Index at which to inject the canary image, by default None.
    canary_image : torch.Tensor | None, optional
        Canary image to inject, by default None.
    caching : bool, optional
        Whether to use cached data if available, by default False.

    Returns
    -------
    dict
        A dictionary with ``"train"`` and ``"eval"`` entries, each holding
        ``"images"`` and ``"targets"`` tensors. The training images are
        reflect-padded when ``hyp["net"]["pad_amount"] > 0``.

    Raises
    ------
    RuntimeError
        If exactly one of `canary_index` and `canary_image` is provided.

    """
    # Reject a half-specified canary before any download / device work is done.
    if (canary_index is None) != (canary_image is None):
        msg = "Both `canary_index` and `canary_image` should be provided."
        raise RuntimeError(msg)

    data_location = bad_hyp["misc"]["data_location"]
    device = bad_hyp["misc"]["device"]

    # `Path` picks the right flavour for the running OS; `PosixPath` cannot be
    # instantiated on Windows.
    if caching and Path(data_location).exists():
        ## This is effectively instantaneous, and takes us practically straight to where the
        ## dataloader-loaded dataset would be, as long as the loading process below has been
        ## run once and the file is still on disk.
        logging.warning("Reusing data from disk")
        data = torch.load(data_location)
    else:
        transform = transforms.Compose([transforms.ToTensor()])
        cifar10 = torchvision.datasets.CIFAR10(
            "cifar10/", download=True, train=True, transform=transform)
        cifar10_eval = torchvision.datasets.CIFAR10(
            "cifar10/", download=False, train=False, transform=transform)

        # Use the dataloaders to get a single batch holding the whole dataset at once.
        train_dataset_gpu_loader = torch.utils.data.DataLoader(
            cifar10,
            batch_size=len(cifar10),
            drop_last=True,
            shuffle=True,
            num_workers=2,
            persistent_workers=False,
        )
        eval_dataset_gpu_loader = torch.utils.data.DataLoader(
            cifar10_eval,
            batch_size=len(cifar10_eval),
            drop_last=True,
            shuffle=False,
            num_workers=1,
            persistent_workers=False,
        )

        train_dataset_gpu = {}
        eval_dataset_gpu = {}

        print("Loading training data")
        train_dataset_gpu["images"], train_dataset_gpu["targets"] = (
            item.to(device=device, non_blocking=True)
            for item in next(iter(train_dataset_gpu_loader))
        )
        print("Loading test/eval data")
        eval_dataset_gpu["images"], eval_dataset_gpu["targets"] = (
            item.to(device=device, non_blocking=True)
            for item in next(iter(eval_dataset_gpu_loader))
        )

        # Per-channel statistics are computed from the training data itself, which
        # keeps the code short and should help adapting to new datasets.
        cifar10_std, cifar10_mean = torch.std_mean(
            train_dataset_gpu["images"], dim=(0, 2, 3),
        )

        def batch_normalize_images(input_images):
            return (
                input_images - cifar10_mean.view(1, -1, 1, 1)
            ) / cifar10_std.view(1, -1, 1, 1)

        train_dataset_gpu["images"] = batch_normalize_images(train_dataset_gpu["images"])
        eval_dataset_gpu["images"] = batch_normalize_images(eval_dataset_gpu["images"])

        data = {
            "train": train_dataset_gpu,
            "eval": eval_dataset_gpu,
        }

        ## Convert dataset to FP16 now for the rest of the process....
        data["train"]["images"] = data["train"]["images"].half().requires_grad_(False)
        data["eval"]["images"] = data["eval"]["images"].half().requires_grad_(False)

        # One-hot targets support cutmix (or other label tricks).
        data["train"]["targets"] = F.one_hot(data["train"]["targets"]).half()
        data["eval"]["targets"] = F.one_hot(data["eval"]["targets"]).half()

        # Saved *before* the canary is injected so the cache stays canary-free.
        torch.save(data, data_location)

    if canary_index is not None:
        # Lazy %-style argument: `logging.info` accepts no arbitrary keyword
        # arguments and does not expand `{}` placeholders.
        logging.info("Injecting canary image into index %s", canary_index)
        data["train"]["images"][canary_index] = canary_image

    ## Loading the raw data to the device is not counted as part of the measured work since it
    ## is such a variable operation. Measuring the preprocessing (outside of the padding) is
    ## still important.

    # Pad the training images evenly on all four sides.
    pad_amount = hyp["net"]["pad_amount"]
    if pad_amount > 0:
        logging.warning('Using padding. If not wanted, set `hyp["net"]["pad_amount"] = 0`')
        data["train"]["images"] = F.pad(
            data["train"]["images"], (pad_amount,) * 4, "reflect",
        )
    return data


#############################################
#          Init Helper Functions            #
#############################################


def get_patches(x, patch_shape=(3, 3), dtype=torch.float32):
    """Extract every stride-1 sliding-window patch from a batch of images.

    `x` is indexed as (batch, channels, height, width). Dimension 2 and then
    dimension 3 are unfolded with step 1, giving one window per valid
    position. The result has shape
    ``(num_windows_total, channels, patch_h, patch_w)`` and is cast to `dtype`.
    Patches are later used to compute the statistics for the whitening layer.
    """
    patch_h, patch_w = patch_shape
    channels = x.shape[1]
    # See torch.Tensor.unfold: each call appends the window as a new last dim.
    windows = x.unfold(2, patch_h, 1).unfold(3, patch_w, 1)
    # Swap the channel axis behind the two position axes so that flattening the
    # leading dims leaves (channels, patch_h, patch_w) per patch.
    windows = windows.transpose(1, 3)
    return windows.reshape(-1, channels, patch_h, patch_w).to(dtype)


def get_whitening_parameters(patches):
    """Eigendecompose the covariance of flattened patches.

    At a high level this finds the high-dimensional ellipsoid that best fits
    the patch data, which can later be used to map inputs onto a nicely
    distributed sphere where the most significant features each get their own
    axis.

    Returns a pair ``(eigenvalues, eigenvectors)`` sorted from largest to
    smallest eigenvalue: eigenvalues shaped ``(c*h*w, 1, 1, 1)`` and
    eigenvectors shaped ``(c*h*w, c, h, w)`` (one eigenvector per leading
    index).
    """
    num_patches, channels, height, width = patches.shape
    flat_dim = channels * height * width

    # torch.cov expects variables in rows and observations in columns.
    flat_patches = patches.view(num_patches, flat_dim)
    covariance = torch.cov(flat_patches.t())

    # UPLO="U": only the upper triangle of the symmetric matrix is read.
    ascending_values, column_vectors = torch.linalg.eigh(covariance, UPLO="U")

    # eigh returns eigenvalues in ascending order with eigenvectors as columns;
    # put eigenvectors in rows and flip both so the largest comes first.
    row_vectors = column_vectors.t().reshape(flat_dim, channels, height, width)
    descending_values = ascending_values.flip(0).view(-1, 1, 1, 1)
    return descending_values, row_vectors.flip(0)


def init_whitening_conv(
    layer,
    train_set=None,
    num_examples=None,
    previous_block_data=None,
    pad_amount=None,
    freeze=True,
    whiten_splits=None,
):
    """Initialize `layer` as a whitening convolution from patch statistics.

    Run this over the training set to calculate the patch statistics, then set
    the initial convolution as a (by default non-learnable) whitening layer.

    Parameters
    ----------
    layer
        Convolution whose weights are overwritten; the spatial size of its
        kernel (``layer.weight.data.shape[2:]``) determines the patch shape.
    train_set
        Training images, used only when `previous_block_data` is None.
    num_examples
        Number of leading examples of `train_set` to use (None uses all).
    previous_block_data
        Input data for this layer. If None, we are at the beginning of the
        network and the data is taken from `train_set`.
    pad_amount
        Border width to crop from each spatial side of `train_set`. None or 0
        means no cropping.
    freeze
        Forwarded to `set_whitening_conv`; if true the weights are frozen.
    whiten_splits
        If given, passed to ``Tensor.split`` along dim 0 so the statistics are
        computed chunk by chunk and then averaged, which saves memory on
        smaller (generally consumer) GPUs.

    Returns
    -------
    The output of `layer` applied to the data the statistics were computed on.

    Raises
    ------
    ValueError
        If neither `train_set` nor `previous_block_data` is provided.
    """
    if previous_block_data is None:
        if train_set is None:
            msg = "Either `train_set` or `previous_block_data` must be provided."
            raise ValueError(msg)
        # `pad_amount` defaults to None, which must not be compared with `>`.
        if pad_amount is not None and pad_amount > 0:
            # Crop the padded border off again before computing statistics.
            previous_block_data = train_set[
                :num_examples, :, pad_amount:-pad_amount, pad_amount:-pad_amount,
            ]
        else:
            previous_block_data = train_set[:num_examples, :, :, :]

    if whiten_splits is None:
        # Whitening in one go: wrap in a list so the loop below can be reused.
        previous_block_data_split = [previous_block_data]
    else:
        previous_block_data_split = previous_block_data.split(whiten_splits, dim=0)

    patch_shape = layer.weight.data.shape[2:]
    eigenvalue_list, eigenvector_list = [], []
    for data_split in previous_block_data_split:
        eigenvalues, eigenvectors = get_whitening_parameters(
            get_patches(data_split, patch_shape=patch_shape),
        )
        eigenvalue_list.append(eigenvalues)
        eigenvector_list.append(eigenvectors)

    # Average the per-chunk decompositions.
    eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
    eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0)

    # `get_patches` casts to float32 by default, so the statistics are cast
    # back to the layer's own dtype before being installed.
    weight_dtype = layer.weight.dtype
    set_whitening_conv(
        layer,
        eigenvalues.to(dtype=weight_dtype),
        eigenvectors.to(dtype=weight_dtype),
        freeze=freeze,
    )
    return layer(previous_block_data.to(dtype=weight_dtype))


def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True):
    """Overwrite `conv_layer`'s weights with whitening filters and their negations.

    Each eigenvector is scaled by ``1 / sqrt(eigenvalue + eps)``. The trailing
    ``conv_layer.weight.data.shape[0]`` scaled filters are kept (all of them
    when the layer has at least as many output channels as there are
    eigenvectors) and concatenated with their negated copies along dim 0, so
    the new weight has twice as many filters as were kept.

    When `freeze` is true the weight is excluded from training, since it
    implicitly whitens over the whole dataset. For more info, see David Page's
    original blogposts (linked from the project's README.md).
    """
    out_channels = conv_layer.weight.data.shape[0]
    whitening_filters = eigenvectors / torch.sqrt(eigenvalues + eps)
    kept_filters = whitening_filters[-out_channels:, :, :, :]
    conv_layer.weight.data = torch.cat((kept_filters, -kept_filters), dim=0)
    if freeze:
        conv_layer.weight.requires_grad = False


# Per-sample (unreduced) cross-entropy with label smoothing of 0.2.
loss_fn = nn.CrossEntropyLoss(reduction="none", label_smoothing=0.2)
