# train_spd_logchol.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from torch.utils.data import TensorDataset, DataLoader
from generate_corre import generate_corre_dataset
from LFRCov_cholesky import LFRCovCholesky_torch

# =========================================
#      Numerics & SPD / log–Cholesky
# =========================================

def to_device(x, device):
    """Move a tensor, or every tensor inside a list/tuple, onto ``device``.

    Lists and tuples are processed recursively and are always returned as a
    plain ``list``. Any object that is neither a tensor nor a list/tuple is
    returned unchanged.
    """
    if isinstance(x, (list, tuple)):
        return [to_device(item, device) for item in x]
    return x.to(device) if isinstance(x, torch.Tensor) else x

def log_cholesky_distance(S1, S2):
    """
    Log–Cholesky distance between two SPD matrices, based on the upper factor U = L^T
    (L being the lower factor returned by ``torch.linalg.cholesky``):

        d(S1, S2) = sqrt( ||sUT(U1) - sUT(U2)||_F^2 + ||log diag(U1) - log diag(U2)||_2^2 )

    where sUT(.) is the strictly upper-triangular part. Both inputs are symmetrized
    and factorized in float64; the result is cast back to the dtype of ``S1``.
    Both sums run over every element of their operands, so the result is a single
    scalar tensor.
    """
    def _upper_factor(S):
        # Work in double precision and enforce exact symmetry before factorizing.
        S64 = S.to(torch.float64)
        S64 = 0.5 * (S64 + S64.transpose(-1, -2))
        return torch.linalg.cholesky(S64).transpose(-1, -2)

    out_dtype = S1.dtype
    U1, U2 = _upper_factor(S1), _upper_factor(S2)

    # Squared Frobenius gap of the strictly upper-triangular parts.
    strict_gap = torch.triu(U1, diagonal=1) - torch.triu(U2, diagonal=1)
    off_dist_sq = torch.norm(strict_gap, p='fro') ** 2

    # Squared Euclidean gap of the log-diagonals.
    log_diag_1 = torch.log(torch.diagonal(U1, dim1=-2, dim2=-1))
    log_diag_2 = torch.log(torch.diagonal(U2, dim1=-2, dim2=-1))
    logD_dist_sq = (log_diag_1 - log_diag_2).pow(2).sum()

    return torch.sqrt(off_dist_sq + logD_dist_sq).to(out_dtype)

def frechet_variance_log_cholesky(M_list):
    """Dispersion of a sample of SPD matrices around its log-Cholesky Fréchet mean.

    The value returned is the mean of the (unsquared) log-Cholesky distances
    from each matrix to the sample Fréchet mean.

    Args:
        M_list: sequence of SPD matrices of identical shape.

    Returns:
        A scalar tensor. For an empty sequence or a single matrix the result
        is ``0.0`` (a single matrix coincides with its own mean).
    """
    n = len(M_list)
    if n == 0:
        # No element to borrow a device from; previously this path raised
        # IndexError while trying to read M_list[0].device.
        return torch.tensor(0.0)
    if n == 1:
        return torch.tensor(0.0, device=M_list[0].device)

    # Fréchet mean under the log-Cholesky metric
    M_mean = compute_frechet_mean_log_cholesky(M_list)

    # Average distance of the sample to that mean
    dists = torch.stack([log_cholesky_distance(M, M_mean) for M in M_list])
    return dists.mean()

def compute_frechet_mean_log_cholesky(M_list):
    """Closed-form Fréchet mean of SPD matrices under the log-Cholesky metric.

    With S_j = L_j L_j^T (L_j lower-triangular), the mean factor has
      * strictly lower part  = arithmetic mean of the strictly lower parts of L_j,
      * diagonal             = exp(arithmetic mean of log diag(L_j)),
    and the mean matrix is L_mean @ L_mean^T.

    Args:
        M_list: non-empty sequence of SPD matrices sharing shape, dtype and device.

    Returns:
        The mean SPD matrix. When exactly one matrix is given it is returned
        as-is (same object, no copy).

    Raises:
        ValueError: if ``M_list`` is empty.
    """
    n = len(M_list)
    if n == 0:
        raise ValueError("M_list must contain at least one matrix")
    if n == 1:
        return M_list[0]

    # One batched factorization instead of n separate calls: (n, d, d)
    L = torch.linalg.cholesky(torch.stack(list(M_list)))

    # Average the two coordinate groups of the log-Cholesky parametrization.
    sLT_mean = torch.tril(L, diagonal=-1).mean(dim=0)
    logD_mean = torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).mean(dim=0)

    # Map the averaged log-diagonal back with exp and rebuild the SPD matrix.
    L_mean = sLT_mean + torch.diag(torch.exp(logD_mean))
    return L_mean @ L_mean.T

def lfr_loss_selffit(X, M_list, theta, h):
    """Self-fitting LFR loss: fit and evaluate on the same samples.

    Each row of ``X`` is projected onto its matching row of ``theta`` to get a
    scalar index. ``LFRCovCholesky_torch`` is then called with that index as
    both ``x`` and ``xout``, and the loss is the mean log-Cholesky distance
    between each observed matrix and its fitted counterpart.

    Args:
        X: predictors, one row per sample.
        M_list: responses, indexable by sample position.
        theta: per-sample direction vectors, row-aligned with ``X``.
        h: bandwidth forwarded to ``LFRCovCholesky_torch``.

    Returns:
        Scalar tensor with the mean per-sample distance.
    """
    # Per-sample index <X_i, theta_i>, kept as a column vector
    index = torch.sum(X * theta, dim=1, keepdim=True)

    # Fit at the very same index values the model is trained on
    fit = LFRCovCholesky_torch(
        x=index, M=M_list, xout=index, h=h,
        metric="log_cholesky", dtype=torch.float64
    )
    fitted = fit['Mout']

    per_sample = [
        log_cholesky_distance(M_list[i], fitted[i])
        for i in range(X.shape[0])
    ]
    return torch.stack(per_sample).mean()

# =========================================
#      Neural Network Model
# =========================================

class ThetaMLP(nn.Module):
    """MLP that maps a predictor vector to a unit-norm direction ``theta``.

    Architecture: two hidden blocks (Linear -> LayerNorm -> LeakyReLU -> Dropout)
    followed by a linear layer back to ``input_dim``. The output is L2-normalized
    along the last axis and its sign is flipped where needed so that the first
    component is non-negative (theta and -theta are otherwise interchangeable).

    Args:
        input_dim: size of the predictor vector, and of the returned theta.
        hidden_dim: width of both hidden layers.
        dropout_rate: dropout probability used after each hidden block.
    """

    def __init__(self, input_dim, hidden_dim=64, dropout_rate=0.3):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.act1 = nn.LeakyReLU()
        self.drop1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.act2 = nn.LeakyReLU()
        self.drop2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(hidden_dim, input_dim)

        # Kaiming-normal weights and zero biases for every linear layer
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.kaiming_normal_(layer.weight, nonlinearity='leaky_relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return theta with the same shape as ``x``.

        ``x`` may be a single vector ``(input_dim,)`` or carry any number of
        leading batch dimensions ``(..., input_dim)``; every operation acts on
        the last axis.
        """
        x = self.drop1(self.act1(self.ln1(self.fc1(x))))
        x = self.drop2(self.act2(self.ln2(self.fc2(x))))
        x = self.fc3(x)

        # Normalize to a unit vector; the epsilon guards against a zero output
        norm = torch.norm(x, dim=-1, keepdim=True)
        theta = x / (norm + 1e-8)

        # Resolve the sign ambiguity: make the first component non-negative.
        # The sign tensor is built from theta so it shares its dtype and device.
        first = theta[..., :1]
        sign = torch.where(first < 0, -torch.ones_like(first), torch.ones_like(first))
        return theta * sign

# =========================================
#      Training Function
# =========================================

def _save_error_statistics(distances, distances_squared):
    """Write per-sample test distances (CSV) and their summary statistics (text) under spd/."""
    np.savetxt('spd/log_cholesky_distances.csv', distances, delimiter=',')
    np.savetxt('spd/log_cholesky_distances_squared.csv', distances_squared, delimiter=',')

    with open('spd/error_statistics.txt', 'w') as f:
        f.write(f"Mean log-Cholesky distance: {np.mean(distances):.6f}\n")
        f.write(f"Mean log-Cholesky distance squared: {np.mean(distances_squared):.6f}\n")
        f.write(f"Standard deviation: {np.std(distances):.6f}\n")
        f.write(f"Min distance: {np.min(distances):.6f}\n")
        f.write(f"Max distance: {np.max(distances):.6f}\n")


def _plot_training_curves(train_losses, val_losses, val_epochs, avg_log_chol_distances):
    """Save a two-panel figure (losses, validation distance) to spd/loss_curves.png."""
    plt.figure(figsize=(12, 6))
    ax1 = plt.subplot(1, 2, 1)
    ax1.plot(range(len(train_losses)), train_losses, label='Train Loss')
    if val_losses:
        ax1.plot(val_epochs, val_losses, label='Val Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_yscale('log')
    ax1.grid(True)
    ax1.legend()

    ax2 = plt.subplot(1, 2, 2)
    if avg_log_chol_distances:
        ax2.plot(val_epochs, avg_log_chol_distances, label='Avg Log-Cholesky Distance', color='green')
    ax2.set_title('Average Log-Cholesky Distance (Validation)')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Distance')
    ax2.grid(True)
    ax2.legend()

    plt.tight_layout()
    plt.savefig('spd/loss_curves.png')
    plt.close()


def _plot_theta_boxplots(theta_np, p):
    """Save per-component boxplots of theta, with reference values overlaid, to spd/theta_boxplots.png."""
    # Reference direction: normalized [0.1, 0.5, 0, -0.1].
    # NOTE(review): said to match generate_corre.py - confirm against that file.
    theta_true = np.array([0.1, 0.5, 0.0, -0.1])
    theta_true = theta_true / np.linalg.norm(theta_true)

    plt.figure(figsize=(8, 6))
    plt.boxplot([theta_np[:, i] for i in range(p)], tick_labels=[f'Theta {i+1}' for i in range(p)])
    # Reference values as red dots; only the first one carries the legend label
    for i in range(p):
        plt.scatter(i + 1, theta_true[i], color='red', s=100, zorder=5, label='True' if i == 0 else None)
    plt.title('Theta Distribution (Boxplots) with True Values')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.savefig('spd/theta_boxplots.png')
    plt.close()


def train_model():
    """Train ThetaMLP with a learnable bandwidth, then evaluate and save results under spd/.

    Pipeline:
      1. Generate the dataset, shuffle it, and split it 40% / 40% / remainder
         into train / validation / test.
      2. Standardize the predictors with training-set statistics.
      3. Minimize (self-fit LFR loss / training dispersion) + lambda_reg / h with
         Adam, halving the learning rate every 100 epochs, and stop early when
         the validation loss has not improved by ``delta`` for ``patience``
         consecutive epochs.
      4. Restore the best checkpoint, predict every test response from the
         training set, and write distances, summary statistics and plots to spd/.

    All hyperparameters are fixed inside the function. Returns ``None``.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create the output directory up front so a full training run cannot end
    # with FileNotFoundError when results are written.
    os.makedirs('spd', exist_ok=True)

    # Parameters
    n = 2500
    p = 4
    q = 3
    batch_size = 64
    lr = 0.01
    num_epochs = 1000
    patience = 5
    delta = 1e-4  # Minimum improvement threshold for early stopping
    lambda_reg = 0.01  # Regularization hyperparameter for 1/h term

    # Generate data
    print("Generating data...")
    X, M_list, _ = generate_corre_dataset(n=n, m=1000, q=q, random_seed=42, dtype=torch.float32)
    X = X.to(device)
    M_list = [M.to(device) for M in M_list]

    # Shuffle, then split (the same permutation is applied to X and M_list)
    perm_cpu = torch.randperm(n)
    X = X[perm_cpu.to(device)]
    M_list = [M_list[idx] for idx in perm_cpu.tolist()]

    train_size = int(0.4 * n)
    val_size = int(0.4 * n)
    val_end = train_size + val_size

    X_train, X_val, X_test = X[:train_size], X[train_size:val_end], X[val_end:]
    M_train, M_val, M_test = M_list[:train_size], M_list[train_size:val_end], M_list[val_end:]

    # Normalize predictors using training statistics only
    X_mean = X_train.mean(dim=0, keepdim=True)
    X_std = X_train.std(dim=0, keepdim=True).clamp_min(1e-8)
    X_train = (X_train - X_mean) / X_std
    X_val = (X_val - X_mean) / X_std
    X_test = (X_test - X_mean) / X_std

    # Dispersion of the training responses, used to normalize the LFR loss
    with torch.no_grad():
        frechet_var = frechet_variance_log_cholesky(M_train)
        frechet_var = frechet_var + 1e-8
    print(f"Training Fréchet variance: {frechet_var.item():.6f}")

    # Create model
    model = ThetaMLP(input_dim=p, hidden_dim=64, dropout_rate=0.2).to(device)
    # Learnable bandwidth parameter; always used through a clamp to [0.1, 1.0]
    h_param = nn.Parameter(torch.tensor(0.5, device=device))
    optimizer = optim.Adam(list(model.parameters()) + [h_param], lr=lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    # Training loop state
    train_losses = []
    val_losses = []
    val_epochs = []
    avg_log_chol_distances = []  # Average validation log-Cholesky distance per epoch
    best_val_loss = float('inf')
    best_model_state = None
    best_h = None
    patience_counter = 0

    print("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        # Mini-batch training (slices past the end are truncated automatically)
        for start in range(0, len(X_train), batch_size):
            Xb = X_train[start:start + batch_size]
            Mb = M_train[start:start + batch_size]

            optimizer.zero_grad()

            # Forward pass
            theta_b = model(Xb)
            h_value = torch.clamp(h_param, min=0.1, max=1.0)
            lfr = lfr_loss_selffit(Xb, Mb, theta_b, h_value)
            loss = lfr / frechet_var + (1.0 / h_value) * lambda_reg

            # Backward pass
            loss.backward()

            # Gradient clipping (network weights only; h_param is not clipped)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)

        # Validation evaluation
        with torch.no_grad():
            model.eval()
            theta_val = model(X_val)
            h_eval = torch.clamp(h_param.detach(), min=0.1, max=1.0)
            val_lfr = lfr_loss_selffit(X_val, M_val, theta_val, h_eval)
            val_loss = (val_lfr / frechet_var) + (1.0 / h_eval) * lambda_reg
            val_losses.append(val_loss.item())
            val_epochs.append(epoch)

            # Out-of-sample check: predict validation responses from the training set
            theta_train_all = model(X_train)
            result_val = LFRCovCholesky_torch(
                x=(X_train * theta_train_all).sum(dim=1),
                M=M_train,
                xout=(X_val * theta_val).sum(dim=1),
                h=h_eval,
                metric="log_cholesky",
                dtype=torch.float64
            )
            M_pred_val = result_val['Mout']
            val_distances = [
                log_cholesky_distance(M_pred_val[i], M_true).item()
                for i, M_true in enumerate(M_val)
            ]
            avg_log_chol_distances.append(sum(val_distances) / len(val_distances))

        # Learning rate scheduling
        scheduler.step()

        # Early stopping: an epoch counts as an improvement only when it beats
        # the best validation loss so far by more than delta.
        current_val_loss = val_loss.item()

        if current_val_loss < best_val_loss - delta:
            best_val_loss = current_val_loss
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_h = h_param.detach().cpu().clone()
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch} (val_loss not improved by {delta} for {patience} epochs)")
            break

        if epoch % 20 == 0:
            # Column mean of theta over the training set. theta_train_all was
            # computed above in eval mode with these same weights, so it is
            # reused instead of running another forward pass.
            theta_mean = theta_train_all.mean(dim=0)

            print(f"Epoch {epoch}, Train Loss: {avg_loss:.6f}, Val Loss: {val_loss.item():.6f}, LR: {optimizer.param_groups[0]['lr']:.6f}, h: {h_value.item():.4f}")
            print(f"  Mean theta: {theta_mean.cpu().numpy()}")
            print(f"  Avg log-Cholesky distance (val): {avg_log_chol_distances[-1]:.6f}")

    # Restore the best checkpoint (weights and bandwidth) before the final evaluation
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        model.to(device)
        if best_h is not None:
            with torch.no_grad():
                h_param.copy_(best_h.to(device))

    model.eval()
    with torch.no_grad():
        theta_final = model(X_test)
        h_eval = torch.clamp(h_param.detach(), min=0.1, max=1.0)

        # Calculate test loss
        test_lfr = lfr_loss_selffit(X_test, M_test, theta_final, h_eval)
        test_loss = (test_lfr / frechet_var) + (1.0 / h_eval) * lambda_reg
        print(f"Final test loss: {test_loss.item():.6f}")

        # Calculate error statistics
        print("\nCalculating error statistics...")
        M_pred_list = []
        for i in range(len(X_test)):
            # The training predictors are projected with the theta predicted
            # for this particular test point.
            Z_test = (X_test[i:i+1] * theta_final[i:i+1]).sum(dim=1, keepdim=True)
            result = LFRCovCholesky_torch(
                x=X_train @ theta_final[i], M=M_train, xout=Z_test, h=h_eval,
                metric="log_cholesky", dtype=torch.float64
            )
            M_pred_list.append(result['Mout'][0])

        # Log-Cholesky distances between predicted and observed test responses
        distances = [
            log_cholesky_distance(M_pred, M_true).item()
            for M_pred, M_true in zip(M_pred_list, M_test)
        ]
        distances_squared = [d ** 2 for d in distances]

    mean_distance = np.mean(distances)
    mean_distance_squared = np.mean(distances_squared)

    print(f"Mean log-Cholesky distance: {mean_distance:.6f}")
    print(f"Mean log-Cholesky distance squared: {mean_distance_squared:.6f}")
    print(f"Learned bandwidth h: {h_eval.item():.6f}")

    # Save to files
    _save_error_statistics(distances, distances_squared)
    print("Results saved to spd/ directory")

    # Plots: loss/validation-distance curves, then the theta distribution on the test set
    _plot_training_curves(train_losses, val_losses, val_epochs, avg_log_chol_distances)
    _plot_theta_boxplots(theta_final.cpu().numpy(), p)

    print("Training completed successfully!")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    train_model()
