import torch
from torch import nn
from pathlib import Path
from dataclasses import dataclass
from typing import Union
from sbibm.metrics import c2st
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

def compute_rmse(samples: torch.Tensor, true_theta: torch.Tensor) -> float:
    """Return the RMSE between the sample mean and ``true_theta``.

    The samples are averaged over dimension 0, the element-wise squared
    difference to ``true_theta`` is averaged over all remaining entries, and
    the square root of that average is returned as a Python float.

    Args:
        samples: Tensor of draws; dimension 0 is reduced by the mean.
            (Presumably ``(n_samples, dim)`` -- confirm against callers.)
        true_theta: Reference value; must broadcast against
            ``samples.mean(0)``.

    Returns:
        The root-mean-square error as a float.
    """
    point_estimate = samples.mean(dim=0)
    squared_error = (point_estimate - true_theta).pow(2)
    return squared_error.mean().sqrt().item()


def compute_coverage(samples: torch.Tensor, theta_true: torch.Tensor, alpha: float = 0.1):
    """Check whether ``theta_true`` lies inside the central sample interval.

    For every column of ``samples`` the ``alpha / 2`` and ``1 - alpha / 2``
    quantiles are taken along dimension 0; a dimension counts as covered when
    the corresponding entry of ``theta_true`` lies between them (inclusive).

    Args:
        samples: Tensor of draws. A 1-D tensor is treated as a single
            column (a trailing dimension is appended).
        theta_true: Reference value; it is flattened before comparison.
        alpha: Total tail mass excluded from the interval (default ``0.1``).

    Returns:
        A tuple ``(per_dim, mean)`` where ``per_dim`` is a NumPy array of
        0/1 indicators and ``mean`` is their average as a Python float.
    """
    draws = samples.unsqueeze(-1) if samples.dim() == 1 else samples
    target = theta_true.flatten()

    tail = alpha / 2
    lower = torch.quantile(draws, tail, dim=0)
    upper = torch.quantile(draws, 1 - tail, dim=0)

    # Inclusive on both ends; comparisons with NaN evaluate to False.
    inside = ((target >= lower) & (target <= upper)).float()

    return inside.cpu().numpy(), inside.mean().item()


def compute_posterior_mmd(samples: torch.Tensor, true_samples: torch.Tensor, device: str) -> float:
    """Return the MMD between two sample sets, evaluated on ``device``.

    Both tensors and a default-constructed ``MMDLoss`` are moved to
    ``device`` before the estimate is computed.

    Args:
        samples: First set of draws.
        true_samples: Second (reference) set of draws.
        device: Target device passed to ``.to(...)``.

    Returns:
        The value produced by ``MMDLoss`` as a Python float.
    """
    estimator = MMDLoss().to(device)
    approx = samples.to(device)
    reference = true_samples.to(device)
    return estimator(approx, reference).item()


def compute_predictive_mmd(samples, x_clean, task, device, seed=None, n_pred=150, return_samples=False):
    """Estimate the MMD between simulated predictive data and ``x_clean``.

    A random subset of at most ``n_pred`` rows of ``samples`` is selected
    with ``torch.randperm``, passed to ``task.simulate`` and the simulated
    data are compared with ``x_clean`` using a default ``MMDLoss``.

    Args:
        samples: Parameter draws, indexable along dimension 0.
        x_clean: Reference observation. It is reshaped to ``(N, H * W)`` or
            ``(N, feat_dim)`` to match 4-D / 3-D simulator output; for 2-D
            output a 1-D ``x_clean`` is promoted to a single row.
        task: Object exposing ``simulate(theta, seed=...)``.
        device: Device used for the parameters and the MMD computation.
        seed: Seed forwarded to ``task.simulate``. When ``None`` a random
            integer below ``1e6`` is drawn from NumPy's global RNG. The
            subset selection itself is not controlled by this seed.
        n_pred: Maximum number of parameter draws to simulate.
        return_samples: When true, also return the simulated data on CPU.

    Returns:
        The MMD as a float, or ``(mmd, x_pred.cpu())`` when
        ``return_samples`` is true. If the simulation contains NaN the MMD
        is ``nan``; the return shape still follows ``return_samples``.

    Raises:
        ValueError: If the simulator output is not 2-, 3- or 4-dimensional.
    """
    mmd_fn = MMDLoss().to(device)
    idx = torch.randperm(len(samples))[:n_pred]
    theta_sub = samples[idx].to(device)

    sim_seed = seed if seed is not None else np.random.randint(1e6)

    with torch.no_grad():
        x_pred = task.simulate(theta_sub, seed=sim_seed)

    if torch.isnan(x_pred).any():
        # Keep the return shape consistent with the success path so callers
        # that unpack ``mmd, x_pred = ...`` do not crash on a bare float.
        nan_value = float('nan')
        if return_samples:
            return nan_value, x_pred.cpu()
        return nan_value

    if x_pred.ndim == 4:
        # Image-like output: pool every simulated item into one sample set.
        n_pred_actual, N, H, W = x_pred.shape
        x_pred_flat = x_pred.reshape(n_pred_actual * N, H * W).to(device)
        x_clean_flat = x_clean.reshape(N, H * W).to(device)
    elif x_pred.ndim == 3:
        n_pred_actual, N, feat_dim = x_pred.shape
        x_pred_flat = x_pred.reshape(n_pred_actual * N, feat_dim).to(device)
        x_clean_flat = x_clean.reshape(N, feat_dim).to(device)
    elif x_pred.ndim == 2:
        x_pred_flat = x_pred.to(device)
        x_clean_flat = x_clean.unsqueeze(0).to(device) if x_clean.ndim == 1 else x_clean.to(device)
    else:
        raise ValueError(f"Unexpected x_pred shape: {x_pred.shape}")

    mmd_value = mmd_fn(x_pred_flat, x_clean_flat).item()

    if return_samples:
        return mmd_value, x_pred.cpu()
    return mmd_value

class RBF(nn.Module):
    """Sum of Gaussian kernels evaluated at several bandwidths.

    The kernel matrix is ``sum_k exp(-d2 / (bw * m_k))`` where ``d2`` holds
    the squared pairwise Euclidean distances, ``bw`` is a base bandwidth and
    ``m_k = mul_factor ** (k - n_kernels // 2)`` for ``k = 0..n_kernels-1``.

    Args:
        n_kernels: Number of bandwidth multipliers.
        mul_factor: Base of the geometric multiplier sequence.
        bandwidth: Fixed base bandwidth (number or tensor). When ``None``
            the base bandwidth is derived from the distances on each call.
        eps: Lower bound applied to the base bandwidth.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None, eps=1e-8):
        super().__init__()
        exponents = torch.arange(n_kernels) - n_kernels // 2
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("bandwidth_multipliers", mul_factor ** exponents)
        self.bandwidth = bandwidth
        self.eps = eps

    def get_bandwidth(self, L2_distances: torch.Tensor) -> torch.Tensor:
        """Return the base bandwidth for the given squared-distance matrix.

        With a fixed ``bandwidth`` that value is returned (converted to the
        device and dtype of ``L2_distances``) after clamping to ``eps``.
        Otherwise the sum of all entries divided by ``n * n - n`` is used,
        detached from the graph and clamped to ``eps``; for fewer than two
        rows the constant ``1.0`` is returned unclamped.
        """
        fixed = self.bandwidth
        if fixed is not None:
            if torch.is_tensor(fixed):
                base = fixed.to(device=L2_distances.device, dtype=L2_distances.dtype)
            else:
                base = L2_distances.new_tensor(float(fixed))
            return base.clamp_min(self.eps)

        n_points = L2_distances.shape[0]
        if n_points < 2:
            # No off-diagonal pairs to average over.
            return L2_distances.new_tensor(1.0)

        n_off_diagonal = n_points * n_points - n_points
        base = L2_distances.detach().sum() / n_off_diagonal
        return base.clamp_min(self.eps)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return the summed multi-bandwidth kernel matrix of ``X`` with itself."""
        sq_dists = torch.cdist(X, X).pow(2)
        base = self.get_bandwidth(sq_dists)

        multipliers = self.bandwidth_multipliers.to(dtype=X.dtype)
        # One denominator per kernel, broadcast over the distance matrix.
        widths = (base * multipliers)[:, None, None]
        per_kernel = torch.exp(-sq_dists.unsqueeze(0) / widths)

        return per_kernel.sum(dim=0)


class MMDLoss(nn.Module):
    """Maximum mean discrepancy estimate between two sample sets.

    The kernel is evaluated once on the concatenation of both sets and the
    block means are combined as ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)``.
    Diagonal entries are included in the block means.

    Args:
        kernel: Callable mapping a stacked sample tensor to its kernel
            matrix. A default-constructed ``RBF`` is used when ``None``.
    """

    def __init__(self, kernel=None):
        super().__init__()
        if kernel is None:
            kernel = RBF()
        self.kernel = kernel

    def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Return the MMD estimate between ``X`` and ``Y``.

        1-D inputs are treated as a single sample (a leading dimension is
        added); the two sets are stacked along dimension 0.
        """
        X = X.unsqueeze(0) if X.dim() == 1 else X
        Y = Y.unsqueeze(0) if Y.dim() == 1 else Y

        n_x = X.shape[0]
        gram = self.kernel(torch.cat((X, Y), dim=0))

        # Rows/columns before n_x belong to X, the rest to Y.
        within_x = gram[:n_x, :n_x].mean()
        cross = gram[:n_x, n_x:].mean()
        within_y = gram[n_x:, n_x:].mean()
        return within_x - 2 * cross + within_y