import os
import json
import time
import torch
import numpy as np
from datetime import datetime
from typing import Dict, List, Tuple, Optional

def save_metrics(metrics: Dict, epoch: int, save_dir: str = "metrics"):
    """
    Save training metrics to JSON files.

    The same payload is written twice inside ``save_dir`` (created if it
    does not exist): once as ``metrics_epoch_{epoch}.json`` and once as
    ``latest_metrics.json``.

    Values are converted recursively: dicts, lists and tuples are walked,
    and any object exposing ``tolist()`` (numpy arrays and numpy scalars,
    among others) is replaced by the result of that call. Everything else
    is passed to ``json`` unchanged.

    Args:
        metrics: Dictionary of training metrics
        epoch: Current epoch number
        save_dir: Directory to save metrics

    Raises:
        TypeError: If a value is still not JSON serializable after
            conversion. This is raised before any file is opened, so
            existing metric files are not truncated by a failed save.
    """

    def _to_builtin(value):
        # Containers are walked so numpy values nested at any depth are found.
        if isinstance(value, dict):
            return {key: _to_builtin(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [_to_builtin(item) for item in value]
        # numpy arrays and numpy scalars both provide tolist().
        if hasattr(value, 'tolist'):
            return value.tolist()
        return value

    os.makedirs(save_dir, exist_ok=True)

    # Serialise once, up front: both files get identical text and a
    # serialization error cannot leave a half-written file behind.
    payload = json.dumps(
        {key: _to_builtin(value) for key, value in metrics.items()},
        indent=2,
    )

    # Save current metrics
    metrics_path = os.path.join(save_dir, f"metrics_epoch_{epoch}.json")
    with open(metrics_path, 'w') as f:
        f.write(payload)

    # Also save as latest
    latest_path = os.path.join(save_dir, "latest_metrics.json")
    with open(latest_path, 'w') as f:
        f.write(payload)

    print(f"  📊 Metrics saved to {metrics_path}")


def save_messages(sender_agent, dataset, device, epoch: int, save_dir: str = "messages", num_samples: int = 100):
    """
    Save encoded messages from sender agent for analysis

    Draws random items from ``dataset`` without replacement, encodes each
    image with ``sender_agent.send_message`` and writes the messages with
    their labels to ``messages_epoch_{epoch}.json`` inside ``save_dir``.

    Args:
        sender_agent: Trained sender agent. ``eval()`` is called on it and
            it is left in that mode afterwards.
        dataset: Dataset to sample from; indexing must yield ``(img, label)``
        device: Compute device the image is moved to before encoding
        epoch: Current epoch number
        save_dir: Directory to save messages
        num_samples: Number of samples to encode. Capped at ``len(dataset)``
            because sampling is done without replacement.
    """
    os.makedirs(save_dir, exist_ok=True)

    sender_agent.eval()

    # Sampling without replacement cannot draw more items than exist, so
    # cap the request instead of letting np.random.choice raise ValueError.
    sample_count = min(num_samples, len(dataset))
    if sample_count:
        indices = np.random.choice(len(dataset), sample_count, replace=False)
    else:
        indices = []

    messages = []
    labels = []

    with torch.no_grad():
        for idx in indices:
            # Plain int index: not every dataset accepts numpy integers.
            img, label = dataset[int(idx)]

            # Get message from sender
            message = sender_agent.send_message(img.to(device))

            # The conversion is the same whatever the message's rank.
            messages.append(message.cpu().numpy().tolist())
            labels.append(int(label))

    # Save messages and labels
    message_data = {
        'messages': messages,
        'labels': labels,
        'epoch': epoch,
        # Number actually encoded, which may be below the requested count.
        'num_samples': len(messages),
        'message_type': 'temporal' if sender_agent.use_temporal else 'rate_coded'
    }

    messages_path = os.path.join(save_dir, f"messages_epoch_{epoch}.json")
    with open(messages_path, 'w') as f:
        json.dump(message_data, f, indent=2)

    print(f"  Messages saved to {messages_path}")


def save_models(agent_a, agent_b, epoch: int, save_dir: str = "models"):
    """
    Write a checkpoint for each of the two agents.

    The target directory is created when missing. Each agent is asked to
    persist itself through its own ``save_agent(path)`` method; the file
    names carry the agent tag and the epoch number.

    Args:
        agent_a: First agent (saved as ``agent_a_epoch_{epoch}.pth``)
        agent_b: Second agent (saved as ``agent_b_epoch_{epoch}.pth``)
        epoch: Current epoch number
        save_dir: Directory to save models
    """
    os.makedirs(save_dir, exist_ok=True)

    checkpoint_a = os.path.join(save_dir, f"agent_a_epoch_{epoch}.pth")
    checkpoint_b = os.path.join(save_dir, f"agent_b_epoch_{epoch}.pth")

    # Agent A is always written before agent B.
    for agent, checkpoint in ((agent_a, checkpoint_a), (agent_b, checkpoint_b)):
        agent.save_agent(checkpoint)

    print(f"  Models saved to {save_dir}")


class NeuromorphicExperimentManager:
    """Manager for orchestrating neuromorphic communication experiments.

    Each instance owns one timestamped directory,
    ``<base_dir>/experiment_<YYYYmmdd_HHMMSS>``, which is created on
    construction together with the sub-directories named in ``SUBDIRS``.
    """
    # Sub-directories created inside every experiment directory.
    SUBDIRS = ["models", "plots", "metrics", "messages", "logs", "comparisons"]

    def __init__(self, base_dir: str = "experiments"):
        """Create the timestamped experiment directory tree under ``base_dir``."""
        self.base_dir = base_dir
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.experiment_dir = os.path.join(self.base_dir, f"experiment_{self.timestamp}")
        self._create_directories()
        print(f" Experiment Manager initialized at {self.experiment_dir}")

    def _create_directories(self):
        """Create the experiment directory and every entry of ``SUBDIRS``."""
        os.makedirs(self.experiment_dir, exist_ok=True)
        for sub in self.SUBDIRS:
            os.makedirs(os.path.join(self.experiment_dir, sub), exist_ok=True)

    def run_experiment(
        self,
        name: str,
        train_fn,
        config: Dict,
        pretrained_path: str,
        plot_fn=None
    ) -> Tuple[Dict, str]:
        """
        Run a single experiment: train, write a summary, optionally plot.

        Args:
            name: Experiment name; also the name of the output directory
                created inside the experiment directory.
            train_fn: Training callable. It is invoked as
                ``train_fn(pretrained_commsmod_path=pretrained_path, **config)``
                and must return ``(metrics, agent_a, agent_b)``.
            config: Keyword arguments forwarded to ``train_fn``.
            pretrained_path: Value passed as ``pretrained_commsmod_path``.
            plot_fn: Optional callable invoked as ``plot_fn(metrics, exp_path)``
                after the summary has been written.

        Returns:
            ``(metrics, exp_path)``: the metrics returned by ``train_fn`` and
            the directory holding this experiment's outputs.
        """
        exp_path = os.path.join(self.experiment_dir, name)
        os.makedirs(exp_path, exist_ok=True)
        print(f" Starting experiment '{name}'...")

        # perf_counter is monotonic, so the measured duration is not
        # distorted if the system clock is adjusted during training.
        start = time.perf_counter()

        # Only the metrics are used here; the returned agents are ignored.
        metrics, _agent_a, _agent_b = train_fn(
            pretrained_commsmod_path=pretrained_path,
            **config
        )

        duration = time.perf_counter() - start
        # The summary is written before plotting so it exists even if
        # plot_fn fails.
        self._save_summary(name, config, metrics, duration, exp_path)

        # Plotting
        if plot_fn:
            print(" Generating plots...")
            plot_fn(metrics, exp_path)

        print(f" Experiment '{name}' completed in {duration/60:.1f} min")
        return metrics, exp_path

    def _save_summary(
        self,
        name: str,
        config: Dict,
        metrics: Dict,
        duration: float,
        exp_path: str
    ):
        """
        Write ``summary.json`` for one experiment into ``exp_path``.

        For each metric, a non-empty list contributes its last element,
        ``None`` is skipped, and any other value is stored as is. All
        values and the config go through ``convert_to_json_serializable``.
        """
        final_metrics = {}
        for key, value in metrics.items():
            if isinstance(value, list) and value:
                # Non-empty history: keep only the most recent entry.
                final_metrics[key] = convert_to_json_serializable(value[-1])
            elif value is not None:
                final_metrics[key] = convert_to_json_serializable(value)

        summary = {
            'experiment_name': name,
            'timestamp': datetime.now().isoformat(),
            'duration_seconds': float(duration),  # Ensure it's a standard float
            'config': convert_to_json_serializable(config),
            'final_metrics': final_metrics
        }

        path = os.path.join(exp_path, 'summary.json')
        with open(path, 'w') as f:
            json.dump(summary, f, indent=2)
        print(f"  📊 Summary saved to {path}")

    def run_multiple(
        self,
        experiments: List[Tuple[str, Dict]],
        pretrained_path: str,
        train_fn,
        plot_fn=None
    ) -> List[Dict]:
        """
        Run several experiments in order and collect their metrics.

        Args:
            experiments: ``(name, config)`` pairs, run sequentially.
            pretrained_path: Forwarded to every ``run_experiment`` call.
            train_fn: Forwarded to every ``run_experiment`` call.
            plot_fn: Forwarded to every ``run_experiment`` call.

        Returns:
            The metrics of each experiment, in the order given.
        """
        results = []
        for name, cfg in experiments:
            metrics, _ = self.run_experiment(
                name, train_fn, cfg, pretrained_path, plot_fn
            )
            results.append(metrics)
        return results

def _json_safe_key(key):
    """Return ``key`` in a form ``json.dump`` accepts as an object key."""
    # numpy scalars such as np.int64 are rejected by json as keys;
    # tolist() turns them into the matching Python scalar.
    if isinstance(key, np.generic):
        key = key.tolist()
    if key is None or isinstance(key, (str, int, float, bool)):
        return key
    return str(key)


def convert_to_json_serializable(obj):
    """
    Convert numpy types and other non-serializable objects to
    JSON-serializable types.

    Conversion rules, applied recursively:
      - dict: values are converted; keys that ``json`` would reject
        (numpy scalars, arbitrary objects) become Python scalars or strings
      - list / tuple: converted element-wise into a list
      - any object with ``tolist()`` (numpy arrays and numpy scalars):
        replaced by ``obj.tolist()``
      - int, float, str, bool, None: returned unchanged
      - anything else: ``str(obj)``, or a ``<non-serializable: TypeName>``
        placeholder if ``str()`` itself fails
    """
    if isinstance(obj, dict):
        return {
            _json_safe_key(key): convert_to_json_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [convert_to_json_serializable(item) for item in obj]
    # Covers numpy arrays and every numpy scalar type, so no separate
    # np.integer / np.floating / np.bool_ checks are needed.
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    if obj is None or isinstance(obj, (int, float, str, bool)):
        return obj
    # Last resort: a readable string is better than failing the whole dump.
    try:
        return str(obj)
    except Exception:
        return f"<non-serializable: {type(obj).__name__}>"

def create_experiment_configs() -> Dict[str, Dict]:
    """
    Build the predefined, named experiment configurations.

    Every configuration starts from the same shared defaults and then
    applies its own overrides; an override replaces the shared value of
    the same key. A fresh dictionary is built on every call.

    Returns:
        Dictionary mapping configuration name to its settings
    """
    # Defaults common to all configurations.
    shared = {
        'K': 3,                    # Number of candidates in referential game
        'γ': 0.99,                 # Q-learning discount factor
        'beta': 0.8,               # Membrane potential decay for spiking neurons
        'target_update_freq': 10,  # Target network update frequency
        'role_swap_freq': 25,      # Role swapping frequency
        'device': None             # Auto-detect device
    }

    # Per-configuration settings layered on top of the shared defaults.
    overrides_by_name = {
        # Short run for quick testing
        'rate_fast': {
            'epochs': 10,
            'trials_per_epoch': 50,
            'ε_start': 0.2,
            'ε_final': 0.05,
            'lr': 2e-4,
            'freeze_commsmod': True,
        },
        # Standard run; also halves how often the target network updates
        'rate_standard': {
            'epochs': 50,
            'trials_per_epoch': 100,
            'ε_start': 0.3,
            'ε_final': 0.05,
            'lr': 1e-4,
            'freeze_commsmod': True,
            'target_update_freq': 20,
        },
        # Like the standard run but with 20 epochs and the shared
        # target update frequency
        'temporal_standard': {
            'epochs': 20,
            'trials_per_epoch': 100,
            'ε_start': 0.3,
            'ε_final': 0.05,
            'lr': 1e-4,
            'freeze_commsmod': True,
        },
        # Staged training: frozen → unfrozen commsmod
        'staged_training': {
            'stage2_epochs': 15,    # Frozen commsmod phase
            'stage3_epochs': 10,    # Unfrozen commsmod phase
            'trials_per_epoch': 100,
            'ε_start': 0.3,
            'ε_final': 0.05,
            'lr': 1e-4,
            'freeze_commsmod': True,  # Will be changed during staged training
        },
    }

    return {
        name: {**shared, **overrides}
        for name, overrides in overrides_by_name.items()
    }