import numpy as np
import torch
import random
import json
from hashlib import sha1
from einops import reduce
from collections import defaultdict
import itertools
import time

import torch
import torch.nn as nn
from torch.distributions.normal import Normal

#loss function with rel/abs Lp loss
class LpLoss(object):
    """Absolute and relative Lp losses between two batched tensors.

    Both inputs are flattened per example (dim 0 is the batch dimension)
    before the p-norm of their difference is taken.

    Args:
        d: Exponent numerator used to scale the absolute loss (``h**(d/p)``).
        p: Order of the norm.
        size_average: When reducing, take the mean (True) or the sum (False)
            over the batch.
        reduction: When False, return the un-reduced per-example values.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super().__init__()

        # Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, per_example):
        """Apply the configured batch reduction to per-example values."""
        if not self.reduction:
            return per_example
        if self.size_average:
            return torch.mean(per_example)
        return torch.sum(per_example)

    def abs(self, x, y):
        """Return the mesh-scaled Lp norm of ``x - y`` for each example (reduced if configured)."""
        num_examples = x.size()[0]

        # Assume uniform mesh; the spacing is derived from the size of dim 1.
        h = 1.0 / (x.size()[1] - 1.0)

        all_norms = (h**(self.d/self.p))*torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)

        return self._reduce(all_norms)

    def rel(self, x, y):
        """Return ``||x - y||_p / ||y||_p`` for each example (reduced if configured)."""
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        return self._reduce(diff_norms/y_norms)

    def __call__(self, x, y):
        """Compute the relative loss; if ``x`` is a tuple/list only its first item is used."""
        # isinstance also accepts tuple subclasses (e.g. namedtuple model outputs),
        # which an exact ``type(x) == tuple`` comparison rejects.
        if isinstance(x, (tuple, list)):
            x = x[0]
        return self.rel(x, y)
    

# normalization, pointwise gaussian
class UnitGaussianNormalizer:
    """Standardize tensors with a mean/std estimated from a reference tensor.

    Args:
        x: Reference tensor; its first dimension is reported as the number of
            samples and the rest as the sample shape.
        eps: Added to the std before dividing/multiplying.
        reduce_dim: Dimensions over which mean and std are computed.
            Defaults to ``[0]``.
        verbose: When True, print a short summary on construction.
    """

    def __init__(self, x, eps=0.00001, reduce_dim=None, verbose=True):
        super().__init__()
        # None sentinel instead of a mutable ``[0]`` default shared by all instances.
        if reduce_dim is None:
            reduce_dim = [0]
        n_samples, *shape = x.shape
        self.sample_shape = shape
        self.verbose = verbose
        self.reduce_dim = reduce_dim

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
        self.mean = torch.mean(x, reduce_dim, keepdim=True).squeeze(0)
        self.std = torch.std(x, reduce_dim, keepdim=True).squeeze(0)
        self.eps = eps

        if verbose:
            print(f'UnitGaussianNormalizer init on {n_samples}, reducing over {reduce_dim}, samples of shape {shape}.')
            print(f'   Mean and std of shape {self.mean.shape}, eps={eps}')

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        x = x - self.mean
        x = x / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`, optionally using only the statistics at ``sample_idx``.

        Raises:
            ValueError: If ``sample_idx[0]`` has more dimensions than the
                stored statistics, so no indexing rule applies.
        """
        if sample_idx is None:
            std = self.std + self.eps # n
            mean = self.mean
        else:
            stat_rank = len(self.mean.shape)
            idx_rank = len(sample_idx[0].shape)
            if stat_rank == idx_rank:
                std = self.std[sample_idx] + self.eps  # batch*n
                mean = self.mean[sample_idx]
            elif stat_rank > idx_rank:
                std = self.std[:,sample_idx]+ self.eps # T*batch*n
                mean = self.mean[:,sample_idx]
            else:
                # Previously this case fell through and failed later with an
                # UnboundLocalError on ``std``.
                raise ValueError(
                    f"sample_idx[0] has {idx_rank} dims but the statistics only have {stat_rank}"
                )

        # x is in shape of batch*n or T*batch*n
        x = x * std
        x = x + mean

        return x

    def cuda(self):
        """Move the stored mean/std to CUDA and return ``self``."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self

    def cpu(self):
        """Move the stored mean/std to the CPU and return ``self``."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self

    def to(self, device):
        """Move the stored mean/std to ``device`` and return ``self``."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return self


def nll_mu_var(out, y):
    """Summed Gaussian negative log-likelihood, without the 1/2 factor and constant term.

    Args:
        out: Sequence whose first two items are the predicted mean and variance.
        y: Target tensor.

    Returns:
        Scalar tensor: sum over all elements of ``(mu - y)^2 / var + log(var)``.
    """
    mean, variance = out[0], out[1]
    squared_error = (mean - y).pow(2)
    per_element = squared_error / variance + torch.log(variance)
    return per_element.sum()

class CRPSLoss(nn.Module):
    """Closed-form CRPS loss for Gaussian predictions given as ``(mu, var)``."""

    def __init__(self):
        super().__init__()

    def forward(self, out, y):
        """
        Compute the CRPS loss for a normal distribution.

        Parameters:
        out: Sequence whose first two items are the predicted mean and the
            predicted variance (the standard deviation is taken as sqrt(var)).
        y (torch.Tensor): Ground truth observation.

        Returns:
        torch.Tensor: CRPS averaged over dims 1, 2, 3 of each example, then
            summed over the batch (scalar).
        """
        mean, variance = out[0], out[1]
        sigma = torch.sqrt(variance)

        # Standard normal used for the CDF and PDF of the standardized residual
        standard_normal = Normal(torch.zeros_like(mean), torch.ones_like(sigma))
        z = (y - mean) / sigma
        cdf_z = standard_normal.cdf(z)
        pdf_z = standard_normal.log_prob(z).exp()

        inv_sqrt_pi = 1 / torch.sqrt(torch.tensor(torch.pi))
        crps = sigma * (z * (2 * cdf_z - 1) + 2 * pdf_z - inv_sqrt_pi)

        per_example = crps.mean(dim=[1,2,3])
        return per_example.sum()

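# # Example usage (forward takes the prediction as a (mu, var) pair and reduces over dims 1, 2, 3,
# # so the tensors must be 4-D)
# mu = torch.tensor([0.5, 1.0]).reshape(2, 1, 1, 1).requires_grad_()   # Example predicted means
# var = torch.tensor([0.2, 0.3]).reshape(2, 1, 1, 1).requires_grad_()  # Example predicted variance
# y = torch.tensor([0.6, 1.1]).reshape(2, 1, 1, 1)                     # Example true observations

# crps_loss_fn = CRPSLoss()
# loss = crps_loss_fn((mu, var), y)
# print("CRPS Loss:", loss.item())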


# Metrics
def compute_mse_by_t(mu, y, reduce="sum"):
    """Squared error per index of dim 2, summed over dims 0, 1, 3 and divided by ``y.shape[1]``.

    With ``reduce="mean"`` the result is additionally divided by ``y.shape[0]``.
    """
    # mu & y: nf nx nt d
    n_samples, n_points = y.shape[0], y.shape[1]
    squared_error = torch.square(mu - y)
    mse_per_step = squared_error.sum(dim=[0, 1, 3]) / n_points
    if reduce == "mean":
        # Mean over n_samples
        mse_per_step /= n_samples
    return mse_per_step

def compute_mse_by_example(mu, var, y):
    """Mean squared error per example, averaged over dims 1, 2, 3.

    ``var`` is unused; it is accepted so all ``compute_*_by_example``
    functions share the same call signature.
    """
    # mu & y: nf nx nt d
    squared_error = torch.square(mu - y)
    return squared_error.mean(dim=[1, 2, 3])

def compute_nll_by_example(mu, var, y):
    """Gaussian negative log-likelihood per example, summed over dims 1, 2, 3.

    Each element contributes ``0.5 * ((mu - y)^2 / var + log(2*pi*var))``.
    """
    scaled_error = (mu - y).pow(2)/var
    log_term = torch.log(2 * np.pi * var)
    return (scaled_error + log_term).sum(dim=[1,2,3]) / 2

def compute_nMeRCI(mu, var, y, alpha=0.95):
    """Compute n-MeRCI (normalized Mean Rescaled Confidence Interval).

    Measures the correlation between predicted uncertainty and errors.
    Papers: https://arxiv.org/pdf/1908.07253.pdf,
    https://www.sciencedirect.com/science/article/pii/S0045782522004595#b55
    Smaller values (closer to zero) are better.

    Per example, the absolute error is summed over dims 1, 2, 3 and the
    uncertainty is the square root of the variance summed over the same dims.
    """
    abs_error = torch.abs(mu - y).sum(dim=[1,2,3])
    uncertainty = torch.sqrt(var.sum(dim=[1,2,3]))

    # Scaling factor at the alpha-quantile of the error / uncertainty ratios
    ratios = abs_error / uncertainty
    scale = torch.quantile(ratios, alpha)

    mean_error = abs_error.mean()
    numerator = (scale * uncertainty).mean() - mean_error
    denominator = abs_error.max() - mean_error
    return numerator/denominator


def compute_rmsce(mu, var, y, nbins=10):
    """Compute root mean squared calibration error.

    For ``nbins + 1`` evenly spaced levels p in [0, 1], compares p with the
    fraction (over dim 0) of targets at or below the predicted p-quantile of
    ``Normal(mu, sqrt(var) + 1e-10)``.
    """
    predictive = torch.distributions.Normal(mu, torch.sqrt(var)+1e-10)
    levels = torch.linspace(0, 1, nbins+1)

    squared_gaps = []
    for level in levels:
        observed = (y <= predictive.icdf(level)).float().mean(dim=0)
        squared_gaps.append((level - observed)**2)

    # RMS over the levels, then averaged over the remaining dims
    rms_per_point = torch.stack(squared_gaps).mean(dim=0).sqrt()
    return rms_per_point.mean()


def compute_crps_by_example(mu, var, y):
    """Closed-form Gaussian CRPS per example, averaged over dims 1, 2, 3."""
    sigma = torch.sqrt(var)
    standard_normal = Normal(torch.zeros_like(mu), torch.ones_like(sigma))

    # Standardized residual and its standard-normal CDF / PDF
    z = (y - mu) / sigma
    cdf_z = standard_normal.cdf(z)
    pdf_z = standard_normal.log_prob(z).exp()

    inv_sqrt_pi = 1 / torch.sqrt(torch.tensor(torch.pi))
    crps = sigma * (z * (2 * cdf_z - 1) + 2 * pdf_z - inv_sqrt_pi)
    return crps.mean(dim=[1,2,3])

def compute_sampling_crps_by_example(mu, var, y, nbins=10):
    """Approximate the Continuous Ranked Probability Score (CRPS) with quantile losses.

    (https://www.jstor.org/stable/23243806?seq=4, https://arxiv.org/pdf/2102.00968.pdf)

    Uses the ``nbins - 1`` interior levels of an even grid on [0, 1], takes
    twice the average quantile loss of ``Normal(mu, sqrt(var) + 1e-10)``
    and averages over dims 1, 2, 3 of each example.
    """
    quantile_levels = torch.linspace(0, 1, nbins+1)[1:-1]
    predictive = torch.distributions.Normal(mu, torch.sqrt(var)+1e-10)

    total = 0.
    for level in quantile_levels:
        quantile = predictive.icdf(level)
        overshoot = (quantile > y).int()
        total += (overshoot - level) * (quantile - y)

    total *= 2/len(quantile_levels)
    return total.mean(dim=[1,2,3])

def compute_piw_by_example(mu, var, y):
    """Width ``2 * 1.96 * sqrt(var)`` of the prediction interval, averaged over dims 1, 2, 3.

    ``mu`` and ``y`` are unused; they are accepted so all
    ``compute_*_by_example`` functions share the same call signature.
    """
    # Assumes p=0.95
    interval_width = 2 * 1.96 * torch.sqrt(var)
    return interval_width.mean(dim=[1,2,3])

def compute_forward_time(model, x, repetitions=100):
    """Return the mean wall-clock time of ``model(x)`` in milliseconds.

    The first ``repetitions // 10`` calls are warm-up and are not included
    in the average.

    Args:
        model: Callable invoked as ``model(x)``.
        x: Input passed to the model.
        repetitions: Number of timed calls.

    Returns:
        ``np.mean`` of the per-call durations in milliseconds.
    """
    warmup = repetitions // 10
    # Only synchronize when CUDA exists; querying the current stream raises on
    # CPU-only installations.
    use_cuda = torch.cuda.is_available()
    times = []
    for i in range(warmup + repetitions):
        # perf_counter is monotonic and higher resolution than time.time()
        t0 = time.perf_counter()
        _ = model(x)
        if use_cuda:
            # Wait for queued kernels so the measurement covers the whole forward pass
            torch.cuda.current_stream().synchronize()
        time_taken = (time.perf_counter() - t0) * 1000   # in ms
        if i >= warmup:
            times.append(time_taken)
    return np.mean(times)

def compute_n_params(model):
    """Return the total number of scalar entries across ``model.parameters()``."""
    total = sum(np.prod(param.shape) for param in model.parameters())
    return int(total)

def compute_n_flops(model_name, Np, fno_modes, fno_width, n_layers, n_models):
    """Estimate the FLOP count of an FNO-style model from its layer sizes.

    ``model_name`` is matched case-insensitively against ``EnsembleFNO2d``
    (every layer repeated ``n_models`` times) and ``DiverseFNO2d`` (only the
    projection layer repeated). Any other name yields ``-1``.
    """
    # Assumes d_i = d_o = 1
    lifting_cost = 2 * Np * fno_width
    fourier_cost = 10 * fno_width * Np * np.log2(Np) + fno_modes * (2 * fno_width**2 - fno_width) + 2 * Np * fno_width**2
    projection_cost = 2 * Np * fno_width

    kind = model_name.lower()
    if kind == 'ensemblefno2d':
        return int(n_models * (lifting_cost + n_layers * fourier_cost + projection_cost))
    if kind == 'diversefno2d':
        return int(lifting_cost + n_layers * fourier_cost + n_models * projection_cost)
    return int(-1)

def compute_all_metrics(out, y, results, metrics=None):
    """Evaluate per-example metrics for one batch and accumulate them into ``results``.

    Args:
        out: Either a tuple whose first two items are ``(mu, var)`` or a
            single mean tensor (a variance of 1e-20 is then used).
        y: Target tensor.
        results: Dict updated in place. ``"<metric>_by_example"`` tensors are
            concatenated along dim 0; ``"<metric>"`` totals are summed.
        metrics: Metric names; each is resolved to the module-level function
            ``compute_<metric>_by_example``. Defaults to
            ``["mse", "nll", "piw", "crps"]``.

    Returns:
        The same ``results`` dict.
    """
    # isinstance also accepts tuple subclasses such as namedtuple outputs.
    if isinstance(out, tuple):
        mu, var = out[0], out[1]
    else:
        mu = out
        var = torch.zeros_like(mu) + 1e-20

    if metrics is None:
        metrics = ["mse", "nll", "piw", "crps"]

    batch_results = {}
    for metric in metrics:
        metric_fn = globals()[f"compute_{metric}_by_example"]
        per_example = metric_fn(mu, var, y).detach().cpu()
        batch_results[f"{metric}_by_example"] = per_example
        batch_results[metric] = per_example.sum().item()

    for key, value in batch_results.items():
        if key not in results:
            results[key] = value
        elif key.endswith("by_example"):
            results[key] = torch.cat([results[key], value], dim=0)
        else:
            results[key] += value
    return results


def compute_all_metrics_avg(out, y, results, metrics=None):
    """Evaluate per-example metrics and return per-example lists plus averages.

    Works like an accumulate-then-finalize step in a single call: after
    merging this batch into ``results``, every key not ending in
    ``"by_example"`` is divided by ``len(y)`` and every tensor value is
    converted to a Python list.

    NOTE: the division and list conversion are applied to everything already
    in ``results``, so this is meant to be called once per ``results`` dict
    (typically with an empty one), not repeatedly across batches.

    Args:
        out: Either a tuple whose first two items are ``(mu, var)`` or a
            single mean tensor (a variance of 1e-20 is then used).
        y: Target tensor.
        results: Dict updated in place.
        metrics: Metric names; each is resolved to the module-level function
            ``compute_<metric>_by_example``. Defaults to
            ``["mse", "nll", "piw", "crps"]``.

    Returns:
        The same ``results`` dict.
    """
    # isinstance also accepts tuple subclasses such as namedtuple outputs.
    if isinstance(out, tuple):
        mu, var = out[0], out[1]
    else:
        mu = out
        var = torch.zeros_like(mu) + 1e-20

    if metrics is None:
        metrics = ["mse", "nll", "piw", "crps"]

    batch_results = {}
    for metric in metrics:
        metric_fn = globals()[f"compute_{metric}_by_example"]
        per_example = metric_fn(mu, var, y).detach().cpu()
        batch_results[f"{metric}_by_example"] = per_example
        batch_results[metric] = per_example.sum().item()

    for key, value in batch_results.items():
        if key not in results:
            results[key] = value
        elif key.endswith("by_example"):
            results[key] = torch.cat([results[key], value], dim=0)
        else:
            results[key] += value

    # Only values are reassigned below, so iterating the dict directly is safe.
    for key in results:
        if not key.endswith("by_example"):
            results[key] /= len(y)
        if isinstance(results[key], torch.Tensor):
            results[key] = results[key].tolist()

    return results


def plot_at_time(index, t_plot, t_sliced, grid, y, mu, std, ax, **kwargs):
    """Plot the true and predicted profile of one example at one time index.

    Args:
        index: Example index into dim 0 of ``y`` / ``mu`` / ``std``.
        t_plot: Time index into dim 2 of ``y`` / ``mu`` / ``std`` and into
            ``t_sliced``.
        t_sliced: Sequence of time values, used only for the legend labels.
        grid: x-coordinates of the plotted curves.
        y: True solution.
        mu: Predicted mean.
        std: Predicted standard deviation, or None to skip the +/- 3 std band.
        ax: Axes-like object to draw on.
        **kwargs: Optional ``ylim`` (unpacked into ``ax.set_ylim``) and
            ``title`` (passed to ``ax.set_title``).
    """
    time_label = f"t={t_sliced[t_plot]:.1f}"
    truth = y[index, :, t_plot]
    prediction = mu[index, :, t_plot]

    ax.plot(grid, truth, label=f"True Solution ({time_label})")
    ax.plot(grid, prediction, label=f"Predicted ({time_label})")
    if std is not None:
        band = 3*std[index, :, t_plot]
        ax.fill_between(grid, prediction-band, prediction+band, color='b', alpha=0.1)
    ax.set_xlabel(r"$x$")
    ax.set_ylabel(r"$u(x, t)$")
    ax.legend()

    if "ylim" in kwargs:
        ax.set_ylim(*kwargs["ylim"])
    if "title" in kwargs:
        # ``ax`` is a single Axes here (see the calls above), so it must not be indexed.
        ax.set_title(kwargs["title"])

def set_seed(seed=314_271):
    """Seed ``random``, NumPy and PyTorch (CPU and CUDA) and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune (and therefore vary) the algorithm it
    # picks, which works against reproducibility; it must be off here.
    torch.backends.cudnn.benchmark = False

def config_to_hash(config):
    """Return the SHA-1 hex digest of ``config`` serialized as key-sorted JSON."""
    payload = json.dumps(config, sort_keys=True).encode()
    digest = sha1()
    digest.update(payload)
    return digest.hexdigest()

def dict_to_file(d, filepath):
    """Write ``d`` as JSON to ``filepath``, overwriting any existing file."""
    with open(filepath, 'w') as handle:
        json.dump(d, handle)

def filter_config(config, keys, mode="remove", new_config=None):
    """Build a filtered view of ``config``.

    Modes:
        ``"remove"``: start from ``new_config`` (or a shallow copy of
            ``config`` when it is None) and drop every key that is in
            ``keys`` or contains a ``"."``.
        ``"add"``: start from ``new_config`` (or an empty dict) and copy
            ``config[key]`` for every key in ``keys``.

    A supplied ``new_config`` is modified in place. For any other mode,
    ``new_config`` is returned untouched (``None`` if it was not given).
    """
    if mode == "add":
        if new_config is None:
            new_config = {}
        for key in keys:
            new_config[key] = config[key]
        return new_config

    if mode == "remove":
        if new_config is None:
            new_config = config.copy()
        # Snapshot the keys first: the dict shrinks while we walk it
        for key in list(new_config.keys()):
            if key in keys or "." in key:
                new_config.pop(key)

    return new_config


def generate_commands(filename, datasets, models, other_config, seed=0):
    """Expand a grid of datasets, models, seeds and option values into shell commands.

    Args:
        filename: Script placed after ``python -u``.
        datasets: Mapping of dataset key to an option dict. Only the part of
            the key before the first ``":"`` is used as the dataset name.
        models: Mapping of model key to an option dict, with the same
            ``":"`` handling.
        other_config: Option dict merged last, so it wins on key clashes.
        seed: A single int or an iterable of seeds.

    Each option dict maps an option string to an iterable of values. An empty
    string value emits the bare option; anything else emits ``option=value``.

    Returns:
        List of command strings, one per combination.
    """
    seeds = [seed] if type(seed) is int else seed

    commands = []
    for s in seeds:
        for dataset_key, dataset_config in datasets.items():
            dataset_name = dataset_key.split(":")[0]
            for model_key, model_config in models.items():
                model_name = model_key.split(":")[0]

                prefix = (
                    f"python -u {filename} "
                    f"--model={model_name} "
                    f"--dataset={dataset_name} "
                    f"--seed={s} "
                )

                # Dataset & Model parameters
                merged = dataset_config | model_config | other_config
                choices = []
                for option, option_values in merged.items():
                    rendered = []
                    for value in option_values:
                        rendered.append(f"{option}" if value == "" else f"{option}={value}")
                    choices.append(rendered)

                for combination in itertools.product(*choices):
                    commands.append(prefix + " ".join(combination))

    return commands


def nll_mu_cov(pred, target):
    """
    Mean negative log-likelihood of ``target`` under a multivariate Gaussian.

    pred = (mu, cov) where cov is a full covariance matrix (..., n, n)
    target has shape (..., n)

    ``1e-4 * I`` is added to the covariance before it is inverted, and the
    per-sample losses are averaged into a scalar.
    """
    mu, cov = pred
    jitter = 1e-4  # Regularization for positive definiteness
    n = target.shape[-1]
    batch_shape = target.shape[:-1]

    # Add small identity for numerical stability
    identity = torch.eye(n, device=cov.device).expand(*batch_shape, n, n)
    cov_stable = cov + jitter * identity

    # Quadratic form (target - mu)^T cov^-1 (target - mu)
    residual = target - mu
    precision = torch.linalg.inv(cov_stable)
    mahalanobis = torch.einsum('...i,...ij,...j->...', residual, precision, residual)

    logdet = torch.logdet(cov_stable)
    log_two_pi = torch.log(torch.tensor(2 * np.pi, device=cov.device))

    loss = 0.5 * (mahalanobis + logdet + n * log_two_pi)
    return loss.mean()


def _predict_mean_std(model, x):
    """Run ``model`` on ``x`` without gradients and return CPU ``(mu, std, var)``."""
    with torch.no_grad():
        out = model(x)

    if isinstance(out, tuple):
        mu, var = out[0].cpu(), out[1].cpu()
        std = torch.sqrt(var)
    else:
        # Deterministic model: no predictive spread.
        # NOTE(review): a zero variance is passed on to the metrics as-is.
        mu = out.cpu()
        std = torch.zeros_like(mu)
        var = torch.square(std)
    return mu, std, var


def _evaluate_split(model, x_data, y_data, t, tpred, grid, dataset_class, apply_probconserv):
    """Compute the statistics dict for one data split; returns ``(stats, mu, std)``."""
    import utils
    import probconserv
    # ``rearrange`` is not imported at module level (only ``reduce`` is).
    from einops import rearrange

    device = next(model.parameters()).device
    x_data = x_data.to(device)
    mu, std, var = _predict_mean_std(model, x_data)

    x_cpu = x_data.cpu()
    mass_rhs_func = dataset_class.get_mass_rhs_func(x=x_cpu)

    if apply_probconserv:
        new_mu, new_std, _, mass_rhs = probconserv.apply_constraint(
            mu=mu[:, :, :, 0],
            std=std[:, :, :, 0],
            mass_rhs_func=mass_rhs_func,
            t=t,
            tpred=tpred,
            grid_train=grid,
            precis_g=float('inf'),
            second_deriv_alpha=None,
        )
        mu = new_mu.unsqueeze(-1)
        std = new_std.unsqueeze(-1)
        var = torch.square(std)
    else:
        mass_rhs = mass_rhs_func(rearrange(x_cpu, "nf nx nt 1-> nf nt nx 1"))

    # Conservation error: absolute gap between empirical and reference mass
    cerr = (probconserv.get_empirical_mass_rhs(mu[:, :, :, 0]) - mass_rhs).abs().sum(dim=-1)

    stats = utils.compute_all_metrics_avg((mu, var), y_data, {})
    stats["nMeRCI_all"] = utils.compute_nMeRCI(mu, var, y_data).item()
    stats["rmsce_all"] = utils.compute_rmsce(mu, var, y_data).item()
    stats["cerr_by_example"] = cerr.tolist()
    stats["mcerr"] = cerr.mean().item()
    return stats, mu, std


def compute_statistics(
    model, 
    x_data, 
    y_data, 
    t, 
    tpred, 
    grid, 
    dataset_class, 
    apply_probconserv=False, 
    plot=False,
    x_data_test=None, 
    y_data_test=None,
    return_latex=False,
    name="Model"
):
    """Evaluate ``model`` on a data split (and optionally a test split).

    For each split the model is run without gradients, the prediction is
    optionally passed through ``probconserv.apply_constraint``, and the
    result is summarized with ``utils.compute_all_metrics_avg``,
    ``utils.compute_nMeRCI``, ``utils.compute_rmsce`` and a conservation
    error (``cerr_by_example`` / ``mcerr``).

    Args:
        model: Module returning either a mean tensor or a ``(mu, var, ...)``
            tuple. Predictions are indexed as 4-D tensors.
        x_data, y_data: Inputs and targets of the first split.
        t, tpred, grid: Forwarded to ``probconserv.apply_constraint``; also
            used for the optional plot (``tpred`` is unpacked into ``slice``).
        dataset_class: Object providing ``get_mass_rhs_func(x=...)``.
        apply_probconserv: Apply the conservation constraint to the prediction.
        plot: Show a matplotlib plot of one example of the first split.
        x_data_test, y_data_test: Optional second split; evaluated only when
            both are given.
        return_latex: Also return a LaTeX table row (None without test stats).
        name: Row label used in the LaTeX row.

    Returns:
        ``(stats, test_stats)`` or, with ``return_latex=True``,
        ``(stats, test_stats, latex_row)``. ``test_stats`` is None when no
        test split was given.
    """
    stats, mu, std = _evaluate_split(
        model, x_data, y_data, t, tpred, grid, dataset_class, apply_probconserv
    )

    # --- Test dataset ---
    test_stats = None
    if x_data_test is not None and y_data_test is not None:
        test_stats, _, _ = _evaluate_split(
            model, x_data_test, y_data_test, t, tpred, grid, dataset_class, apply_probconserv
        )

    # --- Optional plot ---
    if plot:
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        t_idx = 1
        param_idx = 0
        with torch.no_grad():
            plt.ylabel(f"u(x, t={t[slice(*tpred)][t_idx]:.2f})")
            plt.xlabel("x")
            plt.title(f"Predicted vs True (param = {x_data[param_idx,0,0,0].item():.2f})")
            mu_plot = mu[param_idx, :, t_idx, 0]
            std_plot = std[param_idx, :, t_idx, 0]
            y_true_plot = y_data[param_idx, :, t_idx, 0]
            plt.plot(grid, mu_plot, '--', lw=2, label="μ ± 3σ")
            plt.fill_between(grid, mu_plot + 3*std_plot, mu_plot - 3*std_plot, alpha=0.2)
            plt.plot(grid, y_true_plot, color="green", label="true")
            plt.legend()
            plt.show()

    # --- Optional LaTeX row ---
    latex_row = None
    if return_latex and test_stats:
        latex_row = (
            f"{name} & "
            f"{stats['mse']:.2E} & {stats['nMeRCI_all']:.2E} & {stats['rmsce_all']:.2E} & {stats['mcerr']:.2E} & {stats['crps']:.2E} & "
            f"{test_stats['mse']:.2E} & {test_stats['nMeRCI_all']:.2E} & {test_stats['rmsce_all']:.2E} & {test_stats['mcerr']:.2E} & {test_stats['crps']:.2E} \\\\"
        )

    return (stats, test_stats, latex_row) if return_latex else (stats, test_stats)

def tvsolver(x, lmd, maxiters=100):
    """
    Pytorch solver for the 1D total variation denoising problem. Operates on a batch of inputs
    x of size (b, n) all with same lambda.

    Runs exactly ``maxiters`` iterations; each one solves a tridiagonal
    system built from the current iterate via a Cholesky factorization.

    Args:
        x: Input tensor of shape (b, n).
        lmd: Regularization weight; the absolute differences of the current
            iterate are divided by it.
        maxiters: Number of iterations.

    Returns:
        Tensor of shape (b, n). For ``n < 2`` there are no neighboring
        differences, so a copy of ``x`` is returned.
    """
    # initialize u to x
    b, n = x.shape
    u = x.clone()

    if n < 2:
        return u

    Dx = (x[:, :-1] - x[:, 1:]).view(b, n - 1, 1)
    ones = torch.ones((b, n - 2), dtype=x.dtype, device=x.device)
    # Constant off-diagonal part of the system matrix, built once outside the loop
    off_diagonals = torch.diag_embed(ones, offset=1) + torch.diag_embed(ones, offset=-1)

    # iterate
    for _ in range(maxiters):
        L = torch.abs(u[:, :-1] - u[:, 1:]) / lmd
        H = torch.diag_embed(L + 2.0) - off_diagonals
        w = torch.cholesky_solve(Dx, torch.linalg.cholesky(H)).view(b, n - 1)

        u = x.clone()
        u[:, 1:] += w
        u[:, :-1] -= w

    return u

class TotalVariationFcn(torch.autograd.Function):
    """PyTorch autograd function for total variation denoising.

    The forward pass runs ``tvsolver`` outside the autograd graph; the
    backward pass hands the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x, lmd):
        """Return ``tvsolver`` applied to a detached copy of ``x`` with weight ``lmd``."""
        with torch.no_grad():
            detached_input = x.detach().clone()
            denoised = tvsolver(detached_input, lmd)
        return denoised

    @staticmethod
    def backward(ctx, dLdY):
        """Return ``dLdY`` as the gradient for ``x`` and no gradient for ``lmd``."""
        grad_x, grad_lmd = dLdY, None
        return grad_x, grad_lmd