#!/usr/bin/env python3
"""
Script to load a trained model along with datasets and dataloaders.

This script emulates exactly what was done in the training run by reusing
functions from main.py to ensure consistency.

Usage:
    python load_trained_model.py --save_dir /path/to/saved/model --config_path /path/to/config.yaml [--overrides KEY=VALUE ...]
"""

import os
import sys
import pickle
import argparse
from pathlib import Path
import torch
import omegaconf
from omegaconf import OmegaConf
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate

# Add the current directory to the path so we can import from main.py
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Import functions from main.py
from main import (
    get_dataset, 
    get_datasets_from_inds, 
    get_loaders, 
    get_model, 
    get_canon_model,
    set_seed,
    data_dir,
    device
)

def load_config_from_file(config_path):
    """
    Load the configuration from a YAML or pickle file.

    The loader is chosen from the file extension: ``.pkl`` files are
    unpickled, ``.yaml``/``.yml`` files are parsed with ``OmegaConf.load``.

    Args:
        config_path (str | os.PathLike): Path to the config file (YAML or pickle).

    Returns:
        config: The loaded configuration object (an OmegaConf container for
        YAML; whatever object was pickled for ``.pkl``).

    Raises:
        ValueError: If the extension is not ``.pkl``, ``.yaml`` or ``.yml``.
    """
    # Accept pathlib.Path and other path-like objects, not only str
    # (str.endswith is used below, which Path does not provide).
    config_path = os.fspath(config_path)
    if config_path.endswith('.pkl'):
        # SECURITY: unpickling can execute arbitrary code. Only load config
        # pickles that you produced yourself or otherwise trust.
        with open(config_path, 'rb') as f:
            return pickle.load(f)
    if config_path.endswith(('.yaml', '.yml')):
        return OmegaConf.load(config_path)
    raise ValueError(f"Unsupported config file type: {config_path}")

def load_model_from_checkpoint(save_dir, cfg):
    """
    Rebuild the model(s) and optimizer from ``cfg`` and restore them from
    ``<save_dir>/model.pt``.

    A canonicalization model is built whenever ``cfg.dataset.task`` is
    ``'task_dependent'``. Its weights are restored from the checkpoint, and
    its parameters are handed to the optimizer, only when
    ``cfg.dataset.task_dependent_args.c_args.learned`` is truthy.

    Args:
        save_dir (str): Directory containing the ``model.pt`` checkpoint
            (which may be a symbolic link).
        cfg: Configuration object used to rebuild the architecture.

    Returns:
        tuple: (model, canon_model, optimizer). Both models are switched to
        eval mode. ``canon_model`` is None when the task is not
        task-dependent.

    Raises:
        FileNotFoundError: If the (resolved) checkpoint path does not exist.
    """
    canon_model = None
    learned_canon = False
    if cfg.dataset.task == 'task_dependent':
        learned_canon = bool(cfg.dataset.task_dependent_args.c_args.learned)
        canon_model = get_canon_model(cfg)

    model = get_model(cfg.model)

    # The optimizer must cover the same parameter set as during training,
    # otherwise its saved state cannot be loaded.
    trainable = list(model.parameters())
    if learned_canon:
        trainable += list(canon_model.parameters())
    optimizer = torch.optim.Adam(trainable, lr=cfg.training.learning_rate)

    ckpt_file = os.path.join(save_dir, 'model.pt')
    if os.path.islink(ckpt_file):
        # Resolve the link so the printed/error path shows the real file.
        ckpt_file = os.path.realpath(ckpt_file)
    if not os.path.exists(ckpt_file):
        raise FileNotFoundError(f"Model checkpoint not found: {ckpt_file}")

    print(f"Loading checkpoint from: {ckpt_file}")
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(ckpt_file, map_location=device)

    model.load_state_dict(state['model_state_dict'])
    if learned_canon and canon_model is not None:
        canon_model.load_state_dict(state['canon_model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])

    for net in (model, canon_model):
        if net is not None:
            net.eval()

    return model, canon_model, optimizer

def setup_datasets_and_loaders(cfg, dataset_dir):
    """
    Rebuild the train/val/test splits and their dataloaders the way training did.

    Args:
        cfg: Configuration object. Reads ``cfg.training.batch_size``,
            ``cfg.training.num_workers`` and ``cfg.dataset.name`` in addition
            to whatever ``get_dataset`` consumes.
        dataset_dir (str): Directory passed through to ``get_dataset``.

    Returns:
        tuple: (datasets, dataloaders, criterion, aux_criteria), where
        ``datasets`` maps 'train'/'val'/'test' to the split subsets and
        ``dataloaders`` is whatever ``get_loaders`` builds from them.
    """
    # get_dataset also returns a label operator, which this script does not use.
    full_dataset, criterion, aux_criteria, split_indices, _label_operator = get_dataset(
        cfg, dataset_dir=dataset_dir
    )

    train_split, val_split, test_split = get_datasets_from_inds(split_indices, full_dataset)
    splits = {'train': train_split, 'val': val_split, 'test': test_split}

    training_cfg = cfg.training
    loaders = get_loaders(
        splits,
        batch_size=training_cfg.batch_size,
        num_workers=training_cfg.num_workers,
        dataset_type=cfg.dataset.name,
    )

    return splits, loaders, criterion, aux_criteria

def _parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Load a trained model with datasets and dataloaders')
    parser.add_argument('--save_dir', type=str, required=True,
                        help='Path to the saved model directory')
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path to the config file (YAML or pickle)')
    parser.add_argument('--overrides', nargs='*', default=[],
                        help='Additional config overrides (e.g., ++training.batch_size=32)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility')
    return parser.parse_args()


def _apply_overrides(cfg, config_path, overrides):
    """
    Apply Hydra-style overrides to ``cfg`` in place, touching only the overridden keys.

    Hydra composes a throwaway config from the ``configs`` directory next to
    this script, so it parses and types the override values, including
    config-group overrides such as ``model=...``. Only the keys named in
    ``overrides`` are then copied into ``cfg``. Every other setting keeps the
    value saved with the trained run, so the checkpoint still matches the
    architecture.

    Assumes ``cfg`` is an OmegaConf container (as returned by
    ``OmegaConf.load`` or a pickled Hydra config) — TODO confirm for other
    pickled config types.
    """
    # initialize_config_dir() only accepts an absolute directory.
    config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs')
    # splitext strips .yaml / .yml / .pkl alike, and only the final suffix.
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    with hydra.initialize_config_dir(version_base=None, config_dir=config_dir):
        composed = hydra.compose(config_name=config_name, overrides=overrides)

    missing = object()
    for override in overrides:
        # "++a.b=1", "+a.b=1", "a.b=1" and "~a.b" all refer to key "a.b".
        key = override.split('=', 1)[0].lstrip('+~')
        if override.startswith('~'):
            # Deletion override: drop the key from the saved config as well.
            parent_key, _, leaf = key.rpartition('.')
            parent = OmegaConf.select(cfg, parent_key) if parent_key else cfg
            if parent is not None:
                with omegaconf.open_dict(parent):
                    parent.pop(leaf, None)
            continue
        value = OmegaConf.select(composed, key, default=missing)
        if value is missing:
            # e.g. package overrides ("group@pkg=...") whose target path is not the key itself.
            print(f"Warning: override key '{key}' not found in composed config; skipping")
            continue
        if OmegaConf.is_config(value):
            # Resolve interpolations against the composed config before copying.
            value = OmegaConf.to_container(value, resolve=True)
        OmegaConf.update(cfg, key, value, merge=False, force_add=True)


def _resolve_dataset_dir(dataset_cfg):
    """
    Return the dataset directory for this run.

    Uses ``dataset_cfg.data_dir`` when it is present and non-empty. Otherwise
    falls back to ``<main.data_dir>/<dataset_cfg.directory_name>``.
    """
    # Test for the *key* 'data_dir'. Testing the value of the imported
    # data_dir path as a key is always "not in", which ignored the config.
    explicit_dir = dataset_cfg.data_dir if 'data_dir' in dataset_cfg else None
    if explicit_dir:
        return explicit_dir
    return os.path.join(data_dir, dataset_cfg.directory_name)


def _print_summary(cfg, model, canon_model, datasets, dataloaders):
    """Print a human-readable summary of the loaded model(s), splits and loaders."""
    print("\n" + "="*50)
    print("LOADED MODEL SUMMARY")
    print("="*50)
    print(f"Model: {cfg.model.name}")
    print(f"Dataset: {cfg.dataset.name}")
    print(f"Task: {cfg.dataset.task}")
    print(f"Device: {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    if canon_model is not None:
        print(f"Canonicalization model parameters: {sum(p.numel() for p in canon_model.parameters()):,}")

    print("\nDataset sizes:")
    for split, subset in datasets.items():
        print(f"  {split}: {len(subset)} samples")

    print("\nDataloader batch sizes:")
    for split, loader in dataloaders.items():
        print(f"  {split}: {loader.batch_size}")

    print("\nModel and data are ready for use!")
    print("You can now use:")
    print("  - model: the trained model")
    print("  - canon_model: the canonicalization model (if applicable)")
    print("  - dataloaders: dictionary with 'train', 'val', 'test' dataloaders")
    print("  - datasets: dictionary with 'train', 'val', 'test' datasets")
    print("  - criterion: loss function")
    print("  - aux_criteria: auxiliary metrics")


def main():
    """
    Load a saved config, model checkpoint and the matching data pipeline.

    Steps: parse CLI args, seed RNGs, load the config (optionally applying
    Hydra-style overrides), restore model/optimizer from ``--save_dir``, then
    rebuild datasets and dataloaders and print a summary.

    Returns:
        dict: Keys 'model', 'canon_model', 'optimizer', 'dataloaders',
        'datasets', 'criterion', 'aux_criteria' and 'cfg', for interactive use.
    """
    args = _parse_args()

    # Set seed for reproducibility
    set_seed(args.seed)

    print(f"Loading configuration from: {args.config_path}")
    cfg = load_config_from_file(args.config_path)

    if args.overrides:
        print(f"Applying overrides: {args.overrides}")
        _apply_overrides(cfg, args.config_path, args.overrides)

    dataset_dir = _resolve_dataset_dir(cfg.dataset)
    print(f"Dataset directory: {dataset_dir}")

    print("Loading model...")
    model, canon_model, optimizer = load_model_from_checkpoint(args.save_dir, cfg)

    print("Setting up datasets and dataloaders...")
    datasets, dataloaders, criterion, aux_criteria = setup_datasets_and_loaders(cfg, dataset_dir)

    _print_summary(cfg, model, canon_model, datasets, dataloaders)

    return {
        'model': model,
        'canon_model': canon_model,
        'optimizer': optimizer,
        'dataloaders': dataloaders,
        'datasets': datasets,
        'criterion': criterion,
        'aux_criteria': aux_criteria,
        'cfg': cfg,
    }

if __name__ == "__main__":
    loaded_objects = main()

    # Publish each returned object as a module-level global so that, e.g. under
    # `python -i`, names like `model` and `dataloaders` are directly usable.
    for _name, _obj in loaded_objects.items():
        globals()[_name] = _obj

    print("\nObjects are now available in the global namespace for interactive use.")