from functools import partial
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import os
from tqdm import tqdm
import random

import jax

from utils_pcax.models import energy_per_data, infer_on_batch


# fast dataloader for mlp
def get_datax_mlp(args):
    """Load a dataset fully into memory and pre-split it into fixed-size batches.

    Reads ``dataset``, ``train_size``, ``val_size``, ``test_size``,
    ``batch_size``, ``is_supervised``, ``latent_dim``, ``data_dim`` and
    (optionally) ``is_cnn`` from ``args``.

    Returns:
        ``(train_dl, val_dl, test_dl)``: three lists of ``(X, y)`` batches.
        ``X`` has shape ``(batch_size, latent_dim)`` and holds one-hot labels
        when ``is_supervised`` is true, zeros otherwise. ``y`` holds the
        images reshaped to ``(batch_size, *data_shape)``. Samples that do not
        fill a whole batch are dropped.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``"mnist"``,
            ``"cifar10"`` or ``"fashion_mnist"``.
    """
    dataset = args.dataset
    # Supported dataset names -> torchvision class names. Validating up front
    # gives a clear error instead of an unbound-variable failure further down.
    dataset_classes = {
        "mnist": "MNIST",
        "cifar10": "CIFAR10",
        "fashion_mnist": "FashionMNIST",
    }
    if dataset not in dataset_classes:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; "
            f"expected one of {sorted(dataset_classes)}"
        )

    n_train, n_val, n_test = args.train_size, args.val_size, args.test_size
    batch_size = args.batch_size
    is_supervised = args.is_supervised
    latent_dim = args.latent_dim
    is_cnn = getattr(args, "is_cnn", False)

    if isinstance(args.data_dim, int) or len(args.data_dim) == 1:
        data_shape = (784,)
    else:
        channels = 3 if dataset == "cifar10" else 1
        data_shape = (channels, *args.data_dim)

    # Normalizing with mean 0.5 / std 0.5 rescales ToTensor's [0, 1] output
    # to [-1, 1]. Train and eval use the same (augmentation-free) pipeline.
    if dataset == "cifar10":
        steps = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    else:
        steps = [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            *([transforms.Resize((32, 32))] if is_cnn else []),
        ]
    train_transform = transforms.Compose(steps)
    eval_transform = transforms.Compose(steps)

    dataset_cls = getattr(datasets, dataset_classes[dataset])
    train_dataset = dataset_cls(
        "./data", download=True, train=True, transform=train_transform
    )
    eval_dataset = dataset_cls(
        "./data", download=True, train=False, transform=eval_transform
    )

    if n_train < 100 and n_train % 10 == 0:
        # Small, class-balanced subset: take the first n_train // 10 samples
        # of each of the 10 labels, in dataset order.
        per_class = n_train // 10
        indices = []
        label_counts = {i: 0 for i in range(10)}
        for i in range(len(train_dataset)):
            if len(indices) == n_train:
                # Every class quota is filled; scanning further would only
                # load and transform samples that are never selected.
                break
            label = train_dataset[i][1]
            if label_counts[label] < per_class:
                indices.append(i)
                label_counts[label] += 1
        train_dataset = Subset(train_dataset, indices)
    else:
        # Random subset drawn from the global `random` state.
        train_dataset = Subset(
            train_dataset, random.sample(range(len(train_dataset)), n_train)
        )

    # Shuffle the eval split with a fixed seed, then carve out disjoint
    # validation and test subsets from the front of the shuffled order.
    indices = list(range(len(eval_dataset)))
    random.Random(42).shuffle(indices)
    eval_dataset = Subset(eval_dataset, indices)
    val_dataset = Subset(eval_dataset, range(0, n_val))
    test_dataset = Subset(eval_dataset, range(n_val, n_val + n_test))

    def process_dataset(dataset, is_supervised):
        # Load the whole split as a single batch, then cut it into batches.
        data_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        data, label = next(iter(data_loader))
        nm_elements = len(data)
        X = (
            jax.nn.one_hot(label.numpy(), latent_dim)
            if is_supervised
            else np.zeros((nm_elements, latent_dim))
        )
        # Drop the trailing samples that would not fill a complete batch.
        usable = batch_size * (nm_elements // batch_size)
        X = X[:usable]
        y = data.numpy()[:usable]

        return list(
            zip(
                X.reshape(-1, batch_size, latent_dim),
                y.reshape(-1, batch_size, *data_shape),
            )
        )

    # Materialize every split as an in-memory list of (X, y) batches.
    train_dl = process_dataset(train_dataset, is_supervised)
    val_dl = process_dataset(val_dataset, is_supervised)
    test_dl = process_dataset(test_dataset, is_supervised)

    return train_dl, val_dl, test_dl


import numpy as np


def one_hot(labels, num_classes):
    """Return one-hot rows for integer ``labels``.

    Each label selects the matching row of a ``num_classes`` x ``num_classes``
    identity matrix, so the result has one trailing axis of length
    ``num_classes`` added to the shape of ``labels``.
    """
    identity = np.eye(num_classes)
    encoded = identity[labels]
    return encoded


def collate_to_numpy_supervised(batch, latent_dim=None):
    """Collate ``(data, label)`` samples into NumPy arrays with one-hot labels.

    Returns ``(labels, data)`` -- note the labels come first. ``labels`` is the
    one-hot encoding of the sample labels with ``latent_dim`` classes.
    """
    samples, targets = zip(*batch)
    stacked = np.array(samples)
    encoded = one_hot(np.array(targets), latent_dim)
    return encoded, stacked


def collate_to_numpy_unsupervised(batch, latent_dim=None):
    """Collate ``(data, label)`` samples into NumPy arrays, ignoring the labels.

    Returns ``(labels, data)`` -- note the labels come first. ``labels`` is an
    all-zero array with one row per sample and ``latent_dim`` columns.
    """
    samples, _ignored_targets = zip(*batch)
    stacked = np.array(samples)
    placeholder = np.zeros((stacked.shape[0], latent_dim))
    return placeholder, stacked


class TorchDataloader(torch.utils.data.DataLoader):
    """``DataLoader`` with this project's defaults.

    Differences from the stock defaults: one worker process, pinned memory,
    persistent workers, and incomplete final batches are always dropped
    unless a ``batch_sampler`` is supplied (the caller cannot override
    ``drop_last``).
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=None,
        sampler=None,
        batch_sampler=None,
        num_workers=1,
        pin_memory=True,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=True,
        prefetch_factor=2,
        collate_fn=None,
    ):
        # These options are only valid with worker processes; DataLoader
        # rejects them when num_workers == 0, so forward them only when
        # workers are actually used.
        worker_kwargs = {}
        if num_workers > 0:
            worker_kwargs["persistent_workers"] = persistent_workers
            worker_kwargs["prefetch_factor"] = prefetch_factor

        # Zero-argument super(): super(self.__class__, self) would recurse
        # forever as soon as this class is subclassed.
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            # drop_last is mutually exclusive with batch_sampler.
            drop_last=batch_sampler is None,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            **worker_kwargs,
        )


def _tensor_to_numpy(tensor):
    """Convert a torch tensor to a NumPy array."""
    return tensor.numpy()


def _flatten_tensor(tensor):
    """Flatten a torch tensor to one dimension."""
    return tensor.view(-1)


def get_datax_cnn(args):
    """Build streaming train/val/test loaders that yield NumPy batches.

    Reads ``dataset``, ``train_size``, ``val_size``, ``test_size``,
    ``batch_size``, ``is_supervised``, ``latent_dim`` and (optionally)
    ``is_cnn`` from ``args``.

    Returns:
        ``(train_dl, val_dl, test_dl)``: three ``TorchDataloader`` objects.
        Each batch is a ``(labels, data)`` pair of NumPy arrays; labels are
        one-hot when ``is_supervised`` is true and all zeros otherwise. Only
        the training loader shuffles.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``"mnist"``,
            ``"cifar10"``, ``"fashion_mnist"`` or ``"cifar100"``.
    """
    dataset = args.dataset
    # Supported dataset names -> torchvision class names. Validating up front
    # gives a clear error instead of an unbound-variable failure further down.
    dataset_classes = {
        "mnist": "MNIST",
        "cifar10": "CIFAR10",
        "fashion_mnist": "FashionMNIST",
        "cifar100": "CIFAR100",
    }
    if dataset not in dataset_classes:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; "
            f"expected one of {sorted(dataset_classes)}"
        )

    n_train, n_val, n_test = args.train_size, args.val_size, args.test_size
    batch_size = args.batch_size
    is_supervised = args.is_supervised
    latent_dim = args.latent_dim
    is_cnn = getattr(args, "is_cnn", False)

    # Normalizing with mean 0.5 / std 0.5 rescales ToTensor's [0, 1] output
    # to [-1, 1]. The final step converts each sample to a NumPy array.
    # Module-level helpers are used instead of lambdas so the pipeline can be
    # pickled when worker processes are started with the "spawn" method.
    cifar_train = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        _tensor_to_numpy,
    ]
    cifar_eval = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        _tensor_to_numpy,
    ]
    ist_transform = [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        (
            transforms.Resize((32, 32))
            if is_cnn
            else transforms.Lambda(_flatten_tensor)
        ),
        _tensor_to_numpy,
    ]

    # NOTE(review): only "cifar10" gets the CIFAR pipelines; "cifar100" falls
    # through to `ist_transform` (no augmentation) -- confirm this is intended.
    train_transform = transforms.Compose(
        cifar_train if dataset == "cifar10" else ist_transform
    )
    eval_transform = transforms.Compose(
        cifar_eval if dataset == "cifar10" else ist_transform
    )

    dataset_cls = getattr(datasets, dataset_classes[dataset])
    train_dataset = dataset_cls(
        "./data", download=True, train=True, transform=train_transform
    )
    eval_dataset = dataset_cls(
        "./data", download=True, train=False, transform=eval_transform
    )

    # Random training subset drawn from the global `random` state.
    train_dataset = Subset(
        train_dataset, random.sample(range(len(train_dataset)), n_train)
    )

    # Shuffle eval_dataset in a deterministic way - to prevent leakage, shuffle
    # is needed because second half of eval is more like train than first half.
    indices = list(range(len(eval_dataset)))
    random.Random(42).shuffle(indices)
    eval_dataset = Subset(eval_dataset, indices)
    val_dataset = Subset(eval_dataset, range(0, n_val))
    test_dataset = Subset(eval_dataset, range(n_val, n_val + n_test))

    collate_to_numpy = partial(
        collate_to_numpy_supervised
        if is_supervised
        else collate_to_numpy_unsupervised,
        latent_dim=latent_dim,
    )

    def make_loader(split, shuffle):
        return TorchDataloader(
            split,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_to_numpy,
            num_workers=7,
        )

    train_dl = make_loader(train_dataset, True)
    val_dl = make_loader(val_dataset, False)
    test_dl = make_loader(test_dataset, False)
    return train_dl, val_dl, test_dl


def make_mean_images(args, output_file=None, mode="val", verbose=False):
    """Compute the per-class mean image of one data split and save it.

    Args:
        args: configuration forwarded to ``get_datax_cnn``; ``data_dim``,
            ``dataset``, ``latent_dim`` and (optionally) ``is_cnn`` are also
            read here.
        output_file: destination handed to ``torch.save``. Defaults to
            ``mean_images/{dataset}_{mode}.pt`` (``mode`` gets a ``_cnn``
            suffix when ``args.is_cnn`` is true).
        mode: ``"train"`` or ``"val"``; any other value selects the test split.
        verbose: wrap the loader in a ``tqdm`` progress bar.

    The result is a tensor with one mean image per class (``latent_dim``
    entries); a class with no samples ends up as NaN because of the 0/0
    division.
    """
    train, val, test = get_datax_cnn(args)
    # Any mode other than "train"/"val" falls back to the test split.
    data_dl = train if mode == "train" else val if mode == "val" else test

    if isinstance(args.data_dim, int) or len(args.data_dim) == 1:
        img_size = (784,)
    else:
        channels = 3 if args.dataset.split("_")[-1] != "mnist" else 1
        img_size = (channels, *args.data_dim)

    mean_images = np.zeros((args.latent_dim, *img_size))
    data_dl = tqdm(data_dl, desc="Mean images: ") if verbose else data_dl
    class_counter = torch.zeros(args.latent_dim)
    for X, y in data_dl:
        # Batches are (labels, data): recover class indices from the label
        # rows and map the data from [-1, 1] back to [0, 1].
        X = X.argmax(axis=1)
        y = y / 2 + 0.5
        for i in range(X.max() + 1):
            mean_images[i] += y[X == i].sum(axis=0)
            class_counter[i] += (X == i).sum()

    # reshape class_counter to be broadcastable
    class_counter = class_counter.view(-1, *([1] * (mean_images.ndim - 1)))

    mean_images = torch.tensor(mean_images) / class_counter
    if img_size == (784,):
        mean_images = mean_images.view((-1, 28, 28))

    # Same optional-attribute handling as the data loaders: a missing
    # `is_cnn` is treated as False instead of raising AttributeError.
    if getattr(args, "is_cnn", False):
        mode += "_cnn"

    if output_file is None:
        output_file = f"mean_images/{args.dataset}_{mode}.pt"
    # Create the directory the file is actually written to. File-like
    # destinations are passed to torch.save untouched.
    if isinstance(output_file, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(output_file))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    torch.save(mean_images, output_file)


import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from tqdm import tqdm


def create_grid(x_min=-3.2, x_max=3.3, y_min=-3.2, y_max=3.3, step=0.25):
    """Build a regular 2-D grid.

    Returns ``(points, xx, yy)``: ``points`` is a JAX array with one
    ``(x, y)`` row per grid node, and ``xx`` / ``yy`` are the NumPy mesh
    arrays from ``np.meshgrid``. The upper bounds are exclusive.
    """
    x_axis = np.arange(x_min, x_max, step)
    y_axis = np.arange(y_min, y_max, step)
    mesh_x, mesh_y = np.meshgrid(x_axis, y_axis)
    # One row per node: column 0 is x, column 1 is y.
    points = np.column_stack((mesh_x.ravel(), mesh_y.ravel()))
    return jnp.array(points), mesh_x, mesh_y


def organize_batches(y_grd, batch_size):
    """Split ``y_grd`` row-wise into batches of exactly ``batch_size`` rows.

    The final batch is zero-padded up to ``batch_size`` rows when the number
    of rows is not a multiple of ``batch_size``. No extra all-zero batch is
    produced when the rows divide evenly. Empty input still yields a single
    all-zero batch.

    Returns:
        A list of arrays, each with ``batch_size`` rows.
    """
    n_points = y_grd.shape[0]
    # Ceiling division, with a floor of one batch for empty input.
    nm_batches = max(1, -(-n_points // batch_size))
    y_grd_dl = [
        y_grd[i * batch_size : (i + 1) * batch_size] for i in range(nm_batches)
    ]

    # Pad only a partially filled last batch.
    missing = batch_size - y_grd_dl[-1].shape[0]
    if missing > 0:
        padding = jnp.zeros((missing, y_grd.shape[1]))
        y_grd_dl[-1] = jnp.concatenate([y_grd_dl[-1], padding], axis=0)
    return y_grd_dl


def compute_energy_landscape(
    y_grd,
    y_grd_dl,
    batch_size,
    T_,
    model,
    optim_h,
    init_up=False,
    local_infer_on_batch=infer_on_batch,
):
    """Evaluate the model energy at every grid point for labels 0 and 1.

    For each batch in ``y_grd_dl`` and each label, inference is run with
    ``local_infer_on_batch`` and the energies are then read with
    ``energy_per_data``.

    Args:
        y_grd: the full grid; only its number of rows is used, to trim the
            padding rows of the last batch.
        y_grd_dl: list of batches of grid points, each with ``batch_size``
            rows.
        batch_size: number of rows per batch.
        T_: forwarded as the first argument of ``local_infer_on_batch``.
        model, optim_h: forwarded to ``local_infer_on_batch``; ``model`` is
            also passed to ``energy_per_data``.
        init_up: forwarded to ``local_infer_on_batch``.
        local_infer_on_batch: inference routine, ``infer_on_batch`` by default.

    Returns:
        ``(energy_0, energy_0_up, energy_0_down, energy_1, energy_1_up,
        energy_1_down)`` -- note that "up" comes before "down". Each entry has
        one value per row of ``y_grd``.
    """
    n_points = y_grd.shape[0]
    label_inputs = (jnp.zeros((batch_size, 1)), jnp.ones((batch_size, 1)))
    # One (total, up, down) triple of per-batch results for each label.
    collected = [([], [], []) for _ in label_inputs]

    for y_ in tqdm(y_grd_dl):
        # Label 0 first, then label 1, for every batch.
        for x_label, (total, up, down) in zip(label_inputs, collected):
            # The return value is unused; presumably inference updates the
            # model state in place -- TODO confirm. The literal 0 is passed
            # through positionally as in the original call.
            local_infer_on_batch(
                T_, x_label, y_, init_up, 0, model=model, optim_h=optim_h
            )
            e, e_down, e_up = energy_per_data(x_label, y_, model=model)
            total.append(e)
            up.append(e_up)
            down.append(e_down)

    def join(chunks):
        # Concatenate the per-batch results and drop the rows that belong to
        # the zero-padding of the last batch.
        return jnp.concatenate(chunks, axis=0)[:n_points]

    energy_0, energy_0_up, energy_0_down = (join(c) for c in collected[0])
    energy_1, energy_1_up, energy_1_down = (join(c) for c in collected[1])

    return energy_0, energy_0_up, energy_0_down, energy_1, energy_1_up, energy_1_down


def plot_energy_landscape(
    xx, yy, energy, y, x, title, ax, color="k", plot_data=True, energy_threshold=None
):
    """Draw a log-scaled filled contour plot of ``energy`` on ``ax``.

    Args:
        xx, yy: mesh arrays defining the grid; ``energy`` is reshaped to
            ``xx.shape``.
        energy: energy values, one per grid point. The contour levels are
            log-spaced between its min and max, so values are assumed to be
            strictly positive -- TODO confirm.
        y: data points; columns 0 and 1 are scattered over the contours.
        x: labels; rows with ``x[:, 0] == 0`` are drawn in ``color`` and rows
            with ``x[:, 0] == 1`` in grey.
        title: axes title.
        ax: matplotlib axes to draw on.
        color: scatter colour for the label-0 points.
        plot_data: whether to overlay the data points.
        energy_threshold: if given and above the lowest level, all levels
            below it are merged into a single band ending at the threshold.
    """
    num_levels = 20
    levels = np.logspace(
        np.log10(jnp.nanmin(energy)), np.log10(jnp.nanmax(energy)), num_levels
    )
    # A threshold at or below the lowest level has nothing to merge, and
    # inserting it would make the levels non-increasing, which contourf
    # rejects -- so it is ignored in that case.
    if energy_threshold is not None and energy_threshold > levels[0]:
        levels = np.concatenate(
            ([levels[0]], [energy_threshold], levels[levels > energy_threshold])
        )
    ax.contourf(
        xx, yy, energy.reshape(xx.shape), levels=levels, norm=mcolors.LogNorm()
    )
    if plot_data:
        is_label_0 = x[:, 0] == 0
        is_label_1 = x[:, 0] == 1
        ax.scatter(y[is_label_0, 0], y[is_label_0, 1], c=color, s=2, alpha=0.1)
        ax.scatter(y[is_label_1, 0], y[is_label_1, 1], c="grey", s=2, alpha=0.4)
    ax.set_title(title)
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.set_xticks([])
    ax.set_yticks([])


def main_energy_landscape_function(
    model,
    x,
    y,
    batch_size,
    optim_h,
    T_=10000,
    init_up=False,
    plot=True,
    plot_data=True,
    energy_threshold=None,
    local_infer_on_batch=infer_on_batch,
):
    """Compute and plot the energy landscape of ``model`` for labels 0 and 1.

    The energies are evaluated on the default grid from ``create_grid`` and
    drawn in a figure with one column per label. The first row shows the
    total energy; extra rows show the "up" and "down" energies when
    ``model.alpha_up`` / ``model.alpha_down`` are non-zero.

    Args:
        model: model to evaluate; ``alpha_up`` and ``alpha_down`` are read.
        x, y: labels and data points overlaid on the plots.
        batch_size: number of grid points evaluated per batch.
        optim_h, T_, init_up, local_infer_on_batch: forwarded to
            ``compute_energy_landscape``.
        plot: call ``plt.show()`` when true.
        plot_data: overlay the data points on every panel.
        energy_threshold: forwarded to the two total-energy panels only.

    Returns:
        ``(fig, axs, energy_0, energy_1, energy_0_up, energy_1_up,
        energy_0_down, energy_1_down)``.
    """
    # Step 1: Create grid and organize into batches
    y_grd, xx, yy = create_grid()
    y_grd_dl = organize_batches(y_grd, batch_size)

    # Step 2: Compute energy landscape.
    # The unpacking order must match the return statement of
    # compute_energy_landscape, which lists "up" before "down"; unpacking as
    # (down, up) swapped the two in the plots and in the returned values.
    (
        energy_0,
        energy_0_up,
        energy_0_down,
        energy_1,
        energy_1_up,
        energy_1_down,
    ) = compute_energy_landscape(
        y_grd,
        y_grd_dl,
        batch_size,
        T_,
        model,
        optim_h,
        init_up,
        local_infer_on_batch,
    )

    # Step 3: Plot energy landscape
    # Three rows only when both partial energies are shown; otherwise two.
    rows = 3 if model.alpha_down != 0 and model.alpha_up != 0 else 2
    fig, axs = plt.subplots(rows, 2, sharex=True, sharey=True)
    plot_energy_landscape(
        xx,
        yy,
        energy_0,
        y,
        x,
        "Energy label 0",
        axs[0, 0],
        plot_data=plot_data,
        energy_threshold=energy_threshold,
    )
    plot_energy_landscape(
        xx,
        yy,
        energy_1,
        y,
        x,
        "Energy label 1",
        axs[0, 1],
        plot_data=plot_data,
        energy_threshold=energy_threshold,
    )
    if model.alpha_up != 0:
        plot_energy_landscape(
            xx, yy, energy_0_up, y, x, "", axs[1, 0], plot_data=plot_data
        )
        plot_energy_landscape(
            xx, yy, energy_1_up, y, x, "", axs[1, 1], plot_data=plot_data
        )
        axs[1, 0].set_ylabel("Energy up")
    if model.alpha_down != 0:
        # The "down" row follows the "up" row when both are present.
        idx = 2 if model.alpha_up != 0 else 1
        plot_energy_landscape(
            xx, yy, energy_0_down, y, x, "", axs[idx, 0], plot_data=plot_data
        )
        plot_energy_landscape(
            xx, yy, energy_1_down, y, x, "", axs[idx, 1], plot_data=plot_data
        )
        axs[idx, 0].set_ylabel("Energy down")
    if plot:
        plt.show()
    return (
        fig,
        axs,
        energy_0,
        energy_1,
        energy_0_up,
        energy_1_up,
        energy_0_down,
        energy_1_down,
    )
