import os
import math
from typing import Optional
import logging

from attrdict import AttrDict

import torch
from torch.distributions import MultivariateNormal, StudentT


# Module-wide logger. getLogger() is called without a name, so this is the
# root logger rather than a logger scoped to this module.
logger = logging.getLogger()

class GPPriorSampler:
    """Draw function values from a zero-mean GP prior at caller-supplied inputs.

    Args:
        kernel: Either a callable mapping inputs ``x`` to a covariance tensor,
            or one of the names ``"rbf"``, ``"rationalquadratic"``,
            ``"matern"``, ``"periodic"``; a name is replaced by the matching
            kernel class instantiated with its default arguments.
        t_noise: Optional multiplier for additive Student-t noise
            (2.1 degrees of freedom). ``None`` disables the noise.

    Raises:
        ValueError: If ``kernel`` is a string that is not one of the names
            listed above.
    """

    def __init__(self, kernel, t_noise=None):
        if isinstance(kernel, str):
            if kernel not in ("rbf", "rationalquadratic", "matern", "periodic"):
                raise ValueError(f"Invalid kernel {kernel}")
            # The table is built only for an already validated name.
            kernel_classes = {
                "rbf": RBFKernel,
                "rationalquadratic": RationalQuadraticKernel,
                "matern": Matern52Kernel,
                "periodic": PeriodicKernel,
            }
            kernel = kernel_classes[kernel]()

        self.kernel = kernel
        self.t_noise = t_noise

    def sample(self, x, device):
        """Sample one set of function values for the inputs ``x``.

        Args:
            x: Input locations, intended to be ``1 * num_points * 1``.
            device: Device on which the zero mean (and the noise) is placed.

        Returns:
            Tensor of sampled values with a trailing singleton dimension
            appended to the multivariate-normal draw.
        """
        # 1 * num_points * num_points
        cov = self.kernel(x)
        num_points = x.shape[1]
        zero_mean = torch.zeros(1, num_points, device=device)

        prior = MultivariateNormal(zero_mean, cov)
        y = prior.rsample().unsqueeze(-1)

        if self.t_noise is None:
            return y

        # Heavy-tailed perturbation, added in place to the drawn values.
        y += self.t_noise * StudentT(2.1).rsample(y.shape).to(device)
        return y

class RBFKernel:
    """Squared-exponential (RBF) covariance with per-batch random hyperparameters.

    On every call a length scale is drawn uniformly from
    ``[0.1, max_length)`` and an output scale from ``[0.1, max_scale)``,
    independently for each batch element.

    Args:
        sigma_eps: Standard deviation of the jitter added on the diagonal.
        max_length: Upper bound of the sampled length scale.
        max_scale: Upper bound of the sampled output scale.
    """

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    def __call__(self, x):
        """Return the covariance of ``x`` with itself.

        Args:
            x: Inputs of shape ``[B, N, Dx]``.

        Returns:
            Covariance tensor of shape ``[B, N, N]``.
        """
        batch_size = x.shape[0]
        num_points = x.shape[-2]

        # Hyperparameters: the length scale is drawn before the output scale.
        length = 0.1 + (self.max_length - 0.1) * torch.rand(
            [batch_size, 1, 1, 1], device=x.device
        )
        scale = 0.1 + (self.max_scale - 0.1) * torch.rand(
            [batch_size, 1, 1], device=x.device
        )

        # Pairwise scaled differences: [B, N, N, Dx]
        scaled_diff = (x.unsqueeze(-2) - x.unsqueeze(-3)) / length
        # Squared distances summed over the input dimension: [B, N, N]
        sq_dist = scaled_diff.pow(2).sum(-1)

        signal = scale.pow(2) * torch.exp(-0.5 * sq_dist)
        jitter = self.sigma_eps**2 * torch.eye(num_points).to(x.device)
        return signal + jitter  # [B, N, N]

class RationalQuadraticKernel:
    """Rational quadratic covariance with per-batch random hyperparameters.

    The covariance between two inputs at scaled distance ``d`` is
    ``scale**2 * (1 + d**2 / (2 * alpha)) ** (-alpha)``, where ``alpha`` is
    ``mixture_parameter``; ``sigma_eps**2`` is added on the diagonal.

    On every call a length scale is drawn uniformly from
    ``[0.1, max_length)`` and an output scale from ``[0.1, max_scale)``,
    independently for each batch element.

    Args:
        sigma_eps: Standard deviation of the jitter added on the diagonal.
        max_length: Upper bound of the sampled length scale.
        max_scale: Upper bound of the sampled output scale.
        mixture_parameter: The ``alpha`` exponent of the kernel.
    """

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0, mixture_parameter=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale
        self.mixture_parameter = mixture_parameter

    def __call__(self, x):
        """Return the covariance of ``x`` with itself.

        Args:
            x: Inputs of shape ``[B, N, Dx]``.

        Returns:
            Covariance tensor of shape ``[B, N, N]``.
        """
        length = 0.1 + (self.max_length-0.1) \
                * torch.rand([x.shape[0], 1, 1, 1], device=x.device)
        scale = 0.1 + (self.max_scale-0.1) \
                * torch.rand([x.shape[0], 1, 1], device=x.device)

        # Pairwise scaled differences: [B, N, N, Dx]
        dist = (x.unsqueeze(-2) - x.unsqueeze(-3))/length
        # Squared distances summed over the input dimension: [B, N, N]
        sq_dist = dist.pow(2).sum(-1)

        # The base must be used directly. Wrapping it in exp() collapses the
        # expression to exp(-alpha) * exp(-sq_dist / 2), i.e. a rescaled RBF
        # kernel whose prior variance is scale**2 * exp(-alpha).
        alpha = self.mixture_parameter
        correlation = (1 + 0.5 / alpha * sq_dist).pow(-alpha)

        cov = scale.pow(2) * correlation \
                + self.sigma_eps**2 * torch.eye(x.shape[-2]).to(x.device)

        return cov  # [B, N, N]

class Matern52Kernel:
    """Matern-5/2 covariance with per-batch random hyperparameters.

    On every call a length scale is drawn uniformly from
    ``[0.1, max_length)`` and an output scale from ``[0.1, max_scale)``,
    independently for each batch element.

    Args:
        sigma_eps: Standard deviation of the jitter added on the diagonal.
        max_length: Upper bound of the sampled length scale.
        max_scale: Upper bound of the sampled output scale.
    """

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    def __call__(self, x):
        """Return the covariance of ``x`` with itself.

        Args:
            x: Inputs of shape ``batch_size * num_points * dim``.

        Returns:
            Covariance tensor of shape ``batch_size * num_points * num_points``.
        """
        batch_size = x.shape[0]
        num_points = x.shape[-2]
        sqrt5 = math.sqrt(5.0)

        # Hyperparameters: the length scale is drawn before the output scale.
        length = 0.1 + (self.max_length - 0.1) * torch.rand(
            [batch_size, 1, 1, 1], device=x.device
        )
        scale = 0.1 + (self.max_scale - 0.1) * torch.rand(
            [batch_size, 1, 1], device=x.device
        )

        # Euclidean distance between length-scaled inputs:
        # batch_size * num_points * num_points
        scaled_diff = (x.unsqueeze(-2) - x.unsqueeze(-3)) / length
        r = torch.norm(scaled_diff, dim=-1)

        polynomial = 1 + sqrt5 * r + 5.0 * r.pow(2) / 3.0
        signal = scale.pow(2) * polynomial * torch.exp(-sqrt5 * r)
        jitter = self.sigma_eps**2 * torch.eye(num_points).to(x.device)
        return signal + jitter


class PeriodicKernel:
    """Periodic (exp-sine-squared) covariance with per-batch random hyperparameters.

    On every call a period is drawn uniformly from ``[0.1, 0.5)``, a length
    scale from ``[0.1, max_length)`` and an output scale from
    ``[0.1, max_scale)``, independently for each batch element.

    Args:
        sigma_eps: Standard deviation of the jitter added on the diagonal.
        max_length: Upper bound of the sampled length scale.
        max_scale: Upper bound of the sampled output scale.
    """

    def __init__(self, sigma_eps=2e-2, max_length=0.6, max_scale=1.0):
        self.sigma_eps = sigma_eps
        self.max_length = max_length
        self.max_scale = max_scale

    def __call__(self, x):
        """Return the covariance of ``x`` with itself.

        Args:
            x: Inputs of shape ``batch_size * num_points * dim``.

        Returns:
            Covariance tensor of shape ``batch_size * num_points * num_points``.
        """
        per_batch = [x.shape[0], 1, 1]
        num_points = x.shape[-2]

        # Draw order matters for reproducibility: period, length, scale.
        p = 0.1 + 0.4 * torch.rand(per_batch, device=x.device)
        length = 0.1 + (self.max_length - 0.1) * torch.rand(per_batch, device=x.device)
        scale = 0.1 + (self.max_scale - 0.1) * torch.rand(per_batch, device=x.device)

        # Absolute coordinate differences summed over the input dimension.
        diff = x.unsqueeze(-2) - x.unsqueeze(-3)
        abs_dist = diff.abs().sum(-1)

        ratio = torch.sin(math.pi * abs_dist / p) / length
        signal = scale.pow(2) * torch.exp(-2 * ratio.pow(2))
        jitter = self.sigma_eps**2 * torch.eye(num_points).to(x.device)
        return signal + jitter


class GPSampler:
    """Endless iterator over randomly generated GP regression batches.

    Each batch holds inputs drawn uniformly from ``x_range`` and outputs
    drawn from a zero-mean multivariate normal whose covariance comes from
    ``kernel``, split into a context part and a target part.

    Args:
        kernel: Either a callable mapping inputs ``[B,N,Dx]`` to a covariance
            ``[B,N,N]``, or one of the names ``"rbf"``,
            ``"rationalquadratic"``, ``"matern"``, ``"periodic"``; a name is
            replaced by the matching kernel class with default arguments.
        t_noise: ``None`` for no extra noise, ``-1`` for a per-element noise
            multiplier drawn uniformly from ``[0, 0.15)``, or any other value
            used directly as the multiplier of Student-t noise.
        seed: Seed of the sampler's own ``torch.Generator``.
        device: Device for the generator and the produced tensors.
        batch_size: Number of functions per batch (``B``).
        min_num_points: Lower bound (inclusive) for the drawn point counts.
        max_num_points: Used to derive the exclusive upper bounds of the
            drawn point counts.
        x_range: ``(low, high)`` interval the inputs are drawn from.
        x_dim: Input dimensionality (``Dx``).

    Raises:
        ValueError: If ``kernel`` is a string that is not a known name.
    """

    def __init__(
        self,
        kernel,
        t_noise=None,
        seed: int = 0,
        device: str = "cuda",
        batch_size: int = 16,
        min_num_points: int = 3,
        max_num_points: int = 100,
        x_range: tuple = (-2, 2),
        x_dim: int = 1,
    ):
        if isinstance(kernel, str):
            if kernel not in ("rbf", "rationalquadratic", "matern", "periodic"):
                raise ValueError(f"Invalid kernel {kernel}")
            # The table is built only for an already validated name.
            kernel_classes = {
                "rbf": RBFKernel,
                "rationalquadratic": RationalQuadraticKernel,
                "matern": Matern52Kernel,
                "periodic": PeriodicKernel,
            }
            kernel = kernel_classes[kernel]()

        self.kernel = kernel
        self.t_noise = t_noise
        self.seed = seed
        self.device = device
        self.rng = torch.Generator(device=device).manual_seed(seed)
        self.batch_size = batch_size
        self.min_num_points = min_num_points
        self.max_num_points = max_num_points
        self.x_range = x_range
        self.x_dim = x_dim

    def __iter__(self):
        return self

    def __next__(self):
        return self.sample()

    def _draw_count(self, high):
        """Draw one integer from ``[min_num_points, high)`` with ``self.rng``."""
        drawn = torch.randint(  # type: ignore
            low=self.min_num_points,
            high=high,
            size=[1],
            generator=self.rng,
            device=self.device,
        )
        return drawn.item()

    @torch.no_grad()
    def sample(self, num_ctx: Optional[int] = None, num_tar: Optional[int] = None):
        """Generate one batch.

        Args:
            num_ctx: Number of context points (``Nc``); drawn from
                ``[min_num_points, max_num_points - min_num_points)`` when
                ``None``.
            num_tar: Number of target points (``Nt``); drawn from
                ``[min_num_points, max_num_points - num_ctx)`` when ``None``.

        Returns:
            An ``AttrDict`` with keys, in this order,
            ``x`` ``[B,N,Dx]``, ``xc`` ``[B,Nc,Dx]``, ``xt`` ``[B,Nt,Dx]``,
            ``y`` ``[B,N,1]``, ``yc`` ``[B,Nc,1]``, ``yt`` ``[B,Nt,1]``,
            where ``N = Nc + Nt``. The context/target entries are slices of
            ``x`` and ``y``.
        """
        batch = AttrDict()

        if num_ctx is None:
            num_ctx = self._draw_count(self.max_num_points - self.min_num_points)  # Nc
        if num_tar is None:
            num_tar = self._draw_count(self.max_num_points - num_ctx)  # Nt

        num_points = num_ctx + num_tar  # N = Nc + Nt

        unit_x = torch.rand(
            [self.batch_size, num_points, self.x_dim],
            generator=self.rng,
            device=self.device,
        )
        x = self.x_range[0] + (self.x_range[1] - self.x_range[0]) * unit_x  # [B,N,Dx]
        batch.x = x
        batch.xc = x[:, :num_ctx]  # [B,Nc,Dx]
        batch.xt = x[:, num_ctx:]  # [B,Nt,Dx]

        cov = self.kernel(x)  # [B,N,N]
        zero_mean = torch.zeros(self.batch_size, num_points, device=self.device)  # [B,N]
        y = MultivariateNormal(zero_mean, cov).sample().unsqueeze(-1)  # [B,N,Dy=1]
        batch.y = y
        batch.yc = y[:, :num_ctx]  # [B,Nc,1]
        batch.yt = y[:, num_ctx:]  # [B,Nt,1]

        if self.t_noise is None:
            return batch

        if self.t_noise == -1:
            noise_scale = 0.15 * torch.rand(y.shape, generator=self.rng, device=self.device)
        else:
            noise_scale = self.t_noise

        # In-place on purpose: yc and yt are slices of y, so they receive the
        # same noise.
        y += noise_scale * StudentT(2.1).sample(y.shape).to(self.device)
        return batch


class GPFiniteSampler(GPSampler):
    """Fixed, disk-cached set of GP batches generated by ``GPSampler``.

    On construction the batches are read from a cache file under
    ``save_dir`` if it exists; otherwise ``num_batches`` batches are
    generated, moved to the CPU, written to that file and kept in memory.

    NOTE(review): the cache file name encodes ``num_batches``,
    ``batch_size``, ``max_num_points``, ``min_num_points``, ``t_noise`` and
    ``seed`` only. Kernel, ``x_range`` and ``x_dim`` are not part of it, so
    ``save_dir`` has to keep such configurations apart.

    Args:
        save_dir: Directory holding the cache file.
        num_batches: Number of batches to generate when no cache exists.
        kernel: Forwarded to ``GPSampler``.
        t_noise: Forwarded to ``GPSampler``.
        seed: Forwarded to ``GPSampler``; also seeds the shuffling generator
            used when ``loop`` is true.
        device: Device the yielded batches are moved to.
        batch_size: Forwarded to ``GPSampler``.
        max_num_points: Forwarded to ``GPSampler``.
        min_num_points: Forwarded to ``GPSampler``.
        x_range: Forwarded to ``GPSampler``.
        x_dim: Forwarded to ``GPSampler``.
        loop: If true, iteration never ends and reshuffles the batches on
            every pass; otherwise the batches are yielded once, in order.
    """

    def __init__(
        self,
        save_dir: str,
        num_batches: int,
        kernel,
        t_noise=None,
        seed: int = 0,
        device: str = "cuda",
        batch_size: int = 16,
        max_num_points: int = 100,
        min_num_points: int = 5,
        x_range: tuple = (-2, 2),
        x_dim: int = 1,
        loop: bool = False,
    ):
        super().__init__(
            kernel=kernel,
            t_noise=t_noise,
            seed=seed,
            device=device,
            batch_size=batch_size,
            max_num_points=max_num_points,
            min_num_points=min_num_points,
            x_range=x_range,
            x_dim=x_dim,
        )

        t_noise_str = f"-tn{t_noise}" if t_noise is not None else ""
        save_name = f"nb{num_batches}-bs{batch_size}-maxp{max_num_points}-minp{min_num_points}{t_noise_str}-seed{seed}.pt"

        self.save_path = os.path.join(save_dir, save_name)
        self.num_batches = num_batches
        self.loop = loop

        if os.path.exists(self.save_path):
            self.batches = self._load()
        else:
            # An empty save_dir means "current directory"; makedirs("") raises.
            parent_dir = os.path.dirname(self.save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            logger.info('Generating GP data to "%s"...', self.save_path)
            # save() returns what it wrote, so no read-back from disk is needed.
            self.batches = self.save()

    def _load(self):
        """Read the cached list of batches from ``self.save_path``."""
        # The cache stores pickled mapping objects, not bare tensors, so it
        # needs full unpickling; torch >= 2.6 defaults to weights_only=True.
        # SECURITY: full unpickling can execute code - only point save_dir at
        # caches produced by this class.
        try:
            return torch.load(self.save_path, weights_only=False)
        except TypeError:
            # torch versions that predate the weights_only argument.
            return torch.load(self.save_path)

    def _to_device(self, batch):
        """Return a copy of ``batch`` with every tensor on ``self.device``."""
        # A new mapping is built so the cached batch stays untouched.
        return type(batch)({k: v.to(self.device) for k, v in batch.items()})

    def save(self):
        """Generate ``num_batches`` batches, store them on disk and return them.

        Returns:
            The list of generated batches, with all tensors on the CPU.
        """
        batches = []
        for _ in range(self.num_batches):
            batch = self.sample()
            for k, v in list(batch.items()):
                batch[k] = v.cpu()
            batches.append(batch)
        torch.save(batches, self.save_path)
        return batches

    def __iter__(self):
        """Yield batches on ``self.device``; see ``loop`` for the order."""
        if not self.loop:
            for batch in self.batches:
                yield self._to_device(batch)
            return

        if not self.batches:
            # Nothing to cycle through; without this the loop below would
            # spin forever without ever yielding.
            return

        rng = torch.Generator(device=self.device).manual_seed(self.seed)
        while True:
            rand_idxs = torch.randperm(len(self.batches), generator=rng, device=self.device)
            # One transfer of the whole permutation instead of one per index.
            for idx in rand_idxs.tolist():
                yield self._to_device(self.batches[idx])

    def __len__(self):
        return len(self.batches)
