import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import matplotlib.pyplot as plt
from typing import List, Tuple


# Probabilistic neural network (PNN): a single Gaussian-output member of the ensemble
class ProbabilisticNet(nn.Module):
    """
    Gaussian-output MLP used as a single member of a Deep Ensemble.

    Inputs of shape ``(..., input_dim)`` pass through two Softplus hidden layers
    of width ``hidden_dim`` into a linear head with two outputs. The first output
    is the predictive mean; the second is a log-variance that is converted into a
    (bounded) variance. Both returned tensors have shape ``(..., 1)``.
    """

    def __init__(self, input_dim: int = 1, hidden_dim: int = 64):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, 2),  # head: [mean, log-variance]
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return ``(mean, var)`` for the input batch ``x``.

        The log-variance is clamped to ``[-10, 5]`` before exponentiation, so the
        variance always lies in ``[exp(-10), exp(5)]`` for numerical stability.
        """
        head = self.network(x)
        mean = head[..., 0:1]
        log_var = head[..., 1:2].clamp(min=-10, max=5)
        return mean, log_var.exp()


# Define the Deep Ensemble class
class DeepEnsemble:
    """
    Deep Ensemble of independently initialised ``ProbabilisticNet`` members.

    Every member sees the same data and is trained with the Gaussian negative
    log-likelihood. A single Adam optimizer holds the parameters of all members;
    because each member's loss depends only on its own parameters and Adam keeps
    separate state per parameter, this behaves like one optimizer per member.
    """

    def __init__(self, n_ensemble: int = 5, input_dim: int = 1, hidden_dim: int = 32, lr: float = 0.01):
        """
        Args:
            n_ensemble: Number of ensemble members.
            input_dim: Input feature dimension passed to each member.
            hidden_dim: Hidden-layer width of each member.
            lr: Adam learning rate shared by all members.
        """
        # ModuleList registers the members so .parameters() spans all of them.
        self.models = nn.ModuleList(
            [ProbabilisticNet(input_dim, hidden_dim) for _ in range(n_ensemble)]
        )
        self.optimizer = optim.Adam(self.models.parameters(), lr=lr)

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """
        Take one Adam step for every ensemble member on the batch ``(x, y)``.

        Args:
            x: Inputs, e.g. ``[batch_size, input_dim]``.
            y: Targets broadcastable against each member's mean,
                e.g. ``[batch_size, 1]``.

        Returns:
            The members' losses averaged, as a Python float.
        """
        self.optimizer.zero_grad()
        losses = []
        for model in self.models:
            mean, var = model(x)
            # Gaussian negative log-likelihood, dropping the constant 0.5 * log(2 * pi).
            loss = 0.5 * (torch.log(var) + (y - mean) ** 2 / var).mean()
            # Per-member backward: gradients accumulate only into this member's
            # parameters, so one optimizer step afterwards updates every member.
            loss.backward()
            losses.append(loss.item())
        self.optimizer.step()
        return float(np.mean(losses))

    def predict(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run every member on ``x`` and stack their outputs along a new dim 0.

        Autograd stays enabled; wrap the call in ``torch.no_grad()`` when the
        predictions are only used for inference.

        Returns:
            ``(means, var_s)``, each of shape ``[n_ensemble, *member_output_shape]``
            (``[n_ensemble, batch_size, 1]`` for a 2-D input batch).
        """
        outputs = [model(x) for model in self.models]
        means = torch.stack([mean for mean, _ in outputs])  # [n_ensemble, batch_size, 1]
        var_s = torch.stack([var for _, var in outputs])  # [n_ensemble, batch_size, 1]
        return means, var_s


def compute_entropy(means: torch.Tensor, vars: torch.Tensor) -> torch.Tensor:
    """
    Approximate the total predictive entropy of the ensemble.

    The ensemble's predictive distribution is a uniform mixture of the members'
    Gaussians N(mean_m, var_m). By the law of total variance its variance is

        sigma^2 = mean_m(var_m) + mean_m((mean_m - mean_bar)^2)
                = aleatoric       + epistemic

    and the entropy returned is that of a Gaussian with this variance,
    H = 0.5 * log(2 * pi * e * sigma^2). A Gaussian has the largest entropy of
    any distribution with a given variance, so this upper-bounds the exact
    mixture entropy. Differential entropy is negative whenever
    sigma^2 < 1 / (2 * pi * e) (about 0.0585).

    Args:
        means: Member means stacked along dim 0, e.g. ``[n_ensemble, batch, 1]``.
        vars: Member variances, same shape as ``means``.

    Returns:
        Entropy per point, with all singleton dimensions squeezed.
    """
    # Population (divide-by-M) variance of the member means: this is the exact
    # mixture variance, and it stays finite for a single-member ensemble, where
    # the unbiased estimator would divide by M - 1 = 0.
    epistemic = ((means - means.mean(0, keepdim=True)) ** 2).mean(0)
    total_var = vars.mean(0) + epistemic
    entropy = 0.5 * torch.log(2 * np.pi * np.e * total_var)
    return entropy.squeeze()


def compute_ig(means: torch.Tensor, vars: torch.Tensor) -> torch.Tensor:
    """
    Compute the information gain (ensemble disagreement) at each input.

    IG = H[predictive] - mean_m H[N(mean_m, var_m)]: the total predictive
    entropy minus the average entropy of the individual members (the aleatoric
    part). The total entropy comes from ``compute_entropy``, which uses a
    Gaussian whose variance is the average member variance plus the spread of
    the member means. The result is therefore a Gaussian approximation of the
    mutual information between the prediction and the member identity.

    Args:
        means: Member means stacked along dim 0, e.g. ``[n_ensemble, batch, 1]``.
        vars: Member variances, same shape as ``means``.

    Returns:
        Information gain per point, squeezed to the same shape that
        ``compute_entropy`` returns.
    """
    # Entropy of each member's Gaussian, averaged over members.
    avg_entropy = (0.5 * torch.log(2 * np.pi * np.e * vars).mean(0)).squeeze()
    mixture_entropy = compute_entropy(means, vars)
    # Both terms are squeezed from the same per-point shape, so they line up
    # element-wise without any reshape (also for multi-output members).
    return mixture_entropy - avg_entropy


# Total predictive variance, used to draw the +/- 2 std bands in the plots
def compute_variance(means: torch.Tensor, vars: torch.Tensor) -> torch.Tensor:
    """
    Compute the total predictive variance of the ensemble.

    By the law of total variance, the variance of the uniform mixture of the
    members' Gaussians is the average member variance (aleatoric part) plus the
    population variance of the member means (epistemic part).

    Args:
        means: Member means stacked along dim 0, e.g. ``[n_ensemble, batch, 1]``.
        vars: Member variances, same shape as ``means``.

    Returns:
        Total variance per point, with all singleton dimensions squeezed.
    """
    # Divide-by-M variance of the means: exact for the mixture, and finite when
    # the ensemble has a single member (the unbiased form would give NaN).
    epistemic = ((means - means.mean(0, keepdim=True)) ** 2).mean(0)
    total_var = vars.mean(0) + epistemic
    return total_var.squeeze()


def generate_data(n: int, base_noise: float = 0.2, noise_type: str = 'heteroskedastic') -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate a 1-D toy regression dataset on the interval [-3, 3].

    The noise-free target is ``y = 0.5 * x + sin(1.5 * x)`` on ``n`` evenly
    spaced points. Gaussian noise is then added:

    * homoskedastic: constant standard deviation ``base_noise``;
    * heteroskedastic: standard deviation ``base_noise + 0.6 - 0.2 * |x|``,
      largest at ``x = 0`` and shrinking to ``base_noise`` at ``|x| = 3``.

    Noise comes from NumPy's global random state; seed it with
    ``np.random.seed`` for reproducibility.

    Args:
        n: Number of samples.
        base_noise: Noise standard deviation (homoskedastic), or its value at
            the interval edges (heteroskedastic).
        noise_type: ``'homoskedastic'`` or ``'heteroskedastic'`` (the
            ``-scedastic`` spellings are also accepted). The legacy misspelling
            ``'homoskedaistic'`` is kept as an alias for backward compatibility.

    Returns:
        Column vectors ``x`` and ``y``, each of shape ``(n, 1)``.

    Raises:
        ValueError: If ``noise_type`` is not one of the recognised names.
    """
    homoskedastic_names = {'homoskedastic', 'homoscedastic', 'homoskedaistic'}
    heteroskedastic_names = {'heteroskedastic', 'heteroscedastic'}
    if noise_type not in homoskedastic_names | heteroskedastic_names:
        raise ValueError(
            f"Unknown noise_type {noise_type!r}; expected 'homoskedastic' or 'heteroskedastic'"
        )

    x = np.linspace(-3, 3, n)
    y = 0.5 * x + np.sin(1.5 * x)
    if noise_type in homoskedastic_names:
        noise_std = base_noise
    else:
        # Floor at 0: at |x| = 3 the expression 0.6 - 0.2 * 3 rounds to about
        # -1e-16, which np.random.normal rejects when base_noise == 0.
        noise_std = np.maximum(base_noise + 0.6 - 0.2 * np.abs(x), 0.0)
    y += np.random.normal(0, noise_std, n)

    return x.reshape(-1, 1), y.reshape(-1, 1)


def _mark_training_range(ax, lo: float = -3.0, hi: float = 3.0) -> None:
    """Draw dashed black vertical lines at the edges of the training inputs."""
    for x_edge in (lo, hi):
        ax.axvline(x=x_edge, color='black', linestyle='--', linewidth=1, alpha=0.7)


def _plot_fit(ax, x_train, y_train, x_plot, mean, std) -> None:
    """Scatter the training data and draw the ensemble mean with a +/- 2 std band."""
    ax.scatter(x_train, y_train, color='black', s=6)
    ax.plot(x_plot, mean, color='blue', label=r'$\hat{y}$')
    # +/- 2 standard deviations covers roughly 95% of a Gaussian.
    ax.fill_between(x_plot, mean - 2 * std, mean + 2 * std,
                    color='blue', alpha=0.2, label='95% CI')
    ax.legend()
    ax.grid()
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    _mark_training_range(ax)


def _plot_uncertainty(ax, x_plot, entropy, ig) -> None:
    """Plot predictive entropy and information gain with a y = 0 reference line."""
    ax.plot(x_plot, entropy, color='red', label=r'Pred $H[\cdot]$')
    ax.plot(x_plot, ig, color='darkgreen', label='EIG')
    ax.legend()
    ax.grid()
    ax.set_xlabel('x')
    _mark_training_range(ax)
    # Differential entropy can be negative, so mark zero explicitly.
    ax.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.7)


def _predictive_summary(ensemble, x_test):
    """Return per-test-point (mean, std, entropy, ig) of the ensemble as numpy arrays."""
    with torch.no_grad():
        means, variances = ensemble.predict(x_test)
    mean = means.mean(0).squeeze(-1).numpy()
    std = compute_variance(means, variances).sqrt().numpy()
    entropy = compute_entropy(means, variances).numpy()
    ig = compute_ig(means, variances).numpy()
    return mean, std, entropy, ig


def main() -> None:
    """
    Train one Deep Ensemble on homoskedastic and one on heteroskedastic toy data,
    then plot both fits with their predictive entropy and information gain.

    The 1x4 figure is laid out as [homoskedastic fit, homoskedastic entropy/IG,
    heteroskedastic fit, heteroskedastic entropy/IG]. It is saved to
    ``plots/heteroskedasticity.pdf`` and then shown.
    """
    from pathlib import Path

    # Set seeds for reproducibility
    np.random.seed(13)
    torch.manual_seed(13)

    # Generate data on [-3, 3]. 'homoskedaistic' (sic) is the spelling that
    # generate_data matches for constant-variance noise.
    n = 100
    x_hom, y_hom = generate_data(n, noise_type='homoskedaistic')
    x_het, y_het = generate_data(n, noise_type='heteroskedastic')

    # One ensemble per dataset; construction order fixes the seeded weight init.
    n_members = 5
    ensemble_hom = DeepEnsemble(n_ensemble=n_members)
    ensemble_het = DeepEnsemble(n_ensemble=n_members)

    x_hom_t, y_hom_t = torch.from_numpy(x_hom).float(), torch.from_numpy(y_hom).float()
    x_het_t, y_het_t = torch.from_numpy(x_het).float(), torch.from_numpy(y_het).float()

    n_epochs = 500
    for epoch in range(n_epochs):
        loss_hom = ensemble_hom.train_step(x_hom_t, y_hom_t)
        loss_het = ensemble_het.train_step(x_het_t, y_het_t)
        if epoch % 10 == 0:
            print(f'Epoch: {epoch}, Loss (hom): {loss_hom:.4f}, Loss (het): {loss_het:.4f}')

    # Test points extend beyond the training range to expose extrapolation uncertainty.
    x_test = torch.linspace(-6, 6, 100).unsqueeze(-1)
    x_plot = x_test.squeeze(-1).numpy()

    fig, axes = plt.subplots(1, 4, figsize=(13, 3))
    datasets = [(x_hom, y_hom, ensemble_hom), (x_het, y_het, ensemble_het)]
    for i, (x_train, y_train, ensemble) in enumerate(datasets):
        mean, std, entropy, ig = _predictive_summary(ensemble, x_test)
        _plot_fit(axes[2 * i], x_train, y_train, x_plot, mean, std)
        _plot_uncertainty(axes[2 * i + 1], x_plot, entropy, ig)

    # Titles centred over each (fit, uncertainty) pair of panels.
    plt.figtext(0.25, 0.95, 'Homoskedastic', ha='center', va='center', fontsize=12)
    plt.figtext(0.75, 0.95, 'Heteroskedastic', ha='center', va='center', fontsize=12)

    plt.tight_layout()
    # Leave room at the top for the titles.
    plt.subplots_adjust(top=0.85)

    # Save before show(): saving afterwards can produce a blank file on some
    # backends. Create the output directory if it is missing.
    out_path = Path('plots') / 'heteroskedasticity.pdf'
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    main()
