"""
Script to train the routing model on the synthetic dataset.
"""
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Add the project root directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.utils.simulate_multitask_gp import simulate_multitask_dataset
from src.models.modality_router import RoutingModel


def heteroscedastic_loss(pred, target, log_var):
    """Element-wise Gaussian negative log-likelihood with a predicted log-variance.
    Computes ``0.5 * exp(-log_var) * (pred - target)^2 + 0.5 * log_var`` without
    reducing, so callers choose their own reduction (e.g. ``.mean()``).
    Args:
        pred (Tensor): Model predictions
        target (Tensor): Ground truth targets
        log_var (Tensor): Predicted log variance for each prediction
    Returns:
        Tensor: Unreduced loss, one value per broadcast element of the inputs
    """
    # exp(-log_var) is the inverse variance: confident predictions are
    # penalised more heavily for the same squared error.
    inverse_variance = torch.exp(-log_var)
    squared_error = (pred - target) ** 2
    weighted_error = 0.5 * inverse_variance * squared_error
    # The log-variance term stops the model from inflating the variance
    # to make the weighted error vanish.
    variance_penalty = 0.5 * log_var
    return weighted_error + variance_penalty


def train_epoch(model, train_loader, optimizer, device, use_heteroscedastic=False):
    """Run one optimisation pass over the training data.
    Args:
        model (nn.Module): The routing model; called as ``model(x_numeric, x_text)``
            and expected to return ``(pred1, pred2, mod_probs, task_probs)``
        train_loader (DataLoader): Yields ``(x_numeric, x_text, y1, y2)`` batches
        optimizer (Optimizer): Optimizer stepped once per batch
        device (torch.device): Device every batch tensor is moved to
        use_heteroscedastic (bool): Use the heteroscedastic loss instead of MSE
    Returns:
        float: Mean over batches of the total loss (task losses + entropy terms)
    """
    def mean_entropy(probs):
        # Entropy summed over dim 1 and averaged over the rest; the 1e-10
        # offset keeps log() finite when a probability is exactly zero.
        return -(probs * torch.log(probs + 1e-10)).sum(dim=1).mean()

    model.train()
    running_loss = 0

    for batch in train_loader:
        x_numeric, x_text, y1, y2 = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()

        pred1, pred2, mod_probs, task_probs = model(x_numeric, x_text)

        if not use_heteroscedastic:
            loss1 = F.mse_loss(pred1, y1)
            loss2 = F.mse_loss(pred2, y2)
        else:
            # The model does not predict a variance, so a constant zero
            # log-variance is used as a placeholder for both tasks.
            loss1 = heteroscedastic_loss(pred1, y1, torch.zeros_like(pred1)).mean()
            loss2 = heteroscedastic_loss(pred2, y2, torch.zeros_like(pred2)).mean()

        # Entropy of the routing distributions, weighted by 0.1 and added to
        # the objective: one term for the modality router, one per entry of
        # task_probs.
        entropy_modality = 0.1 * mean_entropy(mod_probs)
        entropy_task = 0.1 * sum(mean_entropy(p) for p in task_probs)

        loss = (loss1 + loss2) + (entropy_modality + entropy_task)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)


def evaluate(model, test_loader, device):
    """Evaluate the model on the test set.
    Targets are collected batch by batch alongside the predictions, so the
    RMSE stays correct even when the loader shuffles, uses a sampler, or wraps
    a dataset that is not a ``TensorDataset``.
    Args:
        model (nn.Module): The routing model; called as ``model(x_numeric, x_text)``
            and expected to return ``(pred1, pred2, mod_probs, task_probs)``
        test_loader (DataLoader): Yields ``(x_numeric, x_text, y1, y2)`` batches
        device (torch.device): Device to run on
    Returns:
        dict: Evaluation metrics and predictions with keys
            ``'loss'`` (float): mean over batches of the summed per-task MSE,
            ``'rmse1'`` / ``'rmse2'`` (float): RMSE per task over all samples,
            ``'predictions'`` (tuple): concatenated CPU predictions per task,
            ``'modality_probs'`` (Tensor): concatenated CPU modality probabilities,
            ``'task_probs'`` (list): flat list of CPU tensors, one per element of
            the model's ``task_probs`` per batch, in iteration order
    """
    model.eval()
    total_loss = 0
    all_preds1 = []
    all_preds2 = []
    all_targets1 = []
    all_targets2 = []
    all_mod_probs = []
    all_task_probs = []

    with torch.no_grad():
        for batch in test_loader:
            x_numeric, x_text, y1, y2 = [b.to(device) for b in batch]

            # Forward pass
            pred1, pred2, mod_probs, task_probs = model(
                x_numeric, x_text
            )

            # Compute loss
            loss1 = F.mse_loss(pred1, y1)
            loss2 = F.mse_loss(pred2, y2)
            loss = loss1 + loss2

            total_loss += loss.item()

            # Store predictions together with the targets they belong to, so
            # the two stay aligned whatever order the loader iterates in.
            all_preds1.append(pred1.cpu())
            all_preds2.append(pred2.cpu())
            all_targets1.append(y1.cpu())
            all_targets2.append(y2.cpu())
            all_mod_probs.append(mod_probs.cpu())
            # Kept as a flat per-batch list (not concatenated) to preserve the
            # structure existing callers receive.
            all_task_probs.extend([p.cpu() for p in task_probs])

    # Concatenate all predictions, targets and probabilities
    all_preds1 = torch.cat(all_preds1, dim=0)
    all_preds2 = torch.cat(all_preds2, dim=0)
    all_targets1 = torch.cat(all_targets1, dim=0)
    all_targets2 = torch.cat(all_targets2, dim=0)
    all_mod_probs = torch.cat(all_mod_probs, dim=0)

    # Compute RMSE over the whole set (not an average of per-batch RMSEs)
    rmse1 = torch.sqrt(F.mse_loss(all_preds1, all_targets1))
    rmse2 = torch.sqrt(F.mse_loss(all_preds2, all_targets2))

    return {
        'loss': total_loss / len(test_loader),
        'rmse1': rmse1.item(),
        'rmse2': rmse2.item(),
        'predictions': (all_preds1, all_preds2),
        'modality_probs': all_mod_probs,
        'task_probs': all_task_probs
    }


def parse_args(argv=None):
    """Parse command-line arguments for training script.
    Args:
        argv (list[str] | None): Argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers are unaffected; passing a list allows
            the parser to be driven programmatically.
    Returns:
        argparse.Namespace: Parsed options ``n_samples``, ``batch_size``,
            ``hidden_dim``, ``lr``, ``n_epochs``, ``use_heteroscedastic``
            and ``output_dir``
    """
    parser = argparse.ArgumentParser(description='Train routing model')
    parser.add_argument('--n_samples', type=int, default=1000,
                      help='Number of samples to generate')
    parser.add_argument('--batch_size', type=int, default=32,
                      help='Batch size')
    parser.add_argument('--hidden_dim', type=int, default=32,
                      help='Hidden dimension')
    parser.add_argument('--lr', type=float, default=0.001,
                      help='Learning rate')
    parser.add_argument('--n_epochs', type=int, default=100,
                      help='Number of epochs')
    parser.add_argument('--use_heteroscedastic', action='store_true',
                      help='Use heteroscedastic loss')
    parser.add_argument('--output_dir', type=str, default='experiments/results',
                      help='Directory to save results')
    return parser.parse_args(argv)


def main():
    """Main function to run the training pipeline.
    Simulates a dataset, standardises the features, trains the routing model
    with early stopping, and writes the best checkpoint to
    ``<output_dir>/best_model.pt``. Progress is printed after every epoch.
    """
    args = parse_args()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generate data
    data = simulate_multitask_dataset(
        n_samples=args.n_samples
    )
    # np.asarray so the index arrays produced by the split can slice the
    # features directly (assumes 2-D array-likes, as StandardScaler requires).
    X_numeric = np.asarray(data['X_numeric'])
    X_textual = np.asarray(data['X_text'])
    y1 = data['y1']
    y2 = data['y2']

    # Split into train and test BEFORE fitting the scalers, so that no
    # test-set statistics leak into the preprocessing.
    train_idx, test_idx = train_test_split(
        np.arange(len(X_numeric)),
        test_size=0.2,
        random_state=42
    )

    # Fit the normalisation on the training rows only, then apply the same
    # transform to every row.
    scaler_numeric = StandardScaler().fit(X_numeric[train_idx])
    scaler_textual = StandardScaler().fit(X_textual[train_idx])
    X_numeric = scaler_numeric.transform(X_numeric)
    X_textual = scaler_textual.transform(X_textual)

    # Convert to torch tensors; targets become column vectors
    X_numeric = torch.FloatTensor(X_numeric)
    X_textual = torch.FloatTensor(X_textual)
    y1 = torch.FloatTensor(y1).reshape(-1, 1)
    y2 = torch.FloatTensor(y2).reshape(-1, 1)

    # Create datasets
    train_dataset = TensorDataset(
        X_numeric[train_idx],
        X_textual[train_idx],
        y1[train_idx],
        y2[train_idx]
    )
    test_dataset = TensorDataset(
        X_numeric[test_idx],
        X_textual[test_idx],
        y1[test_idx],
        y2[test_idx]
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False
    )

    # Create model
    model = RoutingModel(
        input_dim_numeric=X_numeric.shape[1],
        input_dim_text=X_textual.shape[1],
        hidden_dim=args.hidden_dim
    ).to(device)

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Training loop
    # NOTE(review): the held-out split is also used to pick the checkpoint
    # and to trigger early stopping, so the printed test metrics are
    # optimistically biased; a separate validation split would avoid this.
    best_test_loss = float('inf')
    epochs_no_improve = 0
    patience = 10  # epochs without test-loss improvement before stopping
    for epoch in range(args.n_epochs):
        # Train for one epoch
        train_loss = train_epoch(
            model, train_loader, optimizer, device,
            use_heteroscedastic=args.use_heteroscedastic
        )

        # Evaluate on test set
        test_results = evaluate(model, test_loader, device)

        # Print progress
        print(f'Epoch {epoch+1}/{args.n_epochs}:')
        print(f'Train Loss: {train_loss:.4f}')
        print(f'Test Loss: {test_results["loss"]:.4f}')
        print(f'Test RMSE (Task 1): {test_results["rmse1"]:.4f}')
        print(f'Test RMSE (Task 2): {test_results["rmse2"]:.4f}')
        print()

        # Save best model checkpoint (overwrites the previous best)
        if test_results['loss'] < best_test_loss:
            best_test_loss = test_results['loss']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'test_results': test_results
            }, os.path.join(args.output_dir, 'best_model.pt'))
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping
        if epochs_no_improve >= patience:
            print('Early stopping triggered.')
            break

    print('Training complete.')


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()