import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from rffautoencoder3 import Encoder

# ==========================================
# NON-LINEAR DEEPIV (For Direct X)
# ==========================================

class DeepIVDensityNet(nn.Module):
    """Stage 1 of DeepIV: conditional density of the treatment given instruments.

    The encoder maps instruments ``z`` to ``2 * latent_dim`` columns. These are
    split into the mean and the log-variance of a diagonal Gaussian.

    Args:
        z_dim: Number of instrument features fed to the encoder.
        latent_dim: Width of the modelled treatment; the encoder emits twice
            this many columns (first half mean, second half log-variance).
        use_rff: Forwarded unchanged to ``Encoder``.
        rff_sigma: Forwarded unchanged to ``Encoder``.
    """

    def __init__(self, z_dim, latent_dim, use_rff=True, rff_sigma=20):
        super().__init__()
        self.encoder = Encoder(
            input_dim=z_dim,
            hidden_dims=[100, 50, 20],
            # Doubled so one forward pass yields both mean and log-variance.
            latent_dim=2 * latent_dim,
            use_rff=use_rff,
            rff_sigma=rff_sigma,
        )
        self.latent_dim = latent_dim

    def forward(self, z):
        """Return ``(mu, logvar)`` by splitting the encoder output along dim 1."""
        mean_part, log_var_part = self.encoder(z).split(self.latent_dim, dim=1)
        return mean_part, log_var_part

class DeepIVOutcomeNet(nn.Module):
    """Stage 2 of DeepIV: outcome model h(D) -> Y, a single affine layer.

    Args:
        latent_dim: Number of treatment features per row of the input.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # One weight per treatment feature plus a bias; no nonlinearity.
        self.fc = nn.Linear(in_features=latent_dim, out_features=1)

    def forward(self, d):
        """Apply the affine map; output has a single column."""
        prediction = self.fc(d)
        return prediction

def train_deepiv(Z, X, Y, device, epochs=1000, *, unbiased=True):
    """Full DeepIV training pipeline for complex/high-dim X.

    Stage 1 fits ``DeepIVDensityNet`` to X given Z as a diagonal Gaussian by
    minimising its negative log-likelihood. Stage 2 freezes that network and
    fits ``DeepIVOutcomeNet`` h so that the conditional mean E[h(X) | Z]
    matches Y.

    Args:
        Z: Instruments; ``Z.shape[1]`` is the instrument width.
        X: Treatments; ``X.shape[1]`` is the treatment width.
        Y: Outcomes, flattened to a single column.
        device: Torch device for the networks and data.
        epochs: Number of full-batch optimisation steps in each stage.
        unbiased: If True (default), stage 2 minimises the two-sample loss
            ``(y - h(x1)) * (y - h(x2))``, where x1 and x2 are independent
            draws from stage 1. Its expectation is ``(y - E[h(X)|Z])**2``,
            the DeepIV objective. If False, the single-sample MSE
            ``(y - h(x))**2`` is used. Its expectation adds ``Var[h(X)|Z]``,
            which for a linear h acts as a ridge penalty and shrinks the
            estimated effect toward zero.

    Returns:
        ``(d_net, o_net)``: the fitted stage-1 and stage-2 networks.
    """
    z_dim, x_dim = Z.shape[1], X.shape[1]

    d_net = DeepIVDensityNet(z_dim=z_dim, latent_dim=x_dim).to(device)
    o_net = DeepIVOutcomeNet(latent_dim=x_dim).to(device)

    opt_d = torch.optim.RMSprop(d_net.parameters(), lr=5e-4)
    opt_o = torch.optim.RMSprop(o_net.parameters(), lr=5e-4)

    # as_tensor accepts numpy arrays and existing tensors alike.
    Z_t = torch.as_tensor(Z, dtype=torch.float32, device=device)
    X_t = torch.as_tensor(X, dtype=torch.float32, device=device)
    # reshape rather than view: as_tensor may keep a non-contiguous layout.
    Y_t = torch.as_tensor(Y, dtype=torch.float32, device=device).reshape(-1, 1)

    # Stage 1: P(X|Z)
    d_net.train()
    for _ in range(epochs):
        opt_d.zero_grad()
        mu, logvar = d_net(Z_t)
        # Gaussian NLL without the constant term. eps floors the variance so
        # the loss cannot run off to -inf when a dimension is fit exactly.
        loss = F.gaussian_nll_loss(mu, X_t, torch.exp(logvar), eps=1e-6)
        loss.backward()
        opt_d.step()

    # Stage 2: h(X) -> Y with the density network frozen.
    d_net.eval()
    o_net.train()
    # NOTE(review): assumes d_net's eval-mode forward is deterministic (no
    # sampling inside Encoder), so it is evaluated once rather than per epoch.
    with torch.no_grad():
        mu, logvar = d_net(Z_t)
        std = torch.exp(0.5 * logvar)

    for _ in range(epochs):
        opt_o.zero_grad()
        # Reparameterised draw from the stage-1 Gaussian.
        x_sample = mu + torch.randn_like(std) * std
        if unbiased:
            # Independent second draw makes the gradient unbiased for
            # (y - E[h(X)|Z])**2 instead of E[(y - h(X))**2].
            x_other = mu + torch.randn_like(std) * std
            loss = ((Y_t - o_net(x_sample)) * (Y_t - o_net(x_other))).mean()
        else:
            loss = F.mse_loss(o_net(x_sample), Y_t)
        loss.backward()
        opt_o.step()

    return d_net, o_net

# ==========================================
# LINEAR DEEPIV (For IRAE Latent D)
# ==========================================

class LinearDeepIVDensityNet(nn.Module):
    """Stage 1 of linear DeepIV: Gaussian P(D|Z) with affine mean and log-variance.

    Args:
        z_dim: Number of instrument features.
        latent_dim: Width of the modelled treatment.
    """

    def __init__(self, z_dim, latent_dim):
        super().__init__()
        # Two independent affine heads sharing the same input; the mean head
        # is built first (this fixes the order of random initialisation).
        self.fc_mu = nn.Linear(in_features=z_dim, out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=z_dim, out_features=latent_dim)

    def forward(self, z):
        """Return ``(mu, logvar)``, each produced by its own linear head."""
        return self.fc_mu(z), self.fc_logvar(z)

class LinearDeepIVOutcomeNet(nn.Module):
    """Stage 2 of linear DeepIV: affine outcome model h(D) -> Y with one output.

    Args:
        latent_dim: Number of treatment features per input row.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=latent_dim, out_features=1)

    def forward(self, d):
        """Map each treatment row to a scalar outcome prediction."""
        y_hat = self.fc(d)
        return y_hat
    
def train_linear_deepiv(Z, Treatment, Y, device, epochs=1000, *, unbiased=True):
    """Linear DeepIV for low-dim discovered latents.

    Stage 1 fits ``LinearDeepIVDensityNet`` to Treatment given Z as a
    diagonal Gaussian by minimising its negative log-likelihood. Stage 2
    freezes that network and fits ``LinearDeepIVOutcomeNet`` h so that
    E[h(D) | Z] matches Y.

    Args:
        Z: Instruments; ``Z.shape[1]`` is the instrument width.
        Treatment: Treatments; ``Treatment.shape[1]`` is the treatment width.
        Y: Outcomes, flattened to a single column.
        device: Torch device for the networks and data.
        epochs: Number of full-batch optimisation steps in each stage.
        unbiased: If True (default), stage 2 minimises the two-sample loss
            ``(y - h(t1)) * (y - h(t2))``, where t1 and t2 are independent
            draws from stage 1. Its expectation is ``(y - E[h(D)|Z])**2``,
            the DeepIV objective. If False, the single-sample MSE is used.
            Its expectation adds ``Var[h(D)|Z]``, which for this linear h is
            a ridge penalty ``w^T Σ w`` that attenuates the estimated effect.

    Returns:
        ``(d_net, o_net)``: the fitted stage-1 and stage-2 networks.
    """
    z_dim, t_dim = Z.shape[1], Treatment.shape[1]

    d_net = LinearDeepIVDensityNet(z_dim, t_dim).to(device)
    o_net = LinearDeepIVOutcomeNet(t_dim).to(device)

    # Higher LR for simpler linear manifold
    opt_d = torch.optim.RMSprop(d_net.parameters(), lr=1e-3)
    opt_o = torch.optim.RMSprop(o_net.parameters(), lr=1e-3)

    # as_tensor accepts numpy arrays and existing tensors alike.
    Z_t = torch.as_tensor(Z, dtype=torch.float32, device=device)
    T_t = torch.as_tensor(Treatment, dtype=torch.float32, device=device)
    # reshape rather than view: as_tensor may keep a non-contiguous layout.
    Y_t = torch.as_tensor(Y, dtype=torch.float32, device=device).reshape(-1, 1)

    # Stage 1: P(D|Z)
    d_net.train()
    for _ in range(epochs):
        opt_d.zero_grad()
        mu, logvar = d_net(Z_t)
        # Gaussian NLL without the constant term. eps floors the variance so
        # the loss cannot run off to -inf when a dimension is fit exactly.
        loss = F.gaussian_nll_loss(mu, T_t, torch.exp(logvar), eps=1e-6)
        loss.backward()
        opt_d.step()

    # Stage 2: h(D) -> Y with the density network frozen.
    d_net.eval()
    o_net.train()
    # The density net is two plain linear layers, so its frozen output is
    # identical every epoch; compute it once.
    with torch.no_grad():
        mu, logvar = d_net(Z_t)
        std = torch.exp(0.5 * logvar)

    for _ in range(epochs):
        opt_o.zero_grad()
        # Reparameterised draw from the stage-1 Gaussian.
        t_sample = mu + torch.randn_like(std) * std
        if unbiased:
            # Independent second draw makes the gradient unbiased for
            # (y - E[h(D)|Z])**2 instead of E[(y - h(D))**2].
            t_other = mu + torch.randn_like(std) * std
            loss = ((Y_t - o_net(t_sample)) * (Y_t - o_net(t_other))).mean()
        else:
            loss = F.mse_loss(o_net(t_sample), Y_t)
        loss.backward()
        opt_o.step()

    return d_net, o_net