import torch
import numpy as np
import random
import os
import csv
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torchvision.utils as vutils
from torch.utils.data import Dataset


# ---------------------------
# Random Seed Setup
# ---------------------------
def set_random_seed(seed, use_cuda=True):
    """
    Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy and PyTorch (CPU), optionally the CUDA
    generators, and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed (int): The seed value applied to every generator.
        use_cuda (bool): Whether to also seed the CUDA generators.
    """
    # CPU-side generators share the same seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Seed the current CUDA device first, then every device.
    if use_cuda:
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # The cuDNN flags are set regardless of `use_cuda`.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print(f"Random seed set to: {seed}")


# ---------------------------
# Dataset Definition
# ---------------------------
class GaussianMixtureDataset(Dataset):
    """
    Dataset of samples drawn from a mixture of Gaussians.

    The component means are the cartesian product of [1, -1] over every
    dimension (2 ** num_dimensions components), and every component uses the
    covariance matrix 0.0125 * I.

    Args:
        num_samples (int): Total number of samples to generate.
        num_dimensions (int): Dimensionality of each sample.
    """
    def __init__(self, num_samples=12800, num_dimensions=3):
        self.num_samples = num_samples
        self.num_dimensions = num_dimensions
        # cartesian_prod returns a 1-D tensor when it receives a single input,
        # so reshape to keep one row per component mean for num_dimensions=1.
        self.means_combinations = torch.cartesian_prod(
            *(torch.tensor([1, -1]) for _ in range(num_dimensions))
        ).float().reshape(-1, num_dimensions)
        self.data = self.generate_data(num_samples, num_dimensions)

    def generate_data(self, num_samples, num_dimensions):
        """
        Generate samples for each Gaussian component.

        The samples are split evenly between the components. When
        ``num_samples`` is not divisible by the number of components, the
        leftover samples go one each to the first components, so exactly
        ``num_samples`` rows are returned.

        Args:
            num_samples (int): Total number of samples.
            num_dimensions (int): Dimensionality of the data.

        Returns:
            Tensor: Samples of every component concatenated along dim 0,
            ordered component by component.
        """
        total = max(int(num_samples), 0)
        n = len(self.means_combinations)
        per_component, leftover = divmod(total, n)
        covariance = 0.0125 * torch.eye(num_dimensions)

        samples = []
        for i, mean in enumerate(self.means_combinations):
            count = per_component + (1 if i < leftover else 0)
            if count == 0:
                continue
            # Create a multivariate normal distribution for each mean
            distribution = torch.distributions.MultivariateNormal(
                loc=mean,
                covariance_matrix=covariance
            )
            samples.append(distribution.sample((count,)))

        if not samples:
            return torch.empty((0, num_dimensions))
        return torch.cat(samples, dim=0)

    def __len__(self):
        # Report the rows that actually exist so every index < len() is valid.
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]


# ---------------------------
# Plotting Functions
# ---------------------------
def plot_duality_gap_and_frobenius(record_largest_frobenius_norm, record_steepness, DG, epochs, threshold, save_path):
    """
    Plot Duality Gap and Frobenius Norm on a dual-axis plot.

    The left axis shows the Frobenius-norm record and a horizontal threshold
    line; the right axis shows the vanilla and perturbed duality gaps. A red
    cross marks the first record at which either the Frobenius norm reaches
    the threshold or the steepness drops to at most half of its previous value.

    Args:
        record_largest_frobenius_norm (list): Frobenius norm values.
        record_steepness (list): Steepness values (only used to locate the marker).
        DG (dict): Dictionary containing duality gap metrics ('vanilla' and 'local_random').
        epochs (list): List of epoch indices.
        threshold (float): Threshold value for the Frobenius norm.
        save_path (str): Path to save the generated plot.
    """
    fig, ax1 = plt.subplots(figsize=(12, 9))

    # Plot Frobenius Norm on the first y-axis
    ax1.plot(epochs, record_largest_frobenius_norm, color="tab:blue", linestyle="-",
             label=r"90% Percentile of $||\nabla d(x)/d(x)||_2$")
    ax1.set_xlabel("Epochs", fontsize=30)
    ax1.set_ylabel(r"$||\nabla d(x)/d(x)||_2$", color="tab:blue", fontsize=30)
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.set_ylim([0, 5000])
    ax1.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax1.yaxis.get_offset_text().set_fontsize(20)

    # Plot threshold line
    ax1.axhline(y=threshold, color="black", linestyle="-",
                label=rf"$||\nabla d(x)/d(x)||_2={threshold}$")

    # Find first occurrence of threshold exceedance or steepness drop
    index_frobenius = next((i for i, v in enumerate(record_largest_frobenius_norm) if v >= threshold), None)
    index_steepness = next((i for i in range(1, len(record_steepness)) if record_steepness[i] <= 0.5 * record_steepness[i - 1]), None)
    candidates = [i for i in (index_frobenius, index_steepness) if i is not None]
    index_of_first = min(candidates) if candidates else None

    if index_of_first is not None:
        # The curves are drawn against `epochs`, so convert the record index to
        # the matching epoch value. Fall back to the raw index if the steepness
        # record is longer than `epochs`.
        marker_epoch = epochs[index_of_first] if index_of_first < len(epochs) else index_of_first
        ax1.annotate(f"Epoch: {marker_epoch}", (marker_epoch, threshold),
                     textcoords="offset points", ha="center", xytext=(1, 10), fontsize=30, zorder=10)
        ax1.plot(marker_epoch, threshold, "xr", zorder=10)

    # Plot Duality Gap on the second y-axis
    ax2 = ax1.twinx()
    ax2.plot(epochs, DG["vanilla"], color="tab:red", linestyle="-", marker="^", label="Vanilla duality gap")
    # Draw the perturbed gap on the same epoch axis as the vanilla gap. Keep
    # the positional x-axis only when its length does not match `epochs`.
    perturbed_gap = DG["local_random"]
    perturbed_x = epochs if len(perturbed_gap) == len(epochs) else list(range(len(perturbed_gap)))
    ax2.plot(perturbed_x, perturbed_gap, color="tab:red", linestyle="--", marker="o", label="Perturbed duality gap")
    ax2.set_ylabel("Duality gap", color="tab:red", fontsize=30)
    ax2.tick_params(axis="y", labelcolor="tab:red")
    ax2.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax2.yaxis.get_offset_text().set_fontsize(20)

    ax1.tick_params(axis="x", labelsize=30)
    ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=30)
    ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=30)

    fig.tight_layout()
    fig.savefig(save_path, transparent=True)
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)


def plot_duality_gap_and_steepness(record_steepness, DG, epochs, save_path):
    """
    Plot Duality Gap and Steepness on a dual-axis plot.

    The left axis shows the steepness record (lower limit fixed at 0); the
    right axis shows the vanilla and perturbed duality gaps.

    Args:
        record_steepness (list): Steepness values.
        DG (dict): Dictionary containing duality gap metrics ('vanilla' and 'local_random').
        epochs (list): List of epoch indices.
        save_path (str): Path to save the plot.
    """
    fig, ax1 = plt.subplots(figsize=(12, 9))

    # Plot steepness on the first y-axis
    ax1.plot(epochs, record_steepness, color="tab:blue", linestyle="-",
             label="Steepness")
    ax1.set_ylim(bottom=0)
    ax1.set_xlabel("Epochs", fontsize=30)
    ax1.set_ylabel("Steepness", color="tab:blue", fontsize=30)
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax1.yaxis.get_offset_text().set_fontsize(20)

    # Plot duality gap on the second y-axis
    ax2 = ax1.twinx()
    ax2.plot(epochs, DG["vanilla"], color="tab:red", linestyle="-", marker="^", label="Vanilla duality gap")
    # Draw the perturbed gap on the same epoch axis as the vanilla gap. Keep
    # the positional x-axis only when its length does not match `epochs`.
    perturbed_gap = DG["local_random"]
    perturbed_x = epochs if len(perturbed_gap) == len(epochs) else list(range(len(perturbed_gap)))
    ax2.plot(perturbed_x, perturbed_gap, color="tab:red", linestyle="--", marker="o", label="Perturbed duality gap")
    ax2.set_ylabel("Duality gap", color="tab:red", fontsize=30)
    ax2.tick_params(axis="y", labelcolor="tab:red")
    ax2.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax2.yaxis.get_offset_text().set_fontsize(20)

    ax1.tick_params(axis="x", labelsize=30)
    ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=30)
    ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=30)

    fig.tight_layout()
    fig.savefig(save_path, transparent=True)
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)


def plot_fid_and_frobenius(record_largest_frobenius_norm, record_steepness, fid_scores, epochs, threshold, save_path):
    """
    Plot FID Score and Frobenius Norm on a dual-axis plot.

    The left axis shows the Frobenius-norm record and a horizontal threshold
    line; the right axis shows the FID scores. A red cross marks the first
    record at which either the Frobenius norm reaches the threshold or the
    steepness drops to at most half of its previous value.

    Args:
        record_largest_frobenius_norm (list): Frobenius norm values.
        record_steepness (list): Steepness values (only used to locate the marker).
        fid_scores (list): FID scores.
        epochs (list): List of epoch indices.
        threshold (float): Threshold value.
        save_path (str): Path to save the plot.
    """
    fig, ax1 = plt.subplots(figsize=(12, 9))

    # Plot Frobenius Norm on the first y-axis
    ax1.plot(epochs, record_largest_frobenius_norm, color="tab:blue", linestyle="-",
             label=r"90% Percentile of $||\nabla d(x)/d(x)||_2$")
    ax1.set_xlabel("Epochs", fontsize=30)
    ax1.set_ylabel(r"$||\nabla d(x)/d(x)||_2$", color="tab:blue", fontsize=30)
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.set_ylim([0, 5000])
    ax1.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax1.yaxis.get_offset_text().set_fontsize(20)

    # Plot threshold line
    ax1.axhline(y=threshold, color="black", linestyle="-",
                label=rf"$||\nabla d(x)/d(x)||_2={threshold}$")

    # Determine the first occurrence of threshold exceedance or steepness drop
    index_frobenius = next((i for i, v in enumerate(record_largest_frobenius_norm) if v >= threshold), None)
    index_steepness = next((i for i in range(1, len(record_steepness)) if record_steepness[i] <= 0.5 * record_steepness[i - 1]), None)
    candidates = [i for i in (index_frobenius, index_steepness) if i is not None]
    index_of_first = min(candidates) if candidates else None

    if index_of_first is not None:
        # The curves are drawn against `epochs`, so convert the record index to
        # the matching epoch value. Fall back to the raw index if the steepness
        # record is longer than `epochs`.
        marker_epoch = epochs[index_of_first] if index_of_first < len(epochs) else index_of_first
        ax1.annotate(f"Epoch: {marker_epoch}", (marker_epoch, threshold),
                     textcoords="offset points", ha="center", xytext=(1, 10), fontsize=30, zorder=10)
        ax1.plot(marker_epoch, threshold, "xr", zorder=10)

    # Plot FID Score on the second y-axis
    ax2 = ax1.twinx()
    ax2.plot(epochs, fid_scores, color="tab:red", linestyle="--", marker="o", label="FID score")
    ax2.set_ylabel("FID score", color="tab:red", fontsize=30)
    ax2.tick_params(axis="y", labelcolor="tab:red")
    ax2.set_ylim([0, max(fid_scores) + 100])
    ax2.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax2.yaxis.get_offset_text().set_fontsize(20)

    ax1.tick_params(axis="x", labelsize=30)
    ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=30)
    ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=30)

    fig.tight_layout()
    fig.savefig(save_path, transparent=True)
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)


def plot_steepness_and_frobenius(record_largest_frobenius_norm, record_steepness, epochs, threshold, save_path):
    """
    Plot Steepness and Frobenius Norm on a dual-axis plot.

    The left axis shows the Frobenius-norm record and a horizontal threshold
    line; the right axis shows the steepness record (lower limit fixed at 0).
    A red cross marks the first record at which either the Frobenius norm
    reaches the threshold or the steepness drops to at most half of its
    previous value.

    Args:
        record_largest_frobenius_norm (list): Frobenius norm values.
        record_steepness (list): Steepness values.
        epochs (list): List of epoch indices.
        threshold (float): Threshold value.
        save_path (str): Path to save the plot.
    """
    fig, ax1 = plt.subplots(figsize=(12, 9))

    # Plot Frobenius norm on the first y-axis
    ax1.plot(epochs, record_largest_frobenius_norm, label=r"90% Percentile of $||\nabla d(x)/d(x)||_2$",
             color="tab:blue")
    ax1.set_xlabel("Epochs", fontsize=30)
    ax1.set_ylabel(r"$||\nabla d(x)/d(x)||_2$", color="tab:blue", fontsize=30)
    ax1.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax1.yaxis.get_offset_text().set_fontsize(20)
    ax1.axhline(y=threshold, color="black", linestyle="-",
                label=rf"$||\nabla d(x)/d(x)||_2={threshold}$")

    # Find first occurrence of threshold exceedance or steepness drop
    index_frobenius = next((i for i, v in enumerate(record_largest_frobenius_norm) if v >= threshold), None)
    index_steepness = next((i for i in range(1, len(record_steepness)) if record_steepness[i] <= 0.5 * record_steepness[i - 1]), None)
    candidates = [i for i in (index_frobenius, index_steepness) if i is not None]
    index_of_first = min(candidates) if candidates else None

    if index_of_first is not None:
        # The curves are drawn against `epochs`, so convert the record index to
        # the matching epoch value (guarded in case the index is out of range).
        marker_epoch = epochs[index_of_first] if index_of_first < len(epochs) else index_of_first
        ax1.annotate(f"Epoch: {marker_epoch}", (marker_epoch, threshold),
                     textcoords="offset points", ha="center", xytext=(1, 10), fontsize=30, zorder=20)
        ax1.plot(marker_epoch, threshold, "xr", zorder=10)

    # Plot steepness on the second y-axis
    ax2 = ax1.twinx()
    ax2.plot(epochs, record_steepness, linestyle="--", marker="o", label="Steepness", color="tab:red")
    ax2.set_ylim(bottom=0)
    ax2.set_ylabel("Steepness", color="tab:red", fontsize=30)
    ax2.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    ax2.yaxis.get_offset_text().set_fontsize(20)

    ax1.tick_params(axis="x", labelsize=30)
    ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=30)
    ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=30)

    fig.tight_layout()
    fig.savefig(save_path, transparent=True)
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)


def plot_fid_and_steepness(record_steepness, fid_scores, epochs, save_path):
    """
    Draw steepness and FID score against epochs on a dual-axis figure and save it.

    Steepness goes on the left (blue) axis with its lower limit fixed at 0;
    the FID score goes on the right (red) axis, which spans from 0 to the
    largest FID score plus 100.

    Args:
        record_steepness (list): Steepness values.
        fid_scores (list): FID scores.
        epochs (list): List of epoch indices.
        save_path (str): Path to save the plot.
    """
    left_color, right_color = "tab:blue", "tab:red"
    _, steepness_axis = plt.subplots(figsize=(12, 9))

    # Left axis: steepness curve. The lower limit is fixed after plotting so
    # the upper limit comes from the data.
    steepness_axis.plot(epochs, record_steepness, color=left_color, linestyle="-",
                        label="Steepness")
    steepness_axis.set_ylim(bottom=0)
    steepness_axis.set_xlabel("Epochs", fontsize=30)
    steepness_axis.set_ylabel("Steepness", color=left_color, fontsize=30)

    # Right axis: FID curve with head-room above the largest score.
    fid_axis = steepness_axis.twinx()
    fid_axis.plot(epochs, fid_scores, color=right_color, linestyle="--", marker="o", label="FID score")
    fid_axis.set_ylabel("FID score", color=right_color, fontsize=30)
    fid_axis.set_ylim([0, max(fid_scores) + 100])

    # Both y-axes share the same scientific-notation formatting and tick styling.
    for axis, color in ((steepness_axis, left_color), (fid_axis, right_color)):
        axis.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
        axis.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
        axis.yaxis.get_offset_text().set_fontsize(20)
        axis.tick_params(axis="y", labelcolor=color, labelsize=30)
    steepness_axis.tick_params(axis="x", labelsize=30)

    plt.tight_layout()
    plt.savefig(save_path, transparent=True)
    plt.close()


# ---------------------------
# File and CSV Saving Functions
# ---------------------------
def save_real_images_to_folder(dataloader, output_folder, num_images):
    """
    Write the first `num_images` images yielded by a dataloader to disk.

    The folder is created if needed. Each batch is indexed with ``[0]`` to
    obtain the image tensor, and every image is saved (normalized) as
    ``real_<counter>.jpg`` with a counter starting at 0.

    Args:
        dataloader (torch.utils.data.DataLoader): DataLoader for the dataset.
        output_folder (str): Destination folder for saving images.
        num_images (int): Number of images to save.
    """
    os.makedirs(output_folder, exist_ok=True)
    written = 0

    with torch.no_grad():
        for batch in dataloader:
            images = batch[0]
            for position in range(images.size(0)):
                # Stop as soon as the requested number of files exists.
                if not written < num_images:
                    return
                target = os.path.join(output_folder, f"real_{written}.jpg")
                vutils.save_image(images[position], target, normalize=True)
                written += 1


def save_metrics_to_csv(epoch, fid_score_value, largest_frobenius_norm, steepness_value,
                        d_loss, g_loss, vanilla_dg, local_random_dg, save_path):
    """
    Append one row of metrics to a CSV file.

    Tensor-valued metrics are converted to plain Python numbers with
    ``.item()`` before writing, so the file contains numbers rather than
    ``tensor(...)`` representations. The row layout is: epoch, FID score,
    Frobenius norm, steepness, discriminator loss, generator loss, vanilla
    duality gap, perturbed duality gap.

    Args:
        epoch (int): Epoch number.
        fid_score_value (float): FID score.
        largest_frobenius_norm (float): Frobenius norm value.
        steepness_value (float): Steepness value.
        d_loss (float): Discriminator loss.
        g_loss (float): Generator loss.
        vanilla_dg (float): Vanilla duality gap.
        local_random_dg (float): Perturbed duality gap.
        save_path (str): Path to the CSV file (opened in append mode).
    """
    # Convert tensor values to plain numbers if needed. This is done
    # explicitly because exec() cannot rebind a function's local variables.
    metrics = [
        value.item() if torch.is_tensor(value) else value
        for value in (fid_score_value, largest_frobenius_norm, steepness_value,
                      d_loss, g_loss, vanilla_dg, local_random_dg)
    ]

    with open(save_path, mode="a", newline="") as file:
        writer = csv.writer(file)
        writer.writerow([epoch, *metrics])


# ---------------------------
# Metric Calculation Functions
# ---------------------------
def calculate_steepness(generator, noise, all=False):
    """
    Calculate the average L2 norm of the generator's gradients with respect to the input noise.

    For every sample i, the summed generator output of that sample is
    differentiated with respect to the noise, and the norm of the gradient
    row belonging to sample i is recorded.

    The gradients are computed with ``torch.autograd.grad`` on a detached copy
    of ``noise``. Nothing accumulates between samples or between calls, the
    caller's tensor is left untouched, and no gradients are written to the
    generator's parameters.

    Args:
        generator (torch.nn.Module): Generator model.
        noise (Tensor): Input noise; dimension 0 is the batch and dimension 1
            is used as the column count when reshaping each gradient.
        all (bool): If True, return all gradient norms; otherwise, return the average.

    Returns:
        float or list: Average steepness or list of gradient norms.

    Raises:
        ValueError: If the average is requested for an empty batch.
    """
    # Work on a private leaf tensor so the caller's noise keeps its flags and grad.
    noise_input = noise.detach().clone().requires_grad_(True)

    jacobian_norms = []
    # Gradients are needed even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        fake_images = generator(noise_input)
        for i in range(fake_images.size(0)):
            output = fake_images[i].reshape(-1).sum()
            # autograd.grad returns a fresh gradient, so contributions from
            # earlier samples (e.g. through batch statistics) cannot leak in.
            (grad,) = torch.autograd.grad(output, noise_input, retain_graph=True)
            jacobian_matrix = grad[i].reshape(-1, noise_input.size(1))
            jacobian_norms.append(torch.linalg.norm(jacobian_matrix, ord=2).item())

    if all:
        return jacobian_norms
    if not jacobian_norms:
        raise ValueError("cannot compute the average steepness of an empty batch")
    return sum(jacobian_norms) / len(jacobian_norms)


def calculate_frobenius_norm(discriminator, batch_x, epsilon=0.0):
    """
    Calculate the 90th percentile of the normalized gradient's L2 norm (Frobenius norm) for a batch.

    For every sample the gradient of the discriminator output with respect to
    the input is divided by the discriminator output, and the L2 norm is taken
    over all non-batch dimensions.

    The gradient is computed with ``torch.autograd.grad`` on a detached copy
    of ``batch_x``. The caller's tensor is not modified, repeated calls do not
    accumulate, non-leaf inputs are accepted, and no gradients are written to
    the discriminator's parameters.

    Args:
        discriminator (torch.nn.Module): Discriminator model; its output is
            flattened and must contain one value per sample.
        batch_x (Tensor): Batch of input images.
        epsilon (float): Value added to the discriminator output before the
            division. The default 0.0 performs no stabilization, so outputs of
            exactly zero produce inf/nan.

    Returns:
        float: 90th percentile Frobenius norm value.
    """
    # Private leaf tensor: keeps the caller's tensor and its .grad untouched.
    inputs = batch_x.detach().clone().requires_grad_(True)

    # Gradients are needed even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        # Compute discriminator output and flatten to 1D
        d_x = discriminator(inputs).reshape(-1)
        # Summing with unit weights yields the per-sample input gradients.
        (grad_d_x,) = torch.autograd.grad(d_x, inputs, grad_outputs=torch.ones_like(d_x))

    # Reshape d_x to broadcast over every non-batch dimension of the gradient
    d_x = d_x.detach().reshape(-1, *[1] * (grad_d_x.dim() - 1))
    normalized_grad = grad_d_x / (d_x + epsilon)

    # Compute the L2 norm of normalized gradients for each sample
    grad_norms = torch.norm(normalized_grad, p=2, dim=tuple(range(1, grad_d_x.dim())))
    return torch.quantile(grad_norms, q=0.9).item()


def calculate_loss(generator, discriminator, real_images, noise, device):
    """
    Calculate discriminator and generator losses.

    Both losses are binary cross-entropies on the discriminator outputs:
    the discriminator loss is BCE(D(real), 1) + BCE(D(fake), 0) and the
    generator loss is BCE(D(fake), 1).

    Only floats are returned, so everything runs under ``torch.no_grad()``
    and the generator / discriminator are evaluated once on the fake batch.
    Discriminator outputs are flattened, and each label tensor is sized from
    the output it is compared with, so ``(B, 1)``-shaped outputs and differing
    real / fake batch sizes both work.

    Args:
        generator (torch.nn.Module): Generator model.
        discriminator (torch.nn.Module): Discriminator model; its outputs are
            passed to binary cross-entropy and must therefore lie in [0, 1].
        real_images (Tensor): Real images.
        noise (Tensor): Input noise for generator.
        device (torch.device): Device on which the label tensors are created.

    Returns:
        tuple: Discriminator loss and generator loss as floats.
    """
    bce = torch.nn.functional.binary_cross_entropy

    with torch.no_grad():
        fake_images = generator(noise)
        output_real = discriminator(real_images).reshape(-1)
        output_fake = discriminator(fake_images).reshape(-1)

        real_labels = torch.ones(output_real.size(0), device=device)
        fake_labels = torch.zeros(output_fake.size(0), device=device)
        # The generator wants its fakes classified as real, so the target is
        # a vector of ones sized by the fake batch.
        generator_targets = torch.ones(output_fake.size(0), device=device)

        d_loss = bce(output_real, real_labels) + bce(output_fake, fake_labels)
        g_loss = bce(output_fake, generator_targets)

    return d_loss.item(), g_loss.item()


def evaluate(generator, discriminator, noise_dim, eval_seed, save_dir, epoch, batch=None, metrics_file="metrics.txt"):
    """
    Compute the steepness and Frobenius-norm metrics and append them to a text file.

    A dedicated random generator, seeded with
    ``eval_seed + epoch * 1000 + batch`` (batch counted as 0 when omitted),
    draws 1280 noise vectors of shape ``(noise_dim, 1, 1)`` on the generator's
    device. The steepness is computed from that noise, and the Frobenius norm
    from images generated from the same noise.

    Args:
        generator (torch.nn.Module): Generator model.
        discriminator (torch.nn.Module): Discriminator model.
        noise_dim (int): Dimensionality of the latent noise.
        eval_seed (int): Base seed for evaluation.
        save_dir (str): Directory to save evaluation metrics (created if missing).
        epoch (int): Current epoch number.
        batch (int, optional): Batch number (if applicable).
        metrics_file (str): Name of the metrics file.
    """
    has_batch = batch is not None
    device = next(generator.parameters()).device

    # A private RNG keeps the evaluation noise separate from the global random state.
    noise_rng = torch.Generator(device=device)
    noise_rng.manual_seed(eval_seed + epoch * 1000 + (batch if has_batch else 0))
    eval_noise = torch.randn(1280, noise_dim, 1, 1, generator=noise_rng, device=device)

    # Steepness first, then the Frobenius norm on images generated without autograd tracking.
    steepness = calculate_steepness(generator, eval_noise)
    with torch.no_grad():
        generated = generator(eval_noise)
    frobenius_norm = calculate_frobenius_norm(discriminator, generated)

    os.makedirs(save_dir, exist_ok=True)
    metrics_path = os.path.join(save_dir, metrics_file)

    # Build the report up front and append it in one go.
    batch_label = batch if has_batch else "N/A"
    report = [
        f"Epoch: {epoch}, Batch: {batch_label}\n",
        f"Steepness: {steepness:.4f}\n",
        f"Frobenius Norm: {frobenius_norm:.4f}\n",
        "-" * 40 + "\n",
    ]
    with open(metrics_path, "a") as f:
        f.writelines(report)

    print(f"Appended evaluation metrics to: {metrics_path}")
