"""
Main Evaluation Script for Human-Prior Correction

This script reproduces the main experimental results from the HPC paper.
It loads pretrained models, applies HPC with different priors, compares
against baseline calibration methods, and generates comprehensive results.

Usage:
    python evaluate_hpc.py --dataset cifar10h --model_path ./models/resnet50_cifar10.pth
    python evaluate_hpc.py --dataset cifar100 --model_path ./models/resnet50_cifar100.pth
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
import argparse
import os
import json
from datetime import datetime
from typing import Dict, List, Tuple, Optional

# Import our HPC modules
from hpc_core import HumanPriorCorrection, AdaptiveAlpha
from proxy_priors import ProxyPriorConstructor
from cifar10h_utils import CIFAR10HUtils
from evaluation_metrics import evaluate_calibration_comprehensive, CalibrationMetrics
from baseline_calibration import fit_all_baselines
from data_loaders import create_data_loader


class ModelEvaluator:
    """
    Comprehensive model evaluation with HPC and baseline methods.
    """
    
    def __init__(
        self,
        model: nn.Module,
        device: torch.device,
        num_classes: int,
        results_dir: str = "./results"
    ):
        """
        Set up the evaluator and the HPC components it uses.

        Args:
            model: Network to evaluate; it is moved to ``device``.
            device: Device used for the model and the HPC module.
            num_classes: Number of output classes.
            results_dir: Directory for saved results; created if missing.
        """
        self.model = model.to(device)
        self.device = device
        self.num_classes = num_classes
        self.results_dir = results_dir

        # Create results directory
        os.makedirs(results_dir, exist_ok=True)

        # Initialize components
        self.hpc = HumanPriorCorrection(num_classes=num_classes).to(device)
        self.proxy_constructor = ProxyPriorConstructor()
        # Always define the attribute: the CIFAR-10H helper is only built for
        # 10-class evaluators, and every other evaluator gets None instead of
        # an AttributeError when the attribute is read.
        self.cifar10h_utils = CIFAR10HUtils() if num_classes == 10 else None

        # Storage for results
        self.results = {}
        
    def extract_model_predictions(
        self, 
        data_loader: torch.utils.data.DataLoader,
        max_batches: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Extract model logits and probabilities on dataset.
        
        Args:
            data_loader: DataLoader for the dataset
            max_batches: Optional limit on number of batches to process
            
        Returns:
            (logits, probabilities, true_labels, human_distributions)
        """
        self.model.eval()
        
        all_logits = []
        all_probs = []
        all_targets = []
        all_human_dists = []
        
        with torch.no_grad():
            for batch_idx, (images, targets, human_dists) in enumerate(data_loader):
                if max_batches and batch_idx >= max_batches:
                    break
                    
                images = images.to(self.device)
                targets = targets.to(self.device)
                human_dists = human_dists.to(self.device)
                
                # Forward pass
                logits = self.model(images)
                probs = F.softmax(logits, dim=1)
                
                # Store results
                all_logits.append(logits.cpu())
                all_probs.append(probs.cpu())
                all_targets.append(targets.cpu())
                all_human_dists.append(human_dists.cpu())
        
        # Concatenate all batches
        logits = torch.cat(all_logits, dim=0)
        probabilities = torch.cat(all_probs, dim=0)
        targets = torch.cat(all_targets, dim=0)
        human_distributions = torch.cat(all_human_dists, dim=0)
        
        print(f"Extracted predictions for {len(targets)} samples")
        return logits, probabilities, targets, human_distributions
    
    def evaluate_hpc_variants(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        human_distributions: torch.Tensor,
        dataset_name: str
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate HPC with different prior types.
        
        Args:
            logits: Model logits (N, K)
            targets: True labels (N,)
            human_distributions: Human label distributions (N, K) 
            dataset_name: Name of dataset for prior construction
            
        Returns:
            Dictionary mapping HPC variant names to metrics
        """
        results = {}
        
        # 1. HPC with empirical CIFAR-10H prior (if applicable)
        if self.num_classes == 10:
            print("Evaluating HPC with empirical CIFAR-10H prior...")
            
            # Get empirical confusion matrix
            empirical_confusion = self.cifar10h_utils.get_empirical_confusion_matrix()
            
            # Apply HPC
            hpc_probs = self.hpc.apply_correction(
                logits, empirical_confusion, alpha=0.3, temperature=1.0
            )
            
            # Evaluate
            metrics = evaluate_calibration_comprehensive(
                hpc_probs, targets, human_distributions,
                method_name="HPC_Empirical",
                save_plots=True,
                save_dir=self.results_dir
            )
            results['HPC_Empirical'] = metrics
        
        # 2. HPC with CLIP proxy prior
        print("Evaluating HPC with CLIP proxy prior...")
        try:
            if dataset_name.lower() in ['cifar10', 'cifar10h']:
                clip_confusion = self.proxy_constructor.create_cifar10_clip_prior()
            elif dataset_name.lower() == 'cifar100':
                clip_confusion = self.proxy_constructor.create_cifar100_clip_prior()
            else:
                # Use generic CLIP prior
                class_names = [f"class_{i}" for i in range(self.num_classes)]
                clip_confusion = self.proxy_constructor.create_clip_prior(class_names)
            
            hpc_probs = self.hpc.apply_correction(
                logits, clip_confusion, alpha=0.25, temperature=1.0
            )
            
            metrics = evaluate_calibration_comprehensive(
                hpc_probs, targets, human_distributions,
                method_name="HPC_CLIP",
                save_plots=True,
                save_dir=self.results_dir
            )
            results['HPC_CLIP'] = metrics
            
        except Exception as e:
            print(f"CLIP prior evaluation failed: {e}")
        
        # 3. HPC with adaptive alpha
        print("Evaluating HPC with adaptive alpha...")
        try:
            if self.num_classes == 10:
                confusion_matrix = empirical_confusion
            else:
                confusion_matrix = clip_confusion
            
            # Initialize adaptive alpha module
            adaptive_alpha = AdaptiveAlpha(
                input_dim=self.num_classes, 
                hidden_dim=64
            ).to(self.device)
            
            # Apply HPC with adaptive alpha
            hpc_probs = self.hpc.apply_correction_adaptive(
                logits.to(self.device), 
                confusion_matrix, 
                adaptive_alpha,
                base_alpha=0.3,
                temperature=1.0
            )
            
            metrics = evaluate_calibration_comprehensive(
                hpc_probs.cpu(), targets, human_distributions,
                method_name="HPC_Adaptive",
                save_plots=True, 
                save_dir=self.results_dir
            )
            results['HPC_Adaptive'] = metrics
            
        except Exception as e:
            print(f"Adaptive alpha evaluation failed: {e}")
        
        return results
    
    def evaluate_baseline_methods(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        human_distributions: torch.Tensor,
        val_split: float = 0.2
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate baseline calibration methods.
        
        Args:
            logits: Model logits (N, K)
            targets: True labels (N,)
            human_distributions: Human label distributions (N, K)
            val_split: Fraction of data to use for calibration fitting
            
        Returns:
            Dictionary mapping baseline method names to metrics
        """
        # Split data for calibration fitting
        n_samples = len(logits)
        n_val = int(n_samples * val_split)
        
        # Random split
        indices = torch.randperm(n_samples)
        val_indices = indices[:n_val]
        test_indices = indices[n_val:]
        
        # Validation data for fitting
        val_logits = logits[val_indices]
        val_probs = F.softmax(val_logits, dim=1)
        val_targets = targets[val_indices]
        
        # Test data for evaluation
        test_logits = logits[test_indices]
        test_probs = F.softmax(test_logits, dim=1)
        test_targets = targets[test_indices] 
        test_human_dists = human_distributions[test_indices]
        
        print(f"Fitting baselines on {len(val_logits)} samples...")
        print(f"Evaluating on {len(test_logits)} samples...")
        
        # Fit baseline calibrators
        calibrators = fit_all_baselines(val_logits, val_probs, val_targets)
        
        # Evaluate each method
        results = {}
        
        # Uncalibrated baseline
        metrics = evaluate_calibration_comprehensive(
            test_probs, test_targets, test_human_dists,
            method_name="Uncalibrated",
            save_plots=True,
            save_dir=self.results_dir
        )
        results['Uncalibrated'] = metrics
        
        # Calibrated methods
        for method_name, calibrator in calibrators.items():
            print(f"Evaluating {method_name}...")
            
            try:
                # Apply calibration
                if hasattr(calibrator, 'forward'):
                    # PyTorch modules (temperature scaling, etc.)
                    with torch.no_grad():
                        calibrated_probs = calibrator(test_logits)
                else:
                    # Scikit-learn style (isotonic regression, etc.)
                    calibrated_probs = calibrator.predict_proba(test_probs)
                
                # Evaluate
                metrics = evaluate_calibration_comprehensive(
                    calibrated_probs, test_targets, test_human_dists,
                    method_name=method_name,
                    save_plots=True,
                    save_dir=self.results_dir
                )
                results[method_name] = metrics
                
            except Exception as e:
                print(f"Error evaluating {method_name}: {e}")
                continue
        
        return results
    
    def generate_results_table(
        self,
        hpc_results: Dict[str, Dict[str, float]],
        baseline_results: Dict[str, Dict[str, float]]
    ) -> str:
        """
        Generate formatted results table.

        Baseline rows come first, followed by HPC rows; if a method name
        occurs in both inputs the HPC entry is the one shown. Each row lists
        the key metrics with four decimals. A metric that is missing, None or
        not convertible to a number is rendered as ``N/A``.

        Args:
            hpc_results: Results from HPC variants
            baseline_results: Results from baseline methods

        Returns:
            Formatted table string (newline-terminated)
        """
        # Combine all results (later keys override earlier ones)
        all_results = {**baseline_results, **hpc_results}

        # Key metrics to include in table
        key_metrics = ['accuracy', 'ece', 'nll_true', 'nll_human', 'aurc']
        name_width = 20
        cell_width = 12

        def format_cell(metrics, metric):
            # A metric may be absent or carry a placeholder such as None;
            # neither should abort rendering of the whole table.
            try:
                return f"{float(metrics.get(metric)):.4f}"
            except (TypeError, ValueError):
                return "N/A"

        header = "Method".ljust(name_width) + "".join(
            metric.upper().ljust(cell_width) for metric in key_metrics
        )
        # The rule spans the full table width (20 + 5 * 12 = 80 characters)
        lines = [header, "=" * (name_width + cell_width * len(key_metrics))]

        # Add rows for each method
        for method_name, metrics in all_results.items():
            cells = "".join(
                format_cell(metrics, metric).ljust(cell_width)
                for metric in key_metrics
            )
            lines.append(method_name.ljust(name_width) + cells)

        return "\n".join(lines) + "\n"
    
    def save_results(
        self,
        hpc_results: Dict[str, Dict[str, float]],
        baseline_results: Dict[str, Dict[str, float]],
        dataset_name: str,
        model_name: str
    ):
        """Save results to JSON and text files."""
        # Combine results
        all_results = {
            'hpc_methods': hpc_results,
            'baseline_methods': baseline_results,
            'metadata': {
                'dataset': dataset_name,
                'model': model_name,
                'num_classes': self.num_classes,
                'timestamp': datetime.now().isoformat()
            }
        }
        
        # Save JSON
        json_path = os.path.join(self.results_dir, f"results_{dataset_name}_{model_name}.json")
        with open(json_path, 'w') as f:
            json.dump(all_results, f, indent=2)
        
        # Save formatted table
        table = self.generate_results_table(hpc_results, baseline_results)
        table_path = os.path.join(self.results_dir, f"results_table_{dataset_name}_{model_name}.txt")
        with open(table_path, 'w') as f:
            f.write(f"Results for {dataset_name} using {model_name}\n")
            f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write(table)
        
        print(f"Results saved to {json_path} and {table_path}")


def _build_resnet50(num_classes: int) -> nn.Module:
    """Return a ResNet-50 without pretrained weights and a ``num_classes``-way head."""
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def load_pretrained_model(model_path: str, num_classes: int, device: torch.device) -> nn.Module:
    """
    Load pretrained model from checkpoint.

    The checkpoint may be a bare state dict or a dict with a ``'state_dict'``
    entry and an optional ``'arch'`` entry (only ``'resnet50'`` is supported;
    it is also assumed when ``'arch'`` is absent). Keys carrying the
    ``module.`` prefix added by ``nn.DataParallel`` are accepted.

    This function never raises for a missing or unreadable checkpoint: in
    those cases it prints a message and returns a ResNet-50 without
    pretrained weights so that the pipeline can still be exercised.

    Args:
        model_path: Path to model checkpoint
        num_classes: Number of output classes
        device: Device to load model on

    Returns:
        Loaded PyTorch model
    """
    if not os.path.exists(model_path):
        print(f"Model path {model_path} not found. Creating dummy ResNet50...")
        # Create dummy model for testing
        return _build_resnet50(num_classes).to(device)

    # Load actual checkpoint
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources (or pass
        # weights_only=True where the installed torch supports it).
        checkpoint = torch.load(model_path, map_location=device)

        # Extract model architecture info from checkpoint if available
        arch = checkpoint['arch'] if 'arch' in checkpoint else 'resnet50'
        if arch != 'resnet50':
            raise ValueError(f"Unknown architecture: {arch}")
        model = _build_resnet50(num_classes)

        # Locate the state dict
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Checkpoints saved from an nn.DataParallel wrapper prefix every key
        # with "module."; strip it so they load into the bare model instead
        # of silently falling back to the dummy model below.
        prefix = 'module.'
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

        model.load_state_dict(state_dict)

        print(f"Loaded model from {model_path}")
        return model.to(device)

    except Exception as e:
        # Deliberate best-effort fallback; the failure is reported, not hidden
        print(f"Error loading model: {e}")
        print("Creating dummy model for testing...")
        return _build_resnet50(num_classes).to(device)


def main():
    """
    Command-line entry point.

    Parses the options, builds the model and the test data loader, evaluates
    the HPC variants and the baseline calibrators on the extracted
    predictions, then prints the results table and saves it via the evaluator.
    """
    # Single table for the supported datasets and their class counts
    num_classes_by_dataset = {
        'cifar10h': 10,
        'cifar10': 10,
        'cifar100': 100,
        'imagenet': 1000,
    }

    def banner(title):
        # Section heading used between the evaluation stages
        print("\n" + "="*50)
        print(title)
        print("="*50)

    cli = argparse.ArgumentParser(description='Evaluate HPC methods')
    cli.add_argument('--dataset', type=str, default='cifar10h',
                     choices=list(num_classes_by_dataset),
                     help='Dataset to evaluate on')
    cli.add_argument('--model_path', type=str, default=None,
                     help='Path to pretrained model checkpoint')
    cli.add_argument('--data_root', type=str, default='./data',
                     help='Root directory for datasets')
    cli.add_argument('--human_annotations', type=str, default=None,
                     help='Path to human annotations file')
    cli.add_argument('--batch_size', type=int, default=128,
                     help='Batch size for evaluation')
    cli.add_argument('--max_batches', type=int, default=None,
                     help='Maximum number of batches to process (for testing)')
    cli.add_argument('--results_dir', type=str, default='./results',
                     help='Directory to save results')
    cli.add_argument('--device', type=str, default='auto',
                     help='Device to use (auto, cpu, cuda)')
    args = cli.parse_args()

    # Resolve the device; 'auto' prefers CUDA when it is available
    device_name = args.device
    if device_name == 'auto':
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print(f"Using device: {device}")

    num_classes = num_classes_by_dataset[args.dataset]

    # Build the model: from the checkpoint if one was given, otherwise a
    # stand-in ResNet-50 with a fresh classification head.
    if args.model_path:
        model = load_pretrained_model(args.model_path, num_classes, device)
        # Name is the checkpoint file name up to its first dot
        model_name = os.path.basename(args.model_path).partition('.')[0]
    else:
        print("No model path provided. Creating dummy ResNet50 for testing...")
        # NOTE(review): pretrained=True is requested here although the model
        # is announced as a dummy — confirm this is intended.
        model = models.resnet50(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        model = model.to(device)
        model_name = "resnet50_dummy"

    data_loader = create_data_loader(
        args.dataset,
        split='test',
        batch_size=args.batch_size,
        data_root=args.data_root,
        human_annotations_path=args.human_annotations
    )

    evaluator = ModelEvaluator(
        model=model,
        device=device,
        num_classes=num_classes,
        results_dir=args.results_dir
    )

    # The softmax probabilities returned in second position are not needed
    # here; both evaluation stages start from the logits.
    print(f"Extracting model predictions on {args.dataset}...")
    logits, _, targets, human_distributions = evaluator.extract_model_predictions(
        data_loader, max_batches=args.max_batches
    )

    banner("EVALUATING HPC METHODS")
    hpc_results = evaluator.evaluate_hpc_variants(
        logits, targets, human_distributions, args.dataset
    )

    banner("EVALUATING BASELINE METHODS")
    baseline_results = evaluator.evaluate_baseline_methods(
        logits, targets, human_distributions
    )

    banner("FINAL RESULTS")
    print(evaluator.generate_results_table(hpc_results, baseline_results))

    evaluator.save_results(hpc_results, baseline_results, args.dataset, model_name)

    print(f"\nEvaluation completed. Results saved to {args.results_dir}")


# Run the evaluation only when this file is executed as a script; importing
# the module has no side effects beyond the definitions above.
if __name__ == "__main__":
    main()
