import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from scipy.special import gammainc
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import torch
from global_config import ROOT_DIRECTORY
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

# Enlarge matplotlib's default font size; rcParams is global, so this affects
# every figure created in the process after this module is imported.
plt.rcParams['font.size'] = 14


class Heterogeneous2DClusters(Dataset):
    """
    Synthetic 2D dataset drawn from a mixture of isotropic Gaussian clusters.

    All points are sampled once in ``__init__`` and kept in ``self.data`` as a
    ``float32`` tensor of shape ``(num_samples, 2)`` on ``device``. Points are
    stored grouped by cluster (in cluster order, not shuffled), so shuffle in
    the DataLoader when order matters.

    The default mixture has three clusters centred inside [-1, 1]² with very
    unequal weights (98.9% / 1% / 0.1%), giving a strongly heterogeneous density.
    """

    # Default cluster specs: "center" (x, y), per-axis Gaussian "std", and a
    # relative "prob" weight (normalised to sum to 1 in __init__).
    DEFAULT_CLUSTERS = (
        {"center": (-0.0, -0.6), "std": 0.06, "prob": 0.989},
        {"center": (-0.6, 0.6), "std": 0.06, "prob": 0.01},
        {"center": (0.6, 0.6), "std": 0.06, "prob": 0.001},
    )

    def __init__(self, num_samples=10000, device='cpu', clusters=None):
        """
        Args:
            num_samples: Total number of points to generate (exactly this many).
            device: Torch device on which ``self.data`` is allocated.
            clusters: Optional iterable of dicts with keys "center", "std" and
                "prob" that replaces ``DEFAULT_CLUSTERS``. The dicts are copied,
                so the caller's objects are never modified.

        Raises:
            ValueError: If the cluster weights do not sum to a positive value.
        """
        self.num_samples = num_samples
        self.device = device

        specs = self.DEFAULT_CLUSTERS if clusters is None else clusters
        # Copy every spec: the normalisation below rewrites "prob" in place.
        self.clusters = [dict(spec) for spec in specs]

        total_prob = sum(cluster["prob"] for cluster in self.clusters)
        if total_prob <= 0:
            raise ValueError("cluster probabilities must sum to a positive value")
        for cluster in self.clusters:
            cluster["prob"] /= total_prob

        self.data = torch.tensor(self.sample_clusters(), dtype=torch.float32, device=device)

    def _cluster_counts(self):
        """
        Split ``self.num_samples`` across clusters proportionally to "prob".

        Flooring each share can lose a few points, so the shortfall is handed to
        the clusters with the largest fractional remainders (largest-remainder
        apportionment). The returned integer counts sum to ``num_samples``.
        """
        probs = np.array([cluster["prob"] for cluster in self.clusters], dtype=float)
        expected = self.num_samples * probs
        counts = np.floor(expected).astype(int)
        shortfall = int(self.num_samples - counts.sum())
        if shortfall > 0:
            # Stable sort keeps ties in cluster order, making the split deterministic.
            order = np.argsort(-(expected - counts), kind="stable")
            counts[order[:shortfall]] += 1
        return counts

    def sample_clusters(self):
        """
        Draw points from every cluster and stack them into one ``(N, 2)`` array.

        Uses NumPy's global RNG (``np.random.normal``), so seed it for
        reproducibility. Rows are ordered cluster by cluster.
        """
        samples = [
            np.random.normal(loc=cluster["center"], scale=cluster["std"], size=(int(count), 2))
            for cluster, count in zip(self.clusters, self._cluster_counts())
        ]
        return np.vstack(samples)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]


def create_heterogeneous_2d_dataloader(batch_size, num_samples=100000, num_workers=4, device='cpu', dataset_name="_"):
    """
    Create train and test DataLoaders for the heterogeneous 2D cluster data.

    If ``<ROOT_DIRECTORY>/results/h2c/<dataset_name>`` is an existing file, it is
    loaded with ``np.load`` and split 50/50 into train/test subsets using a
    dedicated generator seeded with 42, so the split is reproducible. Otherwise
    two independently sampled ``Heterogeneous2DClusters`` datasets of
    ``num_samples`` points each are created for training and testing.

    Args:
        batch_size: Batch size of both loaders.
        num_samples: Points per generated dataset (unused when a file is loaded).
        num_workers: Number of DataLoader worker processes per loader.
        device: Device for generated data; not applied to data loaded from file.
        dataset_name: File name under ``results/h2c`` to load if it exists.

    Returns:
        ``(train_loader, test_loader)``. The batch type differs by path: the
        file-backed loaders wrap a ``TensorDataset``, so each batch is a
        one-element list ``[tensor]``; the generated loaders yield plain tensors.
    """
    saved_path = os.path.join(ROOT_DIRECTORY, "results", "h2c", dataset_name)
    # isfile rather than exists: an empty dataset_name resolves to the h2c
    # directory itself, which np.load cannot read.
    if os.path.isfile(saved_path):
        points = torch.tensor(np.load(saved_path), dtype=torch.float32)
        # Local generator: reproducible split without touching the global RNG.
        split_generator = torch.Generator().manual_seed(42)
        train_size = int(0.5 * len(points))
        train_dataset, test_dataset = random_split(
            TensorDataset(points), [train_size, len(points) - train_size], generator=split_generator
        )
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_dataloader, test_dataloader

    train_set = Heterogeneous2DClusters(num_samples, device=device)
    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_set = Heterogeneous2DClusters(num_samples, device=device)
    # Kept shuffled (unlike the file-backed test loader): generated points are
    # stored grouped by cluster, so without shuffling early batches would be
    # drawn from a single cluster.
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return loader, test_loader


def visualize_clusters_2d_deprecated(data, losses=None, cbar_min=None, cbar_max=None,
                                     title="Heterogeneous 2D Clusters (Quadratic Centers)", savedir=None,
                                     density_plot=True):
    """
    Deprecated: use ``visualize_clusters_2d`` instead.

    This function used to be a copy of ``visualize_clusters_2d`` without the
    fitted-curve overlay. It now forwards all arguments to it (with no
    ``model_params``), so both functions render identically.

    Parameters:
        - data (torch.Tensor or np.ndarray): 2D data points (N x 2).
        - losses (torch.Tensor or np.ndarray, optional): Loss values for color encoding
          (used only when density_plot is False).
        - cbar_min, cbar_max (float, optional): Min/max limits for the colorbar.
        - title (str): Plot title (also the saved file's base name).
        - savedir (str, optional): Directory to save the plot as "<title>.jpg".
        - density_plot (bool): If True, draws a log-scaled hexbin density plot.
    """
    import warnings

    warnings.warn(
        "visualize_clusters_2d_deprecated is deprecated; use visualize_clusters_2d instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    visualize_clusters_2d(data, losses=losses, cbar_min=cbar_min, cbar_max=cbar_max,
                          title=title, savedir=savedir, density_plot=density_plot)


def _as_numpy(values):
    """Convert a torch tensor (detached, moved to CPU) or any array-like to a NumPy array."""
    if isinstance(values, torch.Tensor):
        # detach() so tensors that require grad (e.g. nn.Parameter weights or
        # losses still attached to the autograd graph) can be converted.
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _attach_colorbar(ax, mappable, label, cbar_min=None, cbar_max=None):
    """Add a labelled colorbar for `mappable` in a slim axis to the right of `ax`."""
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    cbar = plt.colorbar(mappable, cax=cax)
    cbar.set_label(label)
    # Either bound may be given alone; a None bound leaves that limit unchanged.
    if cbar_min is not None or cbar_max is not None:
        cbar.mappable.set_clim(cbar_min, cbar_max)
    return cbar


def _model_curve(model_params, x_line):
    """
    Evaluate the line/parabola described by `model_params` on `x_line`.

    Returns a ``(y_line, label)`` tuple, or None when `model_params` is empty or
    lacks a 'weights' or 'bias' entry.
    """
    if not model_params or 'weights' not in model_params or 'bias' not in model_params:
        return None

    weight = _as_numpy(model_params['weights'])
    # Flatten so scalar, 1-D and 2-D biases all yield a plain scalar intercept.
    intercept = _as_numpy(model_params['bias']).ravel()[0]

    if weight.ndim == 2 and weight.shape[1] == 2:
        # Quadratic case y = a*x^2 + b*x + c: the first row holds (b, a).
        b, a = weight[0]
        y_line = a * x_line ** 2 + b * x_line + intercept
        label = f"y = {a:.2f}x² + {b:.2f}x + {intercept:.2f}"
    else:
        # Linear case y = slope*x + intercept: only the first weight is used.
        slope = weight.ravel()[0]
        y_line = slope * x_line + intercept
        label = f"y = {slope:.2f}x + {intercept:.2f}"
    return y_line, label


def visualize_clusters_2d(data, losses=None, cbar_min=None, cbar_max=None,
                          title="Heterogeneous 2D Clusters (Quadratic Centers)", savedir=None, density_plot=True,
                          model_params=None):
    """
    Visualizes 2D clusters with options for density and loss-based coloring.
    Optionally overlays a fitted line or parabola described by `model_params`.

    Parameters:
        - data (torch.Tensor or np.ndarray): 2D data points (N x 2).
        - losses (torch.Tensor or np.ndarray, optional): Loss values for color encoding
          (used only when density_plot is False).
        - cbar_min, cbar_max (float, optional): Min/max limits for the colorbar.
        - title (str): Plot title (also the saved file's base name).
        - savedir (str, optional): Directory to save the plot as "<title>.jpg".
        - density_plot (bool): If True, draws a log-scaled hexbin density plot.
        - model_params (dict, optional): Dictionary with keys:
            - 'weights': array/tensor/scalar. If 2-D with 2 columns, its first row
              is (b, a) of y = a*x^2 + b*x + c; otherwise only its first element
              is used as the slope of y = slope*x + intercept.
            - 'bias': array/tensor/scalar; its first element is the intercept c.
    """
    data = _as_numpy(data)
    x, y = data[:, 0], data[:, 1]

    fig, ax = plt.subplots(figsize=(8, 7))

    if density_plot:
        hb = ax.hexbin(x, y, gridsize=50, cmap='plasma', bins='log')
        _attach_colorbar(ax, hb, "Log Density", cbar_min, cbar_max)
    elif losses is not None:
        scatter = ax.scatter(x, y, c=_as_numpy(losses), cmap='plasma', alpha=0.8, vmin=cbar_min, vmax=cbar_max)
        _attach_colorbar(ax, scatter, "Loss Value", cbar_min, cbar_max)
    else:
        ax.scatter(x, y, alpha=0.1, s=5)

    x_line = np.linspace(-1, 1, 100)
    curve = _model_curve(model_params, x_line)
    if curve is not None:
        y_line, label = curve
        ax.plot(x_line, y_line, color="red", linestyle="--", label=label)
        ax.legend()

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_ylim(-1, 1)
    ax.set_xlim(-1, 1)
    ax.set_title(title)

    # Lay out before saving so the written file matches what is displayed.
    fig.tight_layout()
    if savedir:
        fig.savefig(os.path.join(savedir, f"{title}.jpg"), bbox_inches="tight")

    plt.show()
    # In non-interactive mode show() either blocks until the window is closed or
    # is a no-op on headless backends, so closing here only frees the figure
    # (avoiding an ever-growing set of open figures); in interactive mode the
    # figure is left open for the user.
    if not plt.isinteractive():
        plt.close(fig)


# Example Usage
if __name__ == "__main__":
    # Demo: pull one very large batch from the training loader and render it
    # with the default log-scaled hexbin density view.
    num_samples = 1000000  # more points give a smoother density estimate
    batch_size = num_samples  # one batch as large as the generated dataset

    train_loader, _ = create_heterogeneous_2d_dataloader(batch_size=batch_size, num_samples=num_samples)
    first_batch = next(iter(train_loader))

    visualize_clusters_2d(first_batch, title="Heterogeneous 2D Clusters (Quadratic Centers)")
