import numpy as np
import torch
from tqdm import tqdm, trange
from data import Batch


class MaterialSchedule:
    """Shared state and sampling helpers for the material schedules.

    Stores the conditional model, an optional unconditional model and the
    time range ``[t_min, t_max]``. ``pos_sch`` and ``el_sch`` start as None
    and are filled in by the concrete schedules in this file.
    """

    def __init__(self, model, uncond_model, t_min=0.0, t_max=1.0):
        """Record the models and the time range; no sub-schedule is set yet."""
        self.pos_sch = None
        self.el_sch = None
        self.t_min = t_min
        self.t_max = t_max
        self.model = model
        self.uncond_model = uncond_model

    def gen_random_t(self, batch):
        """Draw one uniform time in ``[t_min, t_max)`` per batch entry.

        Args:
            batch: object providing ``get_batch_size()`` and ``get_positions()``.

        Returns:
            Float tensor of shape ``(batch_size,)`` placed on the device of
            the batch positions.
        """
        n_entries = batch.get_batch_size()
        # Sampled with the default (CPU) generator first, then moved.
        unit = torch.rand((n_entries,)).float()
        unit = unit.to(batch.get_positions().device)
        span = self.t_max - self.t_min
        return unit * span + self.t_min

    def gen_random_sample(self, sample):
        """Build a random starting point for the reverse process.

        Positions come from ``sample.randomize_uniform()``. If the model has
        an element embedding, the elements of ``sample`` are embedded; when an
        element schedule is active the embedding is replaced by standard
        normal noise before being mapped back to elements.
        """
        embedding = self.model.element_embedding
        if embedding is None:
            # No embedding: only the positions are randomized.
            return sample.randomize_uniform().update_attrs(
                element_emb=None, elements=None)

        emb = embedding.embed(sample.get_elements())
        if self.el_sch is not None:
            emb = torch.randn_like(emb)
        elements = embedding.unembed(emb)
        randomized = sample.randomize_uniform()
        return randomized.update_attrs(element_emb=emb, elements=elements)


class FlowPath:
    """Straight-line (optimal transport) flow matching path."""

    def __init__(self):
        """Nothing to configure; the path is fully determined by its endpoints."""

    def velocity(self, x0, x1):
        """Return the constant velocity ``x1 - x0`` carrying x0 to x1 in unit time."""
        return x1 - x0

    def interpolate(self, x0, x1, t):
        """Return the point reached at time t when moving from x0 towards x1.

        The displacement is obtained through ``self.velocity`` so that
        subclasses overriding ``velocity`` are honoured.
        """
        displacement = self.velocity(x0, x1)
        return x0 + t * displacement


class MaterialFlowPath(FlowPath):
    """Flow matching path for material positions.

    Here ``x0`` is a sample object (providing ``get_positions()`` and
    ``cal_velocity()``) rather than a raw tensor. The velocity is delegated
    to ``x0.cal_velocity``; presumably that is where periodic boundary
    conditions are handled - confirm against the sample class.
    """

    def __init__(self):
        super().__init__()

    def velocity(self, x0, x1):
        """Return the velocity from sample ``x0`` towards ``x1`` via ``x0.cal_velocity``."""
        return x0.cal_velocity(x1)

    def interpolate(self, x0, x1, t):
        """Return the positions of ``x0`` advanced by ``t`` times the velocity towards ``x1``."""
        step = t * self.velocity(x0, x1)
        return x0.get_positions() + step


class MaterialFlowSchedule(MaterialSchedule):
    """Flow-matching schedule for material samples.

    Positions follow a ``MaterialFlowPath``; when the model has an element
    embedding, element embeddings follow a plain ``FlowPath``. Time runs from
    ``t_min = 0`` (clean data) to ``t_max = 1`` (noise).
    """

    def __init__(self, model, uncond_model):
        """
        Args:
            model: network called as ``model(sample, t, dt)``; must expose an
                ``element_embedding`` attribute (may be None).
            uncond_model: optional unconditional network used for guidance in
                ``reverse_step``; may be None.
        """
        super().__init__(model, uncond_model, t_min=0.0, t_max=1.0)
        self.pos_sch = MaterialFlowPath()
        # Explicit None test, consistent with how the embedding is checked
        # in the other schedules of this file.
        self.el_sch = FlowPath() if model.element_embedding is not None else None

    def forward_process(self, clean_sample, t):
        """
        Sample source-target pairs and interpolate at time t. Clean sample is the x_{t_min}.

        Args:
            clean_sample: batch of clean material samples.
            t: tensor with shape (batch_size), the diffusion timestep of each batch

        Returns:
            Tuple ``(flow_sample, (velocity_pos, velocity_el))``: the
            interpolated sample and the target velocities for positions and
            element embeddings (``velocity_el`` is None without element flow).
        """
        # Broadcast the per-batch time to one value per atom.
        t_shape = [-1] + [1] * (len(clean_sample.get_positions().shape) - 1)
        t_tensor_atom = t[clean_sample.get_batch_indices()].view(t_shape)
        noisy_sample = clean_sample.randomize_uniform()
        noisy_pos = noisy_sample.get_positions()  # x_{t_max}

        flow_pos = self.pos_sch.interpolate(clean_sample, noisy_pos, t_tensor_atom)
        velocity_pos = self.pos_sch.velocity(clean_sample, noisy_pos)

        # Handle elements similarly
        flow_el, velocity_el, flow_elements = None, None, None
        if self.el_sch is not None:
            clean_el_emb = clean_sample.get_element_emb()
            if clean_el_emb is None:
                clean_el_emb = self.model.element_embedding.embed(
                    clean_sample.get_elements())

            noise_el = torch.randn_like(clean_el_emb)
            flow_el = self.el_sch.interpolate(clean_el_emb, noise_el, t_tensor_atom)
            velocity_el = self.el_sch.velocity(clean_el_emb, noise_el)
            flow_elements = self.model.element_embedding.unembed(flow_el)

        flow_sample = clean_sample.update_attrs(
            positions=flow_pos,
            element_emb=flow_el,
            elements=flow_elements,
        )

        return flow_sample, (velocity_pos, velocity_el)

    def reverse_step(self, sample, t, dt, cond_w=0.0, **kwargs):
        """
        Single ODE integration step.

        Args:
            - sample: the batch of samples at time t.
            - t: tensor with shape (batch_size), the diffusion timestep of each batch
            - dt: same shape with t, the timestep delta of each batch
            - cond_w: guidance weight; when > 0 the prediction is mixed with
              an unconditional prediction (from ``uncond_model`` if set,
              otherwise from ``model``) on ``sample.null_properties()``.
            - **kwargs: accepted for interface compatibility and ignored.

        Returns:
            The predicted previous sample (x_{t-dt}) and clean sample (x_{t_min}).
        """
        batch_idx = sample.get_batch_indices()
        t_shape = [-1] + [1] * (len(sample.get_positions().shape) - 1)
        t_tensor_atom = t[batch_idx].view(t_shape)
        dt_tensor_atom = dt[batch_idx].view(t_shape)

        # Get predicted velocity from model
        pred_velocity_pos, pred_velocity_els = self.model(sample, t, dt)
        if cond_w > 0.0:
            # Use independent condition guidance
            guide_model = self.uncond_model if self.uncond_model is not None else self.model
            uncond_velocity_pos, uncond_velocity_els = guide_model(
                sample.null_properties(), t, dt)
            pred_velocity_pos = (1.0 + cond_w) * pred_velocity_pos - cond_w * uncond_velocity_pos
            # The element velocity may be None (it is checked below), so only
            # mix it when the model actually predicted one.
            if pred_velocity_els is not None:
                pred_velocity_els = (1.0 + cond_w) * pred_velocity_els - cond_w * uncond_velocity_els

        # Euler integration: x_{t-dt} = x_t - dt * v_t
        new_positions = sample.get_positions() - (dt_tensor_atom * pred_velocity_pos)
        # Derive clean positions: x_0 = x_t - t * v_t
        clean_positions = sample.get_positions() - (t_tensor_atom * pred_velocity_pos)

        new_element_emb, clean_element_emb = None, None
        new_elements, clean_elements = None, None
        if pred_velocity_els is not None:
            current_emb = sample.get_element_emb()
            new_element_emb = current_emb - (dt_tensor_atom * pred_velocity_els)
            clean_element_emb = current_emb - (t_tensor_atom * pred_velocity_els)
            new_elements = self.model.element_embedding.unembed(new_element_emb)
            clean_elements = self.model.element_embedding.unembed(clean_element_emb)

        pred_prev_sample = sample.update_attrs(
            positions=new_positions,
            element_emb=new_element_emb,
            elements=new_elements,
        )
        pred_clean_sample = sample.update_attrs(
            positions=clean_positions,
            element_emb=clean_element_emb,
            elements=clean_elements,
        )

        return pred_prev_sample, pred_clean_sample

    def reverse_process(self, noisy_sample, n_steps, cond_w=0.0, **kwargs):
        """Generate by integrating the flow ODE

        Args:
            - noisy_sample: starting sample at ``t_max``.
            - n_steps (int): number of equally sized Euler steps from
              ``t_max`` to ``t_min``. With 0 steps only the input is returned.
            - cond_w: guidance weight forwarded to ``reverse_step``.
            - **kwargs: accepted for interface compatibility and ignored.

        Returns:
            list: the trajectory, starting with ``noisy_sample`` followed by
            the sample after every step (``n_steps + 1`` entries).

        Raises:
            ValueError: if ``n_steps`` is negative.
        """
        if n_steps < 0:
            raise ValueError(f'n_steps must be non-negative, got {n_steps}')

        sample = noisy_sample
        trajectory = [sample]
        if n_steps == 0:
            # Nothing to integrate; also avoids indexing timesteps[1] below.
            return trajectory

        timesteps = np.linspace(self.t_max, self.t_min, n_steps + 1)
        dt = timesteps[0] - timesteps[1]

        device = noisy_sample.get_positions().device
        batch_size = noisy_sample.get_batch_size()
        # The step size is the same for every step, so build its tensor once.
        dt_batch = torch.tensor((dt,)).float().to(device).repeat(batch_size)

        for i in trange(n_steps, desc='Denoising', leave=False):
            t_batch = torch.tensor((timesteps[i],)).float().to(device).repeat(batch_size)
            sample, _ = self.reverse_step(sample, t_batch, dt_batch, cond_w=cond_w)
            trajectory.append(sample)

        return trajectory


class SDESchedule:
    """Base class of diffusion noise schedules.

    Concrete schedules must provide the callables ``alpha(t)``, ``sigma(t)``,
    ``f(t)``, ``g(t)`` and ``g2(t)`` used below.
    """

    def add_noise(self, x, t, noise=None):
        """Return ``(alpha(t) * x + sigma(t) * noise, noise)``.

        ``t`` is assumed to broadcast against ``x`` (its leading dimension
        matches that of ``x``). When ``noise`` is None, standard normal noise
        shaped like ``x`` is drawn.
        """
        scale = self.alpha(t)
        spread = self.sigma(t)
        eps = torch.randn_like(x) if noise is None else noise
        return scale * x + spread * eps, eps

    def denoise_step(self, x, pred_noise, t, dt, z=None, non_stochastic=False):
        """
        Take one reverse-time step of the diffusion process.

        Args:
            x (torch.Tensor): The current noisy sample.
            pred_noise (torch.Tensor): The predicted noise component (usually from a score model).
            t (torch.Tensor): The current diffusion timestep; assumed to broadcast against x.
            dt (torch.Tensor): The difference between the current and previous diffusion timestep; its absolute value is used.
            z (torch.Tensor, optional): The random noise to add to the sample. Drawn from a standard normal if None.
            non_stochastic (bool, optional): Whether to use the non-stochastic update. If true, z is ignored.

        Returns:
            Tuple ``(prev_x, clean_x)``: the sample one step earlier and the
            clean-sample estimate ``(x - sigma * pred_noise) / alpha``.
        """
        step = abs(dt)
        alpha = self.alpha(t)
        sigma = self.sigma(t)
        diffusion = self.g(t)
        diffusion_sq = self.g2(t)
        drift_coef = self.f(t)

        # Score estimate derived from the predicted noise.
        score = -1. * pred_noise / sigma

        if non_stochastic:
            delta = (drift_coef * x - 0.5 * diffusion_sq * score) * step
        else:
            if z is None:
                z = torch.randn_like(x)
            deterministic = (drift_coef * x - diffusion_sq * score) * step
            delta = deterministic + diffusion * z * torch.sqrt(step)

        prev_x = x - delta
        clean_x = (x - sigma * pred_noise) / alpha
        return prev_x, clean_x


class VarianceExplodingSchedule(SDESchedule):
    """
    Variance exploding stochastic differential equation (SDE) scheduler.
    Paper: Song, Yang, et al. "Score-based generative modeling through stochastic differential equations."

    The signal scale is constant (alpha = 1, f = 0) while the noise level
    grows linearly, sigma(t) = sigma_max * t, giving g(t)^2 = 2 * sigma_max^2 * t.
    """

    def __init__(self, sigma_max):
        """Install the schedule callables for the given ``sigma_max``."""
        self.sigma_max = sigma_max

        # The callables close over the constructor argument ``sigma_max``.
        def alpha(t):
            return torch.ones_like(t)

        def drift(t):
            return torch.zeros_like(t)

        def sigma(t):
            return t * sigma_max

        def diffusion_sq(t):
            return 2 * sigma_max**2 * t

        def diffusion(t):
            return self.g2(t)**0.5

        def alpha_rate(t):
            return torch.zeros_like(t)

        def sigma_rate(t):
            return sigma_max * torch.ones_like(t)

        self.alpha = alpha
        self.f = drift
        self.sigma = sigma
        self.g2 = diffusion_sq
        self.g = diffusion
        self.dalpha = alpha_rate
        self.dsigma = sigma_rate


class VariancePreservingSchedule(SDESchedule):
    """
    Variance preserving stochastic differential equation (SDE) scheduler.
    Paper: Song, Yang, et al. "Score-based generative modeling through stochastic differential equations."

    Two schedules are available:
      * 'cosine': alpha = cos(pi/2 * t), sigma = sigma_max * sin(pi/2 * t)
      * 'linear': alpha = sqrt(1 - t),   sigma = sigma_max * sqrt(t)
    """

    def __init__(self, sigma_max=1.0, schedule='cosine'):
        """Install the schedule callables.

        Args:
            sigma_max: scale applied to the noise level.
            schedule: 'cosine' or 'linear'.

        Raises:
            NotImplementedError: for any other ``schedule`` value.
        """
        self.sigma_max = sigma_max

        def diffusion(t):
            return self.g2(t)**0.5

        if schedule == 'cosine':
            def alpha(t):
                return torch.cos(torch.pi/2*t)

            def sigma(t):
                return torch.sin(torch.pi/2*t) * sigma_max

            def drift(t):
                return torch.tan(torch.pi/2*t) * torch.pi * (-0.5)

            def diffusion_sq(t):
                growth = torch.pi * self.alpha(t) * self.sigma(t) * sigma_max
                return growth - 2 * self.f(t) * self.sigma(t)**2

            self.alpha = alpha
            self.sigma = sigma
            self.f = drift
            self.g2 = diffusion_sq
            self.g = diffusion
            return

        if schedule == 'linear':
            def gamma(t):
                return 1 - t

            def alpha(t):
                return self.gamma(t)**0.5

            def sigma(t):
                return (1 - self.gamma(t))**0.5 * sigma_max

            def drift(t):
                return 0.5 / (t - 1)

            def diffusion_sq(t):
                return (1 - 2*self.f(t)*t) * sigma_max**2

            self.gamma = gamma
            self.alpha = alpha
            self.sigma = sigma
            self.f = drift
            self.g2 = diffusion_sq
            self.g = diffusion
            return

        raise NotImplementedError(f'Unknown noise schedule: {schedule}')


class MaterialSDESchedule(MaterialSchedule):
    """The compound noise schedule for both elements and coordinates of atoms in material samples.

    Positions use a ``VarianceExplodingSchedule``; element embeddings
    optionally use a ``VariancePreservingSchedule``.
    """

    # TODO: Make it possible to control sigma_max of the element noise, since this controls the point in time, where the element information disappears.
    # NOTE(review): ``sigma_max_el`` is already forwarded to the element schedule below - confirm whether this TODO is still open.
    def __init__(self, model, uncond_model, t_min, t_max, sigma_max_pos=1.0,
                 noise_schedule_el=None, sigma_max_el=1.0):
        """Noise scheduler for materials. Takes care of adding, and removing noise.

        Args:
            model: network called as ``model(sample, t, dt)``; must expose an
                ``element_embedding`` attribute (may be None).
            uncond_model: optional unconditional network used for guidance; may be None.
            t_min (float): Start time of noise schedule
            t_max (float): End time of noise schedule
            sigma_max_pos (float, optional): Maximum variance of position noise. Defaults to 1.0.
            noise_schedule_el (None|str, optional): Noise schedule for the elements. Can be None, "linear" or "cosine".
            If None, elements are not diffused. Defaults to None.
            sigma_max_el (float, optional): Maximum variance of element noise. Defaults to 1.0.

        Raises:
            ValueError: if an element schedule is requested but the model has
                no element embedding.
        """
        super().__init__(model, uncond_model, t_min, t_max)

        self.pos_sch = VarianceExplodingSchedule(sigma_max_pos)
        self.el_sch = None

        if noise_schedule_el is not None:
            if model.element_embedding is None:
                raise ValueError(
                    "Diffusion of elements is only possible in combination with an element embedding")
            self.el_sch = VariancePreservingSchedule(
                sigma_max_el, noise_schedule_el)

    def gen_infer_timesteps(self, n_steps, t_min=None, t_max=None):
        """Generates the time stamps for the denoising. t_min is excluded.

        Returns ``n_steps`` evenly spaced times starting at ``t_max`` and
        decreasing towards (but not including) ``t_min``. ``t_min``/``t_max``
        default to the values stored on the schedule.
        """
        t_min = self.t_min if t_min is None else t_min
        t_max = self.t_max if t_max is None else t_max
        return np.linspace(1, 0, n_steps + 1)[:-1] * (t_max - t_min) + t_min

    def forward_process(self, clean_sample, t):
        """Add noise to a batch of clean samples based on the given timestamps.

        Args:
            - clean_sample (Sample): a batch of clean material samples.
            - t: a batch of timestamps, with shape (batch_size).

        Returns:
            - Sample: the batch of noisy samples.
            - tuple: ``(noise_pos, noise_els)``, the noise that was applied to
              the positions and element embeddings (``noise_els`` is None
              when elements are not diffused).
        """
        # Broadcast the per-batch time to one value per atom.
        t_shape = [-1] + [1] * (len(clean_sample.get_positions().shape) - 1)
        t_per_at = t[clean_sample.get_batch_indices()].view(t_shape)
        noise_pos = torch.randn_like(clean_sample.get_positions())
        noise_pos = clean_sample.remove_mean(noise_pos)
        noisy_pos, noise_pos = self.pos_sch.add_noise(
            clean_sample.get_positions(), t_per_at, noise=noise_pos)

        noise_els, noisy_els = None, None
        if self.model.element_embedding is not None:
            clean_el_emb = clean_sample.get_element_emb()
            if clean_el_emb is None:
                clean_el_emb = self.model.element_embedding.embed(
                    clean_sample.get_elements())
            if self.el_sch is not None:
                noisy_els, noise_els = self.el_sch.add_noise(clean_el_emb, t_per_at)
            else:
                noisy_els = clean_el_emb  # we don't noise the elements

        noisy_sample = clean_sample.update_attrs(
            positions=noisy_pos,
            element_emb=noisy_els,
            elements=None if self.el_sch is None else self.model.element_embedding.unembed(
                noisy_els)
        )

        return noisy_sample, (noise_pos, noise_els)

    def _denoise_candidate(self, noisy_sample, pred_noise_pos, pred_noise_els,
                           t_tensor_atom, dt_tensor_atom, non_stochastic):
        """Draw fresh noise and take one denoising step for positions and elements.

        Returns the pair ``(prev_sample, clean_sample)`` built from
        ``noisy_sample`` with updated positions and, if an element schedule
        is active, updated element embeddings and elements.
        """
        # Denoising of positions.
        z_pos = torch.randn_like(noisy_sample.get_positions())
        z_pos = noisy_sample.remove_mean(z_pos)
        new_positions, clean_positions = self.pos_sch.denoise_step(
            noisy_sample.get_positions(), pred_noise_pos, t_tensor_atom,
            dt_tensor_atom, z_pos, non_stochastic=non_stochastic)

        # Denoising of elements.
        new_element_emb, clean_element_emb = None, None
        new_elements, clean_elements = None, None
        if self.el_sch is not None:
            z_el = torch.randn_like(noisy_sample.get_element_emb())
            new_element_emb, clean_element_emb = self.el_sch.denoise_step(
                noisy_sample.get_element_emb(), pred_noise_els, t_tensor_atom,
                dt_tensor_atom, z_el, non_stochastic=non_stochastic)
            new_elements = self.model.element_embedding.unembed(new_element_emb)
            clean_elements = self.model.element_embedding.unembed(clean_element_emb)

        prev_sample = noisy_sample.update_attrs(
            positions=new_positions,
            elements=new_elements,
            element_emb=new_element_emb
        )
        clean_sample = noisy_sample.update_attrs(
            positions=clean_positions,
            elements=clean_elements,
            element_emb=clean_element_emb
        )
        return prev_sample, clean_sample

    def reverse_step(self, noisy_sample, t, dt, cond_w=0.0, non_stochastic=False,
        guidance_fn=None, num_cand=16):
        """Predict the sample from the previous timestep by reversing the SDE.

        Args:
            - noisy_sample (Sample): the batch of samples in the current diffusion timesteps.
            - t: tensor with shape (batch_size), the diffusion timestep of each batch
            - dt: same shape with t, the timestep delta of each batch
            - cond_w (float, optional): guidance weight; when > 0 the noise
              prediction is mixed with an unconditional prediction (from
              ``uncond_model`` if set, otherwise from ``model``) on
              ``noisy_sample.null_properties()``.
            - non_stochastic (bool, optional): If False, the stochastic denoising method is used. Defaults to False.
            - guidance_fn (callable, optional): called as
              ``guidance_fn(prev_sample, clean_sample)`` and expected to
              return one score per batch entry. When given, ``num_cand``
              candidate steps are drawn and, per batch entry, the candidate
              with the lowest score is kept.
            - num_cand (int, optional): number of candidates drawn when
              ``guidance_fn`` is given. Defaults to 16.

        Returns:
            - Sample: computed sample (x_{t-dt}) of previous timestep.
            - Sample: predicted denoised sample (x_{t_min}) based on the model output from the current timestep.
        """
        batch_idx = noisy_sample.get_batch_indices()
        t_shape = [-1] + [1] * (len(noisy_sample.get_positions().shape) - 1)
        t_tensor_atom = t[batch_idx].view(t_shape)
        dt_tensor_atom = dt[batch_idx].view(t_shape)

        pred_noise_pos, pred_noise_els = self.model(noisy_sample, t, dt)
        if cond_w > 0:
            # Use independent condition guidance
            guide_model = self.uncond_model if self.uncond_model is not None else self.model
            uncond_noise_pos, uncond_noise_els = guide_model(
                noisy_sample.null_properties(), t, dt)
            pred_noise_pos = (1.0 + cond_w) * pred_noise_pos - cond_w * uncond_noise_pos
            # Presumably the model returns None for the element noise when
            # elements are not diffused; mixing None would raise a TypeError.
            if pred_noise_els is not None:
                pred_noise_els = (1.0 + cond_w) * pred_noise_els - cond_w * uncond_noise_els

        if guidance_fn is None:
            return self._denoise_candidate(
                noisy_sample, pred_noise_pos, pred_noise_els,
                t_tensor_atom, dt_tensor_atom, non_stochastic)

        with torch.no_grad():
            cand_prev_samples, cand_clean_samples, cand_scores = [], [], []
            # Use stochastic control guidance to select the optimal candidate
            for _ in range(num_cand):
                cand_prev_sample, cand_clean_sample = self._denoise_candidate(
                    noisy_sample, pred_noise_pos, pred_noise_els,
                    t_tensor_atom, dt_tensor_atom, non_stochastic)
                cand_prev_samples.append(cand_prev_sample)
                cand_clean_samples.append(cand_clean_sample)
                cand_scores.append(guidance_fn(cand_prev_sample, cand_clean_sample))  # (batch_size) * num_cand

            # (batch_size, num_cand) -> (batch_size), the best candidate for each batch
            best_idx = torch.argmin(torch.stack(cand_scores, 1), 1).tolist()
            pred_prev_sample = [
                cand_prev_samples[cand].samples[entry]
                for entry, cand in enumerate(best_idx)
            ]
            pred_clean_sample = [
                cand_clean_samples[cand].samples[entry]
                for entry, cand in enumerate(best_idx)
            ]

            pred_prev_sample = Batch(pred_prev_sample)
            pred_clean_sample = Batch(pred_clean_sample)

        return pred_prev_sample, pred_clean_sample

    def reverse_process(self, noisy_sample, n_steps, t_min=None, t_max=None, cond_w=0.0,
        guidance_fn=None, num_cand=16, non_stochastic=False, final_step=True):
        """Perform the complete denoising of a given sample/batch

        Args:
            - noisy_sample (Sample/Batch): Sample to denoise
            - n_steps (int): Number of denoising steps
            - t_min (float, optional): Final time of the denoising. If none, the schedulers default is used.
            - t_max (float, optional): First time of the denoising. If none, the schedulers default is used.
            - cond_w (float, optional): guidance weight forwarded to ``reverse_step``.
            - guidance_fn (callable, optional): forwarded to ``reverse_step``.
            - num_cand (int, optional): forwarded to ``reverse_step``.
            - non_stochastic (bool, optional): If False, the stochastic denoising method is used. Defaults to False.
            - final_step (bool, optional): Whether a final denoising step going from t_min to 0 is added at the end of the trajectory. Defaults to True.
              The step is skipped when ``t_min`` is not positive, because there is nothing left to denoise.

        Returns:
            - list(Sample/Batch): The complete denoising trajectory

        Raises:
            ValueError: if ``n_steps`` is negative.
        """
        if n_steps < 0:
            raise ValueError(f'n_steps must be non-negative, got {n_steps}')

        t_min = self.t_min if t_min is None else t_min
        t_max = self.t_max if t_max is None else t_max
        timesteps = self.gen_infer_timesteps(n_steps, t_min, t_max)
        # Uniform step covering [t_min, t_max] in n_steps steps. This also holds
        # for n_steps == 1, where the single step must end at t_min, not at 0.
        dt = abs(t_max - t_min) / n_steps if n_steps > 0 else 0.0
        device = noisy_sample.get_positions().device
        batch_size = noisy_sample.get_batch_size()
        # The step size is identical for every step, so build its tensor once.
        dt_batch = torch.tensor((dt,)).float().to(device).repeat(batch_size)

        sample = noisy_sample
        trajectory = [sample]

        for t in tqdm(timesteps, leave=False, total=n_steps, desc='Denoising'):
            # decrease t
            sample, _ = self.reverse_step(
                noisy_sample=sample,
                t=torch.tensor((t,)).float().to(device).repeat(batch_size),
                dt=dt_batch,
                cond_w=cond_w,
                guidance_fn=guidance_fn,
                num_cand=num_cand,
                non_stochastic=non_stochastic
            )
            trajectory.append(sample)

        # At t_min == 0 the position noise level sigma(t) is zero and the
        # score computation would divide by it, producing NaNs.
        if final_step and t_min > 0:
            t_min_batch = torch.tensor((t_min,)).float().to(device).repeat(batch_size)
            sample, _ = self.reverse_step(
                noisy_sample=sample,
                t=t_min_batch,
                dt=t_min_batch,
                cond_w=cond_w,
                guidance_fn=guidance_fn,
                num_cand=num_cand,
                non_stochastic=non_stochastic
            )
            trajectory.append(sample)

        return trajectory
