import torch
import json
import os
from typing import Dict, List, Tuple, Optional
from torch.utils.data import Subset
from torchvision import datasets

from model.vae_models.vae import VAE
from model.vae_models.conv_vae import ConvVAE
from model.vae_models.hierarchical_vae import HierarchicalMLPVAE, HierarchicalConvVAE


def _checkpoint_sort_key(path: str) -> Tuple[int, str]:
    """Return a key that orders ``checkpoint_epoch_<N>.pth`` files by numeric epoch.

    A plain string sort would rank ``checkpoint_epoch_9.pth`` after
    ``checkpoint_epoch_10.pth``; parsing the epoch as an integer avoids that.
    File names whose epoch part is not an integer get epoch ``-1`` and are
    ordered among themselves by file name.
    """
    fname = os.path.basename(path)
    stem = fname[len('checkpoint_epoch_'):-len('.pth')]
    try:
        epoch = int(stem)
    except ValueError:
        epoch = -1
    return epoch, fname


def _find_model_file(candidate_dirs: List[str]) -> str:
    """Return the first usable weights file found in ``candidate_dirs``.

    For each directory, in order, ``best_model.pth`` is preferred; otherwise
    the ``checkpoint_epoch_*.pth`` file with the highest epoch number is used.

    Raises:
        FileNotFoundError: If no directory contains a usable file.
    """
    tried_paths: List[str] = []
    for dir_path in candidate_dirs:
        best_path = os.path.join(dir_path, 'best_model.pth')
        tried_paths.append(best_path)
        if os.path.exists(best_path):
            return best_path
        # Try checkpoints as fallback
        if os.path.isdir(dir_path):
            checkpoints = [
                os.path.join(dir_path, fname)
                for fname in os.listdir(dir_path)
                if fname.startswith('checkpoint_epoch_') and fname.endswith('.pth')
            ]
            if checkpoints:
                return max(checkpoints, key=_checkpoint_sort_key)
    raise FileNotFoundError(
        "Model file not found. Tried: " + ", ".join(tried_paths)
    )


def load_trained_model(experiment_dir: str, split_idx: int, device: torch.device) -> Tuple[VAE, Dict]:
    """
    Load a trained VAE model from experiment directory.

    The model architecture is chosen from ``config['model']['arch']``
    ('cnn', 'hcnn', 'hmlp'; anything else, including a missing value, builds
    the plain ``VAE``). Weights are looked up in ``split_{split_idx}`` under
    the configured ``logging.model_dir`` first and under ``experiment_dir``
    second.

    Args:
        experiment_dir: Path to experiment directory
        split_idx: Index of the split to load
        device: Device to load model on

    Returns:
        Tuple of (model, config); the model is moved to ``device`` and put
        in eval mode.

    Raises:
        FileNotFoundError: If ``aggregated_results.json`` or a weights file
            cannot be found.
        KeyError: If a required config entry or the checkpoint's
            ``'model_state_dict'`` entry is missing.
    """
    # Load config (fail fast with clear message)
    config_file = os.path.join(experiment_dir, 'aggregated_results.json')
    if not os.path.exists(config_file):
        raise FileNotFoundError(
            f"aggregated_results.json not found at: {config_file}. "
            f"Training might have failed or not completed."
        )
    with open(config_file, 'r') as f:
        aggregated_results = json.load(f)
    config = aggregated_results['config']

    # Initialize model
    arch = config.get('model', {}).get('arch', 'mlp')
    if arch == 'cnn':
        model = ConvVAE(
            in_channels=config['model'].get('in_channels', 1),
            encoder_channels=config['model']['encoder_channels'],
            decoder_channels=config['model']['decoder_channels'],
            latent_dim=config['model']['latent_dim'],
        )
    elif arch == 'hcnn':
        model = HierarchicalConvVAE(
            in_channels=config['model'].get('in_channels', 1),
            encoder_channels=config['model']['encoder_channels'],
            decoder_channels=config['model']['decoder_channels'],
            latent_dims=config['model']['latent_dims'],
        )
    elif arch == 'hmlp':
        model = HierarchicalMLPVAE(
            input_dim=config['model']['input_dim'],
            hidden_dims=config['model']['hidden_dims'],
            latent_dims=config['model']['latent_dims'],
            encoder_hidden_dims=config.get('model', {}).get('encoder_hidden_dims'),
            decoder_hidden_dims=config.get('model', {}).get('decoder_hidden_dims'),
        )
    else:
        model = VAE(
            input_dim=config['model']['input_dim'],
            hidden_dims=config['model']['hidden_dims'],
            latent_dim=config['model']['latent_dim'],
            encoder_hidden_dims=config.get('model', {}).get('encoder_hidden_dims'),
            decoder_hidden_dims=config.get('model', {}).get('decoder_hidden_dims'),
        )

    # Locate model weights
    base_model_dir = config.get('logging', {}).get('model_dir', None)
    if base_model_dir is None:
        # Fallback to old layout (under experiment_dir)
        base_model_dir = experiment_dir

    split_name = f'split_{split_idx}'
    trimmed_base = base_model_dir.rstrip('/')
    if os.path.basename(trimmed_base).startswith('split_'):
        # Case 1: base_model_dir already points to a split directory -> replace the split suffix
        primary_dir = os.path.join(os.path.dirname(trimmed_base), split_name)
    else:
        primary_dir = os.path.join(base_model_dir, split_name)
    # Case 2: fallback to experiment_dir/split_{idx}
    fallback_dir = os.path.join(experiment_dir, split_name)
    # Both candidates coincide when model_dir is not configured; drop the
    # duplicate (order preserved) so it is neither probed nor reported twice.
    candidate_dirs = list(dict.fromkeys([primary_dir, fallback_dir]))

    model_file = _find_model_file(candidate_dirs)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model, config


def load_experiment_data_splits(experiment_dir: str) -> Tuple[List[Tuple[List[int], List[int]]], Dict]:
    """
    Read the stored train/validation index splits of an experiment.

    The files are read from ``<experiment_dir>/data_splits``:
    ``experiment_metadata.json`` supplies ``num_splits``, and each
    ``split_<i>.json`` supplies ``train_indices`` and ``val_indices``.

    Args:
        experiment_dir: Path to experiment directory

    Returns:
        Tuple of (splits, metadata), where ``splits[i]`` is the
        ``(train_indices, val_indices)`` pair read from ``split_<i>.json``.
    """
    splits_root = os.path.join(experiment_dir, 'data_splits')

    def _read_json(file_name):
        # Every file of interest lives directly under the data_splits folder.
        with open(os.path.join(splits_root, file_name), 'r') as handle:
            return json.load(handle)

    metadata = _read_json('experiment_metadata.json')

    payloads = (
        _read_json(f'split_{split_no}.json')
        for split_no in range(metadata['num_splits'])
    )
    splits = [
        (payload['train_indices'], payload['val_indices'])
        for payload in payloads
    ]

    return splits, metadata


def recreate_data_subsets(dataset: datasets.MNIST,
                         train_indices: List[int],
                         val_indices: List[int]) -> Tuple[Subset, Subset]:
    """
    Build the train and validation views of a dataset from saved indices.

    Args:
        dataset: Full dataset
        train_indices: Indices selecting the training samples
        val_indices: Indices selecting the validation samples

    Returns:
        Tuple of (train_subset, val_subset), each a ``Subset`` of ``dataset``
    """
    train_subset, val_subset = (
        Subset(dataset, chosen) for chosen in (train_indices, val_indices)
    )
    return train_subset, val_subset


def get_experiment_info(experiment_dir: str) -> Dict:
    """
    Collect the stored results and split metadata of one experiment.

    Reads ``aggregated_results.json`` from the experiment directory and
    ``data_splits/experiment_metadata.json`` beneath it.

    Args:
        experiment_dir: Path to experiment directory

    Returns:
        Dictionary with the keys ``experiment_dir``, ``config``,
        ``average_metrics``, ``individual_results`` and ``data_metadata``
    """
    results_path = os.path.join(experiment_dir, 'aggregated_results.json')
    with open(results_path, 'r') as results_fh:
        results = json.load(results_fh)

    metadata_path = os.path.join(
        experiment_dir, 'data_splits', 'experiment_metadata.json'
    )
    with open(metadata_path, 'r') as metadata_fh:
        data_metadata = json.load(metadata_fh)

    # Copy the selected result sections, then attach the split metadata.
    info = {'experiment_dir': experiment_dir}
    for section in ('config', 'average_metrics', 'individual_results'):
        info[section] = results[section]
    info['data_metadata'] = data_metadata

    return info


def list_experiments(base_dir: str) -> List[Dict]:
    """
    Gather information for every experiment found under a base directory.

    Experiments are looked for three directory levels below ``base_dir``
    (dataset / model / experiment). Non-directory entries are skipped at
    every level. An experiment whose information cannot be loaded is
    reported on stdout and left out of the result.

    Args:
        base_dir: Base directory containing experiments

    Returns:
        List of experiment information dictionaries
    """
    def _child_dirs(parent):
        # Yield the full path of each directory entry directly under parent.
        for entry in os.listdir(parent):
            child = os.path.join(parent, entry)
            if os.path.isdir(child):
                yield child

    experiments = []
    for dataset_path in _child_dirs(base_dir):
        for model_path in _child_dirs(dataset_path):
            for experiment_path in _child_dirs(model_path):
                try:
                    experiments.append(get_experiment_info(experiment_path))
                except Exception as e:
                    print(f"Error loading experiment {experiment_path}: {e}")

    return experiments