"""
Evaluator for MoE (Mixture of Experts) placement optimization
"""

import asyncio
import importlib.util
import numpy as np
import time
import os
import subprocess
import tempfile
import traceback
import sys
import pickle
import networkx as nx
import math
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any, Optional, Set
from pathlib import Path
import yaml
import argparse
import torch

from openevolve.evaluation_result import EvaluationResult

from activations_dataset import DeterministicActivations, NoisyDeterministicActivations, DeepSeekBasedActivations

from topologies import populate_matrix_dists, prepare_network_topology

def validate_moe_placement(expert_placements, layer_placements, server_expert_count,
                          server_layer_count, distance_matrix,
                          num_layers=32, experts_per_layer=32, num_servers=32,
                          max_experts_per_server=32, max_layers_per_server=2,
                          max_layer_experts_per_server=4):
    """
    Validate that the MoE placement respects all constraints.

    Args:
        expert_placements: List of dicts with ``expert_id``, ``layer_id`` and
            ``server_id`` keys, one per placed expert.
        layer_placements: List of dicts with ``layer_id``, ``layer_type`` and
            ``server_id`` keys, one per model layer.
        server_expert_count: Dict mapping server id to the expert count
            reported by the placement program.
        server_layer_count: Dict mapping server id to the layer count
            reported by the placement program.
        distance_matrix: Accepted for interface compatibility; not read here.
        num_layers: Number of layers; ids ``0 .. num_layers - 1`` must each
            appear exactly once in ``layer_placements``.
        experts_per_layer: Experts expected for every layer of type ``'moe'``.
        num_servers: Server ids must satisfy ``0 <= id < num_servers``.
        max_experts_per_server: Upper bound for values of ``server_expert_count``.
        max_layers_per_server: Upper bound for values of ``server_layer_count``.
        max_layer_experts_per_server: Maximum number of experts from the same MoE layer
                                    that can be placed on the same server (to prevent memory issues)

    Returns:
        Tuple ``(is_valid, message)``. ``message`` is ``"Valid"`` on success,
        otherwise a description of the first violated constraint. Any
        exception raised while inspecting the inputs is reported as
        ``(False, "Validation error: ...")`` instead of propagating.
    """
    try:
        # Check basic structure
        if not isinstance(expert_placements, list) or not isinstance(layer_placements, list):
            return False, "Invalid placement structure: must be lists"

        if not isinstance(server_expert_count, dict) or not isinstance(server_layer_count, dict):
            return False, "Invalid server count structure: must be dictionaries"

        # Check capacity constraints (as reported by the placement program)
        for server_id, expert_count in server_expert_count.items():
            if expert_count > max_experts_per_server:
                return False, f"Server {server_id} has {expert_count} experts, exceeding limit of {max_experts_per_server}"

        for server_id, layer_count in server_layer_count.items():
            if layer_count > max_layers_per_server:
                return False, f"Server {server_id} has {layer_count} layers, exceeding limit of {max_layers_per_server}"

        # Check that no server has too many experts from the same MoE layer
        server_layer_expert_count = {}
        for expert in expert_placements:
            key = (expert['server_id'], expert['layer_id'])
            server_layer_expert_count[key] = server_layer_expert_count.get(key, 0) + 1

        for (server_id, layer_id), count in server_layer_expert_count.items():
            if count > max_layer_experts_per_server:
                return False, f"Server {server_id} has {count} experts from layer {layer_id}, exceeding limit of {max_layer_experts_per_server}"

        # Check that all layers are present
        layer_ids = set(layer['layer_id'] for layer in layer_placements)
        expected_layer_ids = set(range(num_layers))
        if layer_ids != expected_layer_ids:
            return False, f"Missing or extra layers: got {layer_ids}, expected {expected_layer_ids}"

        # The set comparison above cannot see repeated ids, so a layer placed
        # twice would otherwise slip through.
        if len(layer_placements) != num_layers:
            return False, f"Duplicate layer placements: got {len(layer_placements)} entries for {num_layers} layers"

        # Count MoE layers and experts
        moe_layers = [layer for layer in layer_placements if layer['layer_type'] == 'moe']
        total_experts_expected = len(moe_layers) * experts_per_layer

        if len(expert_placements) != total_experts_expected:
            return False, f"Incorrect number of experts: got {len(expert_placements)}, expected {total_experts_expected}"

        # A correct total can still hide one expert placed twice and another
        # one missing, so every (layer, expert) pair must be unique.
        seen_experts = set()
        for expert in expert_placements:
            expert_key = (expert['layer_id'], expert['expert_id'])
            if expert_key in seen_experts:
                return False, f"Expert {expert['expert_id']} of layer {expert['layer_id']} is placed more than once"
            seen_experts.add(expert_key)

        # Check that all experts have valid server assignments
        for expert in expert_placements:
            server_id = expert['server_id']
            if not (0 <= server_id < num_servers):
                return False, f"Expert {expert['expert_id']} assigned to invalid server {server_id}"

        # Check that all layers have valid server assignments
        for layer in layer_placements:
            server_id = layer['server_id']
            if not (0 <= server_id < num_servers):
                return False, f"Layer {layer['layer_id']} assigned to invalid server {server_id}"

        return True, "Valid"

    except Exception as e:
        # Deliberate best-effort: malformed entries (missing keys, wrong
        # types) are reported as an invalid placement, not as a crash.
        return False, f"Validation error: {str(e)}"


def calculate_forward_pass_hops(expert_placements, layer_placements, distance_matrix, start_server=0):
    """
    Total the communication hops of one forward pass that touches every expert.

    Layers are visited in ascending ``layer_id`` order, starting from
    ``start_server``. An attention layer costs the distance to its server.
    An MoE layer costs the distance to its dispatch server plus, for each
    expert placed for that layer, the dispatch->expert and expert->collect
    distances. Expert costs are summed, not reduced with ``max``. Layers of
    any other type contribute nothing.

    Args:
        expert_placements: List of dicts with ``layer_id`` and ``server_id``.
        layer_placements: List of dicts with ``layer_id``, ``layer_type`` and
            either ``server_id`` (attention) or ``dispatch_server`` /
            ``collect_server`` (moe).
        distance_matrix: Indexable as ``distance_matrix[src][dst]``; each
            entry is truncated with ``int`` before being added.
        start_server: Server the pass starts from.

    Returns:
        The summed hop count.
    """
    def hop_count(src, dst):
        return int(distance_matrix[src][dst])

    hops_so_far = 0
    position = start_server

    for entry in sorted(layer_placements, key=lambda item: item['layer_id']):
        kind = entry['layer_type']

        if kind == 'attention':
            # Only pay for the move when the activations are elsewhere.
            target = entry['server_id']
            if target != position:
                hops_so_far += hop_count(position, target)
                position = target
            continue

        if kind != 'moe':
            continue

        dispatch_at = entry['dispatch_server']
        collect_at = entry['collect_server']

        # Travel to the dispatch point first.
        if position != dispatch_at:
            hops_so_far += hop_count(position, dispatch_at)
            position = dispatch_at

        # Servers hosting the experts of this layer, in placement order.
        hosts = [
            placed['server_id']
            for placed in expert_placements
            if placed['layer_id'] == entry['layer_id']
        ]

        # Fan out to every expert, then gather every result.
        hops_so_far += sum(hop_count(dispatch_at, host) for host in hosts)
        hops_so_far += sum(hop_count(host, collect_at) for host in hosts)

        position = collect_at

    return hops_so_far

from tempfile import NamedTemporaryFile
from contextlib import contextmanager

@contextmanager
def get_file(file_path=None, mode="w+"):
    """
    Yield an open file object: the named file, or a throwaway temp file.

    With a truthy ``file_path`` that path is opened in ``mode`` and closed
    when the context exits. Without one, a ``NamedTemporaryFile`` opened in
    ``mode`` is yielded instead; it is created with ``delete=True`` and is
    therefore removed once the context closes it.
    """
    if not file_path:
        with NamedTemporaryFile(mode=mode, delete=True) as scratch:
            yield scratch
        return

    with open(file_path, mode) as handle:
        yield handle

def calculate_dataset_mean_hops(expert_placements, layer_placements,
                               activation_dataset, distance_matrix, start_server=0):
    """
    Calculate mean per-token hops across all forward passes in an activation dataset.

    Each forward pass is a sequence of dicts carrying ``layer_id``; entries
    whose placed layer is of type ``'moe'`` must also carry
    ``activated_experts`` (expert ids, one per routed activation) and
    ``top_k``. An attention layer costs the distance to its server. An MoE
    layer costs the distance to its dispatch server plus, for every entry of
    ``activated_experts``, the dispatch->expert and expert->collect distances.
    The pass total is divided by the token count,
    ``len(activated_experts) // top_k``, taken from the last MoE entry of
    that pass.

    Args:
        expert_placements: List of expert placement dictionaries
            (``layer_id``, ``expert_id``, ``server_id``)
        layer_placements: List of layer placement dictionaries
        activation_dataset: Sized iterable of forward passes
        distance_matrix: Distance matrix between servers
        start_server: Starting server ID

    Returns:
        Tuple of (mean_hops, all_hops_list); ``mean_hops`` is 0.0 for an
        empty dataset.

    Raises:
        ValueError: If a pass references a layer missing from
            ``layer_placements``, or its token count cannot be derived
            (no MoE entry, or fewer activations than ``top_k``).
        KeyError: If an activated expert has no placement.
    """
    all_hops = []

    # Create lookup dictionaries once for faster access. Scanning
    # expert_placements again for every MoE layer of every pass is
    # needlessly quadratic.
    layer_lookup = {layer['layer_id']: layer for layer in layer_placements}
    expert_server_lookup = {
        (expert['layer_id'], expert['expert_id']): expert['server_id']
        for expert in expert_placements
    }

    # Iterate through all forward passes in the dataset
    print("activation_dataset: ", len(activation_dataset))
    for forward_pass_sequence in activation_dataset:
        total_hops = 0
        current_server = start_server
        # Tracked per pass so the token count never comes from an attention
        # entry or from a previous pass.
        last_moe_info = None

        for layer_info in forward_pass_sequence:
            layer_id = layer_info['layer_id']

            # Get the full layer info from placements
            if layer_id not in layer_lookup:
                raise ValueError(f"Layer {layer_id} not found in layer_placements")

            layer = layer_lookup[layer_id]
            layer_type = layer['layer_type']

            if layer_type == 'attention':
                # Move to attention layer server if different
                layer_server = layer['server_id']
                if current_server != layer_server:
                    total_hops += int(distance_matrix[current_server][layer_server])
                    current_server = layer_server

            elif layer_type == 'moe':
                last_moe_info = layer_info

                # MoE layer: dispatch -> experts -> collect
                dispatch_server = layer['dispatch_server']
                collect_server = layer['collect_server']

                # Move to dispatch server
                if current_server != dispatch_server:
                    total_hops += int(distance_matrix[current_server][dispatch_server])
                    current_server = dispatch_server

                # Every activation pays the way out and the way back; an
                # expert listed several times is charged several times.
                for expert_id in layer_info['activated_experts']:
                    expert_server = expert_server_lookup[(layer_id, expert_id)]
                    total_hops += int(distance_matrix[dispatch_server][expert_server])
                    total_hops += int(distance_matrix[expert_server][collect_server])

                current_server = collect_server

        if last_moe_info is None:
            raise ValueError("Forward pass contains no MoE layer; cannot derive its token count")

        tokens_in_batch = len(last_moe_info['activated_experts']) // last_moe_info["top_k"]
        if tokens_in_batch <= 0:
            raise ValueError(
                f"Forward pass has {len(last_moe_info['activated_experts'])} activations "
                f"for top_k={last_moe_info['top_k']}; cannot derive its token count"
            )
        all_hops.append(total_hops / tokens_in_batch)  # per token hops

    mean_hops = np.mean(all_hops) if all_hops else 0.0
    return mean_hops, all_hops

def evaluate(config_path: str, program_path: str, save_placement_to_file: Optional[str] = None) -> Dict[str, Any]:
    """
    Evaluate the MoE placement program and return a plain dictionary.

    Runs ``evaluate_impl`` to completion on a fresh event loop and unwraps
    its result.

    Args:
        config_path: Path to the YAML evaluation config.
        program_path: Path to the placement program to evaluate.
        save_placement_to_file: Optional base path forwarded to
            ``evaluate_impl``.

    Returns:
        The result's ``artifacts`` dict when its ``validity`` metric is 0.0
        (so the caller sees the failure details), otherwise its ``metrics``
        dict. The ``EvaluationResult`` object itself is never returned.
    """
    res = asyncio.run(evaluate_impl(config_path, program_path, save_placement_to_file))
    if res.metrics['validity'] == 0.0:
        return res.artifacts
    return res.metrics


def _failed_evaluation(error_message: str) -> EvaluationResult:
    """Build the zero-score result shared by every failed evaluation."""
    return EvaluationResult(
        metrics={
            "combined_score": 0.0,
            "validity": 0.0,
            "total_hops": 999999.0,
            "efficiency_ratio": 0.0,
        },
        artifacts={"error": error_message},
    )


async def evaluate_impl(config_path: str, program_path: str, save_placement_to_file: str = None) -> EvaluationResult:
    """
    Run a placement program in a subprocess and score the placement it returns.

    A driver script is generated that imports ``program_path``, calls its
    ``construct_moe_placement`` and pickles the return value to a results
    file. The placement is then checked with ``validate_moe_placement`` and,
    when valid, scored by the mean per-token hop count over an activation
    dataset.

    Args:
        config_path: YAML file providing ``num_layers``, ``num_experts``,
            ``num_servers``, ``max_experts_per_server``,
            ``max_layers_per_server``, ``max_layer_experts_per_server``,
            ``topology_type``, ``num_nodes_per_leaf``, ``num_gpus_per_server``
            and optionally ``test_activations_path`` and
            ``train_per_layer_stats_path``.
        program_path: Python file defining ``construct_moe_placement``.
        save_placement_to_file: Optional base path. When given, the pickled
            placement is written to ``<base>.results`` and kept after the
            run; otherwise a temporary results file is used and deleted.

    Returns:
        An ``EvaluationResult``. On timeout, a missing results file or an
        error reported by the program, the metrics are the fixed zero-score
        set and ``artifacts["error"]`` explains why. Otherwise the metrics
        hold ``combined_score``, ``total_hops``, ``efficiency_ratio``,
        ``validity``, ``eval_time`` and ``server_utilization`` (plus
        ``dataset_summary`` when a dataset was scored).

    Raises:
        Exception: Any error raised while preparing the run or scoring the
            placement propagates to the caller unchanged.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Target parameters for evaluation
    NUM_LAYERS = 2 * config['num_layers']
    EXPERTS_PER_LAYER = config['num_experts']
    NUM_SERVERS = config['num_servers']
    MAX_EXPERTS_PER_SERVER = config['max_experts_per_server']
    MAX_LAYERS_PER_SERVER = config['max_layers_per_server']
    MAX_LAYER_EXPERTS_PER_SERVER = config['max_layer_experts_per_server']
    TOPOLOGY_TYPE = config['topology_type']
    NUM_NODES_PER_LEAF = config['num_nodes_per_leaf']
    NUM_GPUS_PER_SERVER = config['num_gpus_per_server']
    TEST_ACTIVATIONS = config.get('test_activations_path', None)
    TRAIN_PER_LAYER_STATS_PATH = config.get('train_per_layer_stats_path', None)

    # Target: minimize communication hops (lower is better)
    TARGET_HOPS = 1000000
    # Score against an activation dataset; False falls back to a single
    # forward pass that activates every expert.
    USE_DATASET = True

    start_time = time.time()

    # Prepare network topology
    # NOTE(review): presumably the number of leaf switches - confirm against topologies.py.
    num_leaf_switches = NUM_SERVERS // (NUM_NODES_PER_LEAF * NUM_GPUS_PER_SERVER)
    distance_matrix = prepare_network_topology(num_leaf_switches, TOPOLOGY_TYPE)
    if NUM_NODES_PER_LEAF > 1 or NUM_GPUS_PER_SERVER > 1:
        distance_matrix = populate_matrix_dists(distance_matrix, NUM_GPUS_PER_SERVER, NUM_NODES_PER_LEAF)

    neighbor_info = None

    # Per-layer statistics are only passed to programs whose path contains
    # "load_aware"; they are serialized into the driver script as a literal.
    stats_arg_line = ""
    if TRAIN_PER_LAYER_STATS_PATH is not None and TRAIN_PER_LAYER_STATS_PATH != "None" and "load_aware" in program_path:
        per_layer_stats = torch.load(TRAIN_PER_LAYER_STATS_PATH)
        stats_arg_line = f"per_layer_stats=np.array({per_layer_stats.tolist()}),"

    # Create a temporary file to execute the program
    with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file:
        temp_file_path = temp_file.name
        output_base = save_placement_to_file if save_placement_to_file is not None else temp_file_path
        results_path = f"{output_base}.results"

        # repr() yields correctly escaped Python literals, so paths that
        # contain quotes or backslashes cannot break the generated script.
        program_literal = repr(str(program_path))
        results_literal = repr(results_path)
        timing_label = f"Time taken for topo {TOPOLOGY_TYPE} program {program_path} seconds: "

        # Write a script that imports and runs the MoE placement. Errors are
        # left to propagate: the traceback reaches the user through the
        # inherited stderr and no results file is written.
        script = f"""
import sys
import numpy as np
import os
import pickle
import traceback
import time

# Add the directory to sys.path
sys.path.insert(0, os.path.dirname({program_literal}))

# Import the program
spec = __import__('importlib.util').util.spec_from_file_location("program", {program_literal})
program = __import__('importlib.util').util.module_from_spec(spec)
spec.loader.exec_module(program)

# Prepare inputs
distance_matrix = np.array({distance_matrix.tolist()})
neighbor_info = {neighbor_info}

start_time = time.time()

# Run the placement function
results = program.construct_moe_placement(
    distance_matrix=distance_matrix,
    neighbor_info=neighbor_info,
    {stats_arg_line}
    num_layers={NUM_LAYERS},
    experts_per_layer={EXPERTS_PER_LAYER},
    max_experts_per_server={MAX_EXPERTS_PER_SERVER},
    max_layers_per_server={MAX_LAYERS_PER_SERVER},
    max_layer_experts_per_server={MAX_LAYER_EXPERTS_PER_SERVER},
    random_seed=42
)

end_time = time.time()
print({timing_label!r}, end_time - start_time)

# Save results to a file
with open({results_literal}, 'wb') as f:
    pickle.dump(results, f)
"""
        temp_file.write(script.encode())

    try:
        # A results file left over from an earlier run with the same
        # save_placement_to_file would otherwise be scored as if this run
        # had produced it whenever the program fails before writing.
        if os.path.exists(results_path):
            os.unlink(results_path)

        # Run the script with timeout
        process = await asyncio.create_subprocess_exec(
            sys.executable, temp_file_path,
            stdout=None,  # Directly output to user
            stderr=None,  # Directly output to user
        )

        try:
            await asyncio.wait_for(process.communicate(), timeout=1000000.0)
        except asyncio.TimeoutError:
            process.kill()
            await process.wait()
            return _failed_evaluation("Timeout during evaluation")

        if not os.path.exists(results_path):
            return _failed_evaluation("Results file not found")

        # Load the results
        # NOTE: the pickle comes from the evaluated program, which has just
        # run with this process's privileges; do not reuse this loader for
        # results files from any other source.
        with open(results_path, "rb") as f:
            results = pickle.load(f)

        # Check if an error was returned
        if "error" in results:
            return _failed_evaluation(f"Program execution failed: {results['error']}")

        # Extract results
        expert_placements, layer_placements, server_expert_count, server_layer_count = results

        # Validate the placement
        valid, validation_msg = validate_moe_placement(
            expert_placements, layer_placements, server_expert_count,
            server_layer_count, distance_matrix,
            NUM_LAYERS, EXPERTS_PER_LAYER, NUM_SERVERS,
            MAX_EXPERTS_PER_SERVER, MAX_LAYERS_PER_SERVER,
            MAX_LAYER_EXPERTS_PER_SERVER
        )

        dataset = None
        mean_hops = 0.0
        all_hops = []

        # Calculate communication hops if valid
        if valid:
            if USE_DATASET:
                # Create a standard deterministic dataset for evaluation
                if TEST_ACTIVATIONS is None:
                    dataset = DeterministicActivations(
                        num_requests=100, num_layers=NUM_LAYERS,
                        experts_per_layer=EXPERTS_PER_LAYER, experts_per_request=4)
                else:
                    dataset = DeepSeekBasedActivations(
                        path=TEST_ACTIVATIONS,
                        substitute_start_value=1 if "16b" in TEST_ACTIVATIONS else 0
                    )

                # Calculate mean hops across all forward passes in the dataset
                mean_hops, all_hops = calculate_dataset_mean_hops(
                    expert_placements, layer_placements, dataset, distance_matrix
                )
                total_hops = mean_hops  # Use mean hops as the main metric
            else:
                total_hops = calculate_forward_pass_hops(expert_placements, layer_placements, distance_matrix)

            # Calculate efficiency metrics
            efficiency_ratio = TARGET_HOPS / total_hops if total_hops > 0 else 0.0
            servers_used = len(set(e['server_id'] for e in expert_placements))
            server_utilization = servers_used / NUM_SERVERS

            # Combined score (higher is better)
            combined_score = efficiency_ratio
        else:
            total_hops = float('inf')
            efficiency_ratio = 0.0
            server_utilization = 0.0
            combined_score = 0.0

        eval_time = time.time() - start_time

        eval_result = EvaluationResult(
            metrics={
                "combined_score": float(combined_score),
                "total_hops": float(total_hops) if total_hops != float('inf') else 999999.0,
                "efficiency_ratio": float(efficiency_ratio),
                "validity": 1.0 if valid else 0.0,
                "eval_time": float(eval_time),
                "server_utilization": float(server_utilization),
            },
            artifacts={
                "validation_message": validation_msg,
                "num_expert_placements": len(expert_placements),
                "num_layer_placements": len(layer_placements),
            }
        )

        # A dataset only exists for a valid placement scored with USE_DATASET.
        if dataset is not None:
            eval_result.metrics["dataset_summary"] = {
                "dataset_type": dataset.__class__.__name__,
                "num_forward_passes": len(dataset),
                "hops_std": float(np.std(all_hops)) if all_hops else 0.0,
                "mean_hops": float(mean_hops),
                "min_hops": float(min(all_hops)) if all_hops else 0.0,
                "max_hops": float(max(all_hops)) if all_hops else 0.0,
            }
        return eval_result

    finally:
        # Clean up temporary files; a results file the caller asked to keep
        # (save_placement_to_file) is left in place.
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        if os.path.exists(results_path) and save_placement_to_file is None:
            os.unlink(results_path)


# For testing
def test_dataset_evaluation():
    """
    Smoke-test the activation datasets by printing a short tour of them.

    Builds one ``DeterministicActivations`` and one
    ``NoisyDeterministicActivations`` dataset (5 requests, 8 layers each),
    prints the first four layers of pass 0 of each, then the attention/MoE
    layer counts of the first three passes of each.
    """
    def show_leading_layers(dataset):
        # First 4 layers of forward pass 0.
        for entry in dataset.get_activations(0)[:4]:
            print(f"     Layer {entry['layer_id']}: {entry['layer_type']}")

    def show_layer_counts(dataset):
        # Layer-type tallies for the first 3 forward passes.
        for index, sequence in enumerate(dataset):
            if index >= 3:
                break
            attention_total = sum(1 for entry in sequence if entry['layer_type'] == 'attention')
            moe_total = sum(1 for entry in sequence if entry['layer_type'] == 'moe')
            print(f"     Pass {index}: {attention_total} attention layers, {moe_total} MoE layers")

    print("Testing dataset evaluation functionality...")

    # Create sample datasets
    print("\n1. Creating sample datasets:")

    plain = DeterministicActivations(num_requests=5, num_layers=8)
    print(f"   Deterministic dataset: {len(plain)} forward passes")

    noisy = NoisyDeterministicActivations(
        num_requests=5, num_layers=8
    )
    print(f"   Mixed dataset: {len(noisy)} forward passes with 70% attention layers")

    print("\n2. Sample forward pass sequences:")

    print("   Deterministic dataset, pass 0:")
    show_leading_layers(plain)

    print("   Mixed dataset, pass 0:")
    show_leading_layers(noisy)

    print("\n3. Iteration over datasets:")

    print("   First 3 passes from deterministic dataset (layer count only):")
    show_layer_counts(plain)

    print("   First 3 passes from mixed dataset (layer count only):")
    show_layer_counts(noisy)

    print("\nSimplified dataset evaluation functionality is working correctly!")


if __name__ == "__main__":
    # Command-line entry point: either run the dataset smoke test (--test)
    # or evaluate one program against one config and print the result as JSON.
    import json
    import sys

    cli = argparse.ArgumentParser()
    string_options = (
        ("--config_path", "configs/default.yaml"),
        ("--program_path", "programs/initial_program_simple.py"),
        ("--save_placement_to_file", None),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, required=False, default=fallback)
    cli.add_argument("--test", action="store_true")
    args = cli.parse_args()

    print(f"Evaluating: {args.config_path=} {args.program_path=} {args.save_placement_to_file=}")

    if not args.test:
        result = evaluate(args.config_path, args.program_path, args.save_placement_to_file)
        print("=== Output Results ===")
        print(json.dumps(result))
    else:
        test_dataset_evaluation()
