#!/usr/bin/env python3
"""
Example script demonstrating the usage of the Multi-Task Learning Influence Analysis package.

This script shows how to:
1. Create synthetic data
2. Train a multi-task model
3. Compute influence attribution scores
4. Analyze the results
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List

from utils import (
    create_synthetic_dataset, 
    create_synthetic_model, 
    train_synthetic_model,
    set_seed,
    plot_training_curves
)
from attribution import compute_influence_attribution
from models import get_model


def _as_flat_numpy(scores):
    """Return *scores* as a flat NumPy array held in host memory.

    Matplotlib converts its inputs through NumPy, and that conversion fails
    for tensors that live on an accelerator or that are attached to an
    autograd graph.  Tensors are therefore detached and copied to the CPU
    first; anything that is not a ``torch.Tensor`` goes through
    ``np.asarray`` as-is.
    """
    if isinstance(scores, torch.Tensor):
        scores = scores.detach().cpu().numpy()
    return np.asarray(scores).ravel()


def main():
    """Main example function.

    Walks through the whole example in eight printed steps: build the
    synthetic data loaders, create and train the multi-task model, plot the
    training curves, take one training batch per task, compute influence
    attribution scores on those batches, print summary statistics for each
    score tensor and finally draw one histogram per score tensor.

    All results are reported through ``print`` and matplotlib windows; the
    function returns ``None``.
    """
    print("Multi-Task Learning Influence Analysis Example")
    print("=" * 50)

    # Set random seed for reproducibility
    set_seed(42)

    # Prefer the GPU when one is available, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Configuration
    num_samples = 1000
    input_dim = 20
    num_tasks = 2
    num_classes = 2
    hidden_dim = 64

    print("\nConfiguration:")
    print(f"  Number of samples: {num_samples}")
    print(f"  Input dimension: {input_dim}")
    print(f"  Number of tasks: {num_tasks}")
    print(f"  Number of classes: {num_classes}")
    print(f"  Hidden dimension: {hidden_dim}")

    # Step 1: Create synthetic data
    print("\nStep 1: Creating synthetic dataset...")
    data_loaders = create_synthetic_dataset(
        num_samples=num_samples,
        input_dim=input_dim,
        num_tasks=num_tasks,
        num_classes=num_classes,
        noise_level=0.01,
        seed=42,
        device=device
    )

    # data_loaders maps each task name to its 'train' / 'val' / 'test' loaders
    print(f"  Created data loaders for {len(data_loaders)} tasks")
    for task_name, loaders in data_loaders.items():
        print(f"    {task_name}: {len(loaders['train'].dataset)} train, "
              f"{len(loaders['val'].dataset)} val, {len(loaders['test'].dataset)} test samples")

    # Step 2: Create model
    print("\nStep 2: Creating multi-task model...")
    model = create_synthetic_model(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_tasks=num_tasks,
        num_classes=num_classes
    )

    print(f"  Model created with {sum(p.numel() for p in model.parameters())} parameters")

    # Step 3: Train model
    print("\nStep 3: Training model...")
    train_loaders = {name: loaders['train'] for name, loaders in data_loaders.items()}
    val_loaders = {name: loaders['val'] for name, loaders in data_loaders.items()}

    results = train_synthetic_model(
        model=model,
        train_loaders=train_loaders,
        val_loaders=val_loaders,
        num_epochs=10,
        learning_rate=0.01,
        weight_decay=1e-4,
        device=device,
        verbose=True
    )

    print("  Training completed!")
    print(f"  Final train accuracy: {results['train_accuracies'][-1]:.4f}")
    print(f"  Final val accuracy: {results['val_accuracies'][-1]:.4f}")

    # Step 4: Plot training curves
    print("\nStep 4: Plotting training curves...")
    plot_training_curves(
        train_losses=results['train_losses'],
        val_losses=results['val_losses'],
        train_accuracies=results['train_accuracies'],
        val_accuracies=results['val_accuracies']
    )

    # Step 5: Prepare data for influence computation
    print("\nStep 5: Preparing data for influence computation...")
    train_features = []
    train_labels = []

    for loaders in data_loaders.values():
        # Only the first training batch of each task is used
        features, labels = next(iter(loaders['train']))
        train_features.append(features)
        train_labels.append(labels)

    print(f"  Prepared {len(train_features)} feature tensors")
    print(f"  Feature shapes: {[f.shape for f in train_features]}")
    print(f"  Label shapes: {[l.shape for l in train_labels]}")

    # Step 6: Compute influence attribution
    print("\nStep 6: Computing influence attribution...")

    # Create regularization parameter (a random vector of length input_dim)
    regularization_param = torch.randn(input_dim, device=device)

    # Compute influence attribution
    attribution_scores = compute_influence_attribution(
        models=[model],
        regularization_param=regularization_param,
        input_data=train_features,
        target_data=train_labels,
        input_dim=input_dim,
        loss_function='cross_entropy',
        computation_mode=0,  # Simple mode
        lambda_reg=[0.1] * num_tasks,
        output_dim=num_classes,
        num_samples=num_samples,
        device=device,
        link_function='logistic'
    )

    print("  Attribution computation completed!")
    print(f"  Number of attribution tensors: {len(attribution_scores)}")
    print(f"  Attribution tensor shapes: {[score.shape for score in attribution_scores]}")

    # Step 7: Analyze results
    print("\nStep 7: Analyzing results...")

    for task_idx, scores in enumerate(attribution_scores):
        print(f"  Task {task_idx}:")
        print(f"    Mean attribution: {scores.mean():.6f}")
        print(f"    Std attribution: {scores.std():.6f}")
        print(f"    Min attribution: {scores.min():.6f}")
        print(f"    Max attribution: {scores.max():.6f}")

        # Find most and least influential samples (indices into the
        # flattened score tensor)
        max_idx = scores.argmax().item()
        min_idx = scores.argmin().item()
        print(f"    Most influential sample index: {max_idx}")
        print(f"    Least influential sample index: {min_idx}")

    # Step 8: Visualize attribution scores
    print("\nStep 8: Visualizing attribution scores...")

    # Size the grid from the scores actually returned so that every score
    # tensor gets an axis even if their count differs from num_tasks.
    num_plots = len(attribution_scores)
    if num_plots:
        # squeeze=False always yields a 2-D array of axes, so a single plot
        # needs no special casing.
        _, axes = plt.subplots(
            1, num_plots, figsize=(5 * num_plots, 4), squeeze=False
        )

        for task_idx, scores in enumerate(attribution_scores):
            ax = axes[0, task_idx]
            # Histogram input must be host-side NumPy data
            ax.hist(_as_flat_numpy(scores), bins=50, alpha=0.7, edgecolor='black')
            ax.set_xlabel('Attribution Score')
            ax.set_ylabel('Frequency')
            ax.set_title(f'Task {task_idx} Attribution Distribution')
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    print("\nExample completed successfully!")
    print("Summary:")
    print(f"  - Created and trained a {num_tasks}-task model")
    print("  - Computed influence attribution scores")
    print("  - Analyzed attribution distributions")
    print("  - Generated visualizations")


def demonstrate_model_factory():
    """Demonstrate the model factory functionality.

    Prints a fixed list of model names, then builds ``FeatureEncoder`` and
    ``SimpleConvolutionalModel`` through ``get_model``, reports how many
    parameters each has, and sends one random ``(1, 3, 32, 32)`` input
    through the convolutional model to show the shape of its output.

    Any exception raised while building or running the models is printed
    instead of propagated, so a failure here does not abort the script.
    Returns ``None``.
    """
    print("\n" + "="*50)
    print("Model Factory Demonstration")
    print("="*50)

    # Available models
    model_names = [
        'FeatureEncoder',
        'ResNetFeatureEncoder',
        'BertFeatureEncoder',
        'FullResNetEncoder',
        'MultiTaskResNet18',
        'SimpleConvolutionalModel',
        'TanhConvolutionalModel',
        'AlternativeConvolutionalModel'
    ]

    print("Available models:")
    for i, name in enumerate(model_names, 1):
        print(f"  {i}. {name}")

    # Create a few example models
    print("\nCreating example models...")

    try:
        # Feature encoder
        feature_encoder = get_model('FeatureEncoder')
        print(f"  FeatureEncoder: {sum(p.numel() for p in feature_encoder.parameters())} parameters")

        # Simple convolutional model
        conv_model = get_model('SimpleConvolutionalModel')
        print(f"  SimpleConvolutionalModel: {sum(p.numel() for p in conv_model.parameters())} parameters")

        # Test forward pass.  This is only a shape probe, so put the model in
        # inference mode first: if it contains dropout or batch-norm layers,
        # training mode would randomise the output, update running
        # statistics, and can fail outright on a one-sample batch.
        # NOTE(review): assumes get_model returns a torch.nn.Module - confirm.
        conv_model.eval()
        test_input = torch.randn(1, 3, 32, 32)
        with torch.no_grad():
            conv_output = conv_model(test_input)
        print(f"  Conv model output shape: {conv_output.shape}")

    except Exception as e:
        # Deliberate best-effort: the demo reports the problem and moves on.
        print(f"  Error creating models: {e}")


# Script entry point: run the end-to-end example first, then the model
# factory demonstration.
if __name__ == "__main__":
    main()
    demonstrate_model_factory()