import os
from typing import Callable, Dict, Tuple
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import torch
from pydantic import BaseModel


def vpd(t: float, h: float) -> float:
    """
    Computes the vapor pressure deficit (VPD) from temperature and humidity.

    The saturation vapor pressure is obtained as
    ``0.611 * exp(17.27 * t / (t + 237.3))`` and the actual vapor pressure as
    ``h / 100`` times that value; the deficit is their difference.

    Args:
        t (float): Air temperature in degrees Celsius.
        h (float): Relative humidity. The value is divided by 100 before use,
            i.e. it is treated as a percentage on a 0-100 scale.

    Returns:
        float: Vapor pressure deficit in units of kPa.
    """
    saturation_pressure = 0.611 * np.exp(17.27 * t / (t + 237.3))
    actual_pressure = h / 100.0 * saturation_pressure
    return saturation_pressure - actual_pressure


def init_relation(config: Dict) -> Callable[[float, float], float]:
    """
    Selects the two-argument relation function associated with a dataset.

    Args:
        config (Dict): Configuration dictionary; only ``config['data_name']`` is read.

    Returns:
        A callable taking two values and returning their relation:
        ``vpd`` for 'rwth', ``x - y`` for 'recl' and 'rillness',
        ``x * y`` for 'rett' and ``x + y`` for 'rtraffic'.

    Raises:
        NotImplementedError: If ``config['data_name']`` is none of the names above.
    """
    name = config['data_name']

    # the only relation that is not plain arithmetic
    if name == 'rwth':
        return vpd

    # x_gt relations expressed as simple arithmetic on the two inputs
    arithmetic_relations = {
        'recl': lambda x, y: x - y,
        'rett': lambda x, y: x * y,
        'rtraffic': lambda x, y: x + y,
        'rillness': lambda x, y: x - y,
    }
    if name in arithmetic_relations:
        return arithmetic_relations[name]

    raise NotImplementedError(f"data_name ({config['data_name']}) not implemented.")


def init_sequences(config: dict, mode: str) -> torch.Tensor:
    """
    Reads ``data/<data_name>/<mode>.csv`` and returns its values as a tensor.

    The first CSV column is dropped; all remaining columns are converted to
    float32.

    Args:
        config (dict): Configuration dictionary; only ``config['data_name']`` is read.
        mode (str): The data mode ('train', 'val', or 'test'), used as the file stem.

    Returns:
        A contiguous float32 torch.Tensor with one row per CSV row.
    """
    csv_path = os.path.join('data', config['data_name'], f'{mode}.csv')
    frame = pd.read_csv(csv_path)

    # skip the first column and keep the rest as a numpy array
    values = frame.iloc[:, 1:].values

    sequences = torch.tensor(values, dtype=torch.float32)
    return sequences.contiguous()


class GaussianDiffusion:
    """
    Gaussian diffusion process with a linearly spaced beta schedule.

    The schedule (``betas``, ``alphas`` and their cumulative products) is
    precomputed in ``__init__``. The class offers forward (noising) steps,
    reverse (denoising) steps driven by a noise-predicting model, and two
    sampling loops: an unconditional one and a prompt-guided one.
    """

    class Config(BaseModel):
        n_steps: int = 1000
        beta_first: float = 1e-4
        beta_last: float = 2e-2
    default_config = Config().dict()

    def __init__(self, config):
        """
        Builds the noise schedule.

        Args:
            config: Mapping with the optional keys 'n_steps', 'beta_first' and
                'beta_last'. Missing keys fall back to the ``Config`` defaults.
        """
        self.config = self.Config(**config).dict()

        # parameters: read from the validated config (not the raw argument) so
        # that keys omitted by the caller resolve to the Config defaults
        self.n_steps = self.config['n_steps']
        self.beta_first = self.config['beta_first']
        self.beta_last = self.config['beta_last']

        # init
        self.betas = torch.linspace(self.beta_first, self.beta_last, self.n_steps)
        self.alphas = 1 - self.betas
        self.alphas_sqrt = self.alphas.sqrt()
        self.alphas_cumprod = torch.cumprod(self.alphas, 0)
        self.alphas_cumprod_sqrt = self.alphas_cumprod.sqrt()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _predict_noise(self, model: torch.nn.Module, x: torch.Tensor, t: int) -> torch.Tensor:
        """
        Runs the model on ``x`` with the normalized time step appended as a last column.

        Args:
            model (torch.nn.Module): The noise-predicting model.
            x (torch.Tensor): Current state; its first dimension is the batch size.
            t (int): The diffusion time step.

        Returns:
            The model output, detached and moved to the CPU.
        """
        time_column = self.t2float(t).unsqueeze(0).repeat(x.shape[0], 1)
        model_input = torch.cat([x.to(self.device), time_column], dim=1)
        return model(model_input).detach().cpu()

    def forward_step(self, t: int, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs a single forward diffusion step: t-1 -> t.

        Args:
            t (int): The diffusion time step to go to.
            x (torch.Tensor): The state at the previous time step.

        Returns:
            A tuple containing:
                - The diffused tensor.
                - The noise added.
        """
        assert t >= 0

        mean = self.alphas_sqrt[t] * x
        std = self.betas[t].sqrt()

        noise = torch.randn_like(x)
        return mean + std * noise, noise

    def forward_zero2t(self, t: int, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs a forward diffusion step from 0 to t.

        Args:
            t (int): The diffusion time step to go to.
            x (torch.Tensor): The input tensor.

        Returns:
            A tuple containing (both moved to ``self.device``):
                - The diffused tensor.
                - The noise added.
        """
        assert t >= 0

        mean = self.alphas_cumprod_sqrt[t] * x
        std = (1 - self.alphas_cumprod[t]).sqrt()

        # sampling from N
        noise = torch.randn_like(x)
        x_t = mean + std * noise
        return x_t.to(self.device), noise.to(self.device)

    def guided_reverse_step(
            self,
            epsilon: torch.Tensor,
            x_t: torch.Tensor,
            t: int,
            x_prompt: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs a single reverse diffusion step with guidance going from t to t-1.

        The step is a plain ``reverse_step``; for ``t > 0`` the entries of the
        sample where ``x_prompt`` is not NaN are then overwritten with the
        prompt diffused forward from 0 to ``t``. At ``t == 0`` nothing is
        overwritten.

        Args:
            epsilon (torch.Tensor): The noise estimate at the current time step.
            x_t (torch.Tensor): The current state of the diffusion process.
            t (int): The current diffusion time step.
            x_prompt (torch.Tensor): The prompt (conditioning); NaN marks the
                entries that are not part of the prompt.

        Returns:
            A tuple containing:
                - The sampled tensor denoised by one time step.
                - The predicted mean denoised by one time step.
        """
        # shared denoising step (adds noise to response and prompt alike)
        sample, mean_pred = self.reverse_step(epsilon, x_t, t)

        if t != 0:
            # overwrite prompt with forward process
            prompt_indices = torch.where(~torch.isnan(x_prompt))
            x_prompt_noisy_t, _ = self.forward_zero2t(t, x_prompt)
            sample[prompt_indices] = x_prompt_noisy_t[prompt_indices].to(sample.device)
        return sample, mean_pred

    def reverse_step(
            self,
            epsilon: torch.Tensor,
            x_t: torch.Tensor,
            t: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs a single reverse diffusion step from t to t-1.

        Args:
            epsilon (torch.Tensor): The noise estimate at the current time step.
            x_t (torch.Tensor): The current state of the diffusion process.
            t (int): The current diffusion time step.

        Returns:
            A tuple containing:
                - The sampled tensor denoised by one time step. For ``t == 0``
                  this is the predicted mean itself (no noise is added).
                - The predicted mean denoised by one time step.
        """

        mean_pred = (x_t - epsilon * ((1 - self.alphas[t]) / (1 - self.alphas_cumprod[t]).sqrt())) / self.alphas_sqrt[t]
        if t == 0:
            sample = mean_pred
        else:
            # add noise
            noise = torch.randn_like(mean_pred)
            sample = mean_pred + self.betas[t].sqrt() * noise
        return sample, mean_pred

    def t2float(self, t: int) -> torch.Tensor:
        """
        Converts an integer time step to a normalized float value.

        Args:
            t (int): The integer time step.

        Returns:
            A one-element float32 torch.Tensor on ``self.device`` holding
            ``(t - 0.5 * n_steps) / n_steps``, i.e. the time step scaled to the range [-0.5, 0.5].
        """
        return torch.tensor([(t - 0.5 * self.n_steps) / self.n_steps], dtype=torch.float32).to(self.device)

    def sample(self, n_samples: int, model: torch.nn.Module, n_dims: int = 2) -> torch.Tensor:
        """
        Samples the diffusion model unconditionally.

        Starts from standard normal noise and applies ``reverse_step`` for
        t = n_steps - 1 down to t = 0.

        Args:
            n_samples (int): Number of samples (rows) to draw.
            model (torch.nn.Module): The model to be used for sampling.
            n_dims (int, optional): Number of data columns per sample. Defaults to 2.

        Returns:
            The sample (torch.Tensor) of shape (n_samples, n_dims).
        """
        x_t = torch.randn(n_samples, n_dims).to(self.device)

        # sample t=T -> t=0; reverse_step adds no noise on the final (t=0) step
        for t in range(self.n_steps - 1, -1, -1):
            epsilon = self._predict_noise(model, x_t, t)
            x_t, _ = self.reverse_step(epsilon, x_t.detach().cpu(), t)
        return x_t

    def guided_sample(self, model: torch.nn.Module, x_prompt: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Samples the diffusion model using RePaint.

        Args:
            model (torch.nn.Module): The model to be used for sampling.
            x_prompt (torch.Tensor): The input prompt tensor for conditioning.
                Non-NaN entries form the prompt, NaN entries the response to
                be generated.

        Returns:
            A tuple containing:
                - The sampled prompt-response pair (torch.Tensor).
                - A dictionary with one list per metric (one value per row):
                    'prompt_ts_list': mean over a row's prompt entries of the root
                        of the squared prompt deviation of the predicted mean,
                        averaged over all reverse steps.
                    'response_ts_list': mean over a row's response entries of the
                        standard deviation of the predicted mean across the
                        steps t = n_steps - 1 ... 1 (empty if there is no response).
                    'prompt_error_list': root mean squared deviation between the
                        final predicted mean and the prompt.
                    'combined_error_list': root mean squared difference between
                        the final sample and one extra t=0 mean prediction made from it.
        """

        x_t = torch.randn_like(x_prompt).to(self.device)

        # prompt/response splits: the index tuples enumerate entries row by row,
        # so cumulative per-row counts give each row's slice in the flat selection
        prompt_indices = np.where(np.isnan(x_prompt) == 0)
        response_indices = np.where(np.isnan(x_prompt) == 1)
        prompt_idx_splits = np.cumsum([0] + [len(np.where(np.isnan(x_prompt_) == 0)[0]) for x_prompt_ in x_prompt])
        response_idx_splits = np.cumsum([0] + [len(np.where(np.isnan(x_prompt_) == 1)[0]) for x_prompt_ in x_prompt])

        # sample t=T -> t=1
        e_p = 0  # for computing prompt_ts
        response_trace = []     # for computing response_ts
        for t in range(self.n_steps - 1, 0, -1):
            epsilon = self._predict_noise(model, x_t, t)
            x_t, mean_pred = self.guided_reverse_step(epsilon, x_t.detach().cpu(), t, x_prompt)

            response_trace.append(mean_pred.detach().cpu().numpy()[response_indices])
            e_p += (mean_pred[prompt_indices] - x_prompt[prompt_indices].to(x_t.device)) ** 2

        # t=0 sample (kept outside the loop: it contributes to e_p but not to response_trace)
        epsilon = self._predict_noise(model, x_t, 0)
        x_t, mean_pred = self.guided_reverse_step(epsilon, x_t.detach().cpu(), 0, x_prompt)

        # prompt trajectory spread
        e_p += (mean_pred[prompt_indices] - x_prompt[prompt_indices].to(x_t.device)) ** 2
        e_p /= self.n_steps
        prompt_ts_list = [
            float(np.sqrt(e_p[start:stop]).mean())
            for start, stop in zip(prompt_idx_splits[:-1], prompt_idx_splits[1:])
        ]

        # response trajectory spread
        if len(response_indices[0]) == 0:       # for when entire data dimension is a prompt
            response_ts_list = []
        else:
            response_trace = np.stack(response_trace)
            error_response = np.std(response_trace, axis=0)
            response_ts_list = [
                float(error_response[start:stop].mean())
                for start, stop in zip(response_idx_splits[:-1], response_idx_splits[1:])
            ]

        # prompt error
        prompt_mse = (mean_pred[prompt_indices] - x_prompt[prompt_indices].to(x_t.device)) ** 2
        prompt_error_list = [
            float(np.sqrt(prompt_mse[start:stop].mean()))
            for start, stop in zip(prompt_idx_splits[:-1], prompt_idx_splits[1:])
        ]

        # combined error: denoise the final sample once more at t=0 and compare
        epsilon = self._predict_noise(model, x_t, 0)
        _, mean_pred_c = self.guided_reverse_step(epsilon, x_t.detach().cpu(), 0, x_prompt)
        combined_mse = (mean_pred_c - x_t) ** 2
        combined_error_list = torch.sqrt(combined_mse.mean(dim=1)).tolist()

        info = {
            'prompt_ts_list': prompt_ts_list,
            'response_ts_list': response_ts_list,
            'prompt_error_list': prompt_error_list,
            'combined_error_list': combined_error_list,
        }
        return x_t, info

    def compute_combined_error(self, model: torch.nn.Module, x_sampled: torch.Tensor) -> torch.Tensor:
        """
        Computes the combined error for the given prompt-response pairs.

        The model is evaluated once at t=0 on ``x_sampled``; the error is the
        per-row root mean squared difference between the resulting predicted
        mean and ``x_sampled``.

        Args:
            model (torch.nn.Module): The diffusion model to be used to compute the combined error.
            x_sampled (torch.Tensor): The sampled prompt-response pairs, one row per sample.
                NOTE(review): the noise estimate is moved to the CPU before being
                combined with ``x_sampled``, so a CPU tensor is presumably expected — confirm.

        Returns:
            A torch.Tensor containing the combined error for each sample in the batch (shape: (batch_size,)).
        """
        epsilon = self._predict_noise(model, x_sampled, 0)

        # at t=0 the guided step never overwrites the prompt, so the plain
        # reverse step yields the same predicted mean
        _, mean_pred = self.reverse_step(epsilon, x_sampled, 0)
        return torch.sqrt(((mean_pred - x_sampled) ** 2).mean(dim=1))


class DiffusionDataset(Dataset):
    """
    Sliding-window dataset that yields noised, normalized windows for training
    a noise-predicting diffusion model.

    Each item is a window of ``input_length`` consecutive rows, normalized
    with the per-column mean and standard deviation of the *train* split,
    diffused to a random time step and flattened.
    """

    class Config(BaseModel):
        # dataset
        data_name: str = 'wth_vpd'
        input_length: int = 20

        # g_diff
        n_steps: int = 1000
        beta_first: float = 1e-4
        beta_last: float = 2e-2

    default_config = Config().dict()
    allowed_modes = ['train', 'val', 'test']
    allowed_data_names = ['wth_vpd', 'ecl']

    def __init__(self, config, mode='train'):
        """
        Loads the requested split and the train-split normalization statistics.

        Args:
            config: Mapping with any of the ``Config`` fields; missing keys
                fall back to the ``Config`` defaults.
            mode (str, optional): One of ``allowed_modes``. Defaults to 'train'.

        Raises:
            AssertionError: If ``mode`` is not in ``allowed_modes``.
        """
        super().__init__()
        self.config = self.Config(**config).dict()
        self.mode = mode
        assert mode in self.allowed_modes, f'mode ({mode}) not in {self.allowed_modes}.'

        # attributes: read from the validated config (not the raw argument) so
        # that keys omitted by the caller resolve to the Config defaults
        self.data_name = self.config['data_name']
        self.input_length = self.config['input_length']
        self.n_steps = self.config['n_steps']
        self.beta_first = self.config['beta_first']
        self.beta_last = self.config['beta_last']

        # checks
        # The former data_name assertion re-tested `mode`, so data_name was
        # never validated. It is now actually checked, but only as a warning:
        # NOTE(review): allowed_data_names may be out of date, and a hard
        # failure would reject configurations that currently load fine.
        if self.data_name not in self.allowed_data_names:
            import warnings
            warnings.warn(f'data_name={self.data_name} not in {self.allowed_data_names}.')

        # init gaussian diffusion
        self.g_diff = GaussianDiffusion(self.config)

        # get train sequence statistics for normalisation
        train_sequences = init_sequences(self.config, 'train')
        self.data_means = train_sequences.mean(axis=0).reshape(1, -1)
        self.data_stds = train_sequences.std(axis=0).reshape(1, -1)

        # init sequences
        if mode == 'train':
            self.sequences = train_sequences
        else:
            del train_sequences
            self.sequences = init_sequences(self.config, mode)
        self.n_data, self.n_sequences = self.sequences.shape

    def get_raw_data(self, idx: int, concatenate: bool = False):
        """
        Retrieves raw (un-normalized) data.

        Args:
            idx (int): The starting index of the data to retrieve.
            concatenate (bool, optional): Concatenate the sequence into a single dimension. Defaults to False.

        Returns:
            A torch.Tensor containing the rows ``idx`` to ``idx + input_length``
            (exclusive) of the raw sequences, flattened if ``concatenate`` is set.
        """

        if concatenate:
            return self.sequences[idx:idx + self.input_length].view(-1)
        else:
            return self.sequences[idx:idx + self.input_length]

    def __len__(self) -> int:
        """
        Returns the number of full windows of ``input_length`` rows that fit in the data.

        Returns:
            int: The number of data points
        """

        return 1 + self.sequences.shape[0] - self.input_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves a single data point from the dataset.

        Args:
            idx (int): The index of the data sample to retrieve. Negative
                indices count from the end.

        Raises:
            IndexError: If the index is out of bounds.

        Returns:
            A tuple containing:
                - The input tensor for the model, concatenating the noisy data and the normalized time step.
                - The noise added during the diffusion process.
        """

        n_items = len(self)

        # resolve negative indices explicitly; used directly as a slice start
        # they would silently produce a truncated or empty window
        if idx < 0:
            idx += n_items
        if idx < 0 or idx >= n_items:
            raise IndexError

        # clone so the in-place operations later don't modify the stored data
        x = self.sequences[idx:idx + self.input_length].clone()

        # normalise
        x -= self.data_means
        x /= self.data_stds

        # add noise to the data
        t = np.random.randint(0, high=self.n_steps)
        x_noisy, noise = self.g_diff.forward_zero2t(t, x)

        # format
        x_noisy = x_noisy.view(-1)
        noise = noise.view(-1)
        return torch.concat([x_noisy, self.g_diff.t2float(t)]), noise


class DiffusionRContourDataset(Dataset):
    """
    Dataset over pre-built rows that yields noised samples for training a
    noise-predicting diffusion model.

    Unlike a windowed dataset, each item is a single row of ``sequences``,
    used as-is (no normalization), diffused to a random time step and flattened.
    """

    class Config(BaseModel):
        # g_diff
        n_steps: int = 1000
        beta_first: float = 1e-4
        beta_last: float = 2e-2

    default_config = Config().dict()
    allowed_modes = ['train', 'val', 'test']

    def __init__(self, config, sequences):
        """
        Stores the sequences and builds the diffusion process.

        Args:
            config: Mapping with any of the ``Config`` fields; missing keys
                fall back to the ``Config`` defaults.
            sequences: Two-dimensional tensor of samples, one per row.
        """
        super().__init__()
        self.config = self.Config(**config).dict()

        # attributes: read from the validated config (not the raw argument) so
        # that keys omitted by the caller resolve to the Config defaults
        self.n_steps = self.config['n_steps']
        self.beta_first = self.config['beta_first']
        self.beta_last = self.config['beta_last']

        # init gaussian diffusion
        self.g_diff = GaussianDiffusion(self.config)

        # sequences
        self.sequences = sequences
        self.n_data, self.n_sequences = self.sequences.shape

    def get_raw_data(self, idx: int):
        """
        Retrieves a raw (un-noised) row.

        Args:
            idx (int): The index of the row to retrieve.

        Returns:
            A torch.Tensor containing the raw data.
        """

        return self.sequences[idx]

    def __len__(self) -> int:
        """
        Returns the number of rows in the stored sequences.

        Returns:
            int: The number of data points
        """

        return self.sequences.shape[0]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves a single data point from the dataset.

        Args:
            idx (int): The index of the data sample to retrieve.

        Raises:
            IndexError: If the index is out of bounds.

        Returns:
            A tuple containing:
                - The input tensor for the model, concatenating the noisy data and the normalized time step.
                - The noise added during the diffusion process.
        """

        if idx >= len(self):
            raise IndexError

        # clone so later operations cannot modify the stored data
        x = self.sequences[idx].clone()

        # add noise to the data
        t = np.random.randint(0, high=self.n_steps)
        x_noisy, noise = self.g_diff.forward_zero2t(t, x)

        # format
        x_noisy = x_noisy.view(-1)
        noise = noise.view(-1)
        return torch.concat([x_noisy, self.g_diff.t2float(t)]), noise