import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


class CriticNetwork(nn.Module):
    """
    Critic (statistics network) T_theta(x, y): an MLP with one hidden layer.

    The two inputs are concatenated feature-wise and mapped to one scalar
    score per row.

    input_dim:  total number of features after concatenating x and y
                (2 when each variable is a single column).
    hidden_dim: width of the hidden layer.
    """

    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # Single scalar output per row
        )

    def forward(self, x, y):
        """
        Score each (x, y) pair.

        x, y: tensors of shape (batch_size, features) or 1D (batch_size,).
        Returns a tensor of shape (batch_size, 1).
        """
        # A 1D tensor has no dim=1 to concatenate along, so treat it as a
        # single-feature column of shape (batch_size, 1).
        if x.dim() == 1:
            x = x.unsqueeze(1)
        if y.dim() == 1:
            y = y.unsqueeze(1)

        xy = torch.cat([x, y], dim=1)  # (batch_size, x_features + y_features)
        return self.net(xy)


def mine_loss(critic_joint, critic_marginal):
    """
    Compute the Donsker-Varadhan lower bound on mutual information:

        I >= E_joint[T(x, y)] - log E_marginal[exp(T(x, y))]

    critic_joint:    critic scores T_theta(x, y) on samples from p(x, y).
    critic_marginal: critic scores T_theta(x, y) on samples from p(x)p(y)
                     (shuffled pairs).

    Returns a 0-dim tensor holding the bound. Despite the function name this
    is the quantity to maximise; negate it to obtain a loss to minimise.
    """
    # Mean critic score over the joint samples.
    joint_mean = torch.mean(critic_joint)

    # log(mean(exp(t))) == logsumexp(t) - log(n). Going through logsumexp
    # avoids exp() overflowing to inf when the critic emits large scores.
    marginal_flat = critic_marginal.reshape(-1)
    sample_count = marginal_flat.new_tensor(float(marginal_flat.numel()))
    log_mean_exp = torch.logsumexp(marginal_flat, dim=0) - torch.log(
        sample_count
    )

    return joint_mean - log_mean_exp


def train_mine(x, y, critic, optimizer, batch_size=128, epochs=10):
    """
    Fit the critic by gradient ascent on the MINE lower bound.

    x, y: paired sample tensors indexed along dim 0 (documented shape
        (n_samples, 1)). Both are detached first, so no gradient reaches them.
    critic: callable scoring (x_batch, y_batch), e.g. a CriticNetwork.
    optimizer: an optimizer for the critic's parameters.
    batch_size: maximum number of rows drawn per step.
    epochs: number of optimisation steps; each step uses one random batch.

    Returns None; parameters are updated in place via optimizer.step().
    Call estimate_mine afterwards to obtain the final estimate.
    """
    x_data, y_data = x.detach(), y.detach()

    # Gradients are force-enabled so training also works when the caller is
    # inside a no_grad() context.
    with torch.enable_grad():
        for _ in range(epochs):
            # Paired rows approximate samples from the joint p(x, y).
            joint_rows = torch.randperm(x_data.size(0))[:batch_size]
            x_batch, y_paired = x_data[joint_rows], y_data[joint_rows]

            # Independently drawn y rows break the pairing, approximating
            # samples from p(x)p(y) while keeping the same x batch.
            shuffled_rows = torch.randperm(y_data.size(0))[:batch_size]
            y_unpaired = y_data[shuffled_rows]

            scores_joint = critic(x_batch, y_paired)
            scores_marginal = critic(x_batch, y_unpaired)

            # Maximising the bound is the same as minimising its negation.
            objective = -mine_loss(scores_joint, scores_marginal)

            optimizer.zero_grad()
            objective.backward()
            optimizer.step()


def estimate_mine(critic, x, y):
    """
    Evaluate the MINE bound on the whole dataset with the given critic.

    critic: callable scoring (x, y) pairs.
    x, y: paired sample tensors indexed along dim 0.

    Returns whatever mine_loss returns for the joint and shuffled scores.
    The autograd mode is not changed here, so gradient tracking follows the
    caller's context.
    """
    joint_scores = critic(x, y)

    # One random permutation of the row indices re-pairs every x with some y
    # from the dataset, standing in for samples from p(x)p(y).
    shuffle = torch.randperm(x.shape[0])
    marginal_scores = critic(x, y[shuffle])

    return mine_loss(joint_scores, marginal_scores)


def mutual_information_mine(
    x, y, *, hidden_dim=64, lr=1e-3, batch_size=256, epochs=200
):
    """
    Estimate the mutual information between x and y with MINE.

    A fresh critic is created, trained on (x, y), and then evaluated on the
    full dataset.

    x, y: paired sample tensors of shape (n_samples, d_x) and
        (n_samples, d_y). Inputs of any other rank are counted as one
        feature each, i.e. the critic is built with input_dim=2.
    hidden_dim: width of the critic's hidden layer.
    lr: learning rate for the Adam optimizer.
    batch_size: rows per training step, forwarded to train_mine.
    epochs: number of training steps, forwarded to train_mine.

    Returns the value produced by estimate_mine for the trained critic.
    """
    # Size the critic's input from the column counts so multi-column
    # variables work; single-column inputs still give input_dim=2.
    x_features = x.shape[1] if x.dim() == 2 else 1
    y_features = y.shape[1] if y.dim() == 2 else 1

    # Create critic network (on the same device as x) and its optimizer.
    critic = CriticNetwork(
        input_dim=x_features + y_features, hidden_dim=hidden_dim
    ).to(x.device)
    optimizer = optim.Adam(critic.parameters(), lr=lr)

    # Train
    train_mine(x, y, critic, optimizer, batch_size=batch_size, epochs=epochs)

    return estimate_mine(critic, x, y)
