####################################################################
# SIEVE: Spurious-aware Iterative Validation Example Selection
# 
# This file contains the core implementation of the SIEVE algorithm
# as described in the paper. The algorithm iteratively identifies 
# confusing samples via feature-space similarity and classifies 
# their group membership based on loss dynamics.
#
# Reference: Algorithm 1 (SIEVE) and Section 2 in the main paper
####################################################################

"""
Example usage:

# Initialize SIEVE selector
selector = SIEVESelector(dataset_name='waterbirds', num_iterations=20, n_confusing=200, device='cuda')

# Run SIEVE algorithm
validation_data = selector.run_iterative_validation_selection(
    args=args, 
    initial_model=model, 
    dataloaders=dataloaders, 
    paths=paths
)

# Use selected validation set
selected_indices = validation_data['indices']
group_labels = validation_data['groups']
"""

import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Set, Tuple, Any
from collections import Counter
import copy
import os
import json
from tqdm import tqdm
from erm_trainer import ERMTrainer
from data_collector import DataCollector
from models import ResNet50
import torch.optim as optim


def precompute_distances(dataloader, fixed_model, save_path, device):
    """
    Compute, save and return each sample's nearest cross-class feature distance.

    Helper for Step 1 of the SIEVE pipeline: every sample is embedded with
    `fixed_model.get_representation`, and its distance to the closest sample
    carrying a *different* label is recorded (`torch.cdist` with its default
    settings). See Section 2.2 in the paper for details.

    Args:
        dataloader: Iterable of 7-element batches. Only the first element
            (inputs), the sixth (`mix_labels`, the ground-truth labels) and
            the seventh (sample indices) are used.
        fixed_model: Model exposing `eval()` and `get_representation(inputs)`.
        save_path: Destination handed to `torch.save`.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        distances: A dictionary mapping sample index → minimum cross-class
        distance. A sample gets `float('inf')` when no sample with a
        different label exists. Keys are the index values exactly as yielded
        by `indices.numpy()`.

    Notes:
        - The object written to `save_path` is a dictionary with keys:
            distances: {sample_index (as int): distance}
            total_samples: total number of samples processed
    """
    fixed_model.eval()

    feature_chunks = []
    label_list = []
    sample_ids = []

    # Forward passes only; no gradients are needed for the embeddings.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            inputs, _, _, _, _, mix_labels, indices = batch
            embedded = fixed_model.get_representation(inputs.to(device))
            feature_chunks.append(embedded)
            label_list.extend(mix_labels.numpy())
            sample_ids.extend(indices.numpy())

    feature_matrix = torch.cat(feature_chunks, dim=0)
    label_array = np.array(label_list)

    distances = {}
    for row, sample_id in enumerate(sample_ids):
        # Rows whose label differs from the current sample's label.
        other_class = feature_matrix[label_array != label_array[row]]

        if len(other_class) > 0:
            query = feature_matrix[row:row + 1]
            nearest = torch.cdist(query, other_class).min().item()
        else:
            nearest = float('inf')

        distances[sample_id] = nearest

    # Keys are cast to plain ints for the saved copy only.
    torch.save(
        {
            "distances": {int(key): value for key, value in distances.items()},
            "total_samples": len(distances)
        },
        save_path
    )
    return distances


def assign_group_by_spuriousness(dataset_name, ground_truth, is_spurious):
    """
    Assign a group label from the class label and the predicted spuriousness.

    The grouping strategy depends on the specific dataset being used,
    you can extend the logic following the same pattern for any other dataset.
    Below, we use the Waterbirds dataset as an example:

        In Waterbirds:
        - `y` is the ground-truth label for bird type:
            - y = 0 → landbird
            - y = 1 → waterbird
        - `c` is the spurious attribute (background type):
            - c = 0 → land background
            - c = 1 → water background
        - Group is defined following the convention:
            - Group 0: y = 0 and c = 0 → spurious (landbird on land)
            - Group 1: y = 0 and c = 1 → non-spurious (landbird on water)
            - Group 2: y = 1 and c = 0 → non-spurious (waterbird on land)
            - Group 3: y = 1 and c = 1 → spurious (waterbird on water)

        Therefore,
        - If the sample is spurious:
            - y = 0 → Group 0
            - y = 1 → Group 3
        - If the sample is non-spurious:
            - y = 0 → Group 1
            - y = 1 → Group 2

    Args:
        dataset_name: Name of the dataset; only 'waterbirds' is implemented.
        ground_truth: Class label; converted with `int()`. Any value other
            than 0 is treated like y = 1.
        is_spurious: Truthy if the sample is considered spurious.

    Returns:
        The integer group id (0-3 for Waterbirds).

    Raises:
        ValueError: If `dataset_name` has no grouping rule. (Previously this
            case failed with an UnboundLocalError at the return statement.)
    """
    ground_truth = int(ground_truth)

    if dataset_name in ['waterbirds']:
        if is_spurious:
            return 0 if ground_truth == 0 else 3
        return 1 if ground_truth == 0 else 2

    raise ValueError(
        f"No group assignment rule defined for dataset '{dataset_name}'"
    )


def remove_validation_from_training(dataloaders, validation_indices):
    """
    Remove selected validation examples from the training and distance dataloaders.

    This function rebuilds the training dataloader without the indices of
    validation samples selected during the current SIEVE iteration,
    ensuring that no validation sample is used again in the next round.
    The input dictionary and its loaders are not modified.

    Args:
        dataloaders: Dictionary of DataLoaders. `dataloaders['train']` must
            wrap a `torch.utils.data.Subset` (its `.dataset` and `.indices`
            are read). An optional `dataloaders['distance']` is filtered the
            same way, but only if it also wraps a `Subset`; otherwise it is
            passed through unchanged.
        validation_indices: Iterable of dataset-level indices to exclude.

    Returns:
        updated_loaders: A shallow copy of `dataloaders` where
            - 'train' is a new shuffled loader without the excluded indices,
            - 'distance' (if present and Subset-backed) is a new unshuffled
              loader without the excluded indices,
            - 'original_train' is set to the incoming 'train' loader
              (overwriting any previous 'original_train' entry).
        Rebuilt loaders keep the source loader's batch size, worker count,
        pin-memory flag, collate function and drop-last flag.
    """
    from torch.utils.data import DataLoader, Subset

    excluded = set(validation_indices)

    def _rebuild_without_excluded(loader, shuffle):
        # Recreate a Subset-backed loader, keeping its batching configuration.
        subset = loader.dataset
        kept_indices = [idx for idx in subset.indices if idx not in excluded]
        return DataLoader(
            Subset(subset.dataset, kept_indices),
            batch_size=loader.batch_size,
            shuffle=shuffle,
            num_workers=getattr(loader, 'num_workers', 0),
            pin_memory=getattr(loader, 'pin_memory', False),
            # Carry over custom batching so batches keep the same format
            # as in the first iteration.
            collate_fn=getattr(loader, 'collate_fn', None),
            drop_last=getattr(loader, 'drop_last', False)
        )

    updated_loaders = dataloaders.copy()
    updated_loaders['original_train'] = dataloaders['train']
    updated_loaders['train'] = _rebuild_without_excluded(dataloaders['train'], shuffle=True)

    if 'distance' in dataloaders:
        distance_loader = dataloaders['distance']
        if isinstance(distance_loader.dataset, Subset):
            # Order must stay deterministic for distance computation.
            updated_loaders['distance'] = _rebuild_without_excluded(distance_loader, shuffle=False)

    print(f"Removed {len(excluded)} samples from training dataloaders")
    return updated_loaders


def save_validation_indices(validation_data, metadata, save_path):
    """
    Save validation indices and metadata to JSON file

    Args:
        validation_data: Either a dictionary or a plain iterable of indices.
            For a dictionary, the indices are read from "indices" if present,
            otherwise from "all_indices". An optional "groups" mapping is
            written as "group_labels" (string keys, int values). When the
            "indices" key is used, the optional "confusing_spurious",
            "confusing_non_spurious" and "accuracy" entries are copied too.
            A dictionary with neither index key produces a file without
            "validation_indices".
        metadata: JSON-serializable object stored under "selection_criteria".
        save_path: Output file path. Missing parent directories are created;
            a bare file name (no directory part) is written to the current
            working directory.

    Raises:
        TypeError: If a value cannot be serialized to JSON (NumPy scalars are
            converted to native Python numbers automatically).
    """
    def _numpy_scalar_to_python(value):
        # json cannot encode NumPy scalars (e.g. int64 indices); unwrap them.
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    serializable_data = {}

    if isinstance(validation_data, dict):
        if "indices" in validation_data:
            index_key = "indices"
            optional_lists = ("confusing_spurious", "confusing_non_spurious")
        elif "all_indices" in validation_data:
            index_key = "all_indices"
            optional_lists = ()
        else:
            index_key = None
            optional_lists = ()

        if index_key is not None:
            serializable_data["validation_indices"] = list(validation_data[index_key])
            if "groups" in validation_data:
                serializable_data["group_labels"] = {
                    str(k): int(v) for k, v in validation_data["groups"].items()
                }
            for key in optional_lists:
                if key in validation_data:
                    serializable_data[key] = list(validation_data[key])
            if index_key == "indices" and "accuracy" in validation_data:
                serializable_data["accuracy"] = validation_data["accuracy"]
    else:
        serializable_data["validation_indices"] = list(validation_data)

    serializable_data["selection_criteria"] = metadata

    # Create directory if it doesn't exist. os.makedirs('') raises, so skip
    # the call when save_path has no directory component.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(save_path, 'w') as f:
        json.dump(serializable_data, f, indent=2, default=_numpy_scalar_to_python)


def load_sample_data(data_path):
    """Read the JSON file at `data_path` and return the decoded object."""
    with open(data_path, 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


class SIEVESelector:
    """
    SIEVE algorithm implementation based on the actual project code

    The selector picks "confusing" samples (those with the smallest
    cross-class feature distance), splits them into spurious and
    non-spurious by their loss change, and assigns pseudo group labels.

    This class expects users to provide:
    1. Precomputed cross-class distances for all samples
    2. Training data containing loss changes over epochs
    """

    def __init__(self, dataset_name, num_iterations=20, n_confusing=200, device='cuda'):
        """
        Store the selector configuration.

        Args:
            dataset_name: Dataset identifier, passed to `is_spurious` when
                scoring selections and used as a fallback dataset name in
                `run_iterative_validation_selection`.
            num_iterations: Number of SIEVE iterations to run.
            n_confusing: Fallback number of confusing samples per iteration
                (used when `args` has no `n_confusing` attribute).
            device: Device used for the per-iteration model and data collection.
        """
        self.dataset_name = dataset_name
        self.num_iterations = num_iterations
        self.n_confusing = n_confusing
        self.device = device

    def get_confusing_nonconfusing_samples(self, distances, num_samples):
        """
        Select confusing and non-confusing samples based on distances

        Args:
            distances: Mapping of sample index → cross-class distance.
            num_samples: How many samples to take from each end of the
                distance ranking.

        Returns:
            (confusing, nonconfusing): two sets of indices holding the
            `num_samples` smallest-distance and the `num_samples`
            largest-distance samples. Both are empty when
            `num_samples <= 0`. The sets overlap when `distances` has fewer
            than `2 * num_samples` entries.
        """
        # Guard explicitly: ranked[-0:] would return the whole list.
        if num_samples <= 0:
            return set(), set()

        ranked = sorted(distances.items(), key=lambda item: item[1])

        confusing = {idx for idx, _ in ranked[:num_samples]}
        nonconfusing = {idx for idx, _ in ranked[-num_samples:]}

        return confusing, nonconfusing

    @staticmethod
    def _select_by_percentile(samples, magnitudes, ratio, keep_above):
        """Return the samples whose magnitude lies strictly above/below the `ratio` percentile."""
        threshold = np.percentile(magnitudes, ratio * 100)
        if keep_above:
            return [idx for idx, mag in zip(samples, magnitudes) if mag > threshold]
        return [idx for idx, mag in zip(samples, magnitudes) if mag < threshold]

    def classify_confusing_samples(self, data, n_confusing=200,
                                 high_threshold_spurious_ratio=0.8,
                                 high_threshold_non_spurious_ratio=0.8,
                                 low_threshold_spurious_ratio=0.1,
                                 low_threshold_non_spurious_ratio=0.1,
                                 select_confusing_spurious_by="decrease",
                                 select_confusing_non_spurious_by="increase",
                                 start_epoch=0, end_epoch=1):
        """
        Classify confusing samples into spurious and non-spurious categories
        based on their loss change dynamics.

        The `n_confusing` samples with the smallest cross-class distance are
        taken as confusing. For each of them the loss change
        ΔL = loss[end_epoch] - loss[start_epoch] is computed; samples without
        a loss record, with too few recorded epochs, or with ΔL == 0 are
        ignored. The remaining samples are split by the sign of ΔL and then
        thresholded on |ΔL| with a percentile computed within that split:

        - "decrease" mode for spurious (default): keep loss-decreasing samples
          whose |ΔL| is strictly above the `high_threshold_spurious_ratio`
          percentile (with 0.8, roughly the 20% with the largest decrease).
        - "increase" mode for non-spurious (default): keep loss-increasing
          samples whose |ΔL| is strictly above the
          `high_threshold_non_spurious_ratio` percentile.
        - The opposite modes keep samples strictly *below* the corresponding
          `low_threshold_*_ratio` percentile instead.

        Key Parameters:
            - data: Dictionary with 'losses' ({str(index): per-epoch losses}),
              'ground_truths' and 'confounders' ({str(index): value}), and
              optionally 'distances' ({index: distance}; keys are converted
              with `int()`).
            - n_confusing: Number of confusing examples to consider.
            - high_threshold_spurious_ratio: Percentile (as a fraction) for selecting spurious examples via loss decrease.
            - high_threshold_non_spurious_ratio: Percentile (as a fraction) for selecting non-spurious examples via loss increase.
            - low_threshold_spurious_ratio / low_threshold_non_spurious_ratio: Percentiles used by the opposite selection modes.
            - select_confusing_spurious_by: Whether to select spurious examples based on 'decrease' or 'increase' in loss.
            - select_confusing_non_spurious_by: Whether to select non-spurious examples based on 'increase' or 'decrease'.
            - start_epoch: Index of the epoch to compute loss before training.
            - end_epoch: Index of the epoch to compute loss after training.

        Returns:
            Dictionary with the sets 'confusing_spurious',
            'confusing_non_spurious', their union 'all_indices', and
            'accuracy': the fraction of selected samples whose predicted
            spuriousness matches `is_spurious(ground_truth, confounder,
            dataset_name)`. Selected samples missing from 'ground_truths' or
            'confounders' count as incorrect. Accuracy is 0.0 when nothing
            is selected.
        """
        from analysis_utils import is_spurious

        # Get distances from data if available; keys may be strings
        # (e.g. after a JSON round trip), so normalize them to ints.
        if 'distances' in data:
            distances = {int(k): v for k, v in data['distances'].items()}
        else:
            distances = {}

        # Only the confusing end of the ranking is used here.
        confusing, _ = self.get_confusing_nonconfusing_samples(distances, n_confusing)

        # Analyze loss changes for confusing samples
        loss_key = 'losses'

        increasing_samples = []
        increasing_losses = []
        decreasing_samples = []
        decreasing_losses = []

        for idx in confusing:
            if str(idx) not in data[loss_key]:
                continue

            sample_loss = data[loss_key][str(idx)]
            if len(sample_loss) <= end_epoch:
                continue

            loss_change = sample_loss[end_epoch] - sample_loss[start_epoch]

            if loss_change > 0:  # Loss increased
                increasing_samples.append(idx)
                increasing_losses.append(abs(loss_change))
            elif loss_change < 0:  # Loss decreased
                decreasing_samples.append(idx)
                decreasing_losses.append(abs(loss_change))

        # Select spurious/non-spurious based on loss dynamics using percentile thresholds
        if select_confusing_spurious_by == "decrease" and decreasing_losses:
            confusing_spurious = self._select_by_percentile(
                decreasing_samples, decreasing_losses,
                high_threshold_spurious_ratio, keep_above=True)
        elif select_confusing_spurious_by == "increase" and increasing_losses:
            confusing_spurious = self._select_by_percentile(
                increasing_samples, increasing_losses,
                low_threshold_spurious_ratio, keep_above=False)
        else:
            confusing_spurious = []

        if select_confusing_non_spurious_by == "increase" and increasing_losses:
            confusing_non_spurious = self._select_by_percentile(
                increasing_samples, increasing_losses,
                high_threshold_non_spurious_ratio, keep_above=True)
        elif select_confusing_non_spurious_by == "decrease" and decreasing_losses:
            confusing_non_spurious = self._select_by_percentile(
                decreasing_samples, decreasing_losses,
                low_threshold_non_spurious_ratio, keep_above=False)
        else:
            confusing_non_spurious = []

        # Calculate accuracy against ground truth spuriousness
        correct = 0
        total = len(confusing_spurious) + len(confusing_non_spurious)

        for predicted, expected_spurious in ((confusing_spurious, True),
                                             (confusing_non_spurious, False)):
            for idx in predicted:
                key = str(idx)
                if key in data['ground_truths'] and key in data['confounders']:
                    actual = is_spurious(data['ground_truths'][key],
                                         data['confounders'][key],
                                         self.dataset_name)
                    if bool(actual) == expected_spurious:
                        correct += 1

        accuracy = correct / total if total > 0 else 0.0

        return {
            "confusing_spurious": set(confusing_spurious),
            "confusing_non_spurious": set(confusing_non_spurious),
            "all_indices": set(confusing_spurious).union(set(confusing_non_spurious)),
            "accuracy": accuracy
        }

    def select_validation_samples(self, data, dataset_name,
                             running_loss=False,
                             high_threshold_spurious_ratio=0.8,
                             high_threshold_non_spurious_ratio=0.8,
                             low_threshold_spurious_ratio=0.1,
                             low_threshold_non_spurious_ratio=0.1,
                             start_epoch=0, end_epoch=1,
                             use_final_loss=False,
                             n_confusing=200,
                             select_confusing_spurious_by="decrease",
                             select_confusing_non_spurious_by="increase"):
        """
        Select validation examples and assign pseudo group labels based on
        their loss dynamics.

        This function wraps `classify_confusing_samples` and adds group assignment.
        It serves as a key component in SIEVE's iterative validation construction pipeline.

        Steps:
            1. Identify confusing examples using cross-class distances.
            2. Classify them into spurious vs. non-spurious based on loss change (ΔL).
            3. Assign pseudo group labels based on the predicted spuriousness and true label.

        Args:
            data: Collected sample data; see `classify_confusing_samples`.
                `data['ground_truths']` must contain every selected index.
            dataset_name: Dataset name passed to `assign_group_by_spuriousness`.
            running_loss, use_final_loss: Accepted for interface
                compatibility; not used by this method.
            Remaining arguments are forwarded to `classify_confusing_samples`.

        Returns:
            A dictionary containing:
                - 'all_indices': All selected validation sample indices.
                - 'confusing_spurious': Subset predicted as spurious.
                - 'confusing_non_spurious': Subset predicted as non-spurious.
                - 'accuracy': Accuracy of pseudo spuriousness prediction (for analysis only).
                - 'groups': Pseudo group label for each selected example.
        """
        # Call the core classification algorithm from this class
        result = self.classify_confusing_samples(
            data=data,
            n_confusing=n_confusing,
            high_threshold_spurious_ratio=high_threshold_spurious_ratio,
            high_threshold_non_spurious_ratio=high_threshold_non_spurious_ratio,
            low_threshold_spurious_ratio=low_threshold_spurious_ratio,
            low_threshold_non_spurious_ratio=low_threshold_non_spurious_ratio,
            select_confusing_spurious_by=select_confusing_spurious_by,
            select_confusing_non_spurious_by=select_confusing_non_spurious_by,
            start_epoch=start_epoch,
            end_epoch=end_epoch
        )

        # Extract results
        accuracy = result["accuracy"]
        validation_indices = result["all_indices"]

        print(f" {len(validation_indices)} examples in total have been chosen:")
        print(f"  - {len(result['confusing_spurious'])} confusing spurious examples")
        print(f"  - {len(result['confusing_non_spurious'])} confusing non-spurious examples")
        print(f"  - selection accuracy: {accuracy:.4f}")

        # Assign group labels: spurious predictions first, then non-spurious
        print("\nAssigning group labels for selected examples...")
        validation_groups = {}

        for indices, predicted_spurious in ((result["confusing_spurious"], True),
                                            (result["confusing_non_spurious"], False)):
            for idx in indices:
                ground_truth = data['ground_truths'][str(idx)]
                validation_groups[idx] = assign_group_by_spuriousness(
                    dataset_name, ground_truth, predicted_spurious)

        # Print group distribution
        group_counts = Counter(validation_groups.values())
        print("\nNumber of examples in each group:")
        for group in sorted(group_counts):
            print(f"  - Group {group}: {group_counts[group]} examples")

        if len(group_counts) < 4:
            print("Note! Not each group has examples in it!")

        return {
            "all_indices": validation_indices,
            "confusing_spurious": result["confusing_spurious"],
            "confusing_non_spurious": result["confusing_non_spurious"],
            "accuracy": accuracy,
            "groups": validation_groups
        }

    def run_iterative_validation_selection(self, args, initial_model, dataloaders, paths, epochs=1):
        """
        Run the full SIEVE pipeline to construct a pseudo-labeled validation set via
        iterative selection based on feature-space distances and loss dynamics.

        This function implements the main loop of Algorithm 1 (SIEVE) from the paper.

        At a high level, each iteration consists of:
            1. Training a fresh model for `args.epochs` epochs.
            2. Collecting per-sample training losses before and after training.
            3. Computing loss change (ΔL) for each confusing sample.
            4. Selecting confusing samples and classifying them into spurious / non-spurious
            based on their loss dynamics.
            5. Assigning pseudo group labels to selected examples.
            6. Removing selected samples from the training set and continuing to the next iteration.

        Args:
            args: Configuration object. Reads `lr`, `wd`, `epochs`, `seed`,
                `select_confusing_spurious_by` and
                `select_confusing_non_spurious_by`. `dataset` and
                `n_confusing` are read if present; otherwise the values
                given to the constructor are used.
            initial_model: Model used once, before the loop, to compute the
                cross-class feature distances.
            dataloaders: Dictionary with at least 'train' and 'distance'
                loaders. It is deep-copied, so the caller's objects are not
                modified.
            paths: Dictionary with 'selection_data' (directory for the
                per-iteration collected data) and 'selected_examples'
                (directory for the final selection file).
            epochs: Currently unused; the number of training epochs per
                iteration is taken from `args.epochs`.

        Returns:
            validation_data: A dictionary with:
                - "indices": Set of all selected validation sample indices across all iterations
                - "groups": Dictionary mapping each selected sample index to its assigned group label

        Side effects:
            Writes "sieve_distances_temp.pt" to the current working
            directory, one collected-data JSON per iteration, and the final
            selection JSON.

        For further details, refer to:
            - Algorithm 1 (SIEVE) in the main paper
            - Section 2.2: Detailed discussion about SIEVE
        """
        print(f"Starting SIEVE with {self.num_iterations} iterations...")

        # Prefer the run configuration; fall back to the constructor values.
        dataset_name = getattr(args, 'dataset', self.dataset_name)
        n_confusing = getattr(args, 'n_confusing', self.n_confusing)

        all_validation_indices = set()
        all_validation_groups = {}
        current_dataloaders = copy.deepcopy(dataloaders)

        # Precompute distances once at the beginning for further use
        print("Computing cross-class distances...")
        distances_path = "sieve_distances_temp.pt"
        original_distances = precompute_distances(
            current_dataloaders['distance'],
            initial_model,
            distances_path,
            self.device
        )
        current_distances = copy.deepcopy(original_distances)
        num_iterations = self.num_iterations
        # Start the iteration
        for iteration in range(num_iterations):
            print(f"\n----- Iteration {iteration+1}/{num_iterations} -----")

            # Every iteration trains a freshly initialized model.
            model = ResNet50().to(self.device)

            current_args = copy.deepcopy(args)
            current_args.distances = current_distances

            # Run the standard validation selection process
            optimizer = optim.SGD(
                model.parameters(),
                lr=args.lr,
                momentum=0.9,
                weight_decay=args.wd
            )

            scheduler = None

            # Initialize a trainer for training ERM later
            trainer = ERMTrainer(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                device=self.device
            )

            # Initialize a data_collector to collect data
            data_collector = DataCollector(
                dataset=current_dataloaders['train'].dataset,
                distances=current_args.distances
            )

            # Epoch 0 holds the losses of the untrained model.
            print(f"Collecting data in epoch 0 for iteration {iteration+1}...")
            data_collector.collect_epoch_data(0, model, current_dataloaders['train'], trainer.analysis_criterion, self.device)

            print(f"Training started for iteration {iteration+1}...")
            for epoch in range(1, args.epochs + 1):
                print(f"Training epoch {epoch}/{args.epochs}")

                epoch_loss, running_losses, batch_losses, predictions = trainer.train_epoch(current_dataloaders['train'], current_args, epoch, mode="selection")
                print(f"Epoch {epoch} completed, loss: {epoch_loss:.5f}")

                print(f"Updating collected data for epoch {epoch}...")
                data_collector.update_training_data(epoch, running_losses, batch_losses, predictions)
                data_collector.collect_epoch_data(epoch, model, current_dataloaders['train'], trainer.analysis_criterion, self.device)

            # Save iteration data
            data_path = f"{paths['selection_data']}/collected_data_iteration_{iteration+1}_seed{args.seed}.json"
            data_collector.save(data_path)
            print(f"Collected data for iteration {iteration+1} saved to {data_path}")

            # Select validation examples for this iteration
            collected_data = load_sample_data(data_path)

            validation_indices_result = self.select_validation_samples(
                data=collected_data,
                dataset_name=dataset_name,
                running_loss=False,
                start_epoch=0,
                end_epoch=args.epochs,
                n_confusing=n_confusing,
                select_confusing_spurious_by=args.select_confusing_spurious_by,
                select_confusing_non_spurious_by=args.select_confusing_non_spurious_by
            )

            selected_indices = validation_indices_result["all_indices"]
            selected_groups = validation_indices_result["groups"]
            print(f"Selected {len(selected_indices)} new validation samples in iteration {iteration+1}")

            all_validation_indices.update(selected_indices)
            all_validation_groups.update(
                {idx: selected_groups[idx] for idx in selected_indices if idx in selected_groups}
            )

            # Update dataloaders to remove selected samples for next iteration
            current_dataloaders = remove_validation_from_training(
                current_dataloaders,
                selected_indices
            )
            # Report the count, not the full index set.
            print(f"Updated dataloaders for next iteration, removed {len(selected_indices)} samples")

            # Update the distance dictionary by removing the selected examples
            current_distances = {k: v for k, v in current_distances.items() if k not in selected_indices}

        # Prepare final result
        validation_data = {
            "indices": all_validation_indices,
            "groups": all_validation_groups
        }

        # Save the final result
        save_path = f"{paths['selected_examples']}/validation_indices_iterative_{num_iterations}_seed{args.seed}.json"

        save_validation_indices(validation_data, {
            "method": "iterative_loss_pattern_analysis",
            "dataset": dataset_name,
            "iterations": num_iterations,
            "seed": args.seed
        }, save_path)

        print(f"Total selected validation indices: {len(all_validation_indices)}")
        print(f"With group labels: {len(all_validation_groups)}")
        print(f"Final selection saved to {save_path}")
        print("==================== Iterative validation selection process completed ====================\n")

        return validation_data

