from pathlib import Path
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta

from margflow.abstract_model import AbstractModel
from margflow.model_utils import (
    sample_gaussian_mixture,
    log_prob_gaussian_mixture,
    log_prob_mixture_betas,
    sample_mixture_betas,
)
from margflow.nn.cond_mlp import Hypernetwork, CondFourier
from margflow.nn.mlp import MLP, MLPSkip
from margflow.utils.math_utils import sample_quasi_random_mixture


class LinearContext(nn.Linear):
    """Affine layer exposing the ``forward(x, context=None)`` call signature.

    The ``context`` argument is accepted so the layer can be called the same
    way as the context-aware networks in this file, but it is never used.
    """

    def __init__(self, in_features, out_features):
        """Create the layer with ``nn.Linear``'s default options."""
        nn.Linear.__init__(self, in_features, out_features)

    def forward(self, x, context=None):
        """Apply the affine map to ``x``; ``context`` is ignored."""
        return nn.Linear.forward(self, x)


def build_mlp_margflow(
    z_dim: int,
    hid_dim: int,
    n_layers: int,
    x_dim: int,
    conditional_network: str = "unconditional",
    cond_dim: int = None,
    fourier_sigma: float = 0.05,
    n_fourier_features: int = 4096,
    dropout: float = 0.0,
    skip_connection: bool = True,
    device: str = "cuda",
) -> MLP | Hypernetwork | CondFourier:
    """Build the network mapping base samples (``z_dim``) to ``x_dim`` outputs.

    Args:
        z_dim: input dimensionality of the network.
        hid_dim: hidden width of the network.
        n_layers: number of layers of the network.
        x_dim: output dimensionality of the network.
        conditional_network: one of ``"unconditional"`` (plain ``MLP``),
            ``"cond_hypernet"`` (``Hypernetwork``) or ``"cond_fourier"``
            (``CondFourier``).
        cond_dim: dimensionality of the conditioning variable; only forwarded
            to the two conditional networks.
        fourier_sigma: forwarded to ``CondFourier`` only; the unconditional MLP
            always receives ``None`` instead.
        n_fourier_features: forwarded as ``fourier_dim`` to the conditional
            networks (and as ``fourier_dim_x`` to ``CondFourier``); the
            unconditional MLP always receives ``None`` instead.
        dropout: dropout value forwarded to every network.
        skip_connection: currently has no effect. The hypernetwork is always
            built with ``skip_connection=True`` and the other networks do not
            receive this argument.
            NOTE(review): confirm whether ignoring the caller's value is intended.
        device: device the freshly built network is moved to.

    Returns:
        The constructed network, already moved to ``device``.

    Raises:
        ValueError: if ``conditional_network`` is not one of the three
            supported values.
    """
    if conditional_network == "unconditional":
        # Fourier features are disabled for the unconditional MLP, whatever
        # the caller passed for fourier_sigma / n_fourier_features.
        return MLP(
            in_dim=z_dim,
            hid_dim=hid_dim,
            n_layers=n_layers,
            out_dim=x_dim,
            dropout=dropout,
            center_init=True,
            mult_output=1.0,
            fourier_sigma=None,
            n_fourier_features=None,
            learnable_range=False,
            sigmoid_output=False,
        ).to(device)

    if conditional_network == "cond_hypernet":
        # Hypernetwork size and the skip connection are fixed here.
        return Hypernetwork(
            in_dim=z_dim,
            hid_dim=hid_dim,
            n_layers=n_layers,
            out_dim=x_dim,
            hypernet_in_dim=cond_dim,
            hypernet_hid_dim=256,
            hypernet_n_layers=3,
            fourier_dim=n_fourier_features,
            dropout=dropout,
            skip_connection=True,
        ).to(device)

    if conditional_network == "cond_fourier":
        layer_widths = [hid_dim] * n_layers
        return CondFourier(
            in_dim=z_dim,
            out_dim=x_dim,
            cond_dim=cond_dim,
            hidden_dims=layer_widths,
            fourier_dim=n_fourier_features,
            fourier_sigma=fourier_sigma,
            fourier_dim_x=n_fourier_features,
            dropout=dropout,
        ).to(device)

    raise ValueError(f"Unknown conditional network value: {conditional_network}")


class MarginalFlow(AbstractModel):
    """Mixture model whose components are produced by a neural network.

    Samples of a base distribution are pushed through ``self.network``; the
    outputs (called "mixtures" throughout) are then used as the components of
    a Gaussian or Beta mixture, both for sampling and for density evaluation.
    With ``use_trainable_means=True`` the mixtures are instead a directly
    trainable tensor and the network is bypassed.
    """

    def __init__(
        self,
        x_dim: int,
        z_dim: int,
        n_layers: int = 5,
        hid_dim: int = 256,
        mixture_distribution: Literal["gaussian", "beta"] = "gaussian",
        base_distribution: Literal["mog", "uniform", "betas", "quasi_random"] = "mog",
        use_trainable_means: bool = False,
        n_base_means: int = 1,
        signature: str | None = None,
        script_path: Path | None = None,
        conditional_network: str = "unconditional",
        cond_dim: int | None = None,
        dropout: float = 0.0,
        log_sigma_init: float = 0.0,
        isotropic_sigma: bool = True,
        fourier_sigma: float = 0.01,
        skip_connection: bool = False,
        device: str = "cuda",
        dtype: torch.dtype = torch.float32,
    ):
        """Build the network and the trainable parameters of the model.

        Args:
            x_dim: dimensionality of data/density.
            z_dim: dimensionality of base distribution.
            n_layers: number of layers in the neural network.
            hid_dim: number of neurons per hidden layer.
            mixture_distribution: family of the modelled mixture.
            base_distribution: distribution the network inputs are drawn from.
            use_trainable_means: train the mixture means directly instead of
                producing them with the network.
            n_base_means: number of components of the base distribution (used
                for "mog", "quasi_random" and "betas") or number of trainable
                means (if ``use_trainable_means``).
            signature: identifier of dataset and training hyperparams.
            script_path: path to script directory.
            conditional_network: one of "unconditional", "cond_hypernet",
                "cond_fourier".
            cond_dim: dimensionality of the conditioning variable, if any.
            dropout: dropout value used in the neural network.
            log_sigma_init: initial value of the trainable log-sigma (also
                used to initialise ``log_bandwidth``).
            isotropic_sigma: if True a single sigma is learned, otherwise one
                sigma per data dimension (diagonal).
            fourier_sigma: frequency of Fourier features in the network.
            skip_connection: forwarded to ``build_mlp_margflow``.
            device: device on which model is initialized and trained.
            dtype: dtype of tensors.
        """
        super().__init__(
            model_name="marginal_flow",
            x_dim=x_dim,
            script_path=script_path,
            signature=signature,
            device=device,
            dtype=dtype,
        )
        self.base_distribution = base_distribution
        self.mixture_distribution = mixture_distribution
        self.n_base_means = n_base_means
        self.use_trainable_means = use_trainable_means
        self.n_layers = n_layers
        self.hid_dim = hid_dim
        self.dropout = dropout
        self.z_dim = z_dim
        self.dtype = dtype
        self.cond_dim = cond_dim
        self.fourier_sigma = fourier_sigma
        self.isotropic_sigma = isotropic_sigma
        self.skip_connection = skip_connection

        # The network is built even when use_trainable_means is set; in that
        # case it is simply left out of trainable_params below.
        self.network = build_mlp_margflow(
            z_dim=z_dim,
            hid_dim=hid_dim,
            n_layers=n_layers,
            x_dim=x_dim,
            conditional_network=conditional_network,
            cond_dim=cond_dim,
            fourier_sigma=fourier_sigma,
            dropout=dropout,
            skip_connection=skip_connection,
            device=device,
        )

        log_sigma = torch.tensor([log_sigma_init], device=device, dtype=dtype)
        log_bandwidth = torch.tensor([log_sigma_init], device=device, dtype=dtype)

        if self.isotropic_sigma:
            self.log_sigma = nn.Parameter(log_sigma, requires_grad=True)
        else:
            # one learnable log-sigma per data dimension
            self.log_sigma = nn.Parameter(
                torch.ones(self.x_dim, device=self.device) * log_sigma, requires_grad=True
            )

        # Registered as a parameter but not read elsewhere in this class and
        # not listed in trainable_params.
        self.log_bandwidth = nn.Parameter(log_bandwidth, requires_grad=True)

        if self.use_trainable_means:
            means = torch.randn(n_base_means, x_dim, device=device, dtype=dtype) * 0.1
            self.trainable_means = nn.Parameter(means, requires_grad=True)
        elif self.base_distribution in ("mog", "quasi_random"):
            # if not high enough, modes might not be separated enough once fed
            # through the neural network
            width = 1
            random_means = torch.randn(n_base_means, z_dim, device=device, dtype=dtype) * width
            self.base_means = nn.Parameter(random_means, requires_grad=True)
        elif self.base_distribution == "betas":
            init_alpha = 1
            init_beta = 1
            # Created on the model's device/dtype like every other parameter.
            self.alpha_unconstrained = nn.Parameter(
                torch.ones(self.z_dim, self.n_base_means, device=device, dtype=dtype)
                * init_alpha
            )
            self.beta_unconstrained = nn.Parameter(
                torch.ones(self.z_dim, self.n_base_means, device=device, dtype=dtype)
                * init_beta
            )
        elif self.base_distribution == "uniform":
            # placeholder buffer; not read anywhere in this class
            self.register_buffer("samples", None)
        if self.base_distribution == "quasi_random":
            # 0.0 disables the jitter added in sample_base_distribution; set to
            # None before the first sampling call to have it estimated instead.
            self.eps_rel = 0.0
            self.register_buffer("quasi_r", None)

        self.trainable_params = {
            "model": self.network,
            "log_sigma": self.log_sigma,
        }

        if self.use_trainable_means:
            del self.trainable_params["model"]
            self.trainable_params["means"] = self.trainable_means
        elif self.base_distribution == "mog":
            self.trainable_params["base_means"] = self.base_means
        elif self.base_distribution == "betas":
            self.trainable_params["alpha"] = self.alpha_unconstrained
            self.trainable_params["beta"] = self.beta_unconstrained
        # NOTE(review): for "quasi_random" base_means is a Parameter but is not
        # listed in trainable_params -- confirm this is intended.

        self.set_model_signature()

    def forward(self, x, context=None):
        """Apply the underlying network to ``x`` (optionally with ``context``)."""
        return self.network(x, context=context)

    @staticmethod
    def _n_flat_samples(n_samples: int, context: torch.Tensor | None) -> int:
        """Return the total number of draws: ``n_samples`` per context row."""
        return n_samples if context is None else n_samples * context.shape[0]

    def _group_by_context(
        self, samples: torch.Tensor, n_samples: int, context: torch.Tensor | None
    ) -> torch.Tensor:
        """Reshape flat draws to ``(context.shape[0], n_samples, z_dim)`` if needed."""
        if context is None:
            return samples
        return samples.reshape(context.shape[0], n_samples, self.z_dim)

    def sample_base_distribution(self, n_samples: int, context: torch.Tensor | None = None):
        """Draw inputs for the network from the configured base distribution.

        Note: the base distribution is not to be confused with the mixture
        modelled by marginal flow; it only provides the network inputs.

        Args:
            n_samples: number of draws (per context row when ``context`` is
                given).
            context: optional conditioning tensor; only its first dimension is
                used here. It is ignored for ``"quasi_random"``.

        Returns:
            For "mog", "uniform" and "betas": the draws, reshaped to
            ``(context.shape[0], n_samples, z_dim)`` when ``context`` is given.
            For "quasi_random": the cached quasi-random points plus jitter.

        Raises:
            ValueError: if ``self.base_distribution`` is not supported.
        """
        if self.base_distribution == "mog":
            num_samples = self._n_flat_samples(n_samples, context)
            sigma = 0.5
            samples = sample_gaussian_mixture(
                n_samples=num_samples,
                means=self.base_means,
                sigma=torch.tensor(sigma, device=self.device),
            )
            return self._group_by_context(samples, n_samples, context)

        if self.base_distribution == "uniform":
            num_samples = self._n_flat_samples(n_samples, context)
            # uniform on [-1, 1)
            samples = torch.rand(num_samples, self.z_dim, device=self.device) * 2 - 1
            return self._group_by_context(samples, n_samples, context)

        if self.base_distribution == "betas":
            # Keep the flat count separate from n_samples: the reshape below
            # needs the per-context count.
            num_samples = self._n_flat_samples(n_samples, context)

            # softplus keeps the Beta concentration parameters positive
            alpha = F.softplus(self.alpha_unconstrained).unsqueeze(0)  # [1, d, K]
            beta = F.softplus(self.beta_unconstrained).unsqueeze(0)  # [1, d, K]

            # pick one component uniformly at random per draw and per dimension
            component_ids = torch.randint(
                0, self.n_base_means, (num_samples, self.z_dim), device=alpha.device
            ).unsqueeze(-1)  # [N, d, 1]
            alpha_sel = torch.gather(
                alpha.expand(num_samples, -1, -1), 2, component_ids
            ).squeeze(-1)  # [N, d]
            beta_sel = torch.gather(
                beta.expand(num_samples, -1, -1), 2, component_ids
            ).squeeze(-1)  # [N, d]

            # reparameterised draw from the selected Beta distributions
            samples = Beta(alpha_sel, beta_sel).rsample().to(self.device)
            return self._group_by_context(samples, n_samples, context)

        if self.base_distribution == "quasi_random":
            # The quasi-random points are generated once, on the first call,
            # and cached; later calls must use a compatible n_samples because
            # the jitter below has shape (n_samples, z_dim).
            if self.quasi_r is None:
                with torch.no_grad():
                    self.quasi_r = sample_quasi_random_mixture(
                        means=self.base_means, sigma=1, n_samples=n_samples, eps_rel=0.0
                    )
                if self.z_dim == 1:  # additional sorting for visualizations
                    self.quasi_r, _ = torch.sort(self.quasi_r, dim=0)
                # Only reached if eps_rel was reset to None (it is initialised
                # to 0.0): estimate the jitter scale as the mean distance to
                # the nearest neighbour.
                if self.eps_rel is None:
                    # sqrt of the square is the absolute value, so this sums
                    # absolute coordinate differences (assumes quasi_r is 2-D
                    # (n, z_dim) -- TODO confirm)
                    eps_rel = ((self.quasi_r[None] - self.quasi_r[:, None]) ** 2).sqrt().sum(-1)
                    eps_rel, _ = torch.sort(eps_rel, dim=-1)
                    # column 0 is the zero distance of each point to itself
                    self.eps_rel = eps_rel[:, 1].mean()
            eps = (
                0.2
                * self.eps_rel
                / np.sqrt(self.z_dim)
                * torch.randn(n_samples, self.z_dim, device=self.device, dtype=self.dtype)
            )
            return self.quasi_r + eps

        raise ValueError(
            "Base distribution type not recognized: "
            "'mog', 'uniform', 'betas' or 'quasi_random'"
        )

    def _mixture_log_prob(self, x: torch.Tensor, mixtures: torch.Tensor):
        """Evaluate the log-density of ``x`` under the mixture given by ``mixtures``."""
        if self.mixture_distribution == "gaussian":
            return log_prob_gaussian_mixture(x=x, mixtures=mixtures, sigma=self.log_sigma.exp())
        if self.mixture_distribution == "beta":
            return log_prob_mixture_betas(x=x, mixtures=mixtures)
        raise ValueError("Mixture distribution not recognized: 'gaussian' or 'beta'")

    def sample_and_log_prob(
        self,
        n_mixtures: int,
        n_samples: int,
        x: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
    ):
        """Sample from the model and evaluate a log-density with the same mixtures.

        This differs from calling ``sample()`` followed by ``log_prob()``:
        ``log_prob()`` would re-sample the mixtures, whereas here the mixtures
        used for the density are the ones that generated the samples.

        Args:
            n_mixtures: number of base draws / mixture components.
            n_samples: number of samples drawn from the mixture.
            x: points at which to evaluate the log-density; defaults to the
                freshly drawn samples.
            context: optional conditioning tensor.

        Returns:
            Tuple ``(mixtures, samples, log_prob, base_samples)``. ``samples``
            are always the freshly drawn ones, even when ``x`` is given.

        Raises:
            ValueError: if ``self.mixture_distribution`` is not supported.
        """
        samples, mixtures, base_samples = self.sample_all(
            n_mixtures=n_mixtures, n_samples=n_samples, context=context
        )
        points = samples if x is None else x
        log_prob = self._mixture_log_prob(points, mixtures)
        return mixtures, samples, log_prob, base_samples

    def log_prob(
        self, x: torch.Tensor, context: torch.Tensor | None = None, n_mixtures: int = 1024
    ):
        """Evaluate the log-density of ``x`` using freshly sampled mixtures.

        Args:
            x: points at which to evaluate the log-density.
            context: optional conditioning tensor.
            n_mixtures: number of base draws / mixture components.

        Raises:
            ValueError: if ``self.mixture_distribution`` is not supported.
        """
        _, mixtures = self.sample_base(n_mixtures, context=context)
        return self._mixture_log_prob(x, mixtures)

    def sample_all(self, n_mixtures: int, n_samples: int, context: torch.Tensor | None = None):
        """Sample from the model and also return the intermediate quantities.

        Returns:
            Tuple ``(samples, mixtures, base_samples)``.

        Raises:
            ValueError: if ``self.mixture_distribution`` is not supported.
        """
        base_samples, mixtures = self.sample_base(n_mixtures=n_mixtures, context=context)
        if self.mixture_distribution == "gaussian":
            # the network outputs are the means of the Gaussian mixture
            samples = sample_gaussian_mixture(
                n_samples=n_samples, means=mixtures, sigma=self.log_sigma.exp()
            )
        elif self.mixture_distribution == "beta":
            samples = sample_mixture_betas(n_samples=n_samples, mixtures=mixtures)
        else:
            raise ValueError("Mixture distribution not recognized: 'gaussian' or 'beta'")

        return samples, mixtures, base_samples

    def sample(
        self,
        n_samples: int,
        context: torch.Tensor | None = None,
        n_mixtures: int = 1024,
        **kwargs,
    ):
        """Draw ``n_samples`` samples from the model; extra kwargs are ignored."""
        samples, _, _ = self.sample_all(
            n_mixtures=n_mixtures, n_samples=n_samples, context=context
        )
        return samples

    def sample_base(self, n_mixtures: int, context: torch.Tensor | None = None):
        """Produce the mixtures, together with the base samples they came from.

        Returns:
            Tuple ``(base_samples, mixtures)``. With ``use_trainable_means``
            the trainable means are returned as-is (``n_mixtures`` and
            ``context`` are ignored) and ``base_samples`` is ``None``.
        """
        if self.use_trainable_means:
            return None, self.trainable_means

        # sample from the base distribution ...
        base_samples = self.sample_base_distribution(n_samples=n_mixtures, context=context)
        # ... and transform the samples into the mixture components
        mixtures = self.network(base_samples, context=context)
        return base_samples, mixtures
