"""
Training script for RAM++ ADE20K adapter

This script trains the MLP adapter that maps RAM++ predictions to ADE20K classes.
The RAM++ backbone is frozen during training; only the adapter and the final
RAM++ layers selected with --unfreeze-layers are optimized.
"""

import os
import sys
import json
import argparse
import time
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import average_precision_score, precision_recall_fscore_support
import warnings
warnings.filterwarnings("ignore")

# Add parent directory to path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from datasets.ade20k_dataset import ADE20KDataset, create_ade20k_dataloaders
from models.ram_plus_ade20k import RAM_plus_ADE20K, load_ram_plus_ade20k_pretrained
from losses.asymmetric_loss import AsymmetricLoss, CombinedLoss
from losses.class_balanced_focal_loss import ClassBalancedFocalLoss, create_ade20k_class_balanced_focal_loss


def calculate_ade20k_metrics(predictions, targets, threshold=0.5):
    """
    Compute multi-label classification metrics for ADE20K predictions.

    Args:
        predictions: per-class scores (torch.Tensor or numpy array), expected
            to be sigmoid outputs of shape [batch_size, num_classes]
        targets: ground-truth label matrix with the same shape
        threshold: float - scores strictly above it count as positive

    Returns:
        dict with 'mAP', macro 'precision' / 'recall' / 'f1' / 'accuracy'
        (averaged over classes that have at least one positive target, 0.0 if
        none do), 'num_valid_classes', 'total_classes', and the per-class
        arrays 'per_class_precision', 'per_class_recall', 'per_class_f1',
        'per_class_accuracy'.
    """
    def _to_numpy(values):
        # Tensors are moved to CPU and converted; anything else is used as-is.
        return values.cpu().numpy() if torch.is_tensor(values) else values

    scores = _to_numpy(predictions)
    labels = _to_numpy(targets)
    num_classes = labels.shape[1]

    # Average precision is undefined for a class without positives, so only
    # classes that occur in `labels` contribute to the mean.
    class_aps = [
        average_precision_score(labels[:, c], scores[:, c])
        for c in range(num_classes)
        if labels[:, c].sum() > 0
    ]
    mean_ap = np.mean(class_aps) if class_aps else 0.0

    # Per-class confusion counts from the thresholded scores.
    eps = 1e-7
    predicted_pos = (scores > threshold).astype(np.float32)
    predicted_neg = 1 - predicted_pos
    actual_neg = 1 - labels
    tp = (predicted_pos * labels).sum(0)
    fp = (predicted_pos * actual_neg).sum(0)
    fn = (predicted_neg * labels).sum(0)
    tn = (predicted_neg * actual_neg).sum(0)

    per_class_precision = tp / (tp + fp + eps)
    per_class_recall = tp / (tp + fn + eps)
    per_class_f1 = (2 * per_class_precision * per_class_recall
                    / (per_class_precision + per_class_recall + eps))
    per_class_accuracy = (tp + tn) / (tp + tn + fp + fn + eps)

    # Macro averages are restricted to classes with at least one positive.
    valid = labels.sum(0) > 0
    num_valid = valid.sum()
    if num_valid > 0:
        macro = [
            per_class[valid].mean()
            for per_class in (per_class_precision, per_class_recall,
                              per_class_f1, per_class_accuracy)
        ]
    else:
        macro = [0.0, 0.0, 0.0, 0.0]
    macro_precision, macro_recall, macro_f1, macro_accuracy = macro

    return {
        'mAP': mean_ap,
        'precision': macro_precision,
        'recall': macro_recall,
        'f1': macro_f1,
        'accuracy': macro_accuracy,
        'num_valid_classes': int(num_valid),
        'total_classes': num_classes,
        'per_class_precision': per_class_precision,
        'per_class_recall': per_class_recall,
        'per_class_f1': per_class_f1,
        'per_class_accuracy': per_class_accuracy
    }


def train_epoch(model, dataloader, optimizer, criterion, device, epoch, writer=None):
    """Train ``model`` for one epoch over ``dataloader``.

    Gradients are enabled exactly for the parameters registered with
    ``optimizer``; every other model parameter is frozen for the epoch.

    Args:
        model: module mapping a batch of images to per-class logits.
        dataloader: yields dicts with ``'image'`` and ``'labels'`` tensors.
        optimizer: optimizer holding the parameters to update.
        criterion: loss called as ``criterion(logits, targets)``; if it has a
            ``get_individual_losses`` method, its breakdown is printed.
        device: device the batch tensors are moved to.
        epoch: epoch number, used for logging and the TensorBoard step.
        writer: optional TensorBoard ``SummaryWriter``.

    Returns:
        (avg_loss, metrics): mean batch loss and the dict returned by
        ``calculate_ade20k_metrics`` on the epoch's sigmoid outputs.

    Raises:
        ValueError: if ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")

    model.train()

    # Enable gradients for exactly the parameters the optimizer updates.
    # Selecting them by name substring (e.g. 'fc') also unfroze any parameter
    # whose name merely contained "fc": those accumulated gradients that were
    # never stepped (optimizer.zero_grad() only clears optimized parameters).
    # It also left optimized layers outside the pattern list frozen, so they
    # silently never trained.
    optimized_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
    for param in model.parameters():
        param.requires_grad = id(param) in optimized_ids

    running_loss = 0.0
    all_predictions = []
    all_targets = []

    for batch_idx, batch in enumerate(dataloader):
        images = batch['image'].to(device)
        targets = batch['labels'].to(device)

        # Forward / backward / update
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value

        # Collect predictions for the epoch metrics
        with torch.no_grad():
            all_predictions.append(torch.sigmoid(logits).cpu())
            all_targets.append(targets.cpu())

        # Log progress
        if batch_idx % 50 == 0:
            if hasattr(criterion, 'get_individual_losses'):
                # For CombinedLoss, show breakdown; logging only, so no graph.
                with torch.no_grad():
                    asym_loss, bce_loss, _ = criterion.get_individual_losses(logits, targets)
                print(f'Epoch {epoch}, Batch {batch_idx}/{num_batches}, '
                      f'Loss: {loss_value:.4f} (Asym: {asym_loss:.4f}, BCE: {bce_loss:.4f})')
            else:
                print(f'Epoch {epoch}, Batch {batch_idx}/{num_batches}, Loss: {loss_value:.4f}')

            if writer:
                global_step = epoch * num_batches + batch_idx
                writer.add_scalar('Train/BatchLoss', loss_value, global_step)

    # Calculate epoch metrics
    metrics = calculate_ade20k_metrics(torch.cat(all_predictions, dim=0),
                                       torch.cat(all_targets, dim=0))
    avg_loss = running_loss / num_batches

    print(f'Epoch {epoch} Training - Loss: {avg_loss:.4f}, mAP: {metrics["mAP"]:.4f}, '
          f'F1: {metrics["f1"]:.4f}, '
          f'Valid Classes: {metrics["num_valid_classes"]}/{metrics["total_classes"]}')

    if writer:
        writer.add_scalar('Train/Loss', avg_loss, epoch)
        writer.add_scalar('Train/mAP', metrics['mAP'], epoch)
        writer.add_scalar('Train/F1', metrics['f1'], epoch)
        writer.add_scalar('Train/Precision', metrics['precision'], epoch)
        writer.add_scalar('Train/Recall', metrics['recall'], epoch)

    return avg_loss, metrics


def validate_epoch(model, dataloader, criterion, device, epoch, writer=None):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: module mapping a batch of images to per-class logits.
        dataloader: yields dicts with ``'image'`` and ``'labels'`` tensors.
        criterion: loss called as ``criterion(logits, targets)``.
        device: device the batch tensors are moved to.
        epoch: epoch number, used for logging.
        writer: optional TensorBoard ``SummaryWriter``.

    Returns:
        (avg_loss, metrics): mean batch loss and the dict returned by
        ``calculate_ade20k_metrics`` on the sigmoid outputs.
    """
    model.eval()

    total_loss = 0.0
    prob_chunks, label_chunks = [], []

    with torch.no_grad():
        for sample in dataloader:
            images = sample['image'].to(device)
            labels = sample['labels'].to(device)

            scores = model(images)
            batch_loss = criterion(scores, labels)

            # Keep CPU copies so the full split can be scored at once.
            prob_chunks.append(torch.sigmoid(scores).cpu())
            label_chunks.append(labels.cpu())

            total_loss += batch_loss.item()

    metrics = calculate_ade20k_metrics(torch.cat(prob_chunks, dim=0),
                                       torch.cat(label_chunks, dim=0))
    avg_loss = total_loss / len(dataloader)

    print(f'Epoch {epoch} Validation - Loss: {avg_loss:.4f}, mAP: {metrics["mAP"]:.4f}, '
          f'F1: {metrics["f1"]:.4f}, Valid Classes: {metrics["num_valid_classes"]}/150')

    if writer:
        scalars = (
            ('Val/Loss', avg_loss),
            ('Val/mAP', metrics['mAP']),
            ('Val/F1', metrics['f1']),
            ('Val/Precision', metrics['precision']),
            ('Val/Recall', metrics['recall']),
        )
        for tag, value in scalars:
            writer.add_scalar(tag, value, epoch)

    return avg_loss, metrics


def save_checkpoint(model, optimizer, epoch, metrics, save_path):
    """Save a training checkpoint to ``save_path`` with ``torch.save``.

    Keys written:
        'epoch', 'metrics', 'optimizer_state_dict', 'model_config' (as before);
        'adapter_state_dict': ``model.adapter.state_dict()`` (unchanged layout);
        'trainable_state_dict': every model parameter registered with
            ``optimizer`` (adapter and any unfrozen RAM++ layers), keyed by its
            name in ``model.named_parameters()``, so fine-tuned non-adapter
            weights are not lost. It can be restored with
            ``model.load_state_dict(ckpt['trainable_state_dict'], strict=False)``.
    """
    # The optimizer may also update RAM++ layers outside the adapter; saving
    # only the adapter would silently drop those trained weights.
    optimized_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
    trainable_state = {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if id(param) in optimized_ids
    }

    checkpoint = {
        'epoch': epoch,
        'adapter_state_dict': model.adapter.state_dict(),
        'trainable_state_dict': trainable_state,
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
        'model_config': {
            'freeze_backbone': model.freeze_backbone,
            'threshold': model.threshold,
            'num_classes': model.num_classes
        }
    }

    torch.save(checkpoint, save_path)
    print(f"Checkpoint saved: {save_path}")


def _select_trainable_params(model, unfreeze_layers):
    """Return the parameters to optimize when the RAM++ backbone is frozen.

    Selects every parameter whose name contains ``'adapter'``, then ALL
    parameters (weights and biases) of the first ``unfreeze_layers`` RAM++
    layer groups from ``layer_patterns`` that exist in the model. Groups not
    present in the model are skipped without counting. Selected parameters
    get ``requires_grad=True``. The returned list has no duplicates.
    """
    # Layer groups in unfreezing order, from the output side inward
    layer_patterns = [
        'ram_plus.fc',           # Final classification layer
        'ram_plus.tagging_head', # Tagging head layers
        'ram_plus.wordvec_proj', # Word vector projection
        'ram_plus.image_proj'    # Image projection
    ]
    named_params = list(model.named_parameters())
    selected = []
    selected_ids = set()

    def _add(param):
        # Register a parameter once and make sure it receives gradients.
        if id(param) not in selected_ids:
            selected_ids.add(id(param))
            selected.append(param)
            param.requires_grad_(True)

    # Always train the adapter
    for name, param in named_params:
        if 'adapter' in name:
            _add(param)

    unfrozen_layers = 0
    for pattern in layer_patterns:
        if unfrozen_layers >= unfreeze_layers:
            break
        # Take the whole module under this prefix, not just its first
        # parameter (previously e.g. fc.weight was trained but fc.bias not).
        matches = [(n, p) for n, p in named_params
                   if n == pattern or n.startswith(pattern + '.')]
        if not matches:
            continue
        for name, param in matches:
            _add(param)
            print(f"  ✅ Adding to training: {name} - {param.shape}")
        unfrozen_layers += 1

    return selected


def main():
    """Train the RAM++ ADE20K adapter from the command line.

    Parses arguments, loads the model and data, builds the optimizer and loss,
    then trains for up to ``--epochs`` epochs with early stopping after
    ``--patience`` epochs without mAP improvement. Arguments, TensorBoard logs
    and checkpoints (best, periodic, final) are written to a timestamped
    ``run_*`` directory under ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description='Train RAM++ ADE20K adapter')

    # Model arguments
    parser.add_argument('--ram-checkpoint', type=str, required=True,
                        help='Path to pretrained RAM++ model')
    parser.add_argument('--adapter-checkpoint', type=str, default=None,
                        help='Path to pretrained adapter (for resuming)')
    # store_true with default=True on its own can never be switched off, so
    # pair it with an explicit negative flag writing the same destination.
    parser.add_argument('--freeze-backbone', dest='freeze_backbone',
                        action='store_true', default=True,
                        help='Freeze RAM++ backbone during training (default)')
    parser.add_argument('--no-freeze-backbone', dest='freeze_backbone',
                        action='store_false',
                        help='Do not freeze the backbone; optimize all model parameters')

    # Data arguments
    parser.add_argument('--ade20k-root', type=str,
                        default='/home/gyf/iclr/recognize-anything/ADE20K',
                        help='Path to ADE20K dataset root')
    parser.add_argument('--image-size', type=int, default=384,
                        help='Input image size')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='Number of data loading workers')

    # Training arguments
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate')
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help='Weight decay')
    # NOTE(review): --threshold is only recorded in args.json; the metric
    # functions use calculate_ade20k_metrics' default threshold.
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Classification threshold')
    parser.add_argument('--patience', type=int, default=20,
                        help='Stop after this many epochs without mAP improvement')

    # Loss function arguments
    parser.add_argument('--loss-type', type=str, default='combined',
                        choices=['asymmetric', 'combined', 'class_balanced_focal'],
                        help='Loss function type')
    parser.add_argument('--alpha', type=float, default=0.7,
                        help='Weight for asymmetric loss in combined loss')
    parser.add_argument('--beta', type=float, default=0.3,
                        help='Weight for BCE loss in combined loss')
    parser.add_argument('--gamma-neg', type=float, default=4.0,
                        help='Asymmetric loss gamma for negative samples')
    parser.add_argument('--gamma-pos', type=float, default=0.0,
                        help='Asymmetric loss gamma for positive samples')
    parser.add_argument('--clip', type=float, default=0.05,
                        help='Asymmetric loss probability clipping')

    # Class-Balanced Focal Loss arguments
    parser.add_argument('--focal-alpha', type=float, default=1.0,
                        help='Class-balanced focal loss alpha (scaling factor)')
    parser.add_argument('--focal-gamma', type=float, default=2.0,
                        help='Class-balanced focal loss gamma (focusing parameter)')
    parser.add_argument('--focal-beta', type=float, default=0.9999,
                        help='Class-balanced focal loss beta (re-weighting factor)')

    # Training data arguments
    parser.add_argument('--use-train-val', action='store_true',
                        help='Use both train and val sets for training (teacher requirement)')
    parser.add_argument('--unfreeze-layers', type=int, default=1,
                        help='Number of final layers to unfreeze (1=fc only, 2=fc+prev layer, etc.)')

    # Output arguments
    parser.add_argument('--output-dir', type=str, default='./logs/ade20k_training',
                        help='Output directory for logs and checkpoints')
    parser.add_argument('--save-freq', type=int, default=10,
                        help='Save checkpoint every N epochs')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Device for training')

    args = parser.parse_args()
    # Guard values the loop relies on: epochs < 1 leaves `epoch` undefined at
    # the final save, and save_freq < 1 breaks the modulo check.
    if args.epochs < 1:
        parser.error('--epochs must be at least 1')
    if args.save_freq < 1:
        parser.error('--save-freq must be at least 1')

    # Setup output directory
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(args.output_dir, f'run_{timestamp}')
    os.makedirs(output_dir, exist_ok=True)

    # Save arguments
    with open(os.path.join(output_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)

    print("Training RAM++ ADE20K adapter")
    print(f"Output directory: {output_dir}")
    print(f"Device: {args.device}")

    # Load model
    print("Loading model...")
    model = load_ram_plus_ade20k_pretrained(
        ram_plus_checkpoint=args.ram_checkpoint,
        ade20k_adapter_checkpoint=args.adapter_checkpoint,
        freeze_backbone=args.freeze_backbone,
        device=args.device
    )

    # Setup data loaders
    print("Setting up data loaders...")
    if args.use_train_val:
        print("Using BOTH train and val sets for training (teacher requirement)")
        from torch.utils.data import ConcatDataset

        train_dataset = ADE20KDataset(
            root_dir=args.ade20k_root,
            split='train',
            image_size=args.image_size
        )
        val_dataset = ADE20KDataset(
            root_dir=args.ade20k_root,
            split='val',
            image_size=args.image_size
        )
        combined_dataset = ConcatDataset([train_dataset, val_dataset])

        train_loader = DataLoader(
            combined_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True
        )
        val_loader = None  # No separate validation when using combined data

        print(f"Combined dataset: {len(combined_dataset)} images")
    else:
        train_loader, val_loader = create_ade20k_dataloaders(
            root_dir=args.ade20k_root,
            image_size=args.image_size,
            batch_size=args.batch_size,
            num_workers=args.num_workers
        )

    print(f"Train set: {len(train_loader.dataset)} images")
    if val_loader is not None:
        print(f"Val set: {len(val_loader.dataset)} images")
    else:
        print("No separate validation set (using combined train+val)")

    # Debug: Print all parameter names to find FC layer
    print("\n🔍 Debugging parameter names (looking for FC layer):")
    for name, param in model.named_parameters():
        if 'fc' in name.lower():
            print(f"  Found FC parameter: {name} - {param.shape}")

    # Setup optimizer with flexible layer unfreezing
    print(f"\n🔧 Setting up trainable parameters (unfreezing {args.unfreeze_layers} layers)...")
    if args.freeze_backbone:
        trainable_params = _select_trainable_params(model, args.unfreeze_layers)
    else:
        print("  Backbone not frozen: optimizing all model parameters")
        trainable_params = list(model.parameters())
        for param in trainable_params:
            param.requires_grad_(True)

    optimizer = optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay)

    # Setup loss function
    if args.loss_type == 'class_balanced_focal':
        # Class frequencies come from whatever the train loader iterates
        # (the combined train+val dataset when --use-train-val is set).
        criterion = create_ade20k_class_balanced_focal_loss(
            train_loader.dataset,
            alpha=args.focal_alpha,
            gamma=args.focal_gamma,
            beta=args.focal_beta
        )
    elif args.loss_type == 'combined':
        criterion = CombinedLoss(
            alpha=args.alpha,
            beta=args.beta,
            gamma_neg=args.gamma_neg,
            gamma_pos=args.gamma_pos,
            clip=args.clip
        )
    else:
        criterion = AsymmetricLoss(
            gamma_neg=args.gamma_neg,
            gamma_pos=args.gamma_pos,
            clip=args.clip
        )

    # Trainable-parameter breakdown, computed from the optimizer's own list
    # (requires_grad flags may be reassigned later by the training loop).
    total_trainable = sum(p.numel() for p in trainable_params)
    print(f"\nOptimizer parameters: {total_trainable:,}")
    name_by_id = {id(p): n for n, p in model.named_parameters()}
    adapter_count = sum(p.numel() for p in trainable_params
                        if 'adapter' in name_by_id.get(id(p), ''))
    ram_count = sum(p.numel() for p in trainable_params
                    if name_by_id.get(id(p), '').startswith('ram_plus.'))
    print(f"  - Adapter parameters: {adapter_count:,}")
    print(f"  - RAM++ parameters: {ram_count:,}")
    print(f"  - Total trainable: {total_trainable:,}")

    # Training loop. -inf guarantees the first epoch's model is saved as
    # "best" even if its mAP is 0.
    best_mAP = float('-inf')
    best_epoch = 0

    writer = SummaryWriter(os.path.join(output_dir, 'tensorboard'))
    try:
        for epoch in range(1, args.epochs + 1):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}/{args.epochs}")
            print(f"{'='*60}")

            # Train
            _, train_metrics = train_epoch(
                model, train_loader, optimizer, criterion, args.device, epoch, writer
            )

            # Validate (skip if using combined train+val)
            if val_loader is not None:
                _, val_metrics = validate_epoch(
                    model, val_loader, criterion, args.device, epoch, writer
                )
            else:
                # When using combined data, use training metrics as proxy
                val_metrics = train_metrics
                print(f"Epoch {epoch} - Using training metrics as validation (combined dataset)")

            # Save best model
            if val_metrics['mAP'] > best_mAP:
                best_mAP = val_metrics['mAP']
                best_epoch = epoch
                best_path = os.path.join(output_dir, 'best_model.pth')
                save_checkpoint(model, optimizer, epoch, val_metrics, best_path)
                print(f"New best model saved! mAP: {best_mAP:.4f}")

            # Save periodic checkpoint
            if epoch % args.save_freq == 0:
                checkpoint_path = os.path.join(output_dir, f'checkpoint_epoch_{epoch}.pth')
                save_checkpoint(model, optimizer, epoch, val_metrics, checkpoint_path)

            # Early stopping check
            if epoch - best_epoch > args.patience:
                print(f"Early stopping triggered. Best mAP: {best_mAP:.4f} at epoch {best_epoch}")
                break

        # Save final model
        final_path = os.path.join(output_dir, 'final_model.pth')
        save_checkpoint(model, optimizer, epoch, val_metrics, final_path)
    finally:
        # Flush TensorBoard events even if training raises.
        writer.close()

    print(f"\n{'='*60}")
    print("Training completed!")
    print(f"Best mAP: {best_mAP:.4f} at epoch {best_epoch}")
    print(f"Models saved in: {output_dir}")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()