import torch
from torch import nn
from typing import Dict, Any
from model import (
    DiscreteTimeModel, ContinuousTimeModel, 
    LinearDynamicsCell, MLPDynamicsCell, MiniMLPDynamicsCell
)

def set_heatmap_from_readout_type(config: Dict[str, Any]) -> None:
    """Set ``config['heatmap']`` from ``config['readout_type']``, in place.

    The readout type is the primary control: the heatmap flag is forced to
    ``True`` when ``'monotonic'`` occurs in the readout type and to ``False``
    otherwise, whatever value was previously stored under ``'heatmap'``.

    Args:
        config: Configuration dictionary. ``'readout_type'`` falls back to
            ``'separate'`` when the key is absent. ``'heatmap'`` is
            overwritten.
    """
    readout_type = config.get('readout_type', 'separate')

    if 'monotonic' in readout_type:
        config['heatmap'] = True
        print(f"Setting heatmap=True for readout_type='{readout_type}'")
        return

    # Capture the caller's value BEFORE overwriting it; checking after the
    # assignment would always see False and the notice could never be printed.
    user_requested_heatmap = config.get('heatmap', False)
    config['heatmap'] = False
    if user_requested_heatmap:  # The user explicitly set heatmap=True
        print(f"Overriding heatmap=False for readout_type='{readout_type}'")


def create_dynamics_cell(config: Dict[str, Any]) -> nn.Module:
    """Build the dynamics cell named by ``config['dynamics_model_type']``.

    Recognised types are ``'rnn'``, ``'gru'``, ``'linear_unconstraint'``,
    ``'linear_constraint'``, ``'mlp'`` and ``'minimlp'``.

    Args:
        config: Configuration dictionary. ``'latent_size'`` and
            ``'dynamics_model_type'`` are always read; the remaining keys
            depend on the selected cell type.

    Returns:
        The constructed cell module.

    Raises:
        ValueError: If the dynamics model type is not one of the above.
        KeyError: If a key required by the selected cell type is missing.
    """
    default_input_size = 1  # Used by every cell except the MiniMLP variants
    latent_size = config['latent_size']
    kind = config['dynamics_model_type']

    if kind == 'rnn':
        return nn.RNNCell(
            default_input_size,
            latent_size,
            nonlinearity=config['dynamics_nonlinearity'],
        )

    if kind == 'gru':
        return nn.GRUCell(default_input_size, latent_size)

    if kind == 'linear_unconstraint':
        return LinearDynamicsCell(
            input_size=default_input_size,
            hidden_size=latent_size,
            dt=config['dynamics_dt'],
            mlp_hidden_dim=config['dynamics_hidden_dim'],
        )

    if kind == 'mlp':
        return MLPDynamicsCell(
            input_size=default_input_size,
            hidden_size=latent_size,
            dt=config['dynamics_dt'],
            mlp_hidden_dims=config['dynamics_hidden_dim'],
            mlp_activation=config['dynamics_nonlinearity'],
        )

    if kind in ('linear_constraint', 'minimlp'):
        # Both types build the same cell; 'linear_constraint' simply passes
        # no activation, while 'minimlp' uses the configured nonlinearity.
        return MiniMLPDynamicsCell(
            input_size=config.get('external_input_size', 1),
            hidden_size=1,
            bias=True,
            dt=config['dynamics_dt'],
            mlp_hidden_dims=config['dynamics_hidden_dim'],
            mlp_activation=(
                config['dynamics_nonlinearity'] if kind == 'minimlp' else None
            ),
            neurons=config['neurons'],
            noise=config['noise'],
            softplus=config['dynamics_monotonic'],
            latent_sizes=config.get('population_latent_sizes'),
            compositional_func=config.get('compositional_func', False),
            stimulated_populations=config.get('stimulated_populations'),
        )

    raise ValueError(f"Unknown dynamics_model_type: {kind}")


def create_model(config: Dict[str, Any]) -> nn.Module:
    """Create the model selected by ``config['solver_type']``.

    The configuration is validated, a dynamics cell is built from it, and the
    cell is wrapped in ``DiscreteTimeModel`` (``'discrete'``) or
    ``ContinuousTimeModel`` (``'continuous'``).

    Note:
        ``config`` is mutated: ``config['heatmap']`` is reset from
        ``config['readout_type']`` before anything else happens.

    Args:
        config: Configuration dictionary. Must contain ``'solver_type'``,
            ``'dynamics_model_type'`` and ``'noise'``, plus whatever the
            selected dynamics cell requires.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``solver_type`` is unknown, if the continuous solver is
            combined with an ``'rnn'`` or ``'gru'`` dynamics model, or if the
            dynamics model type is unknown.
        KeyError: If a required configuration key is missing.
    """
    # Set heatmap based on readout_type (readout_type is primary control)
    set_heatmap_from_readout_type(config)

    # --- Validation ---
    solver_type = config['solver_type']
    dynamics_model_type = config['dynamics_model_type']

    # Reject an unknown solver up front so no cell is built for a
    # configuration that can never produce a model.
    if solver_type not in ('discrete', 'continuous'):
        raise ValueError(f"Unknown solver_type: {solver_type}")

    if solver_type == 'continuous' and dynamics_model_type in ('rnn', 'gru'):
        # The suggestions are the exact type names accepted by
        # create_dynamics_cell (there is no plain 'linear' type).
        raise ValueError(
            "Continuous solver cannot be used with "
            f"'{dynamics_model_type}' dynamics model. "
            "Use 'linear_unconstraint', 'linear_constraint', 'mlp', or 'minimlp'."
        )

    # Create appropriate cell
    cell = create_dynamics_cell(config)

    # Create model based on solver type
    if solver_type == 'discrete':
        return DiscreteTimeModel(cell, config, noise=config['noise'])
    return ContinuousTimeModel(cell, config, noise=config['noise'])


def create_optimizer(model: nn.Module, config: Dict[str, Any]) -> torch.optim.Optimizer:
    """Build an Adam optimizer over all parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.
        config: Configuration dictionary providing ``'lr'`` (learning rate)
            and ``'weight_decay'``.

    Returns:
        A ``torch.optim.Adam`` instance.
    """
    params = model.parameters()
    adam_options = {
        'lr': config['lr'],
        'weight_decay': config['weight_decay'],
    }
    return torch.optim.Adam(params, **adam_options)


def setup_directories(config: Dict[str, Any], experiment_name: str) -> Dict[str, str]:
    """Create the output sub-directories for an experiment.

    Five sub-directories (``plots``, ``checkpoints``, ``logs``, ``metrics``
    and ``config``) are created under ``config['output_dir']``, which
    defaults to ``'./results'``. Directories that already exist are left
    untouched.

    Args:
        config: Configuration dictionary; only ``'output_dir'`` is read.
        experiment_name: Accepted for interface compatibility. It does not
            currently influence the directory layout.

    Returns:
        Mapping from ``'BASE_PATH'``, ``'CHECKPOINT_PATH'``, ``'LOG_PATH'``,
        ``'METRICS_PATH'`` and ``'CONFIG_PATH'`` to the corresponding
        directory path.
    """
    import os

    base_dir = config.get('output_dir', './results')

    # Note: 'BASE_PATH' points at the plots directory, not at base_dir itself.
    dirs = {
        'BASE_PATH': os.path.join(base_dir, 'plots'),
        'CHECKPOINT_PATH': os.path.join(base_dir, 'checkpoints'),
        'LOG_PATH': os.path.join(base_dir, 'logs'),
        'METRICS_PATH': os.path.join(base_dir, 'metrics'),
        'CONFIG_PATH': os.path.join(base_dir, 'config'),
    }

    for dir_path in dirs.values():
        os.makedirs(dir_path, exist_ok=True)

    return dirs


def save_experiment_config(config: Dict[str, Any], experiment_name: str) -> str:
    """Write the experiment configuration to a JSON file.

    Values that cannot be JSON-encoded are stored as their ``str()``
    representation. A ``'_metadata'`` entry holding the experiment name, the
    save timestamp and a config version is added to the saved copy; the
    ``config`` dictionary passed in is not modified.

    Args:
        config: Configuration to save. ``config['CONFIG_PATH']`` is the
            directory the file is written to.
        experiment_name: Used in the metadata and as the file-name prefix
            (``<experiment_name>_config.json``).

    Returns:
        Path of the written JSON file.
    """
    import json
    import os
    from datetime import datetime

    def _json_safe(value: Any) -> Any:
        # Probe by encoding; fall back to the string form when that fails.
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return str(value)
        return value

    serializable = {name: _json_safe(entry) for name, entry in config.items()}

    serializable['_metadata'] = {
        'experiment_name': experiment_name,
        'saved_at': datetime.now().isoformat(),
        'config_version': '1.0',
    }

    target = os.path.join(config['CONFIG_PATH'], f"{experiment_name}_config.json")
    with open(target, 'w') as handle:
        json.dump(serializable, handle, indent=4, sort_keys=True)

    print(f"Experiment configuration saved to: {target}")
    return target


def load_model_from_checkpoint(checkpoint_path: str, config: Dict[str, Any]) -> tuple:
    """Rebuild a model from a checkpoint file and load its weights.

    The model configuration starts from ``checkpoint['config']`` when the
    checkpoint has one, otherwise from a copy of ``config``. Values from
    ``config`` are then layered on top in three groups:

    * architecture keys (``solver_type``, ``dynamics_model_type``,
      ``readout_type``, ``latent_size``), copied when they differ from the
      saved value;
    * decoding keys (``causal_model``, ``window_size``, ``ic_window_size``,
      ``step_size``), copied whenever present in ``config``;
    * runtime keys (device, plotting flags and output paths), copied
      whenever present in ``config``.

    Args:
        checkpoint_path: Path of the checkpoint file. The checkpoint must
            contain a ``'weight'`` entry holding the state dict.
        config: Caller-provided configuration used as described above.

    Returns:
        tuple: (model, model_config) where model_config is the configuration
        used to create the model.
    """
    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Start with saved config if available, otherwise use provided config
    if 'config' in checkpoint:
        print("Found model configuration in checkpoint")
        model_config = checkpoint['config'].copy()
    else:
        print("No model configuration found in checkpoint, using provided config")
        model_config = config.copy()

    overridden_keys = []

    # Architecture parameters explicitly provided in config win over the saved
    # ones. This includes latent_size, so the caller must supply a value that
    # matches the checkpoint weights or load_state_dict below will fail.
    architecture_keys = ['solver_type', 'dynamics_model_type', 'readout_type', 'latent_size']
    for key in architecture_keys:
        if key in config and (key not in model_config or config[key] != model_config.get(key)):
            model_config[key] = config[key]
            overridden_keys.append(key)

    # Decoding parameters (optional) are always taken from config when given,
    # but are only reported as overridden when the value actually changed.
    for key in ['causal_model', 'window_size', 'ic_window_size', 'step_size']:
        if key in config:
            changed = key not in model_config or config[key] != model_config.get(key)
            model_config[key] = config[key]
            if changed:
                overridden_keys.append(key)

    if overridden_keys:
        print(f"Overriding saved config with provided values for: {overridden_keys}")

    # Always preserve runtime settings from the provided config.
    # Note: 'heatmap' is recomputed from readout_type inside create_model, so
    # the value copied here does not survive model creation.
    runtime_keys = ['device', 'plot', 'heatmap', 'BASE_PATH', 'CHECKPOINT_PATH',
                    'LOG_PATH', 'METRICS_PATH', 'CONFIG_PATH']
    for key in runtime_keys:
        if key in config:
            model_config[key] = config[key]

    # Create model with final architecture
    model = create_model(model_config)

    # Load weights
    model.load_state_dict(checkpoint['weight'])

    return model, model_config