
import torch
import numpy as np
from torch_scatter import scatter_mean, scatter_add




def sample_zero_centered_gaussian(size, device, segment_ids):
    """Draw Gaussian noise whose per-segment mean along dim 0 is zero.

    Standard-normal noise of shape ``size`` is sampled on ``device``; the
    rows are then grouped by ``segment_ids`` and every row has the mean of
    its own group subtracted.

    Args:
        size: Shape of the sample; must have exactly two entries.
        device: Device on which the noise tensor is created.
        segment_ids: Group index of each row of the sample. Presumably a
            1-D integer tensor with ``size[0]`` entries -- TODO confirm
            against callers.

    Returns:
        Tensor of shape ``size`` with the segment means removed.
    """
    assert len(size) == 2  # TODO check this
    noise = torch.randn(size, device=device)
    # One mean row per segment id, then broadcast back to one row per sample.
    per_segment_mean = scatter_mean(noise, segment_ids, dim=0)
    row_means = torch.index_select(per_segment_mean, 0, segment_ids)
    return noise - row_means





def clip_noise_schedule(alphas2, clip_value=0.001):
    """Rebuild a cumulative schedule after bounding its per-step ratios.

    ``alphas2`` is treated as a cumulative product of per-step factors. The
    factors are recovered as ratios of consecutive entries, clipped into
    ``[clip_value, 1.0]``, and multiplied back together.

    Args:
        alphas2: 1-D array holding the cumulative product of the step factors.
        clip_value: Smallest per-step factor that is allowed.

    Returns:
        Array with the same length as ``alphas2`` containing the cumulative
        product of the clipped step factors.
    """
    # Prepend 1 so that the first ratio equals alphas2[0] itself.
    padded = np.concatenate([np.ones(1), alphas2], axis=0)
    step_factors = np.clip(padded[1:] / padded[:-1], a_min=clip_value, a_max=1.0)
    return np.cumprod(step_factors, axis=0)


def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
    """Return the cumulative alpha products of a cosine noise schedule.

    Args:
        timesteps: Number of diffusion steps; the grid uses ``timesteps + 2``
            points, so the result has ``timesteps + 1`` entries.
        s: Offset added inside the cosine argument.
        raise_to_power: Exponent applied element-wise to the result when it
            differs from 1.

    Returns:
        1-D NumPy array with the cumulative product of ``1 - beta_t``.
    """
    n_points = timesteps + 2
    grid = np.linspace(0, n_points, n_points)
    # f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalised by f(0).
    cos_sq = np.cos(((grid / n_points) + s) / (1 + s) * np.pi * 0.5) ** 2
    normalised = cos_sq / cos_sq[0]
    # beta_t = 1 - f(t) / f(t - 1), bounded to [0, 0.999].
    step_betas = np.clip(
        1 - (normalised[1:] / normalised[:-1]), a_min=0, a_max=0.999
    )
    schedule = np.cumprod(1.0 - step_betas, axis=0)

    if raise_to_power == 1:
        return schedule
    return np.power(schedule, raise_to_power)


def polynomial_schedule(timesteps: int, s=1e-4, power=3.0):
    """Return a polynomial noise schedule for alpha^2.

    The raw schedule is ``(1 - (x / steps) ** power) ** 2`` on a grid of
    ``timesteps + 1`` points. It is passed through ``clip_noise_schedule``
    and then rescaled by the affine map ``v -> (1 - 2 * s) * v + s``.

    Args:
        timesteps: Number of diffusion steps; the result has
            ``timesteps + 1`` entries.
        s: Margin used by the final affine map; a value in ``[0, 1]`` is
            mapped into ``[s, 1 - s]``.
        power: Exponent of the polynomial.

    Returns:
        1-D NumPy array of alpha^2 values.
    """
    n_points = timesteps + 1
    grid = np.linspace(0, n_points, n_points)
    raw = (1 - np.power(grid / n_points, power)) ** 2

    clipped = clip_noise_schedule(raw, clip_value=0.001)

    # Affine squeeze: shrink by (1 - 2s), then shift up by s.
    return (1 - 2 * s) * clipped + s


class PredefinedNoiseSchedule(torch.nn.Module):
    """Fixed lookup table mapping a time ``t`` to a precomputed gamma value.

    gamma is stored as ``-(log(alpha^2) - log(sigma^2))`` with
    ``sigma^2 = 1 - alpha^2`` and is kept as a parameter that does not
    require gradients.
    """

    def __init__(self, noise_schedule, timesteps, precision):
        """Build the gamma table for the requested schedule.

        Args:
            noise_schedule: ``"cosine"`` or a string of the form
                ``"polynomial_<power>"``.
            timesteps: Number of diffusion steps.
            precision: Passed as ``s`` to ``polynomial_schedule``; not used
                by the cosine schedule.

        Raises:
            ValueError: If ``noise_schedule`` matches neither form.
        """
        super().__init__()
        self.timesteps = timesteps

        if noise_schedule == "cosine":
            alphas2 = cosine_beta_schedule(timesteps)
        elif "polynomial" in noise_schedule:
            # Expected layout is "<name>_<power>", e.g. "polynomial_2".
            parts = noise_schedule.split("_")
            assert len(parts) == 2
            alphas2 = polynomial_schedule(
                timesteps, s=precision, power=float(parts[1])
            )
        else:
            raise ValueError(noise_schedule)

        print("alphas2", alphas2)

        # gamma is the negated log ratio of alpha^2 to sigma^2.
        log_ratio = np.log(alphas2) - np.log(1 - alphas2)
        gamma_table = -log_ratio

        print("gamma", gamma_table)

        self.gamma = torch.nn.Parameter(
            torch.from_numpy(gamma_table).float(), requires_grad=False
        )

    def forward(self, t):
        """Return gamma at the table index nearest to ``t * timesteps``."""
        scaled = t * self.timesteps
        table_index = torch.round(scaled).long()
        return self.gamma[table_index]



class Probability_path_base:
    """Interface for probability paths; subclasses supply the two hooks below."""

    def __init__(self, *args, **kwargs):
        # Forward everything to the next class in the MRO.
        super().__init__(*args, **kwargs)

    def sample_x_t(self, z, x, t):
        """Sample from the path's distribution; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def target_field(self, z, x, t):
        """Return a vector field on the transport plan; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def M_para(self, t):
        """Return the constant 1.0; ``t`` is ignored in this base class."""
        return 1.0




    

