from sched import scheduler
from sklearn.feature_selection import mutual_info_regression
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.data import ConcatDataset


import numpy
import time
import math
import os
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from datetime import datetime

import numpy as np

from tqdm import tqdm

from typing import Protocol

from mi_estimators.mi_mine import CriticNetwork, train_mine
from models.get_loss_function import get_loss_function


def sinusoidal_embedding(t, dim):
    """Return sinusoidal embeddings of the timesteps ``t``.

    With ``half = dim // 2`` the frequencies are
    ``exp(-i * log(10000) / (half - 1))`` for ``i`` in ``[0, half)``. Every
    timestep is multiplied by every frequency, and the result is the sines of
    those angles followed by their cosines.

    Args:
        t: tensor of timesteps, one per sample (presumably 1-D); its device is
            also used for the frequency tensor.
        dim: requested embedding width. The output has ``2 * (dim // 2)``
            columns, so an odd ``dim`` yields ``dim - 1`` columns.

    Returns:
        For a 1-D ``t``, a tensor of shape ``(len(t), 2 * (dim // 2))``.
    """
    half = dim // 2
    log_step = math.log(10000) / (half - 1)
    exponents = torch.arange(half, dtype=torch.float32, device=t.device)
    frequencies = torch.exp(exponents * -log_step)
    angles = t[:, None] * frequencies[None, :]
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


def get_beta_schedule(beta_schedule_args):
    """Build the per-timestep noise variances (betas) of a diffusion process.

    Args:
        beta_schedule_args: dict with ``"type"`` and ``"timesteps"`` (number
            of diffusion steps). Supported types:

            * ``"linear"``: betas evenly spaced from ``"min"`` to ``"max"``.
            * ``"quadratic"``: square roots evenly spaced from ``sqrt(min)``
              to ``sqrt(max)``, then squared.
            * ``"cosine"``: the schedule of Nichol & Dhariwal (2021).
              ``"min"``/``"max"`` are ignored; the optional offset ``"s"``
              defaults to 0.008, and betas are clipped at 0.999 as in the
              paper.

    Returns:
        1-D float32 tensor of length ``timesteps``.

    Raises:
        ValueError: if ``"type"`` is not a supported schedule.
    """
    schedule_type = beta_schedule_args["type"]

    if schedule_type == "linear":
        return torch.linspace(
            beta_schedule_args["min"],
            beta_schedule_args["max"],
            beta_schedule_args["timesteps"],
        )

    if schedule_type == "quadratic":
        return (
            torch.linspace(
                beta_schedule_args["min"] ** 0.5,
                beta_schedule_args["max"] ** 0.5,
                beta_schedule_args["timesteps"],
            )
            ** 2
        )

    if schedule_type == "cosine":
        timesteps = beta_schedule_args["timesteps"]
        s = beta_schedule_args.get("s", 0.008)
        # alpha_bar(t) = f(t) / f(0), f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2).
        # Each beta is one minus the ratio of consecutive alpha_bar values, so
        # alpha_bar is evaluated on the T + 1 grid points t = 0..T.
        # (Returning 1 - alpha_bar directly would give cumulative, not
        # per-step, noise levels.)
        steps = torch.arange(timesteps + 1, dtype=torch.float64)
        f = torch.cos((steps / timesteps + s) / (1 + s) * math.pi / 2) ** 2
        alpha_bar = f / f[0]
        betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
        # alpha_bar -> 0 at t = T, so the last betas approach 1; the paper
        # clips them at 0.999 to avoid singularities.
        return betas.clamp(max=0.999).to(torch.float32)

    raise ValueError(f"Unknown beta_schedule type: {schedule_type!r}")


def evaluate_diffusion_loss(model, test_loss, dataloader):
    """Return the mean per-batch noise-prediction loss of ``model``.

    The model is moved to the best available device (cuda, then mps, then
    cpu) and switched to eval mode. Then, under ``torch.no_grad()``, every
    ``(x, c)`` batch gets one random timestep per sample, is noised using
    ``model.alpha_bar``, and is scored by the loss obtained from
    ``get_loss_function(test_loss)``. The per-batch losses are averaged with
    equal weight per batch.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()
    loss_fn = get_loss_function(test_loss)

    batch_losses = []
    with torch.no_grad():
        for x, c in dataloader:
            x = x.unsqueeze(1).to(device)
            c = c.to(device)
            timesteps = torch.randint(0, model.timesteps, (x.size(0),)).to(device)

            eps = torch.randn_like(x)
            a_bar = model.alpha_bar[timesteps].view(-1, 1)
            noisy_x = torch.sqrt(a_bar) * x + torch.sqrt(1 - a_bar) * eps

            predicted = model(noisy_x, c, timesteps)
            batch_loss, _ = loss_fn(predicted, eps, c, {})
            batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(batch_losses)


def evaluate_full_diffusion_loss(
    model,
    test_loss,
    dataloader,
    prediction_save_path=None,
    add_artifact=None,
):
    """Evaluate ``model`` at every diffusion timestep and report per-timestep metrics.

    For each timestep ``0 .. model.timesteps - 1``, every batch of
    ``dataloader`` is noised at that fixed timestep and the model predicts the
    noise. The ``logging_info`` dict returned by the loss function named
    ``test_loss`` is then averaged over that timestep's batches; ``None``
    values are skipped.

    Args:
        model: diffusion model exposing ``timesteps`` and ``alpha_bar``,
            called as ``model(x_t, c, t)``. It is moved to the best available
            device and put in eval mode.
        test_loss: loss name passed to ``get_loss_function``.
        dataloader: iterable of ``(x, c)`` batches.
        prediction_save_path: if given, each batch's ``pred_noise``,
            ``noise``, ``c`` and ``x`` are written with ``torch.save`` to
            ``f"{prediction_save_path}_timestep_{t}_batch_{b}.pt"``.
        add_artifact: optional callable invoked with each saved file path.

    Returns:
        list with one dict per timestep: the batch-averaged logging values
        plus a ``"timestep"`` key.
    """
    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    model.to(device)
    model.eval()

    test_loss_fn = get_loss_function(test_loss)

    if prediction_save_path is not None:
        # Created once up front instead of per batch. The directory at
        # prediction_save_path also guarantees its parent exists.
        # NOTE(review): the prediction files are written *next to* this
        # directory (path used as a filename prefix), not inside it — confirm
        # this naming is intended.
        os.makedirs(prediction_save_path, exist_ok=True)

    full_logging_info_per_timestep = []
    with torch.no_grad():
        for timestep in tqdm(range(model.timesteps)):
            # Fresh accumulator per timestep. Sharing one dict across
            # timesteps would leak earlier timesteps' values into later
            # averages.
            timestep_totals = {}
            num_batches = 0
            for x, c in dataloader:
                x = x.unsqueeze(1)
                x, c = x.to(device), c.to(device)
                t = torch.full((x.size(0),), timestep, device=device)

                noise = torch.randn_like(x)
                alpha_bar_t = model.alpha_bar[t].view(-1, 1)
                x_t = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise

                pred_noise = model(x_t, c, t)

                if prediction_save_path is not None:
                    path = f"{prediction_save_path}_timestep_{timestep}_batch_{num_batches}.pt"
                    torch.save(
                        {
                            "pred_noise": pred_noise.cpu(),
                            "noise": noise.cpu(),
                            "c": c.cpu(),
                            "x": x.cpu(),
                        },
                        path,
                    )
                    if add_artifact is not None:
                        add_artifact(path)

                _, logging_info = test_loss_fn(pred_noise, noise, c, {})
                num_batches += 1

                for key, value in logging_info.items():
                    if value is None:
                        continue
                    timestep_totals[key] = timestep_totals.get(key, 0.0) + value

            averaged = {
                key: total / num_batches for key, total in timestep_totals.items()
            }
            full_logging_info_per_timestep.append({**averaged, "timestep": timestep})

    return full_logging_info_per_timestep


def get_training_schedule(
    epoch, epochs, training_hyperparameters, previous_parameters=None
):
    """Return the loss parameters (a dict with ``"beta"``) for ``epoch``.

    ``epoch`` is 0-based (``train_diffusion`` iterates ``range(epochs)``).
    ``training_hyperparameters["type"]`` selects the schedule; ``"min_beta"``
    and ``"max_beta"`` default to 0.1 and 1.0.

    * ``"linear"``: ``min_beta + (max_beta - min_beta) * epoch / epochs``.
    * ``"cyclical"``: the run is split into ``M`` cycles (in the style of
      cyclical annealing, Fu et al. 2019). In each cycle beta ramps linearly
      from min_beta to max_beta over the first fraction ``R``, then stays at
      max_beta.
    * ``"constant"``: always max_beta.
    * ``"delayed_linear"``: ``M`` periods (default 1). Each holds min_beta
      for a fraction ``R`` (default 0.2), ramps linearly, then holds
      max_beta for the final ``hold_fraction`` (default 0.2).
    * ``"adaptive"`` / ``"gamma_gda"``: ``{"beta": start_beta}`` on the first
      call. Afterwards ``previous_parameters`` is returned unchanged (the same
      object), so the caller can keep updating it.

    Raises:
        ValueError: for an unsupported ``"type"``.
    """
    schedule_type = training_hyperparameters["type"]
    min_beta = training_hyperparameters.get("min_beta", 0.1)
    max_beta = training_hyperparameters.get("max_beta", 1.0)

    if schedule_type == "linear":
        return {"beta": min_beta + (max_beta - min_beta) * (epoch / epochs)}

    if schedule_type == "cyclical":
        cycle_length = epochs / training_hyperparameters["M"]
        ramp_fraction = training_hyperparameters["R"]
        # epoch is 0-based, so each cycle starts at tau == 0. The paper's
        # mod(t - 1, ...) assumes 1-based t; applying it to 0-based epochs
        # would start training at max_beta.
        tau = (epoch % cycle_length) / cycle_length
        if tau <= ramp_fraction:
            return {"beta": min_beta + (max_beta - min_beta) * (tau / ramp_fraction)}
        return {"beta": max_beta}

    if schedule_type == "constant":
        return {"beta": max_beta}

    if schedule_type == "delayed_linear":
        delay_fraction = training_hyperparameters.get("R", 0.2)
        repetitions = training_hyperparameters.get("M", 1)
        hold_fraction = training_hyperparameters.get("hold_fraction", 0.2)

        # At least one epoch per period; epochs < repetitions would otherwise
        # make the modulo below divide by zero.
        period_length = max(epochs // repetitions, 1)
        epoch_in_period = epoch % period_length

        delay_epochs = int(period_length * delay_fraction)
        hold_epochs = int(period_length * hold_fraction)
        # Remaining part of the period is the linear increase.
        linear_epochs = period_length - delay_epochs - hold_epochs

        if epoch_in_period <= delay_epochs:
            return {"beta": min_beta}
        if epoch_in_period <= delay_epochs + linear_epochs:
            # Only reachable when linear_epochs > 0, so the division is safe.
            progress = (epoch_in_period - delay_epochs) / linear_epochs
            return {"beta": min_beta + (max_beta - min_beta) * progress}
        # Hold at max_beta.
        return {"beta": max_beta}

    if schedule_type in ("adaptive", "gamma_gda"):
        if previous_parameters is None:
            return {"beta": training_hyperparameters["start_beta"]}
        return previous_parameters

    raise ValueError(f"Unsupported training hyperparameter type: {schedule_type}")


def get_optimizer(
    model,
    optimizer_settings,
    lr,
    total_steps,
):
    """Create the optimizer and optional per-step LR scheduler for ``model``.

    Args:
        model: module whose ``parameters()`` are optimized.
        optimizer_settings: dict with keys:

            * ``"optimizer"``: ``"adam"``, ``"adamw"`` or ``"sgd"``; defaults
              to ``"adam"``.
            * ``"scheduler"``: ``"cosine"`` or ``"linear"``; anything else,
              or a missing key, means no scheduler.
            * ``"scheduler_args"``: dict; the cosine scheduler reads
              ``"eta_min"``, default 0.0.

            ``None`` is treated as an empty dict.
        lr: initial learning rate.
        total_steps: number of ``scheduler.step()`` calls planned. The cosine
            schedule reaches ``eta_min`` and the linear schedule reaches 0
            after this many steps. Values below 1 are raised to 1 so the
            cosine schedule never divides by zero.

    Returns:
        ``(optimizer, scheduler)``, where ``scheduler`` is ``None`` when no
        supported scheduler is requested.

    Raises:
        ValueError: for an unsupported optimizer type.
    """
    settings = optimizer_settings or {}
    optimizer_type = settings.get("optimizer", "adam")
    scheduler_type = settings.get("scheduler")
    scheduler_args = settings.get("scheduler_args") or {}
    total_steps = max(int(total_steps), 1)

    optimizer_classes = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
    }
    if optimizer_type not in optimizer_classes:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
    optimizer = optimizer_classes[optimizer_type](model.parameters(), lr=lr)

    if scheduler_type == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=total_steps,
            eta_min=scheduler_args.get("eta_min", 0.0),
        )
    elif scheduler_type == "linear":
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.0,
            total_iters=total_steps,
        )
    else:
        scheduler = None

    return optimizer, scheduler


def _beta_as_float(beta):
    """Return ``beta.item()`` when ``beta`` is a tensor, otherwise ``beta`` unchanged."""
    return beta.item() if isinstance(beta, torch.Tensor) else beta


def _diffusion_noisy_batch(model, x, c, device):
    """Move an ``(x, c)`` batch to ``device`` and apply forward diffusion.

    Draws one timestep per sample uniformly from ``[0, model.timesteps)``
    (on the CPU, then moved), then Gaussian noise shaped like ``x``, and forms
    ``x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``.
    Assumes ``x`` holds one scalar per sample, so that after ``unsqueeze(1)``
    it is ``[batch_size, 1]`` — TODO confirm for other datasets.

    Returns:
        ``(x, c, t, noise, x_t)``.
    """
    x = x.unsqueeze(1).to(device)  # Shape: [batch_size, 1]
    c = c.to(device)
    t = torch.randint(0, model.timesteps, (x.size(0),)).to(device)
    noise = torch.randn_like(x)
    alpha_bar_t = model.alpha_bar[t].view(-1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
    return x, c, t, noise, x_t


def train_diffusion(
    model,
    train_loss_type,
    test_loss_type,
    train_loss_schedule,
    dataset,
    epochs=5000,
    lr=1e-4,
    batch_size=4048,
    validation_split=0.9,
    callbacks=None,
    optimizer_settings=None,
):
    """Train a conditional noise-prediction diffusion model.

    Each training batch is noised at random timesteps. The model predicts
    the noise from ``(x_t, c, t)`` and is optimized with the loss returned by
    ``get_loss_function(train_loss_type)``. The loss parameters (``"beta"``)
    come from ``get_training_schedule(..., train_loss_schedule, ...)`` each
    epoch. Schedule types handled here:

    * ``"adaptive"``: beta is multiplied/divided by ``"adaptive_factor"``
      per batch depending on ``logging_info["mi_estimate"]`` versus
      ``"mi_threshold"``.
    * ``"gamma_gda"``: beta is a learnable multiplier updated by gradient
      ascent on alternating batches; the model is updated on the others.

    For the other types beta is clamped to ``[min_beta, max_beta]`` after
    each batch. For ``gamma_gda`` the multiplier is projected onto that range
    after each ascent step. ``min_beta``/``max_beta`` default to 0.1/1.0.
    When ``train_loss_type`` is ``"mine_mse"`` or ``"mi_mse_debug"``, a MINE
    critic is trained on each batch and passed to the loss.

    Validation runs every 50 epochs and on the last epoch.

    Args:
        dataset: dataset of ``(x, c)`` pairs.
        validation_split: fraction of ``dataset`` held out for validation.
        callbacks: optional dict. ``"store_validation_prediction"(epoch,
            stacked)`` is called per validation batch. ``"epoch_end"(epoch,
            model, metrics)`` is called after every epoch.
        optimizer_settings: forwarded to ``get_optimizer``.

    Returns:
        dict with ``"loss"`` (the last computed loss; from the final
        validation batch when validation data exists), ``"min_loss"``,
        ``"losses"``, ``"train_logging_infos"``, ``"validation_losses"``,
        ``"validation_logging_infos"`` and ``"min_validation_loss"``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using mps")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    callbacks = callbacks or {}
    schedule_type = train_loss_schedule["type"]
    is_gda = schedule_type == "gamma_gda"
    # Same defaults as get_training_schedule.
    min_beta = train_loss_schedule.get("min_beta", 0.1)
    max_beta = train_loss_schedule.get("max_beta", 1.0)

    # Split into training and validation.
    # NOTE(review): validation_split is the *validation* fraction, so the
    # default 0.9 trains on only 10% of the data — confirm this is intended.
    val_size = int(len(dataset) * validation_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # The scheduler steps once per model update, i.e. once per training batch
    # (every other batch for gamma_gda). Base the horizon on the real
    # training loader, not on the whole dataset.
    total_steps = epochs * len(train_loader)
    if is_gda:
        total_steps = (total_steps + 1) // 2
    optimizer, scheduler = get_optimizer(model, optimizer_settings, lr, total_steps)

    if is_gda:
        # Learnable multiplier that plays the role of beta.
        lagrange = torch.tensor([0.0], device=device, requires_grad=True)
        optimizer_l = torch.optim.Adam([lagrange], lr=0.1)

    train_loss_function = get_loss_function(train_loss_type)
    test_loss_function = get_loss_function(test_loss_type)

    if train_loss_type in ["mine_mse", "mi_mse_debug"]:
        mine_network = CriticNetwork(input_dim=2, hidden_dim=512).to(device)
        mine_opt = torch.optim.Adam(mine_network.parameters(), lr=1e-3)
    else:
        mine_network = None
        mine_opt = None

    losses = []
    train_logging_infos = []
    validation_losses = []
    validation_logging_infos = []
    best_val_loss = float("inf")
    maximize_step = False
    parameters = None

    for epoch in tqdm(range(epochs)):
        parameters = get_training_schedule(
            epoch, epochs, train_loss_schedule, parameters
        )
        # Validation switches the model to eval mode; restore training mode
        # at the start of every epoch.
        model.train()
        epoch_losses = []
        epoch_logging_info = {}
        num_batches = 0

        for x, c in train_loader:
            x, c, t, noise, x_t = _diffusion_noisy_batch(model, x, c, device)
            pred_noise = model(x_t, c, t)

            if mine_network is not None:
                mine_network.train()
                train_mine(
                    pred_noise.detach(),
                    c.unsqueeze(1).detach(),
                    mine_network,
                    mine_opt,
                    epochs=2,
                )
                mine_network.eval()

            optimizer.zero_grad()
            if is_gda:
                optimizer_l.zero_grad()
                parameters["beta"] = lagrange

            loss, logging_info = train_loss_function(
                pred_noise,
                noise,
                c,
                parameters,
                mine_network=mine_network,
            )

            if schedule_type == "adaptive":
                if logging_info["mi_estimate"] > train_loss_schedule["mi_threshold"]:
                    parameters["beta"] *= train_loss_schedule["adaptive_factor"]
                else:
                    parameters["beta"] /= train_loss_schedule["adaptive_factor"]

            if not is_gda:
                # For gamma_gda the multiplier tensor is projected after its
                # ascent step instead. Clamping here could replace it with a
                # plain float and break the later .item() calls.
                parameters["beta"] = max(min(parameters["beta"], max_beta), min_beta)

            # Accumulate logging info for averaging.
            for key, value in logging_info.items():
                if value is None:
                    continue
                epoch_logging_info[key] = epoch_logging_info.get(key, 0.0) + value

            epoch_losses.append(loss.item())
            num_batches += 1

            if torch.isnan(loss):
                print("Loss is NaN, skipping batch.")
                continue

            if is_gda:
                if maximize_step:
                    (-loss).backward()
                    optimizer_l.step()
                    # Projected ascent: keep the multiplier within bounds.
                    with torch.no_grad():
                        lagrange.clamp_(min_beta, max_beta)
                else:
                    loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                maximize_step = not maximize_step
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

        # Average the logging info over all batches in the epoch.
        for key in epoch_logging_info:
            epoch_logging_info[key] /= num_batches
        epoch_logging_info["beta"] = _beta_as_float(parameters["beta"])
        train_logging_infos.append(epoch_logging_info)

        losses.append(sum(epoch_losses) / len(epoch_losses))

        val_loss = float("inf")
        if epoch % 50 == 0 or epoch == epochs - 1:
            model.eval()
            val_epoch_losses = []
            with torch.no_grad():
                for x, c in val_loader:
                    x, c, t, noise, x_t = _diffusion_noisy_batch(model, x, c, device)
                    pred_noise = model(x_t, c, t)

                    if "store_validation_prediction" in callbacks:
                        callbacks["store_validation_prediction"](
                            epoch,
                            torch.stack([pred_noise, noise, c.unsqueeze(1)], dim=0),
                        )

                    loss, logging_info = test_loss_function(
                        pred_noise, noise, c, parameters
                    )
                    val_epoch_losses.append(loss.item())
                    if is_gda:
                        validation_logging_infos.append(
                            {**logging_info, "beta": parameters["beta"].item()}
                        )
                    else:
                        validation_logging_infos.append({**logging_info, **parameters})

            val_loss = numpy.mean(val_epoch_losses)
            validation_losses.append(float(val_loss))
            if val_loss < best_val_loss:
                best_val_loss = val_loss

            message = (
                f" Epoch {epoch}/{epochs} - Loss: {losses[-1]:.2f}"
                f" - Val Loss: {val_loss:.2f}"
                f" - Beta: {_beta_as_float(parameters['beta']):.4f}"
            )
            if "mse" in epoch_logging_info:
                message += f" - MSE: {epoch_logging_info['mse']:.2f}"
            for key, value in epoch_logging_info.items():
                if key != "mse":
                    message += f" - {key}: {value:.4f}"
            print(message)

        if "epoch_end" in callbacks:
            metrics = {
                "train_loss": losses[-1],
                "validation_loss": val_loss,
                "best_validation_loss": best_val_loss,
                "epoch": epoch,
                "logging_info": epoch_logging_info,
            }
            callbacks["epoch_end"](epoch, model, metrics)

    return {
        "loss": loss.item(),
        "min_loss": min(losses),
        "losses": losses,
        "train_logging_infos": train_logging_infos,
        "validation_losses": validation_losses,
        "validation_logging_infos": validation_logging_infos,
        "min_validation_loss": best_val_loss,
    }
