import os
import logging
from time import time

import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig, OmegaConf
from torch import nn, tensor, Tensor
import lightning
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.distributions import MultivariateNormal

import ite

import gfm
from gfm import (
    GaugeMap,
    unit_cube_mirror_map, unit_cube_dual_map,
    unit_ball_mirror_map, unit_ball_dual_map,
    odeint_reflect,
    cube_reflect, ball_reflect,
    cube_project, ball_project,
    MlpVelocityField,
    HyperBallUniform, box_uniform,
    maximum_mean_discrepancy, ConstrainedSet,
    TruncatedDistribution,
)

logger = logging.getLogger(__name__)


class PriorDataset(Dataset):
    """
    Map-style dataset that produces a brand-new prior draw on every access.

    The index handed to ``__getitem__`` is ignored: each access calls
    ``distribution.sample([1])``, so reading the same index twice generally
    yields different values. ``length`` only controls what ``len()`` reports,
    i.e. how many draws a DataLoader makes per epoch.
    """

    def __init__(self, distribution, length=10000):
        """
        Args:
            distribution (torch.distributions.Distribution): Anything exposing
                ``sample(sample_shape)``, typically a PyTorch distribution.
            length (int): Virtual size of the dataset (number of draws per epoch).
        """
        self.length = length
        self.distribution = distribution

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # The index carries no information: every item is an independent draw.
        # NOTE(review): sample([1]) keeps a leading singleton axis, so default
        # collation produces (batch, 1, ...) instead of (batch, ...) — confirm
        # this is what the consuming training step expects.
        draw = self.distribution.sample([1])
        return draw


class GfmExampleBase(lightning.LightningModule):
    """
    Base class for implementing GFM examples.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.velocity = MlpVelocityField(
            cfg.example.dimension,
            cfg.velocity.width,
            cfg.velocity.depth,
            cfg.velocity.activation
        )
        self.save_hyperparameters()

        #### The following buffers are for DDPM only ####
        # Precompute forward process constants
        if self.cfg.method.name == "ddpm":
            beta = torch.linspace(1e-4, 0.02, self.cfg.method.horizon)
            alpha = 1. - beta
            self.register_buffer('beta', beta)
            self.register_buffer('alpha', alpha)
            self.register_buffer('alpha_bar', torch.cumprod(alpha, dim=0))
            self.register_buffer('sqrt_alpha_bar', torch.sqrt(self.alpha_bar))
            self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - self.alpha_bar))

    def get_loss(self):
        loss = getattr(self, "_loss", None)
        if loss is None:
            loss = getattr(nn, self.cfg.train.loss)()
            setattr(self, "_loss", loss)
        return loss

    def get_prior(self) -> torch.distributions.Distribution:
        dist = getattr(self, "_prior", None)
        if dist is None:
            match self.cfg.method.transform:
                case "L2":
                    dist = HyperBallUniform(self.cfg.example.dimension, scale=self.cfg.method.scale)
                case "L_inf":
                    dist = box_uniform(self.cfg.example.dimension, scale=self.cfg.method.scale)
                case "mirror_2" | "mirror_inf":
                    dist = MultivariateNormal(
                        torch.zeros(self.cfg.example.dimension),
                        self.cfg.method.scale * torch.eye(self.cfg.example.dimension)
                    )
                case None:
                    match self.cfg.method.name:
                        case "vanilla" | "ddpm":
                            dist = MultivariateNormal(
                                torch.zeros(self.cfg.example.dimension),
                                self.cfg.method.scale * torch.eye(self.cfg.example.dimension)
                            )
                        case "reflect" | "project":
                            dist = TruncatedDistribution(
                                MultivariateNormal(
                                    self.get_interior_point() if self.cfg.method.prior_center is None
                                    else tensor(self.cfg.method.prior_center),
                                    self.cfg.method.scale * torch.eye(self.cfg.example.dimension)
                                ),
                                self.get_domain(),
                                self.device,
                            )
            setattr(self, "_prior", dist)
        return dist

    def get_domain(self) -> gfm.ConstrainedSet:
        domain = getattr(self, "_domain", None)
        if domain is None:
            domain, ip = self._init_domain()
            setattr(self, "_domain", domain)
            setattr(self, "_ip", ip)
        return domain

    def get_interior_point(self) -> Tensor:
        ip = getattr(self, "_ip", None)
        if ip is None:
            domain, ip = self._init_domain()
            setattr(self, "_domain", domain)
            setattr(self, "_ip", ip)
        return ip

    def get_data(self) -> Tensor:
        """
        Returns the true data samples.

        If the data file specified by {out_prefix}/{example.samples.file} does not exist,
        data samples are generated by the `_init_data` method.
        :return: Data samples, Tensor of N * dim.
        """
        samples = getattr(self, "_data", None)
        if samples is None:
            data_file = os.path.join(self.cfg.out_prefix, self.cfg.example.data_file)
            if os.path.exists(data_file) and os.path.isfile(data_file):
                samples = torch.load(data_file, map_location=self.device)
            else:
                samples = self._init_data(self.cfg.example.n_samples)
                torch.save(samples, data_file)
            setattr(self, "_data", samples)
        return samples

    def transform(self, xs: Tensor) -> Tensor:
        transform = getattr(self, "_transform", None)
        if transform is None:
            self._init_transformation()
            transform = getattr(self, "_transform", None)
        return transform(xs)

    def inverse_transform(self, zs: Tensor) -> Tensor:
        inverse_transform = getattr(self, "_inverse_transform", None)
        if inverse_transform is None:
            self._init_transformation()
            inverse_transform = getattr(self, "_inverse_transform", None)
        return inverse_transform(zs)

    def get_reflect_fn(self):
        rf = getattr(self, "_reflect_fn", None)
        if rf is None:
            match self.cfg.method.name:
                case "vanilla" | "gauge_vanilla.yaml" | "gauge_mirror":
                    rf = None
                case "reflect":
                    rf = self._refect_rf()
                case "project":
                    rf = self._project_rf()
                case "gauge_reflect":
                    rf = cube_reflect if self.cfg.method.transform == "L_inf" else ball_reflect
                case "gauge_project.yaml":
                    rf = cube_project if self.cfg.method.transform == "L_inf" else ball_project
                case _:
                    raise NotImplementedError
            setattr(self, "_reflect_fn", rf)
        return rf

    def _reflect_rf(self):
        raise NotImplementedError

    def _project_rf(self):
        raise NotImplementedError

    def _init_domain(self) -> tuple[ConstrainedSet, Tensor]:
        raise NotImplementedError

    def _init_data(self, n: int) -> Tensor:
        raise NotImplementedError

    def _init_transformation(self):
        transform = getattr(self, "_transform", None)
        inverse_transform = getattr(self, "_inverse_transform", None)
        if transform is None:
            match self.cfg.method.transform:
                # TODO: scale the gauge map
                case "L2":
                    gauge_map = GaugeMap(self.get_domain(), self.get_interior_point(), "ball")
                    transform = lambda x: gauge_map.to_disk(x)
                    inverse_transform = lambda x: gauge_map.from_disk(x)
                case "L_inf":
                    gauge_map = GaugeMap(self.get_domain(), self.get_interior_point(), "cube")
                    transform = lambda x: gauge_map.to_disk(x)
                    inverse_transform = lambda x: gauge_map.from_disk(x)
                case "mirror_2":
                    gauge_map = GaugeMap(self.get_domain(), self.get_interior_point(), "ball")
                    transform = lambda x: unit_ball_mirror_map(gauge_map.to_disk(x))
                    inverse_transform = lambda x: gauge_map.from_disk(unit_ball_dual_map(x))
                case "mirror_inf":
                    gauge_map = GaugeMap(self.get_domain(), self.get_interior_point(), "cube")
                    transform = lambda x: unit_cube_mirror_map(gauge_map.to_disk(x))
                    inverse_transform = lambda x: gauge_map.from_disk(unit_cube_dual_map(x))
                case None:
                    transform = lambda x: x
                    inverse_transform = lambda x: x
            setattr(self, "_transform", transform)
            setattr(self, "_inverse_transform", inverse_transform)

    def configure_optimizers(self):
        opti = getattr(torch.optim, self.cfg.train.optimizer)(
            self.parameters(),
            **OmegaConf.to_container(self.cfg.train.optimizer_args)
        )
        sche = torch.optim.lr_scheduler.StepLR(
            opti,
            **OmegaConf.to_container(self.cfg.train.scheduler_args)
        )
        return {
            "optimizer": opti,
            "lr_scheduler": {
                "scheduler": sche,
                "interval": "step",
                "frequency": 1,
            }
        }

    def train_dataloader(self):
        data = self.get_data()
        training_data = self.transform(data)
        return [
            DataLoader(
                TensorDataset(training_data.cpu()),
                batch_size=self.cfg.train.batch_size,
                shuffle=True,
                generator=torch.Generator(device=self.device),
                num_workers=1,
            ),
            DataLoader(
                PriorDataset(self.get_prior(), training_data.shape[0]),
                shuffle=False,
                batch_size=self.cfg.train.batch_size,
                num_workers=self.cfg.train.get("num_workers", 0),
            )
        ]

    def test_dataloader(self):
        return [
            DataLoader(
                TensorDataset(tensor([self.cfg.test.n_gen])),
                batch_size=1, shuffle=False
            )
            for _ in range(self.cfg.test.repeats)
        ]

    def on_train_start(self) -> None:
        self.velocity.to(self.device)

    def on_test_start(self) -> None:
        self.velocity.to(self.device)

    def training_step(self, batch, batch_idx):
        if self.cfg.method.name == "ddpm": return self.ddpm_training_step(batch)
        z_1 = batch[0]
        z_0 = batch[1]
        t = torch.rand(len(z_1), 1).to(z_1)
        z_t = (1 - t) * z_0 + t * z_1
        dz_t = z_1 - z_0
        return self.get_loss()(self.velocity(t, z_t), dz_t)

    @torch.no_grad()
    def sample(self, n_samples: int, n_steps: int) -> Tensor:
        start = time()
        z_0 = self.get_prior().sample([n_samples]).to(self.device)
        prior_time = time() - start
        t = torch.linspace(0, 1, n_steps).to(self.device)
        start = time()
        z_1 = (self.integrate_ddpm(z_0) if self.cfg.method.name == "ddpm" else
               odeint_reflect(self.velocity, z_0, t, self.get_reflect_fn())[-1])
        integral_time = time() - start
        start = time()
        x_1 = self.inverse_transform(z_1)
        transform_time = time() - start
        self.log("prior_time", prior_time)
        self.log("integral_time", integral_time)
        self.log("transform_time", transform_time)
        return x_1

    @torch.no_grad()
    def test_step(self, *args, **kwargs) -> STEP_OUTPUT:
        co = ite.cost.BDKL_KnnK()
        x_1 = self.sample(self.cfg.test.n_gen, self.cfg.test.n_steps)
        data = self.get_data()
        kl = co.estimation(x_1, data)
        mmd = maximum_mean_discrepancy(x_1, data)
        fea = self.get_domain().check_feasibility_v(x_1).sum()
        self.log("kl", kl)
        self.log("mmd", mmd)
        self.log("feasible", tensor([fea]))

        if self.cfg.test.get("save_samples", False):
            i = 0
            while os.path.exists(f"gen_samples_{i}.pt"):
                i += 1
            torch.save(x_1, f"gen_samples_{i}.pt")

        return {
            "loss": 0,
            "kl": kl,
            "mmd": mmd,
            "feasible": fea,
        }

    ###### THE FOLLOWING METHODS ARE FOR DDPM ONLY ######

    def q_sample(self, t: Tensor, x_0: Tensor, noise: Tensor) -> Tensor:
        """
        Sample from the forward process at time t.
        :param t: Time step, shape (N, 1)
        :param x_0: Current sample, shape (N, dim)
        :param noise: Noise, shape (N, dim)
        :return: Sample at time t, shape (N, dim)
        """
        sqrt_ab = self.sqrt_alpha_bar[t].unsqueeze(-1).to(t)
        sqrt_1mab = self.sqrt_one_minus_alpha_bar[t].unsqueeze(-1).to(t)
        return sqrt_ab * x_0 + sqrt_1mab * noise

    def p_sample(self, t: Tensor, x_t: Tensor) -> Tensor:
        """
        Sample from the reverse process at time t.
        :param t: Time step, shape (N, 1)
        :param x_t: Current sample, shape (N, dim)
        :return: Sample at time t-1, shape (N, dim)
        """
        pred_noise = self.velocity(t / self.cfg.method.horizon, x_t)
        beta_t = self.beta[t].unsqueeze(-1).to(t)
        alpha_t = self.alpha[t].unsqueeze(-1).to(t)
        alpha_bar_t = self.alpha_bar[t].unsqueeze(-1).to(t)

        mean = (1 / torch.sqrt(alpha_t)) * (x_t - beta_t / torch.sqrt(1 - alpha_bar_t) * pred_noise)
        if t[0] == 0:
            return mean
        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(beta_t) * noise

    def ddpm_training_step(self, batch):
        x0 = batch[0]
        noise = batch[1]
        t = torch.randint(0, self.cfg.method.horizon, (x0.size(0),), device=x0.device)
        xt = self.q_sample(t, x0, noise)
        pred_noise = self.velocity(t / self.cfg.method.horizon, xt)
        loss = self.get_loss()(pred_noise, noise)
        return loss

    def integrate_ddpm(self, x0: Tensor) -> Tensor:
        x = x0
        for t in reversed(range(self.cfg.method.horizon)):
            t_batch = torch.full((x0.shape[0],), t, dtype=torch.long, device=x0.device)
            x = self.p_sample(t_batch, x)
        return x
