# Modified from https://github.com/ermongroup/smile-mi-estimator and https://colab.research.google.com/github/google-research/google-research/blob/master/vbmi/vbmi_demo.ipynb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from estimators import estimate_mutual_information
import os
import json
import matplotlib.colors as mc


# Pick the compute device once at import time: CUDA first, then Apple's MPS
# backend, otherwise fall back to the CPU.
# `torch.backends.mps` was only added in PyTorch 1.12, so it is looked up with
# getattr; on older installs the MPS branch is skipped instead of raising
# AttributeError while the module is being imported.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

def rho_to_mi(dim, rho):
    """Return the ground-truth mutual information, in bits, for correlation ``rho``.

    Each of the ``dim`` components contributes ``-0.5 * log2(1 - rho**2)``
    (the bivariate-Gaussian formula), and the contributions are summed.
    """
    bits_per_component = -0.5 * np.log2(1 - rho**2)
    return bits_per_component * dim


def mi_to_rho(dim, mi):
    """Return the Gaussian correlation ``rho`` that gives mutual information ``mi``.

    ``mi`` is the total information in bits spread evenly over ``dim``
    components; this is the (non-negative) inverse of ``rho_to_mi``.
    """
    exponent = -2.0 / dim * mi
    shared_variance = 1 - 2**exponent
    return np.sqrt(shared_variance)


def mi_schedule(n_iter):
    """Build a staircase of target MI values, one entry per iteration.

    ``n_iter`` evenly spaced points from 0.5 to just under 5.5 are rounded to
    the nearest integer and doubled; the result is returned as float32.
    """
    ramp = np.linspace(0.5, 5.5 - 1e-9, n_iter)
    staircase = 2.0 * np.round(ramp)
    return staircase.astype(np.float32)


# Dataset maker
def sample_correlated_data(ncom=10, total_size=500, batch_size=128, info_tot=np.log(10), mlp_x=None, mlp_y=None, dev=None):
    """
    Draw one batch of paired observations (x, y) built from correlated Gaussian latents.

    The latents Z_X and Z_Y are each ``ncom``-dimensional standard normals whose
    matching components have correlation ``rho = mi_to_rho(ncom, info_tot)``.

    Parameters:
    - ncom: number of latent dimensions per view.
    - total_size: observed dimensionality; only used when no MLPs are given.
    - batch_size: number of samples drawn.
    - info_tot: target mutual information handed to ``mi_to_rho``, which reads
      it in bits. NOTE(review): the default ``np.log(10)`` looks like a value in
      nats - confirm the intended unit.
    - mlp_x, mlp_y: optional callables applied to Z_X / Z_Y under
      ``torch.no_grad()``. Both must be given; if either is None they are ignored.
    - dev: flag only. Any truthy value moves the batch to the module-level
      ``device``; the value itself is not used as the target.

    Without MLPs the latents are returned unchanged when ``total_size == ncom``,
    otherwise they are repeated ``int(total_size / ncom)`` times along the
    feature axis (so a non-multiple ``total_size`` is truncated).

    Returns:
    - (x, y): a pair of float32 tensors with ``batch_size`` rows.
    """
    # Generate correlated latent variables Z_X, Z_Y
    rho = mi_to_rho(ncom, info_tot)
    cov = np.eye(2 * ncom)
    cov[ncom:, :ncom] = np.eye(ncom) * rho
    cov[:ncom, ncom:] = np.eye(ncom) * rho
    latents = np.random.multivariate_normal(np.zeros(2 * ncom), cov, batch_size)
    z_x_tensor = torch.tensor(latents[:, :ncom], dtype=torch.float32)
    z_y_tensor = torch.tensor(latents[:, ncom:], dtype=torch.float32)

    # Compare against None instead of relying on truthiness: a module that
    # defines __len__ (e.g. an empty nn.Sequential) is falsy and would
    # otherwise be silently treated as "not provided".
    if mlp_x is not None and mlp_y is not None:
        # Same projection whether or not total_size equals ncom; the output
        # width is whatever the MLPs produce.
        with torch.no_grad():
            x = mlp_x(z_x_tensor)
            y = mlp_y(z_y_tensor)
    elif total_size != ncom:
        # Copy the latents total_size/ncom times along the feature axis
        reps = int(total_size / ncom)
        x = torch.tile(z_x_tensor, (reps,))
        y = torch.tile(z_y_tensor, (reps,))
    else:
        # Identity mapping
        x, y = z_x_tensor, z_y_tensor

    if dev:
        return x.to(device), y.to(device)
    return x, y
        

def mut_info_optimized(x, y, threshold=1e-10):
    """
    Covariance-based mutual information between x and y, in bits.

    Evaluates 0.5 * (log2 det C_x + log2 det C_y - log2 det C_xy), where the
    log-determinants are sums of log2 eigenvalues of the sample covariances.
    Eigenvalues below ``threshold`` are raised to ``threshold`` (clamped, not
    dropped) so that near-singular directions cannot produce log(0).

    x and y hold one sample per row and one variable per column.
    Returns NaN if the result is infinite or the eigendecomposition fails.
    """
    def _floored_logdet(matrix):
        # eigh is used because covariance matrices are symmetric
        spectrum = np.linalg.eigh(matrix)[0]
        return np.sum(np.log2(np.maximum(spectrum, threshold)))

    try:
        # One joint covariance; the marginal blocks are sliced out of it
        joint_cov = np.cov(np.hstack((x, y)), rowvar=False)
        split = x.shape[1]  # columns belonging to x

        logdet_joint = _floored_logdet(joint_cov)
        logdet_x = _floored_logdet(joint_cov[:split, :split])
        logdet_y = _floored_logdet(joint_cov[split:, split:])

        info = 0.5 * (logdet_x + logdet_y - logdet_joint)
    except np.linalg.LinAlgError:
        return np.nan

    if np.isinf(info):
        return np.nan
    return info


# Teacher model
class teacher(nn.Module):
    """Two-layer network: Linear(dz, 1024) -> Softplus -> Linear(1024, output_dim)."""

    def __init__(self, dz, output_dim):
        super().__init__()
        hidden_width = 1024
        self.dense1 = nn.Linear(dz, hidden_width)
        self.act = nn.Softplus()
        self.dense2 = nn.Linear(hidden_width, output_dim)

    def forward(self, Z):
        """Map latents ``Z`` (last dimension ``dz``) to ``output_dim`` features."""
        hidden = self.act(self.dense1(Z))
        return self.dense2(hidden)
        
# Dataset loader
class Dataset(torch.utils.data.Dataset):
    """Pair two index-aligned containers so that item ``i`` is ``(X[i], Y[i])``."""

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        # The length is taken from X alone; Y is not checked against it.
        return len(self.X)

    def __getitem__(self, index):
        sample_x = self.X[index]
        sample_y = self.Y[index]
        return sample_x, sample_y


class mlp(nn.Module):
    """Fully connected network: input layer, ``layers`` hidden layers, linear output."""

    def __init__(self, dim, hidden_dim, output_dim, layers, activation):
        """Create an mlp from the configurations.

        ``activation`` must be one of 'relu', 'sigmoid', 'tanh', 'leaky_relu'
        or 'silu'; any other key raises KeyError. Every Linear weight is
        Xavier-uniform initialised; biases keep nn.Linear's default init.
        """
        super().__init__()
        make_activation = {
            'relu': nn.ReLU,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
            'leaky_relu': nn.LeakyReLU,
            'silu': nn.SiLU,
        }[activation]

        def xavier_linear(n_in, n_out):
            # Create the layer and immediately re-initialise its weight
            linear = nn.Linear(n_in, n_out)
            nn.init.xavier_uniform_(linear.weight)
            return linear

        # Input layer followed by `layers` hidden layers, all hidden_dim wide
        shapes = [(dim, hidden_dim)]
        shapes += [(hidden_dim, hidden_dim) for _ in range(layers)]

        stack = []
        for n_in, n_out in shapes:
            stack.append(xavier_linear(n_in, n_out))
            stack.append(make_activation())

        # Everything before the output projection
        self.base_network = nn.Sequential(*stack)

        # Output layer (no activation afterwards)
        self.out = xavier_linear(hidden_dim, output_dim)

    def forward(self, x):
        features = self.base_network(x)
        return self.out(features)


class var_mlp(nn.Module):
    """MLP trunk with two linear heads producing a diagonal-Gaussian embedding.

    ``forward`` returns the mean, the clamped log-variance and one
    reparameterised sample, and stores the batch-averaged KL divergence to a
    standard normal in ``self.kl_loss`` as a side effect.
    """

    def __init__(self, dim, hidden_dim, output_dim, layers, activation):
        """Create a variational mlp from the configurations.

        ``activation`` must be one of 'relu', 'sigmoid', 'tanh', 'leaky_relu'
        or 'silu'; any other key raises KeyError. Every Linear weight is
        Xavier-uniform initialised; biases keep nn.Linear's default init.
        """
        super(var_mlp, self).__init__()
        activation_fn = {
            'relu': nn.ReLU,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
            'leaky_relu': nn.LeakyReLU,
            'silu': nn.SiLU,
        }[activation]

        # Initialize the layers list
        seq = []

        # Input layer
        seq.append(nn.Linear(dim, hidden_dim))
        seq.append(activation_fn())
        nn.init.xavier_uniform_(seq[0].weight)  # Xavier initialization for input layer

        # Hidden layers
        for _ in range(layers):
            layer = nn.Linear(hidden_dim, hidden_dim)
            nn.init.xavier_uniform_(layer.weight)  # Xavier initialization for hidden layers
            seq.append(layer)
            seq.append(activation_fn())

        # Connect all together before the output
        self.base_network = nn.Sequential(*seq)

        # Two heads for means and log variances
        self.fc_mu = nn.Linear(hidden_dim, output_dim)
        self.fc_logvar = nn.Linear(hidden_dim, output_dim)

        # Initialize the heads with Xavier initialization
        nn.init.xavier_uniform_(self.fc_mu.weight)
        nn.init.xavier_uniform_(self.fc_logvar.weight)

        # Standard normal used to draw the reparameterisation noise. It is
        # placed on the module-level `device`; forward() moves the noise to
        # the activations' device if that differs.
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.to(device)
        self.N.scale = self.N.scale.to(device)

        # KL divergence of the most recent forward pass (0.0 until then)
        self.kl_loss = 0.0

        # Set limits for numerical stability
        self.logvar_min = -20  # Lower bound for logVar
        self.logvar_max = 20   # Upper bound for logVar

    def forward(self, x):
        """Encode ``x`` and return ``[meanz, logVar, samples]``.

        ``samples = meanz + exp(0.5 * logVar) * epsilon`` with epsilon drawn
        from a standard normal. ``self.kl_loss`` is overwritten on every call.
        """
        x = self.base_network(x)

        # Get mean and log variance
        meanz = self.fc_mu(x)
        logVar = self.fc_logvar(x)

        # Clamp logVar to prevent extreme values
        logVar = torch.clamp(logVar, min=self.logvar_min, max=self.logvar_max)

        # KL(N(mean, exp(logVar)) || N(0, 1)) per component, summed over
        # dim=1 and averaged over the remaining elements
        kl_terms = 0.5 * (torch.square(meanz) + torch.exp(logVar) - 1 - logVar)
        self.kl_loss = torch.mean(torch.sum(kl_terms, dim=1))

        # Check for NaN in KL loss
        if torch.isnan(self.kl_loss):
            print("NaN detected in KL loss!")
            # Use a small default value instead of NaN. It is created on the
            # activations' device (not the global one) so that it can be
            # combined with other losses from this module. Being a fresh leaf
            # tensor, it carries no gradient back to the network.
            self.kl_loss = torch.tensor(0.1, device=meanz.device, requires_grad=True)

        # Reparameterization trick. The noise is moved to the activations'
        # device: a no-op when it already matches, and it avoids a device
        # mismatch when the module runs somewhere other than the global device.
        epsilon = self.N.sample(meanz.shape).to(meanz.device)
        std = torch.exp(0.5 * logVar)
        samples = meanz + std * epsilon
        return [meanz, logVar, samples]



def log_prob_gaussian(x):
    """Return the standard-normal log-density of ``x``, summed over the last axis."""
    standard_normal = torch.distributions.Normal(0., 1.)
    elementwise = standard_normal.log_prob(x)
    return elementwise.sum(dim=-1)


class decoder_INFO(nn.Module):
    """Wrap ``estimate_mutual_information`` with a critic selected by ``mode``.

    - "sep" / "bi": the critic is the matrix of inner products between the
      rows of the two embeddings.
    - "concat": the first argument already holds the flattened pairwise
      scores and is only reshaped to (batch_size, batch_size) and transposed.
    """

    def __init__(self, typeEstimator, mode="sep", baseline_fn=None):
        super().__init__()
        self.estimator = typeEstimator
        self.baseline_fn = baseline_fn
        self.mode = mode  # "sep" and "bi" share one critic

    def critic_fn(self, dataZX, dataZY, batch_size=None):
        """Return the score matrix for the configured mode; ValueError otherwise."""
        if self.mode == "concat":
            # dataZX is really the final scores here, not an embedding
            scores = torch.reshape(dataZX, (batch_size, batch_size))
            return scores.t()
        if self.mode in ("sep", "bi"):
            return dataZY @ dataZX.t()
        raise ValueError("Invalid mode. Choose 'sep', 'bi', or 'concat'.")

    def forward(self, dataZX, dataZY, batch_size=None):
        """Run the configured estimator on the pair with this module's critic."""
        def critic(x, y):
            return self.critic_fn(x, y, batch_size)

        return estimate_mutual_information(
            self.estimator, dataZX, dataZY, critic, baseline_fn=self.baseline_fn
        )

def write_config(args):
    """Write the attributes of ``args`` as JSON to ``<args.save_dir>/config.json``.

    ``args`` is any object accepted by ``vars()`` (e.g. an argparse namespace)
    that has a ``save_dir`` attribute. An existing file is overwritten.
    """
    config_path = os.path.join(args.save_dir, "config.json")
    with open(config_path, 'w') as handle:
        json.dump(vars(args), handle)

def lighten_color(color, amount=0.5):
    """
    Lighten a color by mixing with white.

    Parameters:
    - color: a color name found in ``mc.cnames`` or any other spec accepted by
      ``mc.to_rgb`` (e.g. a hex string such as '#ff7f0e')
    - amount: float [0 to 1], where 0 = original, 1 = white

    Returns:
    - str hex color
    """
    try:
        # Translate a named color to its hex string when possible
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # KeyError: not a named color. TypeError: unhashable spec such as a
        # list or array. Either way, hand the value to to_rgb unchanged.
        # (A bare `except:` here would also swallow KeyboardInterrupt.)
        c = color
    c = np.array(mc.to_rgb(c))
    w = np.array([1., 1., 1.])  # white
    # Linear blend between the original color and white
    rgb = (1 - amount) * c + amount * w
    return mc.to_hex(rgb)