"""Main runner for ARCOS active learning experiments."""

import os
import time
import argparse
import pandas as pd
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from datetime import datetime

from ..utils.config import load_config
from ..utils.seed import set_seed, set_deterministic_mode
from ..utils.checkpoint import CheckpointManager
from ..data.datasets import get_dataset_class, get_transforms, create_dataloader
from ..models.backbone import get_backbone, freeze_backbone_layers, unfreeze_all_layers
from ..models.train import Trainer, FineTuner, evaluate_model
from ..trace.metrics import compute_trace_metrics
from ..acquisition.policies import get_acquisition_policy
from ..acquisition.feature_cache import FeatureCache
from ..viz.plots import ARCOSPlotter


class ARCOSExperimentRunner:
    """Main runner for ARCOS active learning experiments.

    Orchestrates one experiment end to end:

    1. build source / target (pool) / anchor datasets and loaders,
    2. train a baseline model ``Q`` on the source domains,
    3. run acquisition rounds in which target samples are selected by the
       configured policy, a copy ``Q_tilde`` is fine-tuned on the labeled
       target subset, and ARCOS metrics are computed,
    4. write metrics, plots and experiment state to a timestamped directory.
    """

    def __init__(self, config_path: str, policy: str, **kwargs):
        """Initialize experiment runner.

        Args:
            config_path: Path to configuration file
            policy: Acquisition policy (w1min, discmax)
            **kwargs: Config overrides. Entries whose value is ``None`` are
                skipped; every other entry is applied with
                ``self.config.update({key: value})``.
        """
        self.config = load_config(config_path)
        self.policy = policy

        # Override config with kwargs.
        # NOTE(review): keys arrive exactly as main() forwards them (e.g. "lr",
        # "epochs", "seed", "ot.K"), whereas this class reads dotted keys such
        # as "training.learning_rate" or "ot.num_projections" and
        # get_system_config()['seed']. Whether these overrides take effect
        # depends on Config.update — confirm.
        for key, value in kwargs.items():
            if value is not None:
                self.config.update({key: value})

        # Seed / determinism / device must be set before anything random.
        self._setup_system()

        # Dataset class, transforms, feature cache
        self._initialize_components()

        # Create output directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = Path(self.config.get('output.save_dir', './outputs')) / f"{timestamp}_{policy}"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Initialize plotter after output directory is created
        self.plotter = ARCOSPlotter(str(self.output_dir))

        # Save the (possibly overridden) config next to the results
        self.config.save(self.output_dir / "config.yaml")

        print(f"ARCOS Experiment initialized with policy: {policy}")
        print(f"Output directory: {self.output_dir}")

    def _setup_system(self):
        """Seed RNGs, configure determinism, and resolve ``self.device``.

        ``system.device == 'auto'`` (the default) picks CUDA when available,
        otherwise CPU.
        """
        system_config = self.config.get_system_config()

        # Set seed
        seed = system_config.get('seed', 1337)
        deterministic = system_config.get('deterministic', True)
        set_seed(seed, deterministic)
        set_deterministic_mode(deterministic)

        # Set device
        device = system_config.get('device', 'auto')
        if device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        print(f"Using device: {self.device}")
        print(f"Seed: {seed}, Deterministic: {deterministic}")

    def _initialize_components(self):
        """Resolve the dataset class, build transforms, and set up the feature cache."""
        # Get dataset class
        dataset_name = self.config.get('dataset.name')
        self.dataset_class = get_dataset_class(dataset_name)

        # Get transforms
        image_size = self.config.get('dataset.image_size', 224)
        self.train_transforms = get_transforms(image_size, is_training=True)
        self.eval_transforms = get_transforms(image_size, is_training=False)

        # The feature-cache section is normally a dict, but it can be a bare
        # bool (presumably when a CLI override such as --feature-cache replaced
        # the whole section). Treat a bool as the 'enabled' switch with
        # default settings instead of silently disabling the cache.
        cache_config = self.config.get_feature_cache_config()
        if isinstance(cache_config, bool):
            cache_config = {'enabled': cache_config}
        if isinstance(cache_config, dict) and cache_config.get('enabled', True):
            self.feature_cache = FeatureCache(
                cache_dir=cache_config.get('cache_dir', './cache'),
                refresh_qt_only=cache_config.get('refresh_qt_only', True)
            )
        else:
            self.feature_cache = None
        # NOTE(review): self.feature_cache is not read anywhere in this class.

    def _build_dataset(self, dataset_config, domain, split, transform):
        """Instantiate the configured dataset class for one domain/split.

        Args:
            dataset_config: The ``dataset`` config section.
            domain: Domain name to load.
            split: Dataset split (``'train'`` or ``'test'`` in this class).
            transform: Transform pipeline applied to each sample.
        """
        return self.dataset_class(
            data_root=dataset_config['data_root'],
            domain=domain,
            split=split,
            transform=transform,
            indices_file=dataset_config.get('indices_file'),
            toy_mode=dataset_config.get('toy_mode', False),
            toy_classes=dataset_config.get('toy_classes', 2),
            toy_samples_per_class=dataset_config.get('toy_samples_per_class', 200)
        )

    def _create_datasets_and_loaders(self):
        """Create source, target-pool and anchor datasets and their loaders.

        Sets ``self.source_loader`` (shuffled), ``self.target_loader``
        (unshuffled) and ``self.anchor_loader`` (unshuffled).

        Raises:
            ValueError: If no source domains are configured.
        """
        dataset_config = self.config.get_dataset_config()
        training_config = self.config.get_training_config()

        # Source domains (train split, training transforms)
        source_domains = dataset_config['source_domains']
        if not source_domains:
            raise ValueError("dataset.source_domains must list at least one domain")
        source_datasets = [
            self._build_dataset(dataset_config, domain, 'train', self.train_transforms)
            for domain in source_domains
        ]

        # Combine source datasets
        if len(source_datasets) > 1:
            from torch.utils.data import ConcatDataset
            source_dataset = ConcatDataset(source_datasets)
        else:
            source_dataset = source_datasets[0]

        # Target domain (unlabeled pool).
        # NOTE(review): built with the *training* transforms, so the features
        # extracted from it for acquisition/metrics use training-time
        # transforms too — confirm this is intended.
        target_dataset = self._build_dataset(
            dataset_config, dataset_config['target_domain'], 'train', self.train_transforms
        )

        # Anchor evaluation domain (test split, eval transforms)
        anchor_dataset = self._build_dataset(
            dataset_config, dataset_config['anchor_eval_domain'], 'test', self.eval_transforms
        )

        # Create data loaders
        batch_size = training_config['batch_size']
        num_workers = training_config['num_workers']

        self.source_loader = create_dataloader(
            source_dataset, batch_size=batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=True
        )

        # Unshuffled so that features extracted from this loader line up with
        # target dataset indices (assuming create_dataloader keeps dataset
        # order when shuffle=False) — _run_acquisition_rounds relies on this.
        self.target_loader = create_dataloader(
            target_dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True
        )

        self.anchor_loader = create_dataloader(
            anchor_dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True
        )

        print("Created datasets:")
        print(f"  Source: {len(source_dataset)} samples")
        print(f"  Target: {len(target_dataset)} samples")
        print(f"  Anchor: {len(anchor_dataset)} samples")

    def _create_models(self):
        """Create ``model_Q`` and ``model_Qt`` with identical architecture and weights.

        Both models are moved to ``self.device``. ``model_Qt`` starts as an
        exact weight copy of ``model_Q``.
        """
        model_config = self.config.get_model_config()
        dataset_config = self.config.get_dataset_config()

        def build():
            return get_backbone(
                backbone_name=model_config['backbone'],
                num_classes=dataset_config['num_classes'],
                pretrained=model_config['pretrained'],
                dropout=model_config['dropout'],
                feature_dim=model_config['feature_dim']
            ).to(self.device)

        # Q is trained only on the source domains; Q_tilde is fine-tuned on target labels.
        self.model_Q = build()
        self.model_Qt = build()

        # Initialize Q_tilde with Q weights
        self.model_Qt.load_state_dict(self.model_Q.state_dict())

        print(f"Created models with {dataset_config['num_classes']} classes")

    def _train_baseline_model(self):
        """Train baseline model Q on source domains, then copy its weights into Q_tilde.

        Training runs for exactly ``training.epochs`` epochs in total. If
        ``model.freeze_backbone_epochs`` is positive, the first
        ``min(freeze_backbone_epochs, epochs)`` epochs run with the backbone
        frozen. The remaining epochs run with all layers trainable.

        Returns:
            The history returned by the last ``Trainer.train`` call, or ``None``
            if no epochs were run.
        """
        print("Training baseline model Q on source domains...")

        training_config = self.config.get_training_config()
        model_config = self.config.get_model_config()

        epochs = int(training_config['epochs'])
        freeze_epochs = int(model_config.get('freeze_backbone_epochs', 0))
        frozen_epochs = max(0, min(freeze_epochs, epochs))

        def make_trainer():
            # A fresh Trainer per phase, so that an optimizer built at
            # construction time also covers parameters unfrozen after the
            # warm-up phase.
            return Trainer(
                model=self.model_Q,
                train_loader=self.source_loader,
                val_loader=None,  # No validation for baseline
                device=self.device,
                # YAML can load values like "1e-3" as strings, hence float().
                learning_rate=float(training_config['learning_rate']),
                weight_decay=float(training_config['weight_decay']),
                mixed_precision=training_config['mixed_precision'],
                checkpoint_dir=str(self.output_dir / "checkpoints")
            )

        history = None

        # Phase 1: frozen-backbone warm-up (only the first frozen_epochs epochs).
        if frozen_epochs > 0:
            freeze_backbone_layers(self.model_Q, freeze_epochs)
            print(f"Froze backbone for first {frozen_epochs} epochs")
            history = make_trainer().train(epochs=frozen_epochs, save_best=True)
            unfreeze_all_layers(self.model_Q)
            print("Unfroze all layers")

        # Phase 2: the rest of the epoch budget with all layers trainable.
        remaining_epochs = epochs - frozen_epochs
        if remaining_epochs > 0:
            if frozen_epochs > 0:
                print(f"Training for {remaining_epochs} more epochs with all layers...")
            history = make_trainer().train(epochs=remaining_epochs, save_best=True)

        # Update Q_tilde with final Q weights
        self.model_Qt.load_state_dict(self.model_Q.state_dict())

        print("Baseline model training completed")
        return history

    def _run_acquisition_rounds(self):
        """Run active-learning acquisition rounds and record ARCOS metrics.

        The label budget is ``acquisition.budgets[0]`` percent of the target
        pool (capped at the pool size). It is acquired in batches of
        ``acquisition.batch_size``, and the final batch is truncated so the
        budget is never exceeded. Each round:

        1. select samples with the acquisition policy,
        2. fine-tune ``Q_tilde`` on all labeled target samples so far,
        3. compute metrics with ``compute_trace_metrics``.

        Returns:
            DataFrame with one row per completed round, also written to
            ``metrics.csv`` in the output directory.
        """
        print("Starting acquisition rounds...")

        acquisition_config = self.config.get_acquisition_config()
        training_config = self.config.get_training_config()

        # Calculate rounds
        budget = acquisition_config['budgets'][0]  # Use first budget for now
        batch_size = acquisition_config['batch_size']
        total_samples = len(self.target_loader.dataset)
        total_labels = min(int(budget / 100 * total_samples), total_samples)
        rounds = (total_labels + batch_size - 1) // batch_size

        print(f"Budget: {budget}%, Total labels: {total_labels}, Rounds: {rounds}")

        # Initialize acquisition policy
        policy = get_acquisition_policy(
            self.policy,
            num_projections=self.config.get('ot.num_projections', 256),
            normalize=self.config.get('ot.normalize_features', True)
        )

        all_metrics = []

        # Initial state: nothing labeled, the whole target pool unlabeled.
        labeled_indices: List[int] = []
        unlabeled_indices = list(range(total_samples))
        labeled_features = torch.empty(0, self.config.get('model.feature_dim', 2048))

        print("Extracting initial features...")
        source_features = self._extract_features(self.model_Q, self.source_loader)
        # Q is not updated inside the loop (only Q_tilde is fine-tuned), so the
        # pool features are extracted once. target_loader is unshuffled, so
        # row i is assumed to correspond to target index i.
        target_features = self._extract_features(self.model_Q, self.target_loader)

        for round_idx in range(rounds):
            # Truncate the last round so the labeled count never exceeds the budget.
            round_batch = min(
                batch_size,
                total_labels - len(labeled_indices),
                len(unlabeled_indices),
            )
            if round_batch <= 0:
                break

            print(f"\n=== Round {round_idx + 1}/{rounds} ===")

            # Features of exactly the still-unlabeled samples, in the same
            # order as unlabeled_indices. Passing the whole pool misaligns the
            # two once samples have been labeled.
            unlabeled_features = target_features[
                torch.as_tensor(unlabeled_indices, dtype=torch.long)
            ]

            policy_kwargs = dict(
                unlabeled_indices=unlabeled_indices,
                unlabeled_features=unlabeled_features,
                labeled_features=labeled_features,
                source_features=source_features,
                batch_size=round_batch,
            )
            if self.policy != 'w1min':  # discmax also needs both models and a pool loader
                policy_kwargs.update(
                    model_Q=self.model_Q,
                    model_Qt=self.model_Qt,
                    # Unshuffled so the loader order matches unlabeled_indices.
                    unlabeled_loader=self._create_subset_loader(unlabeled_indices, shuffle=False),
                    device=self.device,
                )
            selected_indices, labeled_features = policy.select_samples(**policy_kwargs)

            # Normalise to plain ints (the policy may return tensors/arrays)
            # and shrink the pool with O(1) set membership.
            selected = [int(i) for i in selected_indices]
            labeled_indices.extend(selected)
            selected_set = set(selected)
            unlabeled_indices = [i for i in unlabeled_indices if i not in selected_set]

            # Fine-tune Q_tilde on everything labeled so far
            print("Fine-tuning Q_tilde...")
            fine_tuner = FineTuner(
                model=self.model_Qt,
                train_loader=self._create_subset_loader(labeled_indices),
                device=self.device,
                learning_rate=float(training_config['learning_rate']) * 0.1,  # Lower LR for fine-tuning
                weight_decay=float(training_config['weight_decay']),
                mixed_precision=training_config['mixed_precision']
            )
            fine_tuner.fine_tune(training_config['fine_tune_epochs'])

            # Compute ARCOS metrics
            print("Computing ARCOS metrics...")
            metrics = compute_trace_metrics(
                model_Q=self.model_Q,
                model_Qt=self.model_Qt,
                source_loader=self.source_loader,
                target_loader=self._create_subset_loader(labeled_indices),
                anchor_loader=self.anchor_loader,
                device=self.device,
                ot_method=self.config.get('ot.estimator', 'max-sliced'),
                Lx_method=self.config.get('metrics.Lx_method', 'p99'),
                normalize_features=self.config.get('ot.normalize_features', True),
                num_projections=self.config.get('ot.num_projections', 256)
            )

            # Store round metrics ('batch_size' is the configured acquisition size)
            all_metrics.append({
                'round': round_idx + 1,
                'policy': self.policy,
                'budget': budget,
                'labeled_count': len(labeled_indices),
                'unlabeled_count': len(unlabeled_indices),
                'batch_size': batch_size,
                **metrics
            })

            print(f"Round {round_idx + 1} completed:")
            print(f"  Labeled samples: {len(labeled_indices)}")
            print(f"  |ΔR|: {metrics['delta_R']:.4f}")
            print(f"  W1: {metrics['W1']:.4f}")
            print(f"  Output discrepancy: {metrics['output_discrepancy']:.4f}")
            print(f"  Bound proxy: {metrics['bound_proxy']:.4f}")

        # Save metrics
        metrics_df = pd.DataFrame(all_metrics)
        metrics_path = self.output_dir / "metrics.csv"
        metrics_df.to_csv(metrics_path, index=False)
        print(f"Saved metrics to: {metrics_path}")

        return metrics_df

    def _extract_features(self, model: nn.Module, loader) -> torch.Tensor:
        """Extract per-sample features from ``model`` over ``loader``.

        Uses ``model.get_features`` when the model has it. Otherwise it uses
        ``model.backbone`` and flattens the output to ``(batch, -1)``. The
        model is moved to ``self.device`` and switched to eval mode (left in
        eval mode afterwards), and no gradients are tracked. ``loader`` must
        yield ``(data, target)`` pairs.

        Returns:
            CPU tensor with one row per sample, in loader order. If the loader
            yields no batches, an empty tensor with
            ``config['model.feature_dim']`` (default 2048) columns.
        """
        model.to(self.device)
        model.eval()
        chunks = []

        with torch.no_grad():
            for data, _ in loader:
                data = data.to(self.device)
                if hasattr(model, 'get_features'):
                    features = model.get_features(data)
                else:
                    features = model.backbone(data)
                    features = features.reshape(features.size(0), -1)
                chunks.append(features.cpu())

        if not chunks:
            # torch.cat([]) would raise; return an empty feature matrix instead.
            return torch.empty(0, self.config.get('model.feature_dim', 2048))
        return torch.cat(chunks, dim=0)

    def _create_subset_loader(self, indices: List[int], shuffle: bool = True):
        """Create a data loader over ``indices`` of the target (pool) dataset.

        Args:
            indices: Target dataset indices to include, in order.
            shuffle: Whether to shuffle. Pass ``False`` when the consumer
                relies on loader order matching ``indices``.
        """
        from torch.utils.data import Subset
        subset_dataset = Subset(self.target_loader.dataset, indices)
        return create_dataloader(
            subset_dataset,
            batch_size=self.config.get('training.batch_size', 64),
            shuffle=shuffle,
            num_workers=self.config.get('training.num_workers', 4)
        )

    def run_experiment(self):
        """Run the complete ARCOS experiment.

        Builds data and models, trains the baseline, runs acquisition rounds,
        plots all curves, and saves the experiment state. Any exception is
        printed and re-raised.
        """
        print("Starting ARCOS experiment...")
        start_time = time.time()

        try:
            self._create_datasets_and_loaders()
            self._create_models()
            baseline_history = self._train_baseline_model()
            metrics_df = self._run_acquisition_rounds()

            # Create plots
            print("Creating plots...")
            self.plotter.plot_all_curves(metrics_df)

            # Save experiment state
            experiment_state = {
                'policy': self.policy,
                'config': self.config.config,
                'baseline_history': baseline_history,
                'final_metrics': metrics_df.to_dict('records')
            }

            checkpoint_manager = CheckpointManager(str(self.output_dir), "experiment")
            checkpoint_manager.save_experiment_state(experiment_state)

            total_time = time.time() - start_time
            print("\nARCOS experiment completed successfully!")
            print(f"Total time: {total_time:.2f} seconds")
            print(f"Results saved to: {self.output_dir}")

        except Exception as e:
            print(f"Experiment failed with error: {e}")
            raise


def main():
    """Parse command-line arguments and run one ARCOS experiment.

    Every parsed option except ``--config`` and ``--policy`` whose value is
    not ``None`` is forwarded to ``ARCOSExperimentRunner`` as a config
    override keyword. Each option is keyed by its argparse destination name
    (e.g. ``lr``, ``batch_size``, ``ot.K``).
    """
    parser = argparse.ArgumentParser(description="Run ARCOS active learning experiment")
    parser.add_argument("--config", type=str, required=True, help="Path to configuration file")
    parser.add_argument("--policy", type=str, required=True, choices=["w1min", "discmax"], help="Acquisition policy")

    # Model arguments
    parser.add_argument("--backbone", type=str, help="Backbone architecture")
    parser.add_argument("--epochs", type=int, help="Number of training epochs")
    parser.add_argument("--lr", type=float, help="Learning rate")
    parser.add_argument("--batch-size", type=int, help="Batch size")
    parser.add_argument("--num-workers", type=int, help="Number of workers")

    # OT arguments (argparse keeps the dot, so the dests are "ot.estimator" / "ot.K")
    parser.add_argument("--ot.estimator", type=str, help="OT estimator method")
    parser.add_argument("--ot.K", type=int, help="Number of projections for sliced OT")

    # Feature cache arguments.
    # default=None (not store_true's implicit False) so that omitting the flag
    # forwards no override, instead of always overriding the config with False.
    parser.add_argument("--feature-cache", action="store_true", default=None,
                        help="Enable feature caching")

    # System arguments
    parser.add_argument("--seed", type=int, help="Random seed")

    args = parser.parse_args()

    # Convert arguments to config updates.
    # NOTE(review): override keys are argparse dest names ("lr", "epochs",
    # "seed", "ot.K", ...), not the dotted keys the runner reads
    # ("training.learning_rate", "ot.num_projections", ...). Confirm that
    # Config.update maps them, otherwise these flags have no effect.
    positional_only = {'config', 'policy'}
    config_updates = {
        arg: value
        for arg, value in vars(args).items()
        if value is not None and arg not in positional_only
    }

    # Run experiment
    runner = ARCOSExperimentRunner(args.config, args.policy, **config_updates)
    runner.run_experiment()


if __name__ == "__main__":
    main()

