#!/usr/bin/env python3
"""
加载BP问题的测试数据
支持多种数据源：
1. TestingData目录中的文件（Heavy-Tail, Mixture, Weibull, Falkenauer）
2. ordered_data/outputs目录中的JSON实例文件（dataset_*.json）
3. 从lower_bounds.json加载lower bounds（如果存在）；缺失时由oracle计算
"""

import json
import math
import os
import re
import sys
from typing import Dict, List, Tuple, Optional

# Make the project root and the heupsro directory importable so that the
# absolute `heupsro.problems...` import that follows can be resolved.
# Directory walk: this script's dir -> (2 up) bp_online -> (2 up) heupsro
# -> (1 up) project root.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_problem_dir = os.path.dirname(os.path.dirname(_script_dir))  # bp_online
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))  # heupsro
_project_root = os.path.dirname(_heupsro_dir)  # project root
# Prepend the project root first and heupsro second, so heupsro ends up with
# the highest precedence; paths already on sys.path are left untouched.
for _extra_path in (_project_root, _heupsro_dir):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path

from heupsro.problems.bp_online.oracle import create_bp_online_oracle


def _parse_integral(text: str) -> int:
    """Parse an integer header value that may be written as ``"150"`` or ``"150.0"``.

    Raises:
        ValueError: if the text is not a number or has a fractional part.
    """
    value = float(text)
    if not value.is_integer():
        raise ValueError(f"expected an integer value, got {text!r}")
    return int(value)


def load_falkenauer_txt(txt_path: str) -> Optional[Dict]:
    """
    Load a Falkenauer format .txt file.

    Format (blank lines are ignored):
    - Line 1: num_items
    - Line 2: capacity
    - Lines 3+: item sizes (one per line)

    Header values may be written in float notation (``150.0``) as long as
    they are integral. Item sizes are converted with ``int(float(...))``,
    i.e. truncated. Lines after the first ``num_items`` sizes are ignored.

    Returns:
        ``{'items': [...], 'capacity': int, 'num_items': int}``, or None when
        the file cannot be read or parsed, or holds fewer than ``num_items``
        sizes. (The reason is printed.)
    """
    try:
        with open(txt_path, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]
        if len(lines) < 2:
            return None
        # Parse headers with the same float tolerance as the item sizes, so a
        # "150.0" capacity is not rejected while "20.0" item sizes are accepted.
        num_items = _parse_integral(lines[0])
        capacity = _parse_integral(lines[1])
        items = [int(float(line)) for line in lines[2:2 + num_items]]
        if len(items) != num_items:
            return None
        return {
            'items': items,
            'capacity': capacity,
            'num_items': num_items
        }
    except Exception as e:
        # Best-effort loader: a malformed file is reported and skipped.
        print(f"Skip {txt_path}: {e}")
        return None


def load_lower_bounds_from_json(lb_file: str) -> Dict[str, float]:
    """
    Load per-instance lower bounds from a lower_bounds.json file.

    Args:
        lb_file: path to lower_bounds.json

    Returns:
        Dict mapping filename to lower bound. An empty dict is returned when
        the file is missing, unreadable, not valid JSON, or its top-level
        value is not a JSON object. Callers index the result by filename, so
        any other shape would fail later.
    """
    if not os.path.exists(lb_file):
        return {}

    try:
        with open(lb_file, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Failed to load lower bounds from {lb_file}: {e}")
        return {}

    if not isinstance(data, dict):
        print(f"Warning: Failed to load lower bounds from {lb_file}: "
              f"expected a JSON object, got {type(data).__name__}")
        return {}
    return data


# Irnich_BPP instance-name prefixes; these files are deliberately not loaded.
_IRNICH_PREFIXES = ('csAA', 'csAB', 'csBA', 'csBB')


def _dir_label(root: str, test_data_dir: str) -> str:
    """Return a dataset label derived from the directory containing a file.

    Files directly inside ``test_data_dir`` get the basename of
    ``test_data_dir``; files in subdirectories get their parent directory's name.
    """
    rel_path = os.path.relpath(root, test_data_dir)
    if rel_path == '.':
        # normpath drops a trailing separator, otherwise basename() returns ''.
        return os.path.basename(os.path.normpath(test_data_dir))
    return os.path.basename(rel_path)


def _ordered_json_label(file_path: str, root: str, test_data_dir: str) -> Optional[str]:
    """Label for an ordered_data ``dataset_*.json`` file.

    Uses ``"{dataset}_{order}"`` (or just ``dataset`` when ``order`` is empty)
    from the JSON content. Falls back to the directory label when the file
    cannot be read or has no ``dataset`` key.
    """
    try:
        with open(file_path, 'r') as f:
            data = json.load(f)
    except Exception:
        return _dir_label(root, test_data_dir)
    if not (isinstance(data, dict) and 'dataset' in data):
        return _dir_label(root, test_data_dir)
    dataset_name = data.get('dataset', '')
    order = data.get('order', '')
    return f"{dataset_name}_{order}" if order else dataset_name


def _falkenauer_label(fname: str) -> str:
    """Label for Falkenauer files, e.g. ``Falkenauer_t60_00.txt`` -> ``Falkenauer_T_60``."""
    parts = os.path.splitext(fname.lower())[0].split('_')
    if len(parts) < 2:
        return "Falkenauer"
    token = parts[1]  # e.g. "t60": class letter(s) followed by the size
    # Take only the leading letters as the class; using the whole token would
    # repeat the size in the label (e.g. "Falkenauer_T60_60").
    type_match = re.match(r'[a-z]+', token)
    falk_type = (type_match.group(0) if type_match else token).upper()
    size_match = re.search(r'(\d+)', token)
    if size_match:
        return f"Falkenauer_{falk_type}_{size_match.group(1)}"
    return f"Falkenauer_{falk_type}"


def _weibull_label(fname: str, file_path: str) -> str:
    """Label for generated Weibull files: ``Weibull_{n_items}_{capacity}``.

    Expected name: ``Weibull_shape1p4_scale30p0_{n_items}_{i:02d}.txt``. The
    capacity is read from the file's header (Falkenauer format).
    """
    match = re.match(r'Weibull_shape[^_]+_scale[^_]+_(\d+)_\d+\.txt', fname)
    if not match:
        return "Weibull"
    n_items = int(match.group(1))
    try:
        temp_inst = load_falkenauer_txt(file_path)
    except Exception:
        return f"Weibull_{n_items}"
    if temp_inst is None:
        return f"Weibull_{n_items}"
    capacity = temp_inst.get('capacity', n_items)
    return f"Weibull_{n_items}_{capacity}"


def _dataset_label_for_file(fname: str, file_path: str, root: str,
                            test_data_dir: str) -> Optional[str]:
    """Pick the dataset label for one data file; None means "skip this file"."""
    if fname.endswith('.json') and fname.startswith('dataset_'):
        return _ordered_json_label(file_path, root, test_data_dir)
    if 'Falkenauer' in fname:
        return _falkenauer_label(fname)
    if fname.startswith('HeavyTail_'):
        # HeavyTail_alpha{alpha}_{n_items}_{i:02d}.txt
        match = re.match(r'HeavyTail_alpha([\d.]+)_(\d+)_\d+\.txt', fname)
        return f"HeavyTail_alpha{match.group(1)}_{match.group(2)}" if match else "HeavyTail"
    if fname.startswith('Weibull_'):
        return _weibull_label(fname, file_path)
    if fname.startswith('Mixture'):
        # Mixture{1|2|3}_{n_items}_{i:02d}.txt
        match = re.match(r'Mixture(\d+)_(\d+)_\d+\.txt', fname)
        return f"Mixture{match.group(1)}_{match.group(2)}" if match else "Mixture"
    if os.path.splitext(fname)[0].startswith(_IRNICH_PREFIXES):
        return None
    return _dir_label(root, test_data_dir)


def _resolve_lower_bound(inst: Dict, lookup_key: str,
                         lower_bounds_dict: Dict[str, float], oracle) -> float:
    """Return the lower bound for one instance.

    Priority: the instance's own ``lower_bound`` field, then
    ``lower_bounds_dict[lookup_key]``, then ``oracle.solve_exact(inst)``.
    The instance's own field is used only if it is >= 1 and at least 90% of
    ``ceil(sum(items) / capacity)``. This rejects corrupt bounds that would
    inflate gaps.
    """
    raw_lb = inst.get('lower_bound')
    if raw_lb is not None:
        lb_from_json = float(raw_lb)
        if lb_from_json >= 1.0:
            items_sum = sum(inst.get('items', []))
            capacity = inst.get('capacity', 100)
            # math.ceil is a true ceiling for float sizes/capacities too;
            # (s + c - 1) // c is only correct for integers.
            theoretical_lb = max(1.0, math.ceil(items_sum / capacity)) if capacity > 0 else 1.0
            if lb_from_json >= theoretical_lb * 0.9:  # allow 10% slack
                return lb_from_json
    if lookup_key in lower_bounds_dict:
        return lower_bounds_dict[lookup_key]
    return oracle.solve_exact(inst)


def _iter_json_instances(file_path: str, instance_name: str,
                         lower_bounds_dict: Dict[str, float], oracle):
    """Yield ``(instance, lower_bound)`` pairs from a JSON instance file.

    Accepts either one instance object (a dict with ``items``) or a list of
    them. List elements are named ``{instance_name}_{idx}`` and looked up in
    ``lower_bounds_dict`` as ``{basename}_{idx}``. Each bound is computed
    before its pair is yielded. An exception therefore never leaves an
    instance without its bound.
    """
    with open(file_path, 'r') as f:
        data = json.load(f)
    basename = os.path.basename(file_path)
    if isinstance(data, dict) and 'items' in data:
        data['instance_name'] = instance_name
        yield data, _resolve_lower_bound(data, basename, lower_bounds_dict, oracle)
    elif isinstance(data, list):
        for idx, inst in enumerate(data):
            if isinstance(inst, dict) and 'items' in inst:
                inst['instance_name'] = f"{instance_name}_{idx}"
                yield inst, _resolve_lower_bound(
                    inst, f"{basename}_{idx}", lower_bounds_dict, oracle)


def load_test_data_dir(
    test_data_dir: str,
    use_lower_bounds_json: bool = True,
    only_ordered_data_outputs: bool = False
) -> Dict[str, Tuple[List[Dict], List[float]]]:
    """
    Load test instances from a TestingData directory, grouped by dataset.

    Recognised files (``lower_bounds.json`` itself is skipped):
    - Heavy-Tail: ``HeavyTail_alpha{alpha}_{n_items}_{i:02d}.txt``
    - Weibull: ``Weibull_shape{shape}_scale{scale}_{n_items}_{i:02d}.txt``
      (labelled ``Weibull_{n_items}_{capacity}``)
    - Mixture: ``Mixture{1|2|3}_{n_items}_{i:02d}.txt``
    - Falkenauer: ``Falkenauer_t{size}_*.txt`` / ``Falkenauer_u{size}_*.txt``
      (labelled ``Falkenauer_T_{size}`` / ``Falkenauer_U_{size}``)
    - ordered_data/outputs JSON files: ``dataset_*.json``
    - Any other .txt/.json is labelled by its parent directory. The exception
      is Irnich_BPP instances (csAA/csAB/csBA/csBB prefixes), which are skipped.

    Directories named ``group_B`` are never descended into. Inside a
    ``Synthetic`` directory only ``Weibull_*`` files are read.

    Args:
        test_data_dir: path to the TestingData directory
        use_lower_bounds_json: whether to use lower bounds from
            ``{test_data_dir}/lower_bounds.json``
        only_ordered_data_outputs: if True, load only
            ``ordered_data/outputs/group_A`` and ``group_C`` plus the
            ``Synthetic`` directory

    Returns:
        Dict mapping dataset label to an ``(instances, lower_bounds)`` tuple.
        The two lists are index-aligned. Datasets with no loadable instance
        are omitted. Returns {} if ``test_data_dir`` is not a directory.
    """
    datasets: Dict[str, Tuple[List[Dict], List[float]]] = {}

    if not os.path.isdir(test_data_dir):
        return datasets
    oracle = create_bp_online_oracle(oracle_type='lb')

    # Restrict the search to ordered_data/outputs (group_A and group_C only)
    # plus Synthetic when requested.
    search_dirs: List[str] = []
    if only_ordered_data_outputs:
        ordered_outputs_dir = os.path.join(test_data_dir, 'ordered_data', 'outputs')
        if os.path.isdir(ordered_outputs_dir):
            for group_name in ['group_A', 'group_C']:
                group_dir = os.path.join(ordered_outputs_dir, group_name)
                if os.path.isdir(group_dir):
                    search_dirs.append(group_dir)
            print(f"Loading data from ordered_data/outputs (excluding group_B): {search_dirs}")

        weibull_dir = os.path.join(test_data_dir, 'Synthetic')
        if os.path.isdir(weibull_dir):
            search_dirs.append(weibull_dir)
            print(f"Also loading Weibull data from: {weibull_dir}")
    else:
        search_dirs = [test_data_dir]

    # lower_bounds.json always lives at the top of test_data_dir.
    lb_file = os.path.join(test_data_dir, 'lower_bounds.json')
    lower_bounds_dict: Dict[str, float] = {}
    if use_lower_bounds_json and os.path.exists(lb_file):
        lower_bounds_dict = load_lower_bounds_from_json(lb_file)
        print(f"Loaded {len(lower_bounds_dict)} lower bounds from {lb_file}")

    # Group candidate files by dataset label: label -> [(file_path, instance_name)]
    files_by_dataset: Dict[str, List[Tuple[str, str]]] = {}
    for search_dir in search_dirs:
        for root, dirs, files in os.walk(search_dir):
            # Prune group_B in place (topdown walk) so it is never visited.
            # This compares whole directory names, not substrings of the
            # absolute path.
            dirs[:] = [d for d in dirs if d != 'group_B']
            rel_parts = os.path.relpath(root, test_data_dir).split(os.sep)
            if 'Synthetic' in rel_parts:
                files = [f for f in files if f.startswith('Weibull_')]

            for fname in sorted(files):
                if not fname.endswith(('.txt', '.json')) or fname == 'lower_bounds.json':
                    continue
                file_path = os.path.join(root, fname)
                dataset_label = _dataset_label_for_file(fname, file_path, root, test_data_dir)
                if dataset_label is not None:
                    files_by_dataset.setdefault(dataset_label, []).append(
                        (file_path, os.path.splitext(fname)[0]))

    # Load each dataset's instances together with their lower bounds.
    for dataset_label, file_info_list in files_by_dataset.items():
        instances: List[Dict] = []
        optimal_values: List[float] = []

        for file_path, instance_name in sorted(file_info_list):
            if file_path.endswith('.txt'):
                inst = load_falkenauer_txt(file_path)
                if inst is None:
                    continue
                inst['instance_name'] = instance_name
                optimal_value = _resolve_lower_bound(
                    inst, os.path.basename(file_path), lower_bounds_dict, oracle)
                instances.append(inst)
                optimal_values.append(optimal_value)
            elif file_path.endswith('.json'):
                try:
                    # Append each pair only once its bound is known, so both
                    # lists stay aligned even if a later element fails.
                    for inst, optimal_value in _iter_json_instances(
                            file_path, instance_name, lower_bounds_dict, oracle):
                        instances.append(inst)
                        optimal_values.append(optimal_value)
                except Exception as e:
                    print(f"Skip {file_path}: {e}")

        if instances:
            datasets[dataset_label] = (instances, optimal_values)
            print(f"Loaded dataset '{dataset_label}': {len(instances)} instances")

    return datasets


def load_all_test_data(
    test_data_dir: Optional[str] = None,
    use_lower_bounds_json: bool = True,
    only_ordered_data_outputs: bool = False
) -> Dict[str, Tuple[List[Dict], List[float]]]:
    """
    Load every available test dataset.

    Currently the only source is ``test_data_dir``, which is delegated to
    ``load_test_data_dir``. A falsy ``test_data_dir`` yields an empty result
    silently. A path that does not exist yields an empty result plus a warning.

    Args:
        test_data_dir: path to the TestingData directory
        use_lower_bounds_json: whether to use lower bounds from lower_bounds.json
        only_ordered_data_outputs: if True, load only data under
            ordered_data/outputs (see ``load_test_data_dir``)

    Returns:
        Dict mapping dataset label to an ``(instances, lower_bounds)`` tuple.
    """
    collected: Dict[str, Tuple[List[Dict], List[float]]] = {}

    if not test_data_dir:
        return collected
    if not os.path.exists(test_data_dir):
        print(f"Warning: test_data_dir does not exist: {test_data_dir}")
        return collected

    scope_suffix = "/ordered_data/outputs (only)" if only_ordered_data_outputs else ""
    print(f"Loading test data from: {test_data_dir}{scope_suffix}")

    loaded = load_test_data_dir(
        test_data_dir,
        use_lower_bounds_json=use_lower_bounds_json,
        only_ordered_data_outputs=only_ordered_data_outputs
    )
    collected.update(loaded)
    print(f"Loaded {len(loaded)} datasets from test_data_dir")
    return collected

