from __future__ import annotations
import copy 
from collections.abc import MutableMapping
from math import pi
from typing import Any

import numpy as np
import gpytorch
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset

from typing import Any

from gpytorch.likelihoods import (
    GaussianLikelihood, 
    MultitaskGaussianLikelihood, 
    DirichletClassificationLikelihood
)

from bayesopt.surrogates.fsplaplace_utils.lanczos import lanczos_compute_efficient
from bayesopt.surrogates.fsplaplace_utils.linear_operators import (
    JacobianLinearOperator, 
    JacobianTransposeLinearOperator, 
    GGNLinearOperator
)
from bayesopt.surrogates.fsplaplace_utils.prior import (
    GPModel, 
    optimize_prior_parameters
)

from datasets import Dataset  # huggingface datasets

def fit_fsplaplace(
    model: nn.Module, 
    prior: GPModel, 
    likelihood: str,
    train_loader: DataLoader,
    context_loader: DataLoader, 
    device: str,
    dtype: torch.dtype,
    n_outputs: int,
    noise_var: float = 1e-1,
    lr: float = 0.001,
    n_epochs: int = 5000,
    val_frequency: int = 100,
    early_stopping_patience: int = 1000,
    jitter: float = 1e-10,
    chunk_size: int = 1,
    dict_key_y: str = "labels"
) -> torch.Tensor:
    """Train ``model`` on the FSP-Laplace MAP objective and return the learned noise variance.

    Per training batch the loss is

        data_nll (batch mean) + 0.5 * sum_k delta_k^T Sigma_k^{-1} delta_k / N

    where ``delta = model(x_context).T - prior_mean`` and ``(prior_mean, Sigma)`` come from
    ``prior.prior_predictive(x_context, jitter)`` on a batch drawn from ``context_loader``.
    Here ``k`` indexes the leading axis of ``delta`` / ``Sigma`` (mapped over with
    ``torch.vmap``) and ``N = len(train_loader.dataset)``.

    Args:
        model: network to train in place.
        prior: object providing ``prior_predictive(x, jitter) -> (mean, cov)``. Expected shapes
            are ``(n_outputs, batch)`` and ``(n_outputs, batch, batch)``, inferred from the
            transpose and ``vmap`` below.
        likelihood: ``"regression"`` (Gaussian NLL with learned per-output variance) or
            ``"classification"`` (``NLLLoss``; the model is expected to output log-probabilities).
        train_loader: yields ``(x, y)`` tuples or dict-like batches containing ``dict_key_y``.
        context_loader: yields batches whose element ``[0]`` is the context input. It is cycled
            when exhausted.
        device, dtype: placement of the noise-variance parameter and of the data.
        n_outputs: number of model outputs (size of the noise-variance vector).
        noise_var: initial observation noise variance.
        lr: Adam learning rate. It is annealed with a cosine schedule over all optimiser steps.
        n_epochs: maximum number of epochs.
        val_frequency: print progress every ``val_frequency`` epochs.
        early_stopping_patience: stop once the epoch loss has not improved for more than this
            many consecutive epochs.
        jitter: forwarded to ``prior.prior_predictive``.
        chunk_size: ``chunk_size`` of the ``torch.vmap`` used for the prior penalty.
        dict_key_y: label key for dict-like batches.

    Returns:
        Detached tensor of shape ``(n_outputs,)`` with the noise variance. It comes from the
        epoch with the lowest loss, or is the final value when no epoch improved (e.g.
        ``n_epochs == 0``). The model is restored to the weights of that best epoch. For
        classification the variance gets no gradient; it is only clamped to ``>= 1e-6``.

    Raises:
        ValueError: if ``likelihood`` is neither ``"regression"`` nor ``"classification"``.
    """
    # Likelihood noise variance, optimised jointly with the network weights.
    noise_var = torch.nn.parameter.Parameter(
        noise_var * torch.ones(n_outputs, dtype=dtype, device=device)
    )

    # FSP-Laplace
    optimizer = torch.optim.Adam(list(model.parameters()) + [noise_var], lr=lr) # type: ignore
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs * len(train_loader))
    
    if likelihood == "regression":
        loss_func = torch.nn.GaussianNLLLoss(reduction="mean", full=True)
    elif likelihood == "classification":
        loss_func = torch.nn.NLLLoss(reduction="mean")
    else:
        raise ValueError("Invalid likelihood. Choose 'regression' or 'classification'.")

    # Prior penalty delta_k^T Sigma_k^{-1} delta_k, mapped over the leading (output) axis.
    quad_form = torch.vmap(
        lambda _mu, _cov: torch.dot(_mu, torch.linalg.solve(_cov, _mu)), chunk_size=chunk_size
    )

    # Early stopping
    best_epoch_loss = float("inf")
    no_improvement_count = 0
    best_net, best_noise_var = None, None

    n_samples = len(train_loader.dataset) # type: ignore
    n_batches = len(train_loader)

    # Training loop
    context_iter = iter(context_loader)
    for epoch in range(n_epochs):
        model.train()
        epoch_loss, epoch_acc = 0., 0.
        for data in train_loader:
            # Get data (dict-like batches are passed to the model as-is)
            if isinstance(data, MutableMapping):
                x, y = data, data[dict_key_y].to(device, non_blocking=True)
            else:
                x, y = data
                x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            # Get context data, restarting the context loader when it runs out
            try:
                x_context = next(context_iter)[0].to(device)
            except StopIteration:
                context_iter = iter(context_loader)
                x_context = next(context_iter)[0].to(device)
            optimizer.zero_grad()
            # Forward pass
            output = model(x)
            output_context = model(x_context)
            prior_mean, prior_cov = prior.prior_predictive(x_context, jitter)
            # Data term (batch mean) + function-space prior penalty scaled by 1/N
            delta = output_context.T - prior_mean
            if likelihood == "regression":
                loss = loss_func(output, y, noise_var.expand(y.shape[0], -1))
            else:
                loss = loss_func(output, y)
            loss = loss + 0.5 * quad_form(delta, prior_cov).sum() / n_samples
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Prevent the likelihood variance from becoming too small
            noise_var.data.clamp_(min=1e-6)
            # Logging
            epoch_loss += loss.item()
            if likelihood == "classification":
                epoch_acc += (output.argmax(dim=1) == y).sum().item()
            else:
                # For regression the "acc" metric is the summed squared error
                epoch_acc += torch.square(output - y).sum().item() 
        # `loss` is already a per-batch mean, so average it over batches (not samples).
        # The denominator is constant, so early-stopping decisions are unaffected.
        epoch_loss /= n_batches
        epoch_acc /= n_samples
        if epoch % val_frequency == 0:
            print(f"Epoch {epoch} - loss: {epoch_loss} - acc: {epoch_acc}", flush=True)
        # Early stopping
        if epoch_loss < best_epoch_loss:
            best_epoch_loss = epoch_loss
            best_net = copy.deepcopy(model.state_dict())
            # Detach so the snapshot holds no autograd link to the live parameter
            best_noise_var = noise_var.detach().clone()
            no_improvement_count = 0
        else:
            no_improvement_count += 1
            if no_improvement_count > early_stopping_patience:
                print(f"Early stopping at epoch {epoch}", flush=True)
                break

    if best_net is not None:
        model.load_state_dict(best_net)
        return best_noise_var # type: ignore
    # No epoch improved: return the current value, detached like the branch above
    return noise_var.detach()


class FSPLaplace:
    """Linearised Laplace approximation with a function-space (GP) prior.

    :meth:`fit` builds a low-rank factor ``posterior_scale`` of shape ``(p, rank)``.
    NOTE(review): ``p`` is presumably ``n_params``, or the sketch dimension when
    ``params_sketch`` is set — confirm against the linear-operator helpers. Predictions
    push this factor through the model Jacobian; see :meth:`__call__` and
    :meth:`predictive_samples`.
    """

    def __init__(
        self,
        model: nn.Module,
        prior: gpytorch.models.ExactGP,
        likelihood: str,
        sigma_noise: torch.Tensor,
        params_sketch: str | None = None,
        params_sketch_dim: int = 10,
        temperature: float = 1.0,
        max_rank: int = 100,
        enable_backprop: bool = False,
        dict_key_x: str = "input_ids",
        dict_key_y: str = "labels",
        dtype=torch.float64
    ):
        """
        Args:
            model: trained network. Its parameters are linearised around their current values.
            prior: GP prior. ``fit`` calls ``kernel_linop`` and ``marginal_variance`` on it.
            likelihood: ``"regression"`` or ``"classification"``.
            sigma_noise: observation noise standard deviation. ``log_marginal_likelihood``
                squares it before using it as a variance.
            params_sketch, params_sketch_dim: forwarded to the Jacobian/GGN linear operators.
            temperature: stored only; not read by any method in this class.
            max_rank: maximum number of Lanczos iterations per output in ``fit``.
            enable_backprop: if False, predictive outputs are detached.
            dict_key_x, dict_key_y: keys used for dict-like (e.g. Huggingface) batches.
            dtype: dtype the context inputs are cast to in ``fit``.
        """
        assert likelihood in ["regression", "classification"], "Invalid likelihood. Choose 'regression' or 'classification'."

        self.model = model
        self.prior = prior
        self.params: dict[str, torch.Tensor] = {k: v.detach() for k, v in model.named_parameters()}

        self.n_params: int = sum(p.numel() for p in model.parameters())
        self.likelihood = likelihood
        self.sigma_noise = sigma_noise
        self.temperature = temperature
        self.enable_backprop = enable_backprop
        self.max_rank = max_rank
        self.params_sketch = params_sketch
        self.params_sketch_dim = params_sketch_dim
        self.dtype = dtype

        # For models with dict-like inputs (e.g. Huggingface LLMs)
        self.dict_key_x = dict_key_x
        self.dict_key_y = dict_key_y

        self.n_outputs: int = 0
        self.n_data: int = 0

        # Populated by `fit`. They start as None so the "Call `fit` method first."
        # assertions fire instead of an AttributeError.
        self.prior_eigvals: torch.Tensor | None = None
        self.posterior_scale: torch.Tensor | None = None
        self.posterior_eigvals: torch.Tensor | None = None

    @property
    def _device(self) -> torch.device:
        """Device of the model's first parameter."""
        return next(self.model.parameters()).device

    def fit(
        self, 
        train_loader: DataLoader,
        context_loader: DataLoader,
        n_chunks: int = 1
    ) -> None:
        """Build the low-rank FSP-Laplace posterior scale.

        1. Infer ``n_outputs`` from one forward pass on the first training batch.
        2. For each output, run Lanczos on the prior kernel over all context points and apply
           the transposed model Jacobian to the resulting factor. The products are
           concatenated into ``M``. Its left singular vectors ``u`` whose singular values
           ``s`` exceed ``sqrt(eps)`` are kept.
        3. Form ``A = diag(s**2) + u^T GGN u``, with the GGN of the training loss, and
           eigendecompose it. Non-positive eigenvalues are dropped.
        4. Append columns ``u @ v_i / sqrt(lambda_i)`` (largest eigenvalue first) while every
           marginal posterior variance on the context points is below the prior marginal
           variance. The column that crosses the prior is dropped again, keeping at least one.

        Sets ``n_outputs``, ``prior_eigvals``, ``posterior_scale`` and ``posterior_eigvals``.

        Args:
            train_loader: training batches, ``(x, y)`` tuples or dict-like (Huggingface).
            context_loader: context points. ``context_loader.dataset[:][key]`` must work, with
                ``key="features"`` for a Huggingface ``Dataset`` and ``0`` otherwise.
            n_chunks: forwarded to the Jacobian/GGN linear operators.

        Raises:
            RuntimeError: if no posterior direction can be kept.
        """
        self.model.eval()

        data: (
            tuple[torch.Tensor, torch.Tensor] | MutableMapping[str, torch.Tensor | Any]
        ) = next(iter(train_loader))

        # Get n_outputs from a single forward pass
        with torch.no_grad():
            if isinstance(data, MutableMapping):  # To support Huggingface dataset
                out = self.model(data)
            else:
                X = data[0]
                try:
                    out = self.model(X[:1].to(self._device))
                except (TypeError, AttributeError):
                    out = self.model(X.to(self._device))
        self.n_outputs = out.shape[-1]
        setattr(self.model, "output_size", self.n_outputs)

        n_context_points = len(context_loader.dataset) # type: ignore
        eps = torch.finfo(out.dtype).eps

        # Prior term: transposed Jacobian applied to a Lanczos factor of each output's kernel
        M_blocks = []
        key = "features" if isinstance(context_loader.dataset, Dataset) else 0  # Huggingface dataset 
        X_context = torch.as_tensor(context_loader.dataset[:][key]).to(self._device, self.dtype) # type: ignore
        ones = torch.ones(n_context_points).to(X_context)
        for output_idx in range(self.n_outputs):
            if isinstance(self.prior, GPModel):
                kernel_linop = self.prior.kernel_linop(X_context, output_idx=output_idx)
            else:
                kernel_linop = self.prior.kernel_linop(X_context)
            b = kernel_linop @ ones # (n_context_points)
            K_scale = lanczos_compute_efficient(
                kernel_linop, 
                b.reshape(-1), 
                tol=eps**0.5, # run lanczos until sq_tol < machine_eps 
                max_iter=self.max_rank
            ).detach()
            K_scale = K_scale.reshape(n_context_points, 1, -1)
            # Jacobian transpose
            jacT_linop = JacobianTransposeLinearOperator(
                self.model, 
                context_loader, 
                sketch=self.params_sketch, 
                sketch_dim=self.params_sketch_dim, 
                output_idx=output_idx, 
                n_outputs=self.n_outputs, 
                n_chunks=n_chunks
            )
            M_blocks.append(jacT_linop @ K_scale) # (n_params, max_iters)

        del X_context, ones, b, K_scale, kernel_linop

        # Concatenate outputs
        M = torch.cat(M_blocks, dim=-1) # (n_params, max_iters)
        del M_blocks

        # Keep singular directions above sqrt(machine eps)
        _u, _s, _ = torch.linalg.svd(M, full_matrices=False)
        tol = eps**0.5 # machine precision for torch.diag(s**2) + uT_ggn_u
        keep = _s > tol
        s = _s[keep]
        u = _u[:, keep]
        self.prior_eigvals = s
        del M, _u, _s, keep

        # GGN linear operator
        if self.likelihood == "regression":
            loss_fn = torch.nn.GaussianNLLLoss(reduction="sum", full=True)
        else:
            loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        ggn_linop = GGNLinearOperator(
            model=self.model, 
            loss_fn=loss_fn, 
            sigma_noise=self.sigma_noise,
            data=train_loader,
            n_chunks=n_chunks,
            sketch=self.params_sketch,
            sketch_dim=self.params_sketch_dim
        )

        # Compute posterior scale
        uT_ggn_u = u.T @ (ggn_linop @ u) # (rk, rk)
        A = torch.diag(s**2) + uT_ggn_u
        _eigvals, _eigvecs = torch.linalg.eigh(A)
        positive = _eigvals > 0 # pseudo-inversion: drop non-positive directions
        eigvals = _eigvals[positive]
        eigvecs = _eigvecs[:, positive]
        del A, _eigvals, _eigvecs, positive

        # eigh returns ascending eigenvalues; sort in descending order
        eigvals = torch.flip(eigvals, dims=(0,))
        eigvecs = torch.flip(eigvecs, dims=(1,))

        # Marginal variance heuristic
        i = 0
        post_var = torch.zeros(n_context_points, self.n_outputs).to(eigvecs)
        prior_var = []
        for x_context in context_loader:
            key = "features" if isinstance(x_context, MutableMapping) else 0
            prior_var.append(
                self.prior.marginal_variance(x_context[key].to(self._device, self.dtype)).T.reshape(-1, self.n_outputs)
            )
        prior_var = torch.concatenate(prior_var)
        jac_linop = JacobianLinearOperator(
            self.model, 
            context_loader, 
            sketch=self.params_sketch, 
            sketch_dim=self.params_sketch_dim, 
            n_outputs=self.n_outputs, 
            n_chunks=n_chunks
        )
        cov_sqrt = []
        while torch.all(post_var < prior_var) and i < eigvals.shape[0]:
            cov_sqrt += [u @ eigvecs[:,i] * (1 / eigvals[i]**0.5)]
            post_var += torch.square(jac_linop @ cov_sqrt[-1])
            print(f"{i} - post_var={post_var.sum()} - prior_var={prior_var.sum()}", flush=True)
            i += 1
        # Drop the last column if it pushed some variance above the prior (keep at least one)
        truncation_idx = i-1 if torch.any(post_var > prior_var) and i > 1 else i
        if truncation_idx == 0:
            raise RuntimeError(
                "FSP-Laplace: no posterior direction was kept (no positive eigenvalues, "
                "or a non-positive prior marginal variance at some context point)."
            )
        self.posterior_scale = torch.stack(cov_sqrt[:truncation_idx], dim=-1) # (p, rk)
        self.posterior_eigvals = torch.linalg.svdvals(self.posterior_scale)

    def log_marginal_likelihood(
        self,
        dataloader,
        context_loader
    ):
        """Approximate log marginal likelihood: ``log_likelihood + log_prior - log_posterior``.

        Args:
            dataloader: yields ``(x, y)`` training batches.
            context_loader: its ``dataset`` must expose ``.tensors`` (a ``TensorDataset``), and
                its batches have the inputs at index ``0``.

        NOTE(review): ``log_prior`` multiplies ``f_c`` by the kernel itself, whereas a Gaussian
        log-density uses ``K^{-1}``. It also uses only output 0's kernel although ``f_c`` stacks
        every output. Confirm the intended formula before relying on this value.
        """
        # Data term
        if self.likelihood == "regression":
            loss_fn = torch.nn.GaussianNLLLoss(reduction="sum", full=True)
        else:
            loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        log_likelihood = 0 
        for x, y in dataloader:
            x, y = x.to(self._device), y.to(self._device)
            f = self.model(x)
            if self.likelihood == "regression":
                log_likelihood -= loss_fn(f, y, torch.square(self.sigma_noise).expand(y.shape[0], -1))
            else:
                # CrossEntropyLoss takes no variance argument
                log_likelihood -= loss_fn(f, y)
        # Prior term on the context points
        f_c = torch.concatenate(
            [self.model(x_c[0].to(self._device)) for x_c in context_loader]
        ).reshape(-1)
        X_context = context_loader.dataset.tensors[0].to(self._device)
        kernel_linop = self.prior.kernel_linop(X_context, output_idx=0)
        log_prior = -0.5 * torch.dot(f_c, kernel_linop @ f_c) - torch.log(self.prior_eigvals).sum()
        # Posterior term
        log_posterior = -torch.log(self.posterior_eigvals).sum()

        return log_likelihood + log_prior - log_posterior

    def __call__(
        self,
        x: torch.Tensor | MutableMapping,
        joint: bool = False,
        diagonal_output: bool = False,
        n_chunks: int = 1
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Linearised predictive mean and covariance at ``x``.

        The mean is ``model(x)``. The covariance is ``(J S)(J S)^T``, where ``J`` is the model
        Jacobian at ``x`` and ``S`` is ``posterior_scale``.

        Args:
            x: input batch.
            joint: if True, return the full ``(batch*out, batch*out)`` covariance, with
                ``f_mu`` flattened to ``(batch*out,)`` in the same order.
            diagonal_output: only used when ``joint`` is False. If set, return per-input
                variances ``(batch, out)`` instead of ``(batch, out, out)`` blocks.
            n_chunks: forwarded to the Jacobian linear operator.

        Returns:
            ``(f_mu, f_var)``, detached unless ``enable_backprop`` is set.
        """
        assert self.posterior_scale is not None, "Call `fit` method first."

        jac_linop = JacobianLinearOperator(
            self.model, 
            (x,), 
            sketch=self.params_sketch, 
            sketch_dim=self.params_sketch_dim, 
            n_outputs=self.n_outputs, 
            n_chunks=n_chunks
        )

        f_mu = self.model(x)
        n_batch, n_out = f_mu.shape[0], f_mu.shape[1]
        f_scale = jac_linop @ self.posterior_scale # (batch, out, rk)

        if joint:
            f_scale = f_scale.reshape(n_batch * n_out, -1)  # (batch*out, rk)
            f_mu = f_mu.flatten()  # (batch*out)
            f_var = f_scale @ f_scale.T  # (batch*out, batch*out)
        else:
            f_scale = f_scale.reshape(n_batch, n_out, -1)  # (batch, out, rk)
            f_var = torch.einsum("ncp,nkp->nck", f_scale, f_scale)  # (batch, out, out)
            if diagonal_output:
                # Taken exactly once: a second diagonal would collapse (batch, out)
                f_var = torch.diagonal(f_var, dim1=-2, dim2=-1)  # (batch, out)

        if not self.enable_backprop:
            f_mu, f_var = f_mu.detach(), f_var.detach()

        return f_mu, f_var

    def predictive_samples(
        self,
        x: torch.Tensor | MutableMapping,
        n_samples: int = 100,
        generator: torch.Generator | None = None,
        n_chunks: int = 1
    ) -> torch.Tensor:
        """Draw linearised function samples ``model(x) + J (S @ eps)`` with ``eps ~ N(0, I)``.

        Args:
            x: input batch.
            n_samples: number of samples.
            generator: optional RNG for the standard-normal draws (must live on the model device).
            n_chunks: forwarded to the Jacobian linear operator.

        Returns:
            Tensor ``(n_samples, batch, out)``, detached unless ``enable_backprop`` is set.
        """
        assert self.posterior_scale is not None, "Call `fit` method first."
        # Match the posterior scale's dtype: `x` may be dict-like (no `.dtype`) or
        # hold a different dtype than the scale it is multiplied with.
        eps = torch.randn(
            n_samples,
            self.posterior_scale.shape[-1],
            device=self._device,
            dtype=self.posterior_scale.dtype,
            generator=generator
        )
        delta_samples = torch.einsum("pk,nk->pn", self.posterior_scale, eps)  # (p, n_samples)

        jac_linop = JacobianLinearOperator(
            self.model, 
            (x,), 
            sketch=self.params_sketch, 
            sketch_dim=self.params_sketch_dim, 
            n_outputs=self.n_outputs, 
            n_chunks=n_chunks
        )
        f_samples = self.model(x) + (jac_linop @ delta_samples).permute(2, 0, 1) 

        return f_samples.detach() if not self.enable_backprop else f_samples


class MAPRandomDataset(IterableDataset):
    """Endless stream of random input batches (e.g. context points for FSP-Laplace).

    Every item is a tensor of shape ``(1, batch_size, *input_size)`` on ``device`` with
    dtype ``dtype``:

    * ``"uniform"``: ``U * (max_val - min_val) + min_val`` with ``U`` drawn by ``torch.rand``.
    * ``"uniform_bitstring"``: independent entries drawn from ``{0, 1}``.

    Any other ``distribution`` raises ``ValueError`` when the first batch is requested,
    not at construction time.

    All draws use a private ``torch.Generator``, seeded with ``generator_seed`` when given.
    NOTE(review): with ``generator_seed=None`` the generator is never seeded, so it keeps
    PyTorch's default initial seed and the stream is likely identical across runs. Confirm
    this is intended.
    """
    def __init__(
        self, 
        batch_size, 
        input_size, 
        min_val, 
        max_val, 
        device, 
        dtype,
        distribution,
        generator_seed, 
    ):
        # Batch geometry
        self.batch_size, self.input_size = batch_size, input_size
        # Value range used by the "uniform" distribution
        self.min_val, self.max_val = min_val, max_val
        # Placement and sampling scheme
        self.device, self.dtype = device, dtype
        self.distribution = distribution

        self.generator = torch.Generator(device=self.device)
        if generator_seed is not None:
            self.generator.manual_seed(generator_seed)

    def _sample_batch(self):
        """Draw a single batch according to ``self.distribution``."""
        shape = (1, self.batch_size, *self.input_size)
        rng_kwargs = dict(generator=self.generator, device=self.device, dtype=self.dtype)

        if self.distribution == "uniform":
            unit = torch.rand(shape, **rng_kwargs)
            # Rescale [0, 1) samples to the requested range
            return unit * (self.max_val - self.min_val) + self.min_val
        if self.distribution == "uniform_bitstring":
            return torch.randint(0, 2, shape, **rng_kwargs)
        raise ValueError(f"Invalid distribution: {self.distribution}")

    def __iter__(self):
        # Never terminates: the stream is infinite by design
        while True:
            yield self._sample_batch()


if __name__ == "__main__":
    # Demo: fit FSP-Laplace on a 1-D toy regression problem and plot predictive samples.
    # Every tensor and module is placed on `device`, so the model, the GP prior and
    # the data agree when CUDA is available.
    dtype = torch.float64 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_outputs = 1
    likelihood = "regression"
    params_sketch = "" #"ssrft"
    params_sketch_dim = 140
    
    # Data: two clusters of 50 points, on [-1, -0.5] and [0.5, 1]
    x1 = torch.linspace(-1, -0.5, 50).reshape(-1, 1).to(device=device, dtype=dtype)
    x2 = torch.linspace(0.5, 1, 50).reshape(-1, 1).to(device=device, dtype=dtype)
    
    # Train data 
    train_X = torch.cat([x1, x2], dim=0)
    # Context data: endless stream of uniform points on [-2, 2)
    context_dataset = MAPRandomDataset(
        batch_size=100, 
        input_size=train_X.shape[1:], 
        min_val=-2, 
        max_val=2, 
        device=device,
        dtype=dtype, 
        distribution="uniform",
        generator_seed=None
    )
    # Validation data
    val_X = torch.cat([x1, x2], dim=0)

    # Noisy sine targets; the noise has one row per input point
    train_Y = torch.sin(2 * pi * train_X) + torch.normal(0, 0.1, (train_X.shape[0], n_outputs)).to(device=device, dtype=dtype)
    val_Y = torch.sin(2 * pi * val_X) + torch.normal(0, 0.1, (val_X.shape[0], n_outputs)).to(device=device, dtype=dtype)

    if likelihood == "classification":
        train_Y = torch.round(train_Y)
        val_Y = torch.round(val_Y)

    # Initialize model
    model = torch.nn.Sequential(
        torch.nn.Linear(train_X.shape[-1], 50, dtype=dtype),
        torch.nn.Tanh(), 
        torch.nn.Linear(50, 50, dtype=dtype),
        torch.nn.Tanh(), 
        torch.nn.Linear(50, train_Y.shape[-1], dtype=dtype),
    ).to(device=device, dtype=dtype)

    # Prior
    kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=1, batch_shape=torch.Size([n_outputs]), lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3)))
    # kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([n_outputs]), lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3)))
    #kernel = gpytorch.kernels.LinearKernel(batch_shape=torch.Size([n_outputs]))
    mean = gpytorch.means.ZeroMean(batch_shape=torch.Size([n_outputs]))
    # kernel.outputscale = 1.
    # kernel.base_kernel.lengthscale = 1.
    
    # Initialize GP model
    if likelihood == "regression":
        if n_outputs == 1:
            gp_likelihood = GaussianLikelihood()
            gp_likelihood.noise = 0.001
        else:
            gp_likelihood = MultitaskGaussianLikelihood(num_tasks=n_outputs, has_global_noise=False)
        prior = GPModel(mean, kernel, gp_likelihood, train_X, train_Y).to(device).to(dtype)
    else:
        gp_likelihood = DirichletClassificationLikelihood(train_Y, learn_additional_noise=True)
        prior = GPModel(mean, kernel, gp_likelihood, train_X, gp_likelihood.transformed_targets).to(device)

    # Prior maximum likelihood estimation
    prior, noise_var = optimize_prior_parameters(prior, gp_likelihood, train_X, train_Y, n_steps=100, verbose=True)
    prior.eval()
    gp_likelihood.eval()

    train_dataset = torch.utils.data.TensorDataset(train_X, train_Y) # type: ignore
    val_dataset = torch.utils.data.TensorDataset(val_X, val_Y) # type: ignore
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100) # type: ignore
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100) # type: ignore
    context_loader = torch.utils.data.DataLoader(context_dataset, batch_size=None) # type: ignore

    var_noise = fit_fsplaplace(
        model, 
        prior,  # type: ignore
        likelihood,
        train_loader, 
        context_loader, 
        device, 
        dtype, 
        n_outputs, 
        noise_var=noise_var, #1e-1,
        lr=0.001,
        n_epochs=5000,
        val_frequency=100,
        early_stopping_patience=1000,
        jitter=1e-10
    )

    model.eval()
    laplace = FSPLaplace(
        model, 
        prior=prior, 
        likelihood=likelihood,
        sigma_noise=var_noise**0.5, 
        params_sketch=params_sketch,
        params_sketch_dim=params_sketch_dim, 
        dtype=dtype
    )
    # Fixed context set for the Laplace fit: 5000 uniform points on [-2, 2)
    context_X = torch.empty(5000, 1, dtype=dtype, device=device).uniform_(-2, 2)
    context_dataset = torch.utils.data.TensorDataset(context_X) # type: ignore
    context_loader = torch.utils.data.DataLoader(context_dataset, batch_size=100) # type: ignore

    laplace.fit(train_loader, context_loader)
    mu, cov = laplace(train_X)
    print(mu.shape, cov.shape)
    # RBF - tanh : -174'019.2336
    # Matern 1/2 - tanh : -41'608.1547
    # Linear - tanh : -105'251.9863
    
    import matplotlib.pyplot as plt
    n_samples = 100
    x = torch.arange(-2, 2, 0.01).reshape(-1, 1).to(device=device, dtype=dtype)
    samples = laplace.predictive_samples(x, n_samples=n_samples).reshape(n_samples, -1, 1).cpu().numpy()
    
    # Format input 
    x = x.reshape(-1).cpu().numpy()
    samples = np.squeeze(samples)
    pred_mean = np.squeeze(samples.mean(0))
    pred_std = np.squeeze(samples.std(0))

    # Plot train data
    X_train, y_train = train_X, train_Y
    plt.scatter(
        X_train.reshape(-1).cpu().numpy(), 
        y_train.reshape(-1).cpu().numpy(), 
        marker="o",
        facecolors='lightgrey',
        edgecolors='dimgrey',
        linewidth=1,
        s=10
    )

    # Plot predictive mean
    plt.plot(
        x, 
        pred_mean, 
        c="#e41a1c", #"red",
        linewidth=2
    )

    # Plot predictive std dev
    plt.fill_between(
        x, 
        pred_mean-2*pred_std, 
        pred_mean+2*pred_std, 
        color="#2ca02c",
        alpha=0.2
    )

    # Plot individual mean functions
    for i in range(samples.shape[0]):
        plt.plot(
            x, 
            samples[i].reshape(-1),
            c="#2ca02c", #"forestgreen",
            alpha=0.5,
            linewidth=1
        )
    plt.ylim(-2, 2)
    plt.show()
