#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
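Generate testing data from an experiment's generator pool and its final sigma_g.
Only generators whose sigma_g weight exceeds the threshold (0.01 by default) are used;
instances are generated with the parameters taken from the experiment config.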
"""

import os
import sys
import json
import argparse
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional

# Add the project paths to sys.path so the `heupsro` package imports below
# resolve. The directories are derived from this file's own location by
# walking up the directory tree; each one is inserted at the front of
# sys.path only if it is not already present.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_problem_dir = os.path.dirname(os.path.dirname(_script_dir))  # bp_online
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))  # heupsro
_project_root = os.path.dirname(_heupsro_dir)  # project root
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
if _heupsro_dir not in sys.path:
    sys.path.insert(0, _heupsro_dir)

# 导入 resume_experiment 模块
from heupsro.experiments.resume_experiment import load_experiment_config


def validate_bp_instance(instance: Dict, capacity: int, num_items: int) -> Optional[Dict]:
    """Validate a bin-packing (BP) instance and repair out-of-range item sizes.

    The instance must be a dict with the keys 'items', 'capacity' and
    'num_items'. Item sizes outside ``[1, capacity - 1]`` are clipped into
    that range, and all sizes are truncated to integers.

    Args:
        instance: Candidate instance dict.
        capacity: Bin capacity to validate against. This value, not
            ``instance['capacity']``, is written to the returned dict.
        num_items: Required number of items.

    Returns:
        A new dict with 'items' (1-D integer ``np.ndarray``), 'capacity' and
        'num_items', or ``None`` when the instance cannot be repaired:
        it is not a dict, a key is missing, a field cannot be converted to a
        number, 'items' is not a flat sequence, the item count does not match
        ``num_items``, an item is NaN, or ``capacity`` leaves no valid size.
    """
    if not isinstance(instance, dict):
        return None

    if any(key not in instance for key in ('items', 'capacity', 'num_items')):
        return None

    try:
        items = np.asarray(instance['items'], dtype=float)
        # The declared capacity only has to be integer-like; its value is not
        # used because the caller-supplied ``capacity`` is authoritative.
        int(instance['capacity'])
        declared_num_items = int(instance['num_items'])
    except (TypeError, ValueError, OverflowError):
        return None

    # Require a flat sequence: len() of a nested array would only count rows.
    if items.ndim != 1:
        return None

    # The actual and the declared item count must both match the request.
    if len(items) != num_items or declared_num_items != num_items:
        return None

    # NaN passes every range comparison and survives np.clip, so it would be
    # cast to an arbitrary integer below; such an instance is not repairable.
    if np.isnan(items).any():
        return None

    # With capacity < 2 the valid size range [1, capacity - 1] is empty.
    if items.size > 0 and capacity < 2:
        return None

    # Clip into [1, capacity - 1] (a no-op for in-range sizes), then truncate
    # to integers.
    items = np.clip(items, 1, capacity - 1).astype(int)

    return {
        'items': items,
        'capacity': capacity,
        'num_items': num_items
    }


def generate_instances_from_generator_code(
    gen_code: str,
    n: int,
    capacity: int,
    num_items: int,
    debug_mode: bool = False
) -> List[Dict]:
    """Execute generator source code and collect validated BP instances.

    The code is executed in a fresh module namespace that pre-defines
    ``np``/``numpy`` and ``erf``. It must define
    ``generate_instances(seeds, capacity, num_items)``; that function is
    called once per seed ``0 .. n-1`` with a single-element seed list, and
    the first element of its result is taken as the item-size array.

    Args:
        gen_code: Python source code of the generator.
        n: Number of seeds, i.e. the maximum number of instances.
        capacity: Bin capacity passed to the generator and the validator.
        num_items: Number of items passed to the generator and the validator.
        debug_mode: When True, print diagnostics for failures.

    Returns:
        The instances accepted by ``validate_bp_instance``. Seeds whose
        generation raises, yields nothing, or fails validation are skipped.
        An empty list is returned when the code cannot be executed or does
        not define ``generate_instances``.
    """
    # SECURITY: ``gen_code`` is run with exec() and full builtins. It must
    # come from a trusted source (the experiment's own generator pool).
    try:
        import sys
        import types
        import uuid

        # Unique module name for every call so names in sys.modules never clash.
        module_name = f"generator_module_{uuid.uuid4().hex[:8]}"
        gen_mod = types.ModuleType(module_name)

        # Execute directly in the module's own namespace. ModuleType already
        # provides __name__ (the unique module name) and __package__ (None).
        namespace = gen_mod.__dict__
        namespace['__builtins__'] = __builtins__
        namespace['__file__'] = '<generated>'
        namespace['numpy'] = np
        namespace['np'] = np

        # Expose erf to the generator: prefer SciPy's, fall back to math.erf.
        try:
            from scipy.special import erf
        except ImportError:
            from math import erf
        namespace['erf'] = erf

        # Register the module while its code runs so that lookups through
        # sys.modules[__name__] succeed, then always unregister it again.
        sys.modules[module_name] = gen_mod
        try:
            exec(gen_code, namespace)
        finally:
            sys.modules.pop(module_name, None)

        if not hasattr(gen_mod, "generate_instances"):
            if debug_mode:
                print("      [GenerateInstances] ERROR: Generator has no generate_instances function")
            return []

        instances = []
        for i, seed in enumerate(np.arange(n, dtype=int)):
            try:
                # The generator returns a sequence of item arrays, one per seed.
                items_arrays = gen_mod.generate_instances([seed], capacity, num_items)
                # Test emptiness with len(): the truth value of a multi-element
                # ndarray result is ambiguous and would raise.
                if items_arrays is None or len(items_arrays) == 0:
                    continue
                # Assemble the dict format expected by the validator.
                inst = {
                    'items': np.asarray(items_arrays[0], dtype=int),
                    'capacity': capacity,
                    'num_items': num_items
                }
                fixed_inst = validate_bp_instance(inst, capacity, num_items)
                if fixed_inst is not None:
                    instances.append(fixed_inst)
                elif debug_mode:
                    print(f"      [GenerateInstances] Instance {i} produced invalid BP instance")
            except Exception as e:
                # One failing seed must not abort the remaining seeds.
                if debug_mode:
                    print(f"      [GenerateInstances] Error generating instance {i}: {e}")

        return instances
    except Exception as e:
        # Best-effort boundary around arbitrary generated code.
        if debug_mode:
            import traceback
            print(f"      [GenerateInstances] Error: {e}")
            print("      [GenerateInstances] Traceback:")
            traceback.print_exc()
        return []


def save_falkenauer_format(output_path: str, instance: Dict) -> None:
    """Write a BP instance to a text file in Falkenauer format.

    The file contains one value per line: the number of items, then the bin
    capacity, then every item size in order.

    Args:
        output_path: Path of the file to create or overwrite.
        instance: Dict with the keys 'items', 'capacity' and 'num_items'.
    """
    sizes = instance['items']
    bin_capacity = instance['capacity']
    item_count = instance['num_items']

    with open(output_path, 'w') as out:
        # Header: item count first, capacity second.
        out.write(f"{item_count}\n")
        out.write(f"{bin_capacity}\n")
        # Body: one item size per line.
        out.writelines(f"{size}\n" for size in sizes)


def load_pools(experiment_dir: str) -> dict:
    """Load ``psro_results/pools.json`` from an experiment directory.

    Args:
        experiment_dir: Path of the experiment directory.

    Returns:
        The parsed JSON content of pools.json.

    Raises:
        FileNotFoundError: If pools.json does not exist.
    """
    pools_file = os.path.join(experiment_dir, "psro_results", "pools.json")
    if os.path.exists(pools_file):
        with open(pools_file, 'r', encoding='utf-8') as handle:
            return json.loads(handle.read())
    raise FileNotFoundError(f"找不到 pools.json: {pools_file}")


def load_sigma_g(experiment_dir: str, round_num: Optional[int] = None) -> np.ndarray:
    """Load the generator mixture ``sigma_g`` from a nash_mixture file.

    Files are read from ``<experiment_dir>/psro_results`` and are named
    ``nash_mixture_round_<N>.json``.

    Args:
        experiment_dir: Path of the experiment directory.
        round_num: Round whose file is loaded. When None, the file with the
            highest round number is used and its name is printed.

    Returns:
        ``data["generator_mixture"]["sigma_g"]`` as a numpy array.

    Raises:
        FileNotFoundError: If the requested round file does not exist, or if
            no ``nash_mixture_round_<N>.json`` file is found.
    """
    import re

    psro_dir = os.path.join(experiment_dir, "psro_results")

    if round_num is not None:
        nash_file = os.path.join(psro_dir, f"nash_mixture_round_{round_num}.json")
        if not os.path.exists(nash_file):
            raise FileNotFoundError(f"找不到 nash_mixture_round_{round_num}.json: {nash_file}")
    else:
        # Find the latest round. Only names with a purely numeric round id
        # are considered, so a stray file such as
        # "nash_mixture_round_final.json" is ignored instead of crashing the
        # integer conversion.
        pattern = re.compile(r"nash_mixture_round_(\d+)\.json")
        candidates = []
        for name in os.listdir(psro_dir):
            match = pattern.fullmatch(name)
            if match:
                candidates.append((int(match.group(1)), name))
        if not candidates:
            raise FileNotFoundError(f"找不到任何 nash_mixture_round_*.json 文件在: {psro_dir}")
        _, latest_name = max(candidates)
        nash_file = os.path.join(psro_dir, latest_name)
        print(f"  使用最新的 nash_mixture 文件: {os.path.basename(nash_file)}")

    with open(nash_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return np.array(data["generator_mixture"]["sigma_g"])


def generate_testing_data(
    experiment_dir: str,
    weight_threshold: float = 0.01,
    round_num: Optional[int] = None
) -> None:
    """Generate testing data from an experiment's generators and sigma_g.

    Steps: load the experiment config, load the generator pool and the
    sigma_g mixture, keep the generators whose weight exceeds the threshold,
    run each of them, and write the instances plus a ``metadata.json`` file
    under ``<experiment_dir>/generator_testing_data``.

    Args:
        experiment_dir: Path of the experiment directory.
        weight_threshold: sigma_g weight threshold; only generators with a
            weight strictly greater than this value are used.
        round_num: Round to load sigma_g from (None means the latest one).

    Raises:
        ValueError: If no generator with code has a weight above the threshold.
    """
    print("=" * 60)
    print("Generator 测试数据生成器 (BP Online)")
    print("=" * 60)
    print(f"实验目录: {experiment_dir}")
    print(f"权重阈值: {weight_threshold}")

    # 1. Load the configuration; missing attributes fall back to defaults.
    print("\n[1/4] 加载实验配置...")
    cfg = load_experiment_config(experiment_dir)
    capacity = getattr(cfg, 'capacity', 100)
    num_items = getattr(cfg, 'num_items', 50)
    n_instances = getattr(cfg, 'eval_n_instances', 3)

    print("  配置参数:")
    print(f"    capacity: {capacity}")
    print(f"    num_items: {num_items}")
    print(f"    n_instances per generator: {n_instances}")

    # 2. Load the generator pool and sigma_g.
    print("\n[2/4] 加载 pools 和 sigma_g...")
    pools_data = load_pools(experiment_dir)
    generators = pools_data.get("generators", [])
    sigma_g = load_sigma_g(experiment_dir, round_num)

    print(f"  总 generator 数量: {len(generators)}")
    print(f"  sigma_g 长度: {len(sigma_g)}")

    # Align sigma_g with the pool: zero-pad when it is shorter, truncate when
    # it is longer, then renormalise.
    if len(sigma_g) != len(generators):
        if len(sigma_g) < len(generators):
            sigma_g = np.pad(sigma_g, (0, len(generators) - len(sigma_g)), 'constant')
        else:
            sigma_g = sigma_g[:len(generators)]
        # Skip the division when nothing is left to normalise; dividing by a
        # zero sum would turn every weight into NaN.
        weight_sum = sigma_g.sum()
        if weight_sum > 0:
            sigma_g = sigma_g / weight_sum
        print(f"  调整后的 sigma_g 长度: {len(sigma_g)}")

    # 3. Keep the generators whose weight exceeds the threshold and that
    #    actually carry code.
    print(f"\n[3/4] 过滤权重 > {weight_threshold} 的 generator...")
    selected_generators = []
    selected_weights = []
    selected_indices = []

    for idx, (gen, weight) in enumerate(zip(generators, sigma_g)):
        if weight > weight_threshold and gen.get("code"):
            selected_generators.append(gen)
            selected_weights.append(weight)
            selected_indices.append(idx)
            print(f"  Generator {idx}: weight={weight:.4f}, program_id={gen.get('program_id', 'unknown')}")

    if len(selected_generators) == 0:
        raise ValueError(f"没有找到权重 > {weight_threshold} 的有效 generator")

    print(f"  共选择了 {len(selected_generators)} 个 generator")

    # 4. Generate the testing data.
    print("\n[4/4] 生成测试数据...")
    output_dir = os.path.join(experiment_dir, "generator_testing_data")
    os.makedirs(output_dir, exist_ok=True)

    total_instances = 0
    for gen, weight, orig_idx in zip(selected_generators, selected_weights, selected_indices):
        gen_code = gen["code"]
        gen_id = gen.get("program_id", f"gen_{orig_idx}")

        print(f"\n  处理 Generator {orig_idx} (weight={weight:.4f}, id={gen_id}):")

        # Each generator gets its own sub-directory.
        gen_dir = os.path.join(output_dir, f"generator_{orig_idx}_{gen_id}")
        os.makedirs(gen_dir, exist_ok=True)

        # Generate the instances.
        instances = generate_instances_from_generator_code(
            gen_code=gen_code,
            n=n_instances,
            capacity=capacity,
            num_items=num_items,
            debug_mode=True
        )

        if not instances:
            print(f"    ⚠️  警告: Generator {orig_idx} 未能生成任何实例")
            continue

        # Save the instances; file numbering starts at 1.
        for inst_idx, instance in enumerate(instances):
            filename = f"BP{num_items}_{capacity}_{inst_idx+1}.txt"
            filepath = os.path.join(gen_dir, filename)
            save_falkenauer_format(filepath, instance)
            total_instances += 1

        print(f"    ✅ 生成了 {len(instances)} 个实例，保存在: {gen_dir}")

    # 5. Save the metadata describing this run.
    metadata = {
        "experiment_dir": experiment_dir,
        "weight_threshold": weight_threshold,
        "config": {
            "capacity": capacity,
            "num_items": num_items,
            "n_instances_per_generator": n_instances
        },
        "selected_generators": [
            {
                "original_index": idx,
                "program_id": gen.get("program_id", f"gen_{idx}"),
                # Convert the numpy scalar so that it is JSON serialisable.
                "weight": float(weight)
            }
            for idx, gen, weight in zip(selected_indices, selected_generators, selected_weights)
        ],
        "total_instances": total_instances
    }

    metadata_path = os.path.join(output_dir, "metadata.json")
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    print("\n" + "=" * 60)
    print("✅ 完成！")
    print("=" * 60)
    print(f"共生成 {total_instances} 个测试实例")
    print(f"输出目录: {output_dir}")
    print(f"元数据: {metadata_path}")


def main():
    """Command-line entry point.

    Parses the experiment directory, the weight threshold and the optional
    round number, checks that the experiment directory exists, and runs
    ``generate_testing_data`` with those values.

    Raises:
        FileNotFoundError: If the experiment directory does not exist.
    """
    cli = argparse.ArgumentParser(
        description="使用实验中的 generator pool 和最后的 sigma_g 生成测试数据"
    )
    cli.add_argument("experiment_dir", type=str, help="实验目录路径")
    cli.add_argument(
        "--weight-threshold",
        type=float,
        default=0.01,
        help="sigma_g 权重阈值，只使用权重 > threshold 的 generator (默认: 0.01)",
    )
    cli.add_argument(
        "--round",
        type=int,
        default=None,
        help="从哪个 round 加载 sigma_g (默认: 最新的)",
    )
    opts = cli.parse_args()

    # Fail early, before any work is done, when the directory is missing.
    target_dir = opts.experiment_dir
    if not os.path.exists(target_dir):
        raise FileNotFoundError(f"实验目录不存在: {target_dir}")

    generate_testing_data(
        experiment_dir=target_dir,
        weight_threshold=opts.weight_threshold,
        round_num=opts.round,
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()














