import numpy as np
import torch
from einops import rearrange
from scipy.interpolate import CubicSpline

from margflow.datasets.dataset_abstracts import HybridDataset
from margflow.datasets.target_distribution import TargetMixture


class MixtureOfGaussianTime(HybridDataset):
    """Time-conditioned 2-D mixture of Gaussians with smoothly drifting means.

    Each of the ``args.n_mog`` components follows its own random trajectory:
    a Gaussian random walk is smoothed with a moving average and interpolated
    by a cubic spline over ``[t0, t1]``. Sampling evaluates every trajectory
    at a set of timesteps and draws points from a ``TargetMixture`` centred on
    the resulting means.

    Reads ``x_dim``, ``n_mog``, ``mog_sigma``, ``device`` and ``bounds`` from
    ``args``. Relies on the base class to expose ``self.args`` and
    ``self.dataset_suffix`` (not visible in this file - confirm against
    ``HybridDataset``).
    """

    def __init__(self, args):
        """Build the mean trajectories and cache the shared component sigma.

        Raises:
            AssertionError: if ``args.x_dim`` is not 2.
        """
        super().__init__(args)
        assert args.x_dim == 2, "MixtureOfGaussianTime only supports x_dim == 2"
        self._initialize_means(args)
        self.sigma = torch.tensor(self.args.mog_sigma, device=args.device)
        self.dataset_suffix += f"_condmix_nmog{self.args.n_mog:d}_tsigma{self.args.mog_sigma}"
        self.D = self.args.x_dim

    @staticmethod
    def _moving_average(values, window):
        """Return the moving average of ``values`` over ``window`` points.

        Uses ``mode="valid"``, so the output is ``window - 1`` elements
        shorter than the input.
        """
        return np.convolve(values, np.ones(window) / window, mode="valid")

    def _initialize_means(self, args):
        """Create one ``(x_spline, y_spline)`` pair per mixture component.

        Sets ``self.t0``, ``self.t1`` and ``self.means_trajectories``. Draws
        from NumPy's global random state, so results are reproducible only if
        the caller seeds ``np.random`` beforehand.

        ``args`` is accepted for signature compatibility; the values are read
        from ``self.args``.
        """
        self.t0, self.t1 = 0, 1
        num_steps = 50
        step_size = 0.3
        smoothing_window = 5  # must be odd so the padding is symmetric
        assert smoothing_window % 2 != 0  # check window has odd length
        buffer_length = smoothing_window // 2
        # The "valid" moving average drops `smoothing_window - 1` points, so
        # the raw walk is padded to leave exactly `num_steps` smoothed values.
        n_raw_steps = num_steps + 2 * buffer_length
        t = np.linspace(self.t0, self.t1, num_steps)

        self.means_trajectories = []
        offset, scale = 0, 5
        # Row 0 holds the x start of every component, row 1 the y start.
        starting_position = (np.random.rand(2, self.args.n_mog) - 0.5) * scale + offset
        for x_start, y_start in starting_position.T:
            # Draw order (all dx, then all dy) is kept so seeded runs match.
            dx = np.random.normal(0, step_size, n_raw_steps)
            dy = np.random.normal(0, step_size, n_raw_steps)
            x = np.cumsum(dx) + x_start
            y = np.cumsum(dy) + y_start
            x_smooth = self._moving_average(x, smoothing_window)
            y_smooth = self._moving_average(y, smoothing_window)
            self.means_trajectories.append((CubicSpline(t, x_smooth), CubicSpline(t, y_smooth)))

    def _means_at(self, timestep):
        """Return the component means at ``timestep`` as a float32 tensor.

        The result has one row per component and columns ``(x, y)``, and lives
        on ``self.args.device``.
        """
        means = np.array(
            [
                [float(x_spline(timestep)), float(y_spline(timestep))]
                for x_spline, y_spline in self.means_trajectories
            ]
        )
        return torch.from_numpy(means).float().to(self.args.device)

    def sample(
        self,
        n_samples,
        data_type: str = "train",
        n_timesteps=50,
        ordered=False,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Draw samples from the mixture at several timesteps.

        Args:
            n_samples: Total sample budget. Each timestep receives
                ``n_samples // n_timesteps`` samples, so any remainder is
                dropped.
            data_type: Unused by this implementation.
            n_timesteps: Number of timesteps (conditions) to sample at.
            ordered: If true, use evenly spaced timesteps over ``[t0, t1]``;
                otherwise draw them uniformly at random from NumPy's global
                random state.

        Returns:
            ``(samples, timesteps)``. ``samples`` stacks the per-timestep
            draws along dim 0 - presumably
            ``(n_timesteps, n_samples // n_timesteps, x_dim)``, assuming
            ``TargetMixture.sample`` returns ``(n, x_dim)``; confirm against
            that class. ``timesteps`` is a float32 tensor of shape
            ``(n_timesteps, 1)``.
        """
        samples_per_timepoint = n_samples // n_timesteps
        if ordered:
            timesteps = np.linspace(self.t0, self.t1, n_timesteps)
        else:
            timesteps = np.random.rand(n_timesteps) * (self.t1 - self.t0) + self.t0

        samples = []
        for timestep in timesteps:
            target = TargetMixture(
                n_dim=self.args.x_dim,
                n_target_modes=self.args.n_mog,
                sigma=self.sigma,
                means=self._means_at(timestep),
                bounds=self.args.bounds,
                device=self.args.device,
                dtype=torch.float32,
            )
            samples.append(target.sample(n_samples=samples_per_timepoint))

        samples = torch.stack(samples, dim=0)  # timesteps (context) x n_samples x dim
        timesteps = torch.from_numpy(timesteps).float().to(self.args.device)[:, None]  # n_timesteps x 1

        # TODO: decide on the format of conditional data:
        #  either (n_cond, n_samples_per_cond, ndim) - requires each condition to have
        #  the same number of samples (the format currently returned),
        #  or (n_samples, ndim) where each sample has its own condition (not
        #  necessarily shared with other samples). For the second option, enable:
        #
        # samples = rearrange(samples, "n_c n_s d -> (n_c n_s) d")
        # timesteps = timesteps.expand(timesteps.shape[0], n_samples//n_timesteps)
        # timesteps = rearrange(timesteps, "n_t n_s -> (n_t n_s) 1")

        return samples, timesteps

    def log_prob(self, x):
        """Not implemented for this dataset.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
