"""
Training script for rotation-equivariant CNN using e2cnn.
Integrates with existing MNIST rotation data generator.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import math
import os
import sys

# Add current directory to path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from dataset_generator import load_mnist_rotation_datasets
from e2cnn_rotation_model import RotationEquivariantCNN_Simple, angle_to_cos_sin, cos_sin_to_angle, circular_mae_loss


def _cos_sin_targets(angles_0_360, device):
    """Turn a batch of angles (degrees) into (cos θ, sin θ) regression targets.

    The angles are wrapped into [-180, 180), passed to ``angle_to_cos_sin`` as
    a NumPy array, and the result is moved to ``device`` and cast to float32.
    """
    angles_gt = ((angles_0_360 + 180) % 360) - 180
    # angle_to_cos_sin takes a NumPy array, so the wrap is done wherever the
    # loader put the tensor; only the final targets are sent to `device`.
    return angle_to_cos_sin(angles_gt.cpu().numpy()).to(device).float()


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float('nan')


def train_rotation_model(model, train_loader, val_loader, num_epochs=50, learning_rate=0.001, device='cpu'):
    """Train the rotation-equivariant model.

    Optimises an MSE loss between the model output and the (cos θ, sin θ)
    encoding of the ground-truth angle, using Adam with a StepLR schedule
    that halves the learning rate every 20 epochs.

    Args:
        model: Network mapping an image batch to (cos θ, sin θ) predictions.
        train_loader: Iterable of ``(images, angles_in_degrees)`` batches.
        val_loader: Iterable of ``(images, angles_in_degrees)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Initial Adam learning rate.
        device: Device the model and batches are moved to.

    Returns:
        Tuple ``(train_losses, val_losses, val_maes)`` with one entry per
        epoch. Each entry is the unweighted mean of the per-batch values; an
        epoch whose loader yields no batches is recorded as NaN instead of
        raising ``ZeroDivisionError``.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    # Stateless criterion: build it once rather than once per batch.
    criterion = nn.MSELoss()

    train_losses = []
    val_losses = []
    val_maes = []

    print(f"Training on device: {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        train_batches = 0

        for batch_idx, (images, angles_0_360) in enumerate(train_loader):
            images = images.to(device)
            target_cos_sin = _cos_sin_targets(angles_0_360, device)

            optimizer.zero_grad()
            pred_cos_sin = model(images)
            loss = criterion(pred_cos_sin, target_cos_sin)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            train_batches += 1

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {batch_loss:.4f}')

        # Validation
        model.eval()
        val_loss = 0.0
        val_mae = 0.0
        val_batches = 0

        with torch.no_grad():
            for images, angles_0_360 in val_loader:
                images = images.to(device)
                target_cos_sin = _cos_sin_targets(angles_0_360, device)

                pred_cos_sin = model(images)

                val_loss += criterion(pred_cos_sin, target_cos_sin).item()
                val_mae += circular_mae_loss(pred_cos_sin, target_cos_sin).item()
                val_batches += 1

        # Record metrics (NaN when a loader produced no batches this epoch)
        avg_train_loss = _mean_or_nan(train_loss, train_batches)
        avg_val_loss = _mean_or_nan(val_loss, val_batches)
        avg_val_mae = _mean_or_nan(val_mae, val_batches)

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        val_maes.append(avg_val_mae)

        print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val MAE: {avg_val_mae:.2f}°')

        # Update learning rate
        scheduler.step()

    return train_losses, val_losses, val_maes


def evaluate_model(model, test_loader, device='cpu'):
    """Evaluate the trained model and print circular-error statistics.

    Args:
        model: Network mapping an image batch to (cos θ, sin θ) predictions.
        test_loader: Iterable of ``(images, angles_in_degrees)`` batches.
        device: Device the model and batches are moved to.

    Returns:
        Tuple ``(all_predictions, all_targets, all_errors)`` of flat lists
        with one entry per test sample: the predicted angle, the ground-truth
        angle wrapped into [-180, 180), and the absolute circular error. All
        three lists are empty when ``test_loader`` yields no batches.
    """
    # Keep the model on the same device as the inputs so the function also
    # works when it is called without a preceding training run.
    model = model.to(device)
    model.eval()

    all_predictions = []
    all_targets = []
    all_errors = []

    with torch.no_grad():
        for images, angles_0_360 in test_loader:
            images = images.to(device)
            angles_0_360 = angles_0_360.to(device)

            # Wrap ground-truth angles into [-180, 180)
            angles_gt = ((angles_0_360 + 180) % 360) - 180

            # Forward pass; the decoded angles are treated as degrees below
            pred_cos_sin = model(images)
            pred_angles = cos_sin_to_angle(pred_cos_sin)

            # Circular error: wrap the signed difference into [-180, 180)
            # before taking the magnitude, so 359° vs 1° counts as 2°.
            # NOTE(review): assumes pred_angles and angles_gt have matching
            # shapes; (B,) vs (B, 1) would silently broadcast — confirm.
            errors = pred_angles - angles_gt
            errors = torch.abs(((errors + 180) % 360) - 180)

            all_predictions.extend(pred_angles.cpu().numpy())
            all_targets.extend(angles_gt.cpu().numpy())
            all_errors.extend(errors.cpu().numpy())

    # Nothing to summarise: np.min/np.max would raise on an empty sequence.
    if not all_errors:
        print("No test samples to evaluate.")
        return all_predictions, all_targets, all_errors

    # Compute final metrics
    mae = np.mean(all_errors)
    std = np.std(all_errors)

    print(f"Test MAE: {mae:.2f}° ± {std:.2f}°")
    print(f"Min error: {np.min(all_errors):.2f}°")
    print(f"Max error: {np.max(all_errors):.2f}°")

    return all_predictions, all_targets, all_errors


def plot_results(train_losses, val_losses, val_maes, test_errors=None,
                 save_path='e2cnn_training_results.png', show=True):
    """Plot training curves and, optionally, the test error distribution.

    Draws a 1x3 figure: train/val loss per epoch, validation MAE per epoch,
    and a histogram of ``test_errors``. The figure is written to
    ``save_path`` and then displayed.

    Args:
        train_losses: Per-epoch training loss values.
        val_losses: Per-epoch validation loss values.
        val_maes: Per-epoch validation MAE values in degrees.
        test_errors: Optional per-sample test errors in degrees. When it is
            ``None`` or empty the third panel is hidden instead of being left
            as an empty frame.
        save_path: File the figure is saved to.
        show: If true, call ``plt.show()``; if false, close the figure after
            saving so headless or batch runs neither block nor keep it alive.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    loss_ax, mae_ax, hist_ax = axes

    # Training losses
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(val_losses, label='Val Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Validation MAE
    mae_ax.plot(val_maes, label='Val MAE', color='red')
    mae_ax.set_xlabel('Epoch')
    mae_ax.set_ylabel('MAE (degrees)')
    mae_ax.set_title('Validation MAE')
    mae_ax.legend()
    mae_ax.grid(True)

    # Test error distribution
    if test_errors is not None and len(test_errors) > 0:
        hist_ax.hist(test_errors, bins=50, alpha=0.7, edgecolor='black')
        hist_ax.set_xlabel('Error (degrees)')
        hist_ax.set_ylabel('Frequency')
        hist_ax.set_title(f'Test Error Distribution\nMAE: {np.mean(test_errors):.2f}°')
        hist_ax.grid(True)
    else:
        # No data for the third panel: hide it rather than show a blank axis
        # (and avoid a NaN mean in the title for an empty sequence).
        hist_ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)


def main():
    """Run the full pipeline: load data, train, evaluate, save, and plot.

    Returns:
        Tuple ``(model, predictions, targets, errors)`` where the last three
        items are the values returned by ``evaluate_model`` on the test set.
    """
    seed = 42
    batch_size = 64
    model_path = 'e2cnn_rotation_model.pth'

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load datasets
    print("Loading MNIST rotation datasets...")
    train_loader, test_loader = load_mnist_rotation_datasets(
        rotation_range=(0.0, 360.0),
        augmentation_factor=1,
        batch_size=batch_size,
        seed=seed
    )

    # Create validation split from training data (20% for validation)
    full_train_dataset = train_loader.dataset
    val_size = len(full_train_dataset) // 5
    train_size = len(full_train_dataset) - val_size

    # Seed the split so the train/val partition is the same on every run,
    # matching the seeded dataset generation above.
    split_generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset, [train_size, val_size], generator=split_generator
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")

    # Create model
    print("Creating rotation-equivariant model...")
    model = RotationEquivariantCNN_Simple(N=8)

    # Train model
    print("Starting training...")
    train_losses, val_losses, val_maes = train_rotation_model(
        model, train_loader, val_loader,
        num_epochs=30, learning_rate=0.001, device=device
    )

    # Evaluate on test set
    print("Evaluating on test set...")
    predictions, targets, errors = evaluate_model(model, test_loader, device)

    # Save the weights before plotting: the plot window blocks and plotting
    # can fail on headless machines, which must not cost the trained model.
    torch.save(model.state_dict(), model_path)
    print(f"Model saved as '{model_path}'")

    # Plot results
    plot_results(train_losses, val_losses, val_maes, errors)

    return model, predictions, targets, errors


if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not on import.
    # The results are bound to module-level names (presumably so they can be
    # inspected afterwards in an interactive session, e.g. `python -i` —
    # confirm intent).
    model, predictions, targets, errors = main()
