import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset, random_split

import numpy
import time
import math

from tqdm import tqdm


def get_beta_schedule(beta_schedule_args):
    """Build the per-timestep noise variances (betas) of the forward diffusion.

    Args:
        beta_schedule_args: Mapping with a ``"type"`` key (``"linear"``,
            ``"cosine"`` or ``"quadratic"``) and a ``"timesteps"`` key giving
            the number of diffusion steps. The linear and quadratic schedules
            also read ``"min"`` and ``"max"``; the cosine schedule ignores
            them.

    Returns:
        A 1-D tensor of length ``timesteps`` holding ``beta_t`` for each step.

    Raises:
        ValueError: If ``"type"`` is not a supported schedule name.
    """
    schedule_type = beta_schedule_args["type"]

    if schedule_type == "linear":
        return torch.linspace(
            beta_schedule_args["min"],
            beta_schedule_args["max"],
            beta_schedule_args["timesteps"],
        )

    if schedule_type == "cosine":
        # Cosine schedule following Nichol & Dhariwal (2021); ignores min/max.
        # The squared-cosine curve defines the *cumulative* product alpha_bar,
        # so the betas have to be recovered from ratios of consecutive values:
        #     beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
        # Returning 1 - alpha_bar directly would make callers that take
        # cumprod(1 - beta) accumulate the noise twice.
        timesteps = beta_schedule_args["timesteps"]
        offset = 0.008
        # timesteps + 1 grid points so that every step has a predecessor.
        # float64 keeps the ratios accurate where the curve approaches zero.
        grid = torch.arange(timesteps + 1, dtype=torch.float64)
        curve = (
            torch.cos((grid / timesteps + offset) / (1 + offset) * math.pi / 2) ** 2
        )
        alpha_bar = curve / curve[0]
        betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
        # Clip as in the paper: the last ratio tends to 0, which would give
        # beta = 1 and a singular final step.
        return betas.clamp(max=0.999).to(torch.float32)

    if schedule_type == "quadratic":
        # Linear in sqrt(beta), then squared.
        return (
            torch.linspace(
                beta_schedule_args["min"] ** 0.5,
                beta_schedule_args["max"] ** 0.5,
                beta_schedule_args["timesteps"],
            )
            ** 2
        )

    raise ValueError(f"Unknown beta_schedule type: {schedule_type!r}")


def sinusoidal_embedding(t, dim):
    """Generate sinusoidal embeddings for the timestep t.

    Args:
        t: 1-D tensor of timesteps, one per sample.
        dim: Width of the embedding to produce.

    Returns:
        A ``(len(t), dim)`` tensor laid out as ``[sin | cos]`` over
        ``dim // 2`` geometrically spaced frequencies from 1 down to 1/10000.
        When ``dim`` is odd, a trailing zero column pads the result to ``dim``.
    """
    half_dim = dim // 2
    # With a single frequency (dim of 2 or 3) ``half_dim - 1`` is zero; any
    # non-zero denominator works there because only index 0 is used.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=t.device) * -log_step
    )
    angles = t.unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
    if dim % 2 == 1:
        # sin/cos pairs only cover an even width; pad so the result really
        # has ``dim`` columns for downstream layers sized with ``dim``.
        emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=-1)
    return emb


class ConditionalDiffusionModel(nn.Module):
    """MLP noise predictor conditioned on a scalar value and a diffusion timestep.

    The module stores the noise schedule (``beta``, ``alpha`` and the
    cumulative product ``alpha_bar``) as buffers, so they follow the module
    across devices and are saved in its state dict. ``forward`` concatenates
    the noisy sample, a learned embedding of the condition and a learned
    embedding of the timestep, and feeds them through a ReLU MLP.
    """

    def __init__(
        self, input_dim, condition_dim, beta_schedule_args=None, layer_sizes=(128, 64)
    ):
        """Build the schedule buffers and the three sub-networks.

        Args:
            input_dim: Width of the sample ``x`` and of the predicted noise.
            condition_dim: Width of the learned condition embedding.
            beta_schedule_args: Schedule description passed to
                ``get_beta_schedule``; must contain ``"timesteps"``. Required
                despite the ``None`` default.
            layer_sizes: Hidden layer widths of the main MLP, in order.

        Raises:
            ValueError: If ``beta_schedule_args`` is not provided.
        """
        super().__init__()
        if beta_schedule_args is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not subscriptable".
            raise ValueError(
                "beta_schedule_args is required and must provide at least "
                "'type' and 'timesteps'"
            )
        self.timesteps = beta_schedule_args["timesteps"]

        self.register_buffer("beta", get_beta_schedule(beta_schedule_args))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

        # Width of the sinusoidal timestep embedding.
        self.t_emb_dim = 4

        if condition_dim == 0:
            print("Conditioning dim is 0")

        # Learned refinement of the sinusoidal embedding: expand to 32 units,
        # then project back to t_emb_dim.
        self.t_embedding_layer = nn.Sequential(
            nn.Linear(self.t_emb_dim, 32),
            nn.ReLU(),
            nn.Linear(32, self.t_emb_dim),
        )

        # Lifts the scalar condition to a condition_dim-wide embedding.
        self.conditioning_network = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, condition_dim),
        )

        # Main MLP: input is [x | condition embedding | timestep embedding].
        network_layers = []
        prev_size = input_dim + condition_dim + self.t_emb_dim
        for size in layer_sizes:
            network_layers.append(nn.Linear(prev_size, size))
            network_layers.append(nn.ReLU())
            prev_size = size

        # Final projection back to the sample width (the predicted noise).
        network_layers.append(nn.Linear(prev_size, input_dim))

        self.network = nn.Sequential(*network_layers)

    def forward(self, x, c, t):
        """Predict the noise contained in ``x`` at timestep ``t``.

        Args:
            x: Noisy samples, shape ``(batch, input_dim)``.
            c: Condition values, shape ``(batch,)``; one scalar per sample.
            t: Timesteps, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        t_emb = self.t_embedding_layer(sinusoidal_embedding(t, self.t_emb_dim))
        # (batch,) -> (batch, 1) so the scalar condition feeds nn.Linear(1, 16).
        c_emb = self.conditioning_network(c.unsqueeze(1))
        x_input = torch.cat([x, c_emb, t_emb], dim=1)
        return self.network(x_input)


def _noise_prediction_loss(model, x, c, device):
    """Return the noise-prediction MSE for one batch at random timesteps."""
    x = x.unsqueeze(1)  # Shape: [batch_size, 1]
    x, c = x.to(device), c.to(device)
    t = torch.randint(0, model.timesteps, (x.size(0),)).to(device)

    noise = torch.randn_like(x)  # Shape: [batch_size, 1]
    # view(-1, 1) so the per-sample coefficient broadcasts against x.
    alpha_bar_t = model.alpha_bar[t].view(-1, 1)

    # Forward diffusion in closed form: x_t = sqrt(a_bar) x + sqrt(1 - a_bar) eps
    x_t = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise

    return F.mse_loss(model(x_t, c, t), noise)


def train_diffusion(
    model,
    dataset,
    epochs=5000,
    lr=1e-4,
    batch_size=512,
    validation_split=0.1,
    callbacks=None,
):
    """Train ``model`` to predict the noise added by the forward diffusion.

    The dataset is split randomly into training and validation parts. The
    model is validated on epoch 0, on every 50th epoch and on the final
    epoch; after each validation the optional ``callbacks["epoch_end"]`` is
    invoked as ``callback(epoch, model, metrics)``.

    Args:
        model: Module exposing ``timesteps`` and an ``alpha_bar`` buffer, and
            callable as ``model(x_t, c, t)``.
        dataset: Dataset yielding ``(x, c)`` pairs; each batch of ``x`` is
            unsqueezed to ``[batch_size, 1]`` before use.
        epochs: Number of passes over the training split.
        lr: Adam learning rate.
        batch_size: Batch size for both loaders.
        validation_split: Fraction of the dataset held out for validation.
        callbacks: Optional mapping; only the ``"epoch_end"`` key is used.

    Returns:
        Dict with ``"loss"`` (last recorded training loss), ``"min_loss"``,
        ``"losses"`` (loss of the last batch of each epoch),
        ``"validation_losses"`` (one mean per validation run) and
        ``"min_validation_loss"``. The scalar loss entries are NaN when
        ``epochs`` is 0.

    Raises:
        ValueError: If the training split is empty.
    """
    optimizer = Adam(model.parameters(), lr=lr)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using mps")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    model.to(device)

    # Split into training and validation
    val_size = int(len(dataset) * validation_split)
    train_size = len(dataset) - val_size
    if train_size <= 0:
        raise ValueError("Training split is empty; nothing to train on")
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    losses = []
    validation_losses = []

    start_time = time.time()
    best_val_loss = float("inf")

    for epoch in tqdm(range(epochs)):
        # Validation switches the model to eval mode; switch back every epoch.
        model.train()
        for x, c in train_loader:
            loss = _noise_prediction_loss(model, x, c, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Only the loss of the epoch's last batch is recorded.
        losses.append(loss.item())

        # Validate periodically and always on the final epoch.
        do_validation = epoch % 50 == 0 or epoch == epochs - 1
        if not do_validation:
            continue

        model.eval()
        with torch.no_grad():
            val_epoch_losses = [
                _noise_prediction_loss(model, x, c, device).item()
                for x, c in val_loader
            ]

        # An empty validation split has no mean; report NaN without the
        # numpy "mean of empty slice" warning.
        val_loss = numpy.mean(val_epoch_losses) if val_epoch_losses else float("nan")
        validation_losses.append(val_loss)

        # Check if this is the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss

        elapsed = time.time() - start_time
        print(
            f"Epoch {epoch}/{epochs} - Loss: {losses[-1]:.4f} - Val Loss: {val_loss:.4f} - Time: {elapsed:.1f}s"
        )

        # Call callbacks
        if callbacks and "epoch_end" in callbacks:
            metrics = {
                "train_loss": losses[-1],
                "validation_loss": val_loss,
                "best_validation_loss": best_val_loss,
                "epoch": epoch,
            }
            callbacks["epoch_end"](epoch, model, metrics)

    return {
        "loss": losses[-1] if losses else float("nan"),
        "min_loss": min(losses) if losses else float("nan"),
        "losses": losses,
        "validation_losses": validation_losses,
        "min_validation_loss": best_val_loss,
    }


def evaluate_diffusion_loss(model, dataloader):
    """Estimate the mean noise-prediction loss of ``model`` over ``dataloader``.

    The model is moved to the best available device (cuda, then mps, then
    cpu) and left in eval mode. Each batch is noised at freshly sampled
    random timesteps, so the result is a stochastic estimate unless the
    caller seeds the RNG.

    Args:
        model: Module exposing ``timesteps`` and an ``alpha_bar`` buffer, and
            callable as ``model(x_t, c, t)``.
        dataloader: Iterable of ``(x, c)`` batches; ``x`` is unsqueezed to
            ``[batch_size, 1]`` before use.

    Returns:
        The unweighted mean of the per-batch MSE losses, or NaN when the
        dataloader yields no batches.
    """
    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    model.to(device)
    model.eval()

    total_loss = 0
    num_batches = 0
    with torch.no_grad():
        for x, c in dataloader:
            x = x.unsqueeze(1)
            x, c = x.to(device), c.to(device)
            t = torch.randint(0, model.timesteps, (x.size(0),)).to(device)

            noise = torch.randn_like(x)
            # view(-1, 1) so the per-sample coefficient broadcasts against x.
            alpha_bar_t = model.alpha_bar[t].view(-1, 1)
            x_t = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise

            pred_noise = model(x_t, c, t)
            total_loss += F.mse_loss(pred_noise, noise).item()
            num_batches += 1

    if num_batches == 0:
        # Nothing was evaluated: report NaN rather than dividing by zero.
        return float("nan")
    return total_loss / num_batches
