import torch
import random
import numpy as np
from pathlib import Path
import json
from datetime import datetime
from typing import Any, Dict, Union, Optional

def set_seed(seed: int) -> None:
    """Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    generator. If CUDA is available, it also seeds the current and all CUDA
    devices. It then switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Integer to use as random seed
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for bit-for-bit repeatable kernels.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def convert_tensors_to_serializable(obj: Any) -> Any:
    """Convert torch tensors to a JSON-serializable format.

    Recursively walks dicts, lists and tuples and replaces every
    ``torch.Tensor`` with its nested-Python-list form. A 0-dim tensor becomes
    a plain Python scalar. Tuples stay tuples, and all other objects are
    returned unchanged.

    Args:
        obj: Object potentially containing torch tensors

    Returns:
        Same object with all torch tensors converted to lists/scalars
    """
    if isinstance(obj, torch.Tensor):
        # Tensor.tolist() supports every dtype, including bfloat16, which the
        # former ``.numpy()`` round-trip rejected with a TypeError.
        return obj.detach().cpu().tolist()
    if isinstance(obj, dict):
        return {key: convert_tensors_to_serializable(value)
                for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_tensors_to_serializable(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_tensors_to_serializable(item) for item in obj)
    return obj

def _parse_scale_experiment_dir(results_dir: Path) -> Dict[str, str]:
    """Extract timestamp, training_mode, scale, optimizer and seed from a run directory.

    Names of the form
    ``cpu_experiment_{training_mode}_scale{scale}_{optimizer}_seed{seed}`` are
    parsed with a regex. In that case the timestamp comes from a
    ``YYYYMMDD_HHMMSS`` fragment in the parent directory name (e.g.
    ``slurm_cpu_batch_YYYYMMDD_HHMMSS``), or from the current time if there is
    none. Any other name is parsed token by token on a best-effort basis.
    Fields that cannot be found are set to ``"unknown"``, and the timestamp is
    the current time.
    """
    import re

    dir_name = results_dir.name
    match = re.match(
        r"cpu_experiment_([a-z_0-9]+)_scale([0-9.]+)_([a-z]+)_seed([0-9]+)",
        dir_name,
    )

    if match:
        training_mode, scale, optimizer, seed = match.groups()
        timestamp_match = re.search(r"(\d{8}_\d{6})", results_dir.parent.name)
        if timestamp_match:
            timestamp = timestamp_match.group(1)
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return {
            "timestamp": timestamp,
            "training_mode": training_mode,
            "scale": scale,
            "optimizer": optimizer,
            "seed": seed,
        }

    print(f"Warning: Could not parse directory name '{dir_name}' with expected pattern")
    # Plain (non-f) string: single braces are printed literally.
    print("Expected format: cpu_experiment_{training_mode}_scale{scale}_{optimizer}_seed{seed}")

    fields = {
        "training_mode": "unknown",
        "scale": "unknown",
        "optimizer": "unknown",
        "seed": "unknown",
    }
    parts = dir_name.split('_')
    for i, part in enumerate(parts):
        if part.startswith('scale'):
            fields["scale"] = part[len('scale'):]
        elif part in ('adamw', 'sgd'):
            fields["optimizer"] = part
        elif part.startswith('seed'):
            fields["seed"] = part[len('seed'):]
        elif part in ('minimal', 'balanced', 'maximal'):
            # split('_') breaks e.g. "balanced_2" into "balanced", "2", so the
            # numeric suffix has to be re-joined from the following token.
            if part == 'balanced' and i + 1 < len(parts) and parts[i + 1].isdigit():
                fields["training_mode"] = f"{part}_{parts[i + 1]}"
            else:
                fields["training_mode"] = part

    fields["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")
    return fields


def log_scale_experiment(
    results_dir: Path,
    description: str = "",
    base_results_dir: Union[str, Path] = Path("./results"),
) -> None:
    """
    Append scale-experiment details to a central CSV log file.

    The log lives at ``<base_results_dir>/scale_experiments_log.csv``. It is
    created with a header row if missing, and the base directory is created
    if needed.

    Args:
        results_dir: Path to experiment results directory. Its name should
            follow ``cpu_experiment_{training_mode}_scale{scale}_{optimizer}_seed{seed}``.
        description: Optional description of the experiment purpose/configuration
        base_results_dir: Directory holding the central log (default ``./results``).
    """
    import csv

    results_dir = Path(results_dir)
    log_file = Path(base_results_dir) / "scale_experiments_log.csv"
    log_file.parent.mkdir(parents=True, exist_ok=True)

    info = _parse_scale_experiment_dir(results_dir)

    write_header = not log_file.exists()
    # The csv module quotes/escapes commas, quotes and newlines in any field.
    # '\n' keeps the line endings of rows appended to older logs consistent.
    with open(log_file, 'a', newline='') as f:
        writer = csv.writer(f, lineterminator="\n")
        if write_header:
            writer.writerow(["timestamp", "directory", "training_mode", "scale",
                             "optimizer", "seed", "description"])
        writer.writerow([info["timestamp"], results_dir.name, info["training_mode"],
                         info["scale"], info["optimizer"], info["seed"], description])

    print(f"Logged experiment to {log_file}")

def load_experiment_results(results_dir: Path) -> Dict:
    """
    Read the aggregated results file of a single experiment run.

    Args:
        results_dir: Experiment directory expected to contain ``all_results.json``
    Returns:
        The parsed contents of ``all_results.json``, exactly as stored on disk
    """
    results_file = results_dir / "all_results.json"
    with results_file.open('r') as handle:
        return json.load(handle)

def _read_all_results(experiment_dir: Path) -> Dict:
    """Parse ``all_results.json`` from a single experiment directory."""
    with open(experiment_dir / "all_results.json", 'r') as f:
        return json.load(f)


def _copy_rank_seeds(mode_results: Dict, source: Dict, rank_str: str) -> None:
    """Copy each configured seed of ``rank_str`` found in ``source`` into ``mode_results``."""
    rank_results = mode_results.setdefault(rank_str, {})
    source_rank = source['results'].get(rank_str, {})
    for seed in source['config']['seeds']:
        seed_str = str(seed)
        if seed_str in source_rank:
            rank_results[seed_str] = source_rank[seed_str]


def load_multi_experiment_results(
    base_dir: Path, 
    experiment_dirs: Dict[str, Dict[str, str]],
    target_rank: Optional[int] = None
) -> Dict:
    """
    Load and combine results from multiple experiment directories.

    For each training mode, the full-rank model and all lower ranks were run
    as two separate experiments. Their ``all_results.json`` files are merged
    here into a single structure.

    Args:
        base_dir: Base directory containing all ``experiment_<timestamp>`` folders
        experiment_dirs: Dictionary mapping training modes to their experiment details
            Format: {
                'training_mode': {
                    'full_rank': 'timestamp_for_rank128',
                    'other_ranks': 'timestamp_for_other_ranks'
                }
            }
            Either key may be omitted or empty.
        target_rank: If specified, only include data for this rank

    Returns:
        Combined results dictionary of the form
        ``{'config': {'training_modes', 'ranks', 'seeds'},
        'results': {mode: {rank_str: {seed_str: ...}}}}``.
        ``config['seeds']`` is taken from the first experiment file loaded.
        ``config['ranks']`` is the sorted list of int ranks present in the results.
    """
    combined_results = {
        'config': {
            'training_modes': list(experiment_dirs.keys()),
            'ranks': [],
            'seeds': None
        },
        'results': {}
    }

    def is_wanted(rank_str: str) -> bool:
        return target_rank is None or int(rank_str) == target_rank

    def remember_seeds(data: Dict) -> None:
        if combined_results['config']['seeds'] is None:
            combined_results['config']['seeds'] = data['config']['seeds']

    for training_mode, timestamps in experiment_dirs.items():
        mode_results: Dict = {}
        combined_results['results'][training_mode] = mode_results

        if timestamps.get('full_rank'):
            full_rank_data = _read_all_results(base_dir / f"experiment_{timestamps['full_rank']}")
            remember_seeds(full_rank_data)

            # The full-rank run holds a single rank; fall back to 128 if unspecified.
            full_ranks = full_rank_data['config'].get('ranks') or []
            full_rank = str(full_ranks[0]) if full_ranks else '128'

            # Only the full-rank data is filtered out here. The other-ranks
            # experiment below must still be loaded when target_rank is set.
            if is_wanted(full_rank):
                _copy_rank_seeds(mode_results, full_rank_data, full_rank)

        if timestamps.get('other_ranks'):
            other_ranks_data = _read_all_results(base_dir / f"experiment_{timestamps['other_ranks']}")
            remember_seeds(other_ranks_data)

            for rank in other_ranks_data['config']['ranks']:
                rank_str = str(rank)
                if is_wanted(rank_str):
                    _copy_rank_seeds(mode_results, other_ranks_data, rank_str)

    # Compile unique ranks across all experiments
    all_ranks = {
        int(rank_str)
        for mode_results in combined_results['results'].values()
        for rank_str in mode_results
    }
    combined_results['config']['ranks'] = sorted(all_ranks)

    return combined_results

def load_model_for_evaluation(model_path, hidden_size, rank, device=None,
                              input_size=None, output_size=None):
    """Build a ``RankControlledRNN``, load saved weights into it and return it in eval mode.

    Args:
        model_path: Path to a file whose ``torch.load`` result is accepted by
            ``model.load_state_dict`` (presumably a saved state dict; verify
            against the saving code).
        hidden_size: Hidden size passed to ``RankControlledRNN``.
        rank: Rank passed to ``RankControlledRNN``.
        device: Target device. Defaults to CUDA when available, otherwise CPU.
        input_size: Model input size. If None, read from the global
            ``env.config.input_size`` (the original behavior).
        output_size: Model output size. If None, read from the global
            ``env.config.output_size`` (the original behavior).

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    # NOTE(review): neither ``RankControlledRNN`` nor ``env`` is imported or
    # defined in the visible part of this module. They must exist as module
    # globals at call time, otherwise this raises NameError. Confirm where
    # they are meant to come from.
    # If no device is specified, use CUDA if available
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only touch the global ``env`` when sizes are not supplied explicitly,
    # so callers can use this function without that global.
    if input_size is None:
        input_size = env.config.input_size
    if output_size is None:
        output_size = env.config.output_size

    model = RankControlledRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        rank=rank
    )

    # map_location lets weights saved on GPU load on a CPU-only host (and vice versa).
    # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    # Move model to the appropriate device and disable training-only behavior.
    model = model.to(device)
    model.eval()

    return model

def load_hidden_activities(file_path: Union[str, Path], map_location: Optional[Any] = None) -> Dict:
    """
    Load hidden activities from a saved PT file and check its top-level keys.

    Args:
        file_path: Path (or path string) to the saved hidden activities file.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"``
            to load tensors saved on GPU on a machine without CUDA. None keeps
            torch's default behavior.

    Returns:
        The loaded dictionary. It is guaranteed to contain the keys
        'timepoints', 'task_indices', 'hidden_states' and 'task_info'.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file does not hold a dict, or a required key is missing.
    """
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"Hidden activities file not found at {file_path}")

    # NOTE(review): torch.load unpickles its input; only load trusted files.
    hidden_activities = torch.load(file_path, map_location=map_location)

    # A non-dict payload (list, tensor, ...) would otherwise produce a
    # misleading "missing key" error or a RuntimeError from Tensor.__contains__.
    if not isinstance(hidden_activities, dict):
        raise ValueError(
            "Expected a dict in hidden activities file, got "
            f"{type(hidden_activities).__name__}"
        )

    expected_keys = ['timepoints', 'task_indices', 'hidden_states', 'task_info']
    for key in expected_keys:
        if key not in hidden_activities:
            raise ValueError(f"Missing expected key '{key}' in hidden activities file")

    return hidden_activities