############################################
# ERM Training with SIEVE-Constructed Validation Set
#
# This file implements Algorithm 2 from the paper:
# "ERM Training with Model Selection Using the SIEVE-Constructed Validation Set"
#
# The SIEVE-constructed validation set can be plugged into any training
# pipeline that relies on group-labeled validation data for robust evaluation.
############################################

"""
Complete example for integrating SIEVE selection with ERM training:

# Step 1: Run SIEVE algorithm
sieve = SIEVESelector(dataset_name='waterbirds', device='cuda')
validation_set = sieve.run_iterative_validation_selection(args, model, dataloaders, paths)

# Step 2: Train ERM with SIEVE validation set  
results = train_erm_with_sieve_validation(args, dataloaders, 
                                        validation_set['indices'], 
                                        validation_set['groups'], paths)
"""


import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, List, Tuple, Any
from collections import defaultdict
from tqdm import tqdm
import copy


class ERMTrainer:
    """
    Minimal empirical-risk-minimization (ERM) trainer.

    Bundles a model with its optimizer, an optional learning-rate scheduler
    and the target device, and exposes a single-epoch training routine that
    also records per-sample losses and predictions.
    """
    def __init__(self, model, optimizer, scheduler, device):
        """
        Args:
            model: Module to train; called as ``model(inputs)``.
            optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
            scheduler: Optional scheduler; ``step()`` is called once per epoch
                when it is not None.
            device: Device the inputs and labels are moved to.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        # reduction='none' keeps one loss value per sample so that samples can
        # be re-weighted and tracked individually.
        self.train_criterion = nn.CrossEntropyLoss(reduction='none')
        # Not used by the methods of this class; kept as a public attribute.
        self.analysis_criterion = nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def _sample_weights(losses, sample_ids, validation_lookup, weight):
        """Return per-sample loss weights: `weight` for ids in `validation_lookup`, 1.0 otherwise."""
        weights = torch.ones_like(losses)
        for position, sample_id in enumerate(sample_ids):
            if sample_id in validation_lookup:
                weights[position] = weight
        return weights

    def train_epoch(self, dataloader, args, current_epoch, mode='train'):
        """
        Train for one epoch.

        Args:
            dataloader: Iterable yielding 7-tuples
                ``(inputs, ground_truth, confounders, groups, filename,
                mix_labels, indices)``. Only ``inputs``, ``mix_labels`` (the
                training targets) and ``indices`` (per-sample ids) are used.
            args: Namespace-like object. Optional attributes read here:
                ``selected_examples_weight`` (loss weight for selected
                samples, default 1.0) and ``validation_indices`` (ids of the
                selected samples; weighting is skipped when absent or empty).
            current_epoch: Accepted for interface compatibility; not used.
            mode: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(mean_batch_loss, running_losses, batch_losses,
            predictions_dict)`` where
            - ``mean_batch_loss`` is the average of the (weighted) batch
              losses, or 0.0 if the dataloader yielded no batches;
            - ``running_losses`` maps sample id -> last unweighted loss;
            - ``batch_losses`` maps sample id -> list of
              ``(batch_number, unweighted_loss)`` with 1-based batch numbers;
            - ``predictions_dict`` maps sample id -> last predicted class.
        """
        self.model.train()
        running_losses = {}
        batch_losses = {}
        predictions_dict = {}
        epoch_loss = 0.0
        count = 0

        # If you want to change the weights of the selected examples, you can set args.selected_examples_weight
        # But it is an option for further exploration and not used in our current paper
        current_weight = getattr(args, 'selected_examples_weight', 1.0)

        # Build the membership set once per epoch instead of scanning the
        # original container for every sample.
        raw_validation = getattr(args, 'validation_indices', None)
        if torch.is_tensor(raw_validation):
            raw_validation = raw_validation.flatten().tolist()
        validation_lookup = set(raw_validation) if raw_validation else set()

        for inputs, ground_truth, confounders, groups, filename, mix_labels, indices in tqdm(dataloader):
            inputs, mix_labels = inputs.to(self.device), mix_labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)

            # The mix_labels here is ground truth labels
            losses = self.train_criterion(outputs, mix_labels)

            # Convert the ids to plain Python numbers once per batch.
            if torch.is_tensor(indices):
                sample_ids = indices.tolist()
            else:
                sample_ids = [idx.item() for idx in indices]

            # Apply weighting for validation samples if specified
            if validation_lookup:
                sample_weights = self._sample_weights(
                    losses, sample_ids, validation_lookup, current_weight
                )
                loss = (losses * sample_weights).mean()
            else:
                loss = losses.mean()

            loss.backward()
            self.optimizer.step()

            _, preds = torch.max(outputs, 1)
            epoch_loss += loss.item()
            count += 1

            # Store individual sample data (unweighted losses); one host
            # transfer per batch rather than several per sample.
            loss_values = losses.detach().tolist()
            pred_values = preds.tolist()
            for sample_id, loss_value, pred_value in zip(sample_ids, loss_values, pred_values):
                running_losses[sample_id] = loss_value
                predictions_dict[sample_id] = pred_value
                batch_losses.setdefault(sample_id, []).append((count, loss_value))

        if self.scheduler is not None:
            self.scheduler.step()

        # Guard against an empty dataloader, which would otherwise divide by zero.
        mean_loss = epoch_loss / count if count > 0 else 0.0
        return mean_loss, running_losses, batch_losses, predictions_dict


def evaluate_model(dataloader, model, device):
    """
    Compute overall and per-group accuracy of `model` over `dataloader`.

    Each batch is a 7-tuple; only the inputs (position 0), the group ids
    (position 3) and the target labels (position 5) are used. The model is
    switched to eval mode and gradients are disabled.

    Returns:
        (overall_accuracy, group_accuracies) where `group_accuracies` maps
        each observed group id to its accuracy. With no samples the result
        is (0.0, {}).
    """
    model.eval()

    seen = 0
    hits = 0
    seen_per_group = defaultdict(int)
    hits_per_group = defaultdict(int)

    with torch.no_grad():
        for batch in dataloader:
            inputs, _, _, groups, _, mix_labels, _ = batch
            inputs = inputs.to(device)
            mix_labels = mix_labels.to(device)

            # Index of the max logit along dim 1 is the predicted class.
            preds = torch.max(model(inputs), 1)[1]
            correct = (preds == mix_labels)

            hits += correct.sum().item()
            seen += mix_labels.size(0)

            # Tally per-group counts
            for position, group_tensor in enumerate(groups):
                gid = group_tensor.item()
                seen_per_group[gid] += 1
                if correct[position]:
                    hits_per_group[gid] += 1

    overall_accuracy = hits / seen if seen > 0 else 0.0
    group_accuracies = {
        gid: (hits_per_group[gid] / n_group if n_group > 0 else 0.0)
        for gid, n_group in seen_per_group.items()
    }

    return overall_accuracy, group_accuracies


def evaluate_on_indices_with_custom_groups(dataloader, model, device, validation_indices, validation_groups):
    """
    Evaluate model on specific indices with custom group labels.

    Only samples whose id is in `validation_indices` AND is a key of
    `validation_groups` are scored. Batches that contain no such sample are
    skipped before any device transfer or forward pass.

    Args:
        dataloader: Iterable yielding 7-tuples; only the inputs (position 0),
            the target labels (position 5) and the per-sample ids
            (position 6) are used.
        model: Module called as ``model(inputs)``; put into eval mode here.
        device: Device the inputs and labels are moved to.
        validation_indices: Iterable of sample ids to evaluate on.
        validation_groups: Container queried with ``sample_id in ...`` and
            ``...[sample_id]`` to obtain the group label of a sample.

    Returns:
        (overall_accuracy, worst_group_accuracy) over the selected samples;
        (0.0, 0.0) when no sample was selected.
    """
    model.eval()

    group_correct = defaultdict(int)
    group_total = defaultdict(int)

    validation_indices_set = set(validation_indices)

    with torch.no_grad():
        for inputs, _, _, _, _, mix_labels, indices in dataloader:
            # Convert ids to plain Python numbers once per batch.
            if torch.is_tensor(indices):
                sample_ids = indices.tolist()
            else:
                sample_ids = [idx.item() for idx in indices]

            # Positions (within the batch) of the SIEVE-selected samples.
            selected = [
                (position, sample_id)
                for position, sample_id in enumerate(sample_ids)
                if sample_id in validation_indices_set and sample_id in validation_groups
            ]
            if not selected:
                # Nothing to score here: skip the transfer and forward pass.
                continue

            inputs, mix_labels = inputs.to(device), mix_labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            # One host transfer per batch instead of one comparison per sample.
            is_correct = (preds == mix_labels).tolist()

            for position, sample_id in selected:
                # Use the pseudo group label assigned by SIEVE
                group_id = validation_groups[sample_id]
                group_total[group_id] += 1
                if is_correct[position]:
                    group_correct[group_id] += 1

    total_samples = sum(group_total.values())
    total_correct = sum(group_correct.values())
    overall_accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    # Every key of group_total has a count of at least one, so the division is safe.
    group_accuracies = {
        group_id: group_correct[group_id] / n_group
        for group_id, n_group in group_total.items()
    }

    # Find worst group accuracy
    worst_group_accuracy = min(group_accuracies.values(), default=0.0)

    return overall_accuracy, worst_group_accuracy


def train_erm_with_sieve_validation(args, dataloaders, validation_indices, validation_groups, paths=None):
    """
    Train ERM model using SIEVE-constructed validation set for model selection.

    This implements Algorithm 2 from the paper. After every epoch the model is
    scored on the SIEVE validation samples; the checkpoint with the highest
    worst-group accuracy (ties broken by higher overall accuracy) is restored
    at the end. Test metrics are computed for monitoring only and never
    influence the selection.

    Args:
        args: Namespace-like object. Reads ``device``, ``lr``, ``wd``,
            ``erm_epochs`` and, when `paths` is given, ``seed``. It is also
            forwarded to ``ERMTrainer.train_epoch``.
        dataloaders: Mapping with ``'train'`` and ``'test'`` loaders. The
            ``'train'`` loader is used both for training and for scoring the
            validation samples, so it must contain them.
        validation_indices: Non-empty collection of sample ids forming the
            validation set.
        validation_groups: Non-empty container giving the group label of each
            validation sample id.
        paths: Optional mapping; when truthy, ``paths['save_model']`` is the
            directory the final weights are written to.

    Returns:
        Dict with keys ``'model'``, ``'best_val_worst_group_acc'`` and
        ``'best_val_overall_acc'``.

    Raises:
        ValueError: If `validation_indices` or `validation_groups` is empty.
    """
    device = args.device
    print("Training ERM with SIEVE validation set...")

    if not validation_indices or not validation_groups:
        raise ValueError("validation_indices and validation_groups cannot be empty")
    print(f"Using SIEVE validation set with {len(validation_indices)} samples")

    # Initialize model
    from models import ResNet50
    model = ResNet50().to(device)

    # Initialize optimizer
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.wd
    )

    # Initialize trainer
    trainer = ERMTrainer(
        model=model,
        optimizer=optimizer,
        scheduler=None,
        device=device
    )

    # Training loop with SIEVE validation
    best_worst_group_acc = 0.0
    best_overall_acc = 0.0
    best_model_state = None

    print(f"Training for {args.erm_epochs} epochs...")

    for epoch in range(args.erm_epochs):
        # Train epoch
        epoch_loss, _, _, _ = trainer.train_epoch(dataloaders['train'], args, epoch)

        # Evaluate using SIEVE validation set (for model selection)
        val_overall_acc, val_worst_group_acc = evaluate_on_indices_with_custom_groups(
            dataloaders['train'], model, device, validation_indices, validation_groups
        )

        # Evaluate on test set (only for monitoring)
        _, test_group_accs = evaluate_model(dataloaders['test'], model, device)
        test_worst_group_acc = min(test_group_accs.values()) if test_group_accs else 0.0

        # Primary criterion: worst-group accuracy; tie-break: overall accuracy.
        improved = (
            val_worst_group_acc > best_worst_group_acc
            or (
                val_worst_group_acc == best_worst_group_acc
                and val_overall_acc > best_overall_acc
            )
        )

        if improved:
            best_worst_group_acc = val_worst_group_acc
            best_overall_acc = val_overall_acc
            # Deep copy so later optimizer steps do not overwrite the snapshot.
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"New best model at epoch {epoch+1}: val_worst_group={val_worst_group_acc:.4f}, val_overall={val_overall_acc:.4f}")

        # Print progress
        if epoch % 10 == 0 or epoch == args.erm_epochs - 1:
            print(f"Epoch {epoch+1}/{args.erm_epochs}:")
            print(f"  Train Loss: {epoch_loss:.4f}")
            print(f"  Val Worst Group Acc: {val_worst_group_acc:.4f}")
            print(f"  Val Overall Acc: {val_overall_acc:.4f}")
            print(f"  Test Worst Group Acc: {test_worst_group_acc:.4f}")
            print(f"  Best Val Worst Group: {best_worst_group_acc:.4f}")

    # Load best model (if no epoch ever improved, the last-epoch weights are kept)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        # best_overall_acc is a validation metric, so label it as such.
        print(f"Loaded best model: val_worst_group={best_worst_group_acc:.4f}, val_overall={best_overall_acc:.4f}")

    # Save model if paths provided
    if paths:
        import os

        # Make sure the target directory exists before saving.
        os.makedirs(paths['save_model'], exist_ok=True)
        model_save_path = f"{paths['save_model']}/sieve_erm_model_seed{args.seed}.model"
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")

    return {
        'model': model,
        'best_val_worst_group_acc': best_worst_group_acc,
        'best_val_overall_acc': best_overall_acc
    }
    

