import os

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


def set_random_seed(seed):
    """Seed every global RNG used by the project so runs are reproducible.

    Seeds Python's built-in ``random`` module, NumPy's legacy global RNG
    (``np.random``) and PyTorch (``torch.manual_seed`` seeds all devices).

    Args:
        seed: Integer seed shared by all generators.
    """
    # Local import keeps the module's import block untouched.
    import random

    # Without this, any code relying on ``random`` (e.g. shuffling helpers)
    # would differ from run to run even with a fixed seed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def create_directories(script_dir):
    """Make sure the output sub-directories exist under ``script_dir``.

    Creates ``plots``, ``trained_models`` and ``data`` (in that order) below
    ``script_dir``; existing directories are left untouched.

    Args:
        script_dir: Base directory; must support the ``/`` operator with a
            string (e.g. a ``pathlib.Path``).
    """
    for sub_dir in ("plots", "trained_models", "data"):
        os.makedirs(script_dir / sub_dir, exist_ok=True)


def model_signature(args, dataset):
    """Build the identifier string for a model configuration.

    The result is ``dataset`` + training mode + training settings (+ the
    temperature range for KL-based modes) + seed, e.g.
    ``"mnistlog_likelihood_manifoldM_ndim2_nepochs3_seed7"``.

    Args:
        args: Namespace with ``manifold``, ``x_dim``, ``n_epochs``, ``T0``,
            ``Tn``, ``seed`` and ``training_mode`` attributes.
        dataset: Dataset name prefix.

    Returns:
        The concatenated signature string.

    Raises:
        ValueError: If ``args.training_mode`` is not a known mode.
    """
    settings_part = f"_manifold{args.manifold}_ndim{args.x_dim}_nepochs{args.n_epochs}"
    temperature_part = f"_T0{args.T0}_Tn{args.Tn}"
    seed_part = f"_seed{args.seed}"
    mode = args.training_mode

    if mode in ("log_likelihood", "score_matching"):
        middle = settings_part
    elif mode in ("kl_divergence", "symmetric_kl"):
        # KL-based modes are annealed, so the temperature range is part of the identity.
        middle = settings_part + temperature_part
    else:
        raise ValueError(f"Unknown training mode: {args.training_mode}")

    return dataset + mode + middle + seed_part


def check_tuple(data, move_to_torch=False, device=None, one_hot=False):
    """Split a batch into ``(observations, context)`` and move it to ``device``.

    A ``tuple`` (of two items) or a two-element ``list`` is treated as
    ``(observations, context)``; anything else is treated as bare
    observations with no context.

    Args:
        data: Either observations alone, or a pair ``(observations, context)``.
            ``context`` may be ``None``.
        move_to_torch: If True, convert any non-tensor element (NumPy array or
            other array-like) to a ``float32`` tensor first. Each element is
            converted independently.
        device: Target passed to ``Tensor.to``.
        one_hot: Unused; kept for backward compatibility.

    Returns:
        ``(observations, context)`` where ``context`` is ``None`` when absent.
    """

    def _prepare(x):
        # Convert one element (if requested) and move it to the target device.
        if x is None:
            return None
        if move_to_torch and not isinstance(x, torch.Tensor):
            if isinstance(x, np.ndarray):
                x = torch.from_numpy(x).float()
            else:
                x = torch.as_tensor(x).float()
        return x.to(device)

    if isinstance(data, tuple) or (isinstance(data, list) and len(data) == 2):
        observations, context = data
        return _prepare(observations), _prepare(context)
    return _prepare(data), None


class ConditionalDataset(torch.utils.data.Dataset):
    """Dataset in which each item is one whole condition.

    Indexing returns every sample stored for a condition together with that
    condition's context vector. ``data`` is indexed with three axes and
    ``context`` with two; the layout is presumably
    ``data: (n_conditions, n_samples_per_condition, n_dims)`` and
    ``context: (n_conditions, n_condition_dims)`` — confirm against callers.
    """

    def __init__(self, data, context):
        n_data_conditions = data.shape[0]
        n_context_conditions = context.shape[0]
        if n_data_conditions != n_context_conditions:
            raise ValueError("Number of conditions in data and context must match.")

        self.data = data
        self.context = context
        self.n_cond = n_data_conditions

    def __len__(self):
        # One item per condition, not per sample.
        return self.n_cond

    def __getitem__(self, idx):
        samples = self.data[idx, :, :]
        condition = self.context[idx, :]
        return samples, condition


class FlexibleDataset(Dataset):
    """
    A flexible PyTorch Dataset class that handles both unconditional and
    conditional data.

    The dataset can be used in two modes:
    1. Unconditional: `__getitem__` returns only the data sample.
    2. Conditional: `__getitem__` returns a tuple of (data_sample, condition).

    This structure is designed to be seamlessly compatible with PyTorch's DataLoader.
    """

    def __init__(self, data, conditions=None):
        """
        Initializes the dataset.

        Args:
            data (torch.Tensor, np.ndarray or array-like):
                All data samples, converted to a float32 tensor.
                Shape: (n_total_samples, n_data_dims)

            conditions (torch.Tensor, np.ndarray or array-like, optional):
                The condition for each data sample, converted to a float32
                tensor. If provided, the dataset operates in conditional mode.
                Shape: (n_total_samples, n_condition_dims)
                Defaults to None for unconditional data.

        Raises:
            ValueError: If ``data`` and ``conditions`` have a different
                number of samples along the first axis.
        """
        self.data = self._as_float_tensor(data)

        # `conditions` is always defined so callers can inspect it safely;
        # it stays None in unconditional mode.
        self.is_conditional = conditions is not None
        self.conditions = None
        if self.is_conditional:
            self.conditions = self._as_float_tensor(conditions)

            # --- Data Validation ---
            # Ensure that for every data point, there is a corresponding condition.
            if self.data.shape[0] != self.conditions.shape[0]:
                raise ValueError(
                    "Mismatch in number of samples between data and conditions. "
                    f"Got {self.data.shape[0]} data samples and "
                    f"{self.conditions.shape[0]} conditions."
                )

    @staticmethod
    def _as_float_tensor(array):
        """Convert a tensor, NumPy array or array-like to a float32 tensor."""
        if isinstance(array, torch.Tensor):
            return array.float()
        if isinstance(array, np.ndarray):
            # from_numpy shares memory with the array (no copy for float32 input).
            return torch.from_numpy(array).float()
        return torch.as_tensor(array).float()

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return self.data.shape[0]

    def __getitem__(self, idx):
        """
        Retrieves a sample from the dataset at the specified index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            - If unconditional: a single tensor (the data sample).
            - If conditional: a tuple of (data_sample, condition).
        """
        data_sample = self.data[idx]

        if self.is_conditional:
            condition_sample = self.conditions[idx]
            return data_sample, condition_sample
        return data_sample


class TemperatureSchedule:
    """Annealing schedule mapping a training iteration to a temperature.

    For the "quad", "log", "exp_mult" and "custom" schedules, the raw decay
    curve starts at ``temp_max`` at iteration 0. It reaches ``temp_min``
    exactly at iteration ``temp_min_it`` and is clamped to ``temp_min``
    afterwards. Any other ``cooling_schedule`` value yields a constant
    ``temp_min``.

    Args:
        temp_max: Starting temperature of the raw decay curve.
        num_iter: Total number of iterations; combined with
            ``temp_min_it_ratio`` to decide when ``temp_min`` is reached.
        cooling_schedule: "log", "quad", "exp_mult" or "custom".
        temp_min_it_ratio: Fraction of ``num_iter`` after which the raw curve
            has decayed to ``temp_min``.
        temp_min: Floor temperature.
        temp_factor: Multiplier applied to the part of the temperature above
            ``temp_min``.
    """

    def __init__(
        self,
        temp_max,
        num_iter,
        cooling_schedule="log",
        temp_min_it_ratio=0.9,
        temp_min=1.0,
        temp_factor=1.0,
    ):
        self.temp_min = temp_min
        self.temp_max = temp_max
        # Iteration at which the raw curve hits temp_min; at least 1 so the
        # divisions below (and in "exp_mult") never divide by zero.
        self.temp_min_it = max(int(num_iter * temp_min_it_ratio), 1)
        self.temp_factor = temp_factor  # multiplies temperature above the offset

        self.cs = cooling_schedule
        # Precompute the per-iteration slope so each curve equals temp_min at
        # temp_min_it. "exp_mult" and unknown schedules don't use it, so
        # temp_it_factor is NOT set for them.
        if self.cs == "quad":
            # Solves temp_max / (1 + k**2) == temp_min.
            # NOTE(review): NaN if temp_max < temp_min.
            k = np.sqrt(self.temp_max / self.temp_min - 1)
            self.temp_it_factor = k / self.temp_min_it
        elif self.cs == "log":
            # Solves temp_max / (1 + log(1 + k)) == temp_min.
            # NOTE(review): np.exp can overflow for large temp_max / temp_min.
            k = np.exp(self.temp_max / self.temp_min - 1) - 1
            self.temp_it_factor = k / self.temp_min_it
        elif self.cs == "custom":
            # Chosen so the "custom" curve in get_temperature equals temp_min
            # at iteration temp_min_it.
            self.temp_it_factor = (
                np.exp((self.temp_max * (1 + np.log(self.temp_min))) / self.temp_min - 1.0)
                - self.temp_min
            )
            self.temp_it_factor = self.temp_it_factor / self.temp_min_it

    def mult_temperature(self, temperature):
        """Clamp ``temperature`` at ``temp_min`` and scale the excess above it by ``temp_factor``.

        Works element-wise on scalars or NumPy arrays (uses ``np.maximum``).
        """
        # caps minimum temperature at minimal temperature and adds a factor to the amount higher than min
        return self.temp_factor * np.maximum(temperature - self.temp_min, 0.0) + self.temp_min

    def get_temperature(self, iteration):
        """Return the temperature for ``iteration`` according to the configured schedule.

        "custom" applies its own affine rescaling; every other schedule goes
        through ``mult_temperature``.
        """
        if self.cs == "custom":
            # Raw curve: equals temp_max at iteration 0 (as long as
            # 1 + log(temp_min) != 0) and temp_min at temp_min_it.
            temperature = (self.temp_max * (1 + np.log(self.temp_min))) / (
                1 + np.log(self.temp_it_factor * iteration + self.temp_min)
            )
            temperature = np.maximum(temperature, self.temp_min)
            # NOTE(review): with temp_factor == 1 and temp_min == 1 this maps the
            # iteration-0 value temp_max to 2 rather than temp_max, and it
            # divides by zero when temp_max == 1 — confirm the intended formula.
            temperature = (self.temp_factor * self.temp_max / (self.temp_max - 1.0) - 1.0) * (
                temperature - self.temp_min
            ) + self.temp_min
            return temperature

        if self.cs == "exp_mult":
            # Geometric interpolation from temp_max (k=0) to temp_min (k=1).
            k = iteration / self.temp_min_it
            temperature = self.temp_max * (self.temp_min / self.temp_max) ** k
        elif self.cs == "quad":
            k = iteration * self.temp_it_factor
            temperature = self.temp_max / (1 + k**2)
        elif self.cs == "log":
            k = iteration * self.temp_it_factor
            temperature = self.temp_max / (1 + np.log(1 + k))
        else:
            # Unknown schedule: silently fall back to a constant temp_min.
            temperature = self.temp_min
        # Past temp_min_it the raw curves drop below temp_min; this clamps them.
        return self.mult_temperature(temperature)


@torch.no_grad()
def batched_evaluation(data, batch_size, function):
    """Apply ``function`` to ``data`` in chunks along the first axis.

    Each chunk's result is moved to the CPU, converted to NumPy, and the
    chunks are concatenated along the first axis. Runs without gradient
    tracking.

    Args:
        data: Sliceable along the first axis (e.g. a tensor).
        batch_size: Maximum number of rows per call to ``function``.
        function: Callable returning a tensor for a slice of ``data``.

    Returns:
        np.ndarray with the concatenated per-chunk results.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_points = data.shape[0]
    if n_points == 0:
        # Evaluate once on the empty input so the output keeps the function's
        # trailing dimensions.
        return np.concatenate([function(data).cpu().numpy()])

    # Stepping by batch_size never produces an empty trailing chunk, even when
    # n_points is an exact multiple of batch_size. Slicing clips the last chunk.
    results = [
        function(data[start : start + batch_size]).cpu().numpy()
        for start in range(0, n_points, batch_size)
    ]
    return np.concatenate(results)


@torch.no_grad()
def batched_sampling(n_points, batch_size, sampling_function):
    """Draw exactly ``n_points`` samples using fixed-size calls to ``sampling_function``.

    ``sampling_function`` is always called with ``batch_size``, so samplers
    that assume a fixed batch size keep working. Enough batches are drawn to
    cover ``n_points`` and the surplus of the last batch is discarded. Runs
    without gradient tracking.

    Args:
        n_points: Number of samples to return.
        batch_size: Number of samples requested per call.
        sampling_function: Callable ``(int) -> tensor`` whose first axis
            indexes samples.

    Returns:
        np.ndarray whose first axis has length ``n_points``.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``n_points`` is negative.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if n_points < 0:
        raise ValueError(f"n_points must be non-negative, got {n_points}")

    # Ceil division so a final partial batch is still covered; at least one
    # call so an empty request still yields a correctly-shaped (empty) array.
    n_batches = max((n_points + batch_size - 1) // batch_size, 1)
    results = [sampling_function(batch_size).cpu().numpy() for _ in range(n_batches)]
    return np.concatenate(results)[:n_points]
