import math
import os

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from ntldm.utils.dataset_utils import load_vae_data
from ntldm.utils.utils import standardize_array
from ntldm.networks.networks import SinusoidalPosEmb
from tqdm.auto import tqdm

from torch.utils.data import Subset


def simulate_lds(
    n_ic, n_reps, sequence_length, n_latents, rotation_angles, noise_variance, seed=42
):
    """
    Simulate a noisy LDS whose dynamics matrix is block-diagonal in 2D rotations.

    Each trajectory starts from one of ``n_ic`` random initial states and evolves as
    ``x_{t+1} = a @ x_t + eps_t`` with ``eps_t ~ N(0, noise_variance * I)``. Each
    initial state is reused for ``n_reps`` independent noisy repetitions.

    parameters:
        n_ic (int): number of initial conditions
        n_reps (int): number of repetitions for each initial condition
        sequence_length (int): number of time steps per sequence
        n_latents (int): latent dimensionality; must equal ``2 * len(rotation_angles)``
        rotation_angles (sequence of floats): angle (radians) of each 2D rotation block of ``a``
        noise_variance (float): variance of the process noise
        seed (int, optional): passed to ``torch.manual_seed``; note that this reseeds
            torch's *global* RNG, which also affects any sampling done after this call

    returns:
        torch.tensor: float32 tensor of shape (n_ic * n_reps, sequence_length, n_latents).
            Rows ``i * n_reps`` to ``(i + 1) * n_reps - 1`` share initial condition ``i``.

    raises:
        ValueError: if ``n_latents`` does not match the size of the block-diagonal ``a``
    """
    # fail early with a clear message instead of a matmul shape error mid-simulation
    if n_latents != 2 * len(rotation_angles):
        raise ValueError(
            f"n_latents ({n_latents}) must equal 2 * len(rotation_angles) "
            f"({2 * len(rotation_angles)})"
        )

    # set the random seed for reproducibility
    torch.manual_seed(seed)

    # block-diagonal dynamics: one 2x2 rotation per pair of latent dimensions
    blocks = [rotation_matrix(theta) for theta in rotation_angles]
    a = torch.block_diag(*blocks)

    # initial states ~ N(ic_bias, I) with ic_bias = [0, 1, 0, 1, ...]
    ic_bias = torch.ones(n_latents)
    ic_bias[::2] = 0
    initial_states = torch.randn(n_ic, n_latents) + ic_bias

    # loop-invariant noise scale; computed with the same expression as before so
    # the generated values are unchanged
    noise_std = torch.sqrt(torch.tensor(noise_variance))

    # prepare storage for all latents across initial conditions and repetitions
    latents = torch.empty(
        (n_ic * n_reps, sequence_length, n_latents), dtype=torch.float32
    )
    # the context manager closes the progress bar even if the simulation raises
    with tqdm(total=n_ic * n_reps, desc="Simulating LDS") as pbar:
        for i in range(n_ic):  # loop over initial conditions
            for rep in range(n_reps):  # loop over repetitions
                idx = i * n_reps + rep
                x = initial_states[i]
                for t in range(sequence_length):  # simulate dynamics over time
                    latents[idx, t] = x
                    # one noise draw per step, including a final unused draw after the
                    # last stored state, so RNG consumption order is preserved
                    x = a @ x + torch.randn(n_latents) * noise_std
                pbar.update(1)

    return latents


def rotation_matrix(theta):
    """
    Create a 2D rotation matrix for the given angle theta.

    Parameters:
        theta (float or 0-d torch.tensor): Rotation angle in radians. Tensors are
            used as-is (no copy-construct warning); Python numbers become tensors.

    Returns:
        torch.tensor: 2x2 matrix ``[[cos, -sin], [sin, cos]]``; float32 for Python
            int/float inputs.
    """
    # as_tensor avoids the UserWarning (and copy) that torch.tensor emits for tensor input
    theta = torch.as_tensor(theta)

    c, s = torch.cos(theta), torch.sin(theta)

    # assemble with torch.stack rather than torch.tensor(list of 0-d tensors)
    return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])


class LDS_VAE_OUTPUT(Dataset):
    """
    Dataset of the VAE embeddings (``mu``) of LDS time series.

    Each item is a dict with ``"signal"`` (a float32 tensor for one sample) and, when
    ``with_time_emb`` is True, ``"cond"`` (a sinusoidal time embedding shared by all
    items).
    """

    def __init__(
        self,
        with_time_emb=False,
        cond_time_dim=32,
        filepath=None,
        filename=None,
        n_latents=4,
        model=None,
        signal_length=200,
    ):
        """
        :param with_time_emb: if True, every item also carries the time embedding under "cond"
        :param cond_time_dim: size of the sinusoidal time embedding
        :param filepath: path prefix, concatenated as-is with ``filename`` (include the
            trailing separator)
        :param filename: file name appended to ``filepath`` and passed to ``load_vae_data``
        :param n_latents: number of latent channels, stored as ``num_channels``
        :param model: optional model, stored as ``self.model`` (not otherwise used here yet)
        :param signal_length: number of time steps covered by the time embedding
        :raises ValueError: if ``filepath`` or ``filename`` is not provided
        """
        super().__init__()

        self.with_time_emb = with_time_emb
        self.cond_time_dim = cond_time_dim
        self.signal_length = signal_length
        self.num_channels = n_latents
        # assigned unconditionally so the attribute exists whether or not a model is given
        self.model = model

        if filepath is None or filename is None:
            raise ValueError("filepath and filename must both be provided")

        # TODO: replace this load vae data with the load LDS data
        # and the model weights of the specified run and embed the data
        dataset = load_vae_data(filepath + filename)
        # swap the last two axes of the training embeddings; presumably
        # (n_samples, time, channels) -> (n_samples, channels, time) - TODO confirm
        mu = dataset["mu"].numpy().transpose(0, 2, 1)

        self.data_array = standardize_array(mu, ax=(0, 2))

        temp_emb = SinusoidalPosEmb(cond_time_dim).forward(
            torch.arange(self.signal_length)
        )
        self.emb = torch.transpose(temp_emb, 0, 1)

    def __getitem__(self, index, cond_channel=None):
        """
        Return ``{"signal": float32 tensor}`` for ``index``, plus ``"cond"`` when the
        time embedding is enabled. ``cond_channel`` is accepted but ignored.
        """
        # np.array(..., dtype=...) always copies, so the returned tensor never aliases
        # the dataset's storage
        return_dict = {
            "signal": torch.from_numpy(
                np.array(self.data_array[index], dtype=np.float32)
            )
        }
        cond = self.get_cond()
        if cond is not None:
            return_dict["cond"] = cond
        return return_dict

    def get_cond(self):
        """Return the shared time embedding if enabled, else None."""
        return self.emb if self.with_time_emb else None

    def __len__(self):
        """Number of samples (first axis of the data array)."""
        return len(self.data_array)


class LinearDynamicalSystem(Dataset):
    """
    Synthetic Poisson spike-count dataset driven by a rotational linear dynamical system.

    Latent trajectories come from ``simulate_lds``. They are projected onto
    ``d_dimension`` channels by a random readout ``C`` plus a constant bias ``b``,
    mapped through a softplus to Poisson rates, and spike counts are sampled from
    ``torch.poisson``. Arrays are stored as (n_sequences, time, channels).
    """

    def __init__(
        self,
        n_latents,
        d_dimension,
        rotation_angles,
        seed=0,
        sequence_length=100,
        noise_variance=0.05,
        n_ic=5,
        n_reps=10,
        mean_spike_count=500.0,
        random_seed=42,
        softplus_beta=1.0,
        time_last=False,
    ):
        """
        Initialize the LDS with periodic dynamics and sample spikes from it.

        :param n_latents: number of latent dimensions; must be even (one 2D rotation per pair)
        :param d_dimension: number of observed channels
        :param rotation_angles: one angle (radians) per 2D rotation block, i.e. n_latents // 2 angles
        :param seed: unused; kept for backward compatibility (see ``random_seed``)
        :param sequence_length: number of time steps per sequence
        :param noise_variance: variance of the latent process noise
        :param n_ic: number of distinct initial conditions
        :param n_reps: number of noisy repetitions per initial condition
        :param mean_spike_count: sets the constant bias ``b = log(mean_spike_count / sequence_length)``
        :param random_seed: seed passed to ``simulate_lds``; because that call reseeds torch's
            global RNG, it also determines ``C`` and the Poisson samples drawn afterwards
        :param softplus_beta: ``beta`` of the softplus mapping log-rates to rates
        :param time_last: if True, ``__getitem__`` returns (channels, time) instead of (time, channels)
        """
        super().__init__()
        assert n_latents % 2 == 0, "n_latents should be even."
        assert (
            len(rotation_angles) == n_latents // 2
        ), "Provide an angle for each 2D rotation block."

        self.n_latents = n_latents
        self.d_dimension = d_dimension
        self.sequence_length = sequence_length
        self.time_last = time_last

        # latents: (n_ic * n_reps, sequence_length, n_latents)
        latents = simulate_lds(
            n_ic=n_ic,
            n_reps=n_reps,
            sequence_length=sequence_length,
            n_latents=n_latents,
            rotation_angles=rotation_angles,
            noise_variance=noise_variance,
            seed=random_seed,
        )

        # random readout with every row rescaled to norm 0.5, then rows sorted by
        # their first-column weight
        self.C = torch.randn(d_dimension, n_latents, dtype=torch.float32)
        self.C /= torch.norm(self.C, dim=1, keepdim=True) * 2
        self.C = self.C[self.C[:, 0].argsort()]

        # constant bias shared by all channels, shape (d_dimension, 1)
        self.b = torch.log(
            torch.tensor(mean_spike_count) / sequence_length
        ) * torch.ones(d_dimension, 1)

        # log_rates[k, l, i] = sum_j C[i, j] * latents[k, l, j] + b[i]
        # -> shape (n_sequences, sequence_length, d_dimension)
        self.log_rates = torch.einsum("ij,klj->kli", self.C, latents) + self.b.squeeze()

        self.poisson_rates = torch.nn.functional.softplus(
            self.log_rates, beta=softplus_beta
        )
        self.samples = torch.poisson(self.poisson_rates)

        # Store for dataset
        self.latents = latents
        self.rates = self.poisson_rates

    @staticmethod
    def rotation_matrix(theta):
        """
        Create a 2D float32 rotation matrix for the given angle theta (radians).

        The simulation itself uses the module-level ``rotation_matrix`` via
        ``simulate_lds``; this static method is kept for external callers.
        """
        return torch.tensor(
            [[math.cos(theta), -math.sin(theta)], [math.sin(theta), math.cos(theta)]],
            dtype=torch.float32,
        )

    def __len__(self):
        """
        Return the number of sequences.
        """
        return len(self.samples)

    def __getitem__(self, index, cond_channel=None):
        """
        Return a dict with "signal" (spike counts), "latents" and "rates" for one sequence.

        Arrays are (time, channels), or (channels, time) when ``time_last`` is set.
        ``cond_channel`` is accepted but ignored.
        """
        return_dict = {
            "signal": self.samples[index],
            "latents": self.latents[index],
            "rates": self.rates[index],
        }
        if self.time_last:
            return_dict = {key: value.permute(1, 0) for key, value in return_dict.items()}

        cond = self.get_cond()
        if cond is not None:
            return_dict["cond"] = cond

        return return_dict

    def get_cond(self):
        """Return the conditioning input; always None for now."""
        # TODO: implement time embedding if ever relevant
        return None


def sequential_split(dataset, lengths):
    """
    Split a dataset into consecutive, non-overlapping subsets (no shuffling).

    :param dataset: a torch.utils.data.Dataset to split
    :param lengths: size of each split, in order; must sum to len(dataset)
    :return: list of torch.utils.data.Subset, one per entry of ``lengths``, whose
        ``indices`` are slices of ``torch.arange(len(dataset))``
    """
    total = len(dataset)
    assert (
        sum(lengths) == total
    ), "Sum of input lengths does not equal the total length of the dataset"

    all_indices = torch.arange(0, total)
    subsets = []
    start = 0
    for length in lengths:
        stop = start + length
        # each split picks up exactly where the previous one ended
        subsets.append(Subset(dataset, all_indices[start:stop]))
        start = stop
    return subsets


def get_lds_dataset(
    n_latents=4,
    n_neurons=128,
    sequence_length=200,
    rotation_angles=(math.pi / 13, math.pi / 99),
    n_ic=1000,
    n_reps=10,
    noise_variance=0.02,
    mean_spike_count=100,
    batch_size=100,
    train_frac=0.7,
    valid_frac=0.15,
    random_seed=42,
    time_last=False,
    softplus_beta=1.0,
):
    """
    Build a synthetic LDS spiking dataset and split it into train/validation/test loaders.

    The split is sequential, with no shuffling before splitting. With ``N = n_ic * n_reps``:
    - the first ``int(train_frac * N)`` sequences go to training;
    - the next ``int(valid_frac * N)`` go to validation;
    - the remainder goes to test.

    Both fractions are relative to the total dataset size. Sequences are stored grouped
    by initial condition, so most initial conditions fall in a single split. A group can
    straddle a boundary when the split sizes are not multiples of ``n_reps``.

    Parameters:
    - n_latents: Number of latent dimensions (must equal 2 * len(rotation_angles)).
    - n_neurons: Number of observed channels.
    - sequence_length: Number of time steps per sequence.
    - rotation_angles: Angles (radians) of the 2D rotation blocks of the dynamics matrix.
    - n_ic: Number of initial conditions.
    - n_reps: Number of repetitions per initial condition.
    - noise_variance: Variance of the latent process noise.
    - mean_spike_count: Sets the constant log-rate bias log(mean_spike_count / sequence_length).
    - batch_size: Batch size for all three data loaders.
    - train_frac: Fraction of all sequences used for training.
    - valid_frac: Fraction of all sequences used for validation.
    - random_seed: Seed forwarded to the dataset; also used to seed numpy's global RNG.
    - time_last: If True, items are returned as (channels, time) instead of (time, channels).
    - softplus_beta: Beta of the softplus mapping log-rates to Poisson rates.

    Returns:
    - (train_loader, val_loader, test_loader). Only the training loader shuffles.

    Raises:
    - ValueError: if the fractions give a negative split size (e.g. train_frac + valid_frac > 1).
    """
    n_seqs = n_ic * n_reps
    train_seqs = int(train_frac * n_seqs)
    valid_seqs = int(valid_frac * n_seqs)
    test_seqs = n_seqs - train_seqs - valid_seqs  # test takes the remainder
    # a negative size would make sequential_split silently truncate/empty splits
    if min(train_seqs, valid_seqs, test_seqs) < 0:
        raise ValueError(
            f"invalid split fractions train_frac={train_frac}, valid_frac={valid_frac}: "
            f"sizes {train_seqs}/{valid_seqs}/{test_seqs} of {n_seqs} sequences"
        )

    # NOTE(review): the permutation is not used (the split below is sequential). It is
    # kept because seeding and advancing numpy's *global* RNG is an observable side
    # effect that downstream code may rely on; drop it once that is confirmed unneeded.
    np.random.seed(random_seed)
    np.random.permutation(n_seqs)

    lds_dataset = LinearDynamicalSystem(
        n_latents=n_latents,
        d_dimension=n_neurons,
        sequence_length=sequence_length,
        rotation_angles=rotation_angles,
        noise_variance=noise_variance,
        n_ic=n_ic,
        n_reps=n_reps,
        mean_spike_count=mean_spike_count,
        random_seed=random_seed,
        time_last=time_last,
        softplus_beta=softplus_beta,
    )

    train_dataset, valid_dataset, test_dataset = sequential_split(
        lds_dataset, [train_seqs, valid_seqs, test_seqs]
    )

    lds_dataloader_train = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    lds_dataloader_val = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    lds_dataloader_test = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return lds_dataloader_train, lds_dataloader_val, lds_dataloader_test


if __name__ == "__main__":
    # Smoke run: build a small LDS dataset and report how many sequences it holds.
    demo_config = dict(
        n_latents=4,
        d_dimension=128,
        rotation_angles=[math.pi / 13, math.pi / 99],
        sequence_length=200,
        noise_variance=0.02,
        n_ic=2,
        n_reps=10,
        mean_spike_count=10,
    )
    dataset = LinearDynamicalSystem(**demo_config)
    print(len(dataset))
