#!/usr/bin/env python3
"""
Load CVRP test data
Supports loading .vrp and .sol files from TestingData directory
"""

import os
import sys
import re
from typing import Dict, List, Tuple, Optional
import numpy as np

# Make the surrounding project importable when this file is loaded directly.
# Starting from this file's directory, walk up the tree and prepend two
# ancestor directories to sys.path (each only if it is not already present).
# Both are inserted at index 0, so when both get added _heupsro_dir ends up
# ahead of _project_root in the search order.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_problem_dir = os.path.dirname(os.path.dirname(_script_dir))  # two levels above this file's directory (presumably "cvrp" -- verify layout)
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))  # two levels above _problem_dir (presumably "heupsro" -- verify layout)
_project_root = os.path.dirname(_heupsro_dir)  # parent of _heupsro_dir, treated as the project root
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
if _heupsro_dir not in sys.path:
    sys.path.insert(0, _heupsro_dir)


def _vrp_header_value(line: str) -> str:
    """Return the value part of a ``KEY : value`` (or ``KEY value``) header line.

    Only the first colon separates keyword from value, so values that contain
    a colon themselves are returned intact. When the line has no colon at all,
    everything after the first whitespace-separated token is the value.
    """
    _, colon, value = line.partition(':')
    if not colon:
        tokens = line.split(None, 1)
        value = tokens[1] if len(tokens) > 1 else ''
    return value.strip()


def _is_vrp_section_boundary(line: str) -> bool:
    """Tell whether *line* ends the current data section (``*_SECTION`` or ``EOF``)."""
    keyword = line.partition(':')[0].strip()
    return line.startswith('EOF') or keyword.endswith('_SECTION')


def load_vrp_file(vrp_path: str) -> Optional[Dict]:
    """
    Load a CVRP instance from a .vrp file.

    VRP file format:
    - NAME: instance name
    - TYPE: CVRP
    - DIMENSION: number of nodes (including depot)
    - CAPACITY: vehicle capacity
    - EDGE_WEIGHT_TYPE: distance type (EUC_2D or EXPLICIT); not used here
    - NODE_COORD_SECTION: node coordinates
    - DEMAND_SECTION: demands
    - DEPOT_SECTION: depot location

    Header lines may be written as ``KEY : value`` or ``KEY value``. Each data
    section is read until the next ``*_SECTION`` keyword or ``EOF``, so the
    sections may appear in any order without contaminating each other.

    Args:
        vrp_path: path of the .vrp file to read

    Returns:
        Dict with 'depot' ([x, y]), 'customers' (list of
        {'coords': [x, y], 'demand': int} ordered by ascending node id),
        'vehicle_capacity' and 'instance_name' (the NAME header, or the file's
        base name when NAME is absent or empty).
        None when the file cannot be read or parsed, is empty, lacks
        DIMENSION or CAPACITY, or has no coordinates for the depot node; a
        "Skip ..." message is printed when an exception caused the failure.
    """
    try:
        with open(vrp_path, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]

        if not lines:
            return None

        # parse header information and remember where each data section begins
        name = None
        dimension = None
        capacity = None

        coord_section_start = None
        demand_section_start = None
        depot_section_start = None

        for i, line in enumerate(lines):
            if line.startswith('NAME'):
                name = _vrp_header_value(line)
            elif line.startswith('DIMENSION'):
                dimension = int(_vrp_header_value(line))
            elif line.startswith('CAPACITY'):
                capacity = int(_vrp_header_value(line))
            elif line.startswith('NODE_COORD_SECTION'):
                coord_section_start = i + 1
            elif line.startswith('DEMAND_SECTION'):
                demand_section_start = i + 1
            elif line.startswith('DEPOT_SECTION'):
                depot_section_start = i + 1

        if dimension is None or capacity is None:
            return None

        # parse node coordinates: "<node_id> <x> <y>"
        coords = {}
        if coord_section_start is not None:
            for line in lines[coord_section_start:]:
                if _is_vrp_section_boundary(line):
                    break
                parts = line.split()
                if len(parts) < 3:
                    continue
                try:
                    coords[int(parts[0])] = (float(parts[1]), float(parts[2]))
                except ValueError:
                    # malformed row: skip it, keep the rest of the section
                    continue

        # parse demands: "<node_id> <demand>"
        demands = {}
        if demand_section_start is not None:
            for line in lines[demand_section_start:]:
                if _is_vrp_section_boundary(line):
                    break
                parts = line.split()
                if len(parts) < 2:
                    continue
                try:
                    # int(float(...)) also accepts demands written as "4.0"
                    demands[int(parts[0])] = int(float(parts[1]))
                except ValueError:
                    continue

        # parse depot (defaults to node 1); only the first listed depot is used
        depot_id = 1
        if depot_section_start is not None:
            for line in lines[depot_section_start:]:
                # "-1" terminates the depot list
                if _is_vrp_section_boundary(line) or line.startswith('-1'):
                    break
                parts = line.split()
                if parts and parts[0].isdigit():
                    depot_id = int(parts[0])
                    break

        # without depot coordinates no instance can be built
        if depot_id not in coords:
            return None

        # build customer list (excluding depot); nodes that have coordinates
        # but no demand entry are left out
        customers = [
            {'coords': list(coords[node_id]), 'demand': demands[node_id]}
            for node_id in sorted(coords)
            if node_id != depot_id and node_id in demands
        ]

        return {
            'depot': list(coords[depot_id]),
            'customers': customers,
            'vehicle_capacity': capacity,
            'instance_name': name or os.path.splitext(os.path.basename(vrp_path))[0]
        }

    except Exception as e:
        # best-effort loader: report the problem and let the caller skip this file
        print(f"Skip {vrp_path}: {e}")
        return None


def load_sol_file(sol_path: str) -> Optional[float]:
    """
    Read a .sol file and return the cost it reports.

    Expected .sol layout:
    Route #1: 6 32 57 50 17 58 33 55
    Route #2: 7 10 8 43 29 61 40 20
    ...
    Cost 1288

    Returns:
        the cost as a float, or None when no parsable "Cost" line exists or
        the file cannot be read (a "Skip ..." message is printed in that case)
    """
    try:
        with open(sol_path, 'r') as handle:
            stripped_rows = [raw.strip() for raw in handle]
        rows = [row for row in stripped_rows if row]

        # scan from the bottom up: the Cost line is normally the last one
        for row in rows[::-1]:
            if not row.startswith('Cost'):
                continue
            tokens = row.split()
            if len(tokens) < 2:
                continue
            try:
                return float(tokens[1])
            except ValueError:
                # unparsable value: keep looking at earlier lines
                continue

        # no usable Cost line anywhere in the file
        return None

    except Exception as exc:
        print(f"Skip {sol_path}: {exc}")
        return None


def load_test_data_dir(
    test_data_dir: str,
) -> Dict[str, Tuple[List[Dict], List[float]]]:
    """
    Collect every .vrp instance found under a TestingData directory.

    Files are grouped into datasets by location: a file directly inside
    test_data_dir is labelled with the alphabetic prefix of its name
    (e.g. A-n62-k8.vrp -> A, or "Root" when there is no such prefix); a file
    in a subdirectory is labelled with the first directory level, joined with
    the second level by "_" when the file sits deeper than one level.

    Args:
        test_data_dir: TestingData directory path

    Returns:
        Dict mapping dataset label to (instances, optimal_values). Both lists
        are aligned; an optimal_values entry is None when no matching .sol
        file exists or its cost cannot be read. Instances that fail to load
        are left out, and datasets without any loaded instance are omitted.
        An empty dict is returned when test_data_dir is not a directory.
    """
    datasets: Dict[str, Tuple[List[Dict], List[float]]] = {}

    if not os.path.isdir(test_data_dir):
        return datasets

    # dataset label -> [(vrp_path, instance_name, sol_path or None), ...]
    grouped: Dict[str, List[Tuple[str, str, Optional[str]]]] = {}

    # pass 1: discover .vrp files in the whole tree and assign dataset labels
    for current_dir, _subdirs, entries in os.walk(test_data_dir):
        rel_dir = os.path.relpath(current_dir, test_data_dir)
        segments = rel_dir.split(os.sep)

        for entry in sorted(entries):
            if not entry.lower().endswith('.vrp'):
                continue

            stem = os.path.splitext(entry)[0]
            vrp_path = os.path.join(current_dir, entry)

            # the solution file, if any, sits next to the instance file
            sol_candidate = os.path.join(current_dir, f"{stem}.sol")
            sol_path = sol_candidate if os.path.exists(sol_candidate) else None

            if rel_dir == '.':
                # top-level file: derive the label from the file name prefix
                prefix = re.match(r'^([A-Za-z]+)-', stem)
                label = prefix.group(1) if prefix else "Root"
            elif len(segments) > 1:
                label = f"{segments[0]}_{segments[1]}"
            else:
                label = segments[0]

            grouped.setdefault(label, []).append((vrp_path, stem, sol_path))

    # pass 2: parse the files of each dataset
    for label, members in grouped.items():
        loaded = []
        optima = []

        for vrp_path, stem, sol_path in sorted(members):
            parsed = load_vrp_file(vrp_path)
            if parsed is None:
                continue

            # the file name, not the NAME header, identifies the instance
            parsed['instance_name'] = stem
            parsed['file_path'] = vrp_path
            loaded.append(parsed)

            if sol_path and os.path.exists(sol_path):
                optima.append(load_sol_file(sol_path))
            else:
                optima.append(None)

        if not loaded:
            continue

        datasets[label] = (loaded, optima)
        known_optima = len([value for value in optima if value is not None])
        print(f"Loaded dataset '{label}': {len(loaded)} instances, {known_optima} with optimal values")

    return datasets


def load_all_test_data(
    test_data_dir: Optional[str] = None,
) -> Dict[str, Tuple[List[Dict], List[float]]]:
    """
    Load every available test dataset.

    Args:
        test_data_dir: TestingData directory path; when None or empty nothing
            is loaded, and when the path does not exist a warning is printed

    Returns:
        Dict mapping dataset label to (instances, optimal_values) tuple;
        empty when no directory was given or the path does not exist
    """
    datasets_dict: Dict[str, Tuple[List[Dict], List[float]]] = {}

    # nothing requested
    if not test_data_dir:
        return datasets_dict

    # requested path is missing: warn and return the empty result
    if not os.path.exists(test_data_dir):
        print(f"Warning: test_data_dir does not exist: {test_data_dir}")
        return datasets_dict

    print(f"Loading test data from: {test_data_dir}")
    found = load_test_data_dir(test_data_dir)
    datasets_dict.update(found)
    print(f"Loaded {len(found)} datasets from test_data_dir")

    return datasets_dict


if __name__ == '__main__':
    # Manual smoke test: load a TestingData directory and print a summary
    # line for each dataset that was found.
    import argparse

    # by default look for a TestingData folder next to this script
    default_dir = os.path.join(os.path.dirname(__file__), 'TestingData')

    cli = argparse.ArgumentParser(description="Load CVRP test data")
    cli.add_argument(
        '--test_data_dir',
        type=str,
        default=default_dir,
        help='TestingData directory path',
    )
    options = cli.parse_args()

    loaded = load_all_test_data(test_data_dir=options.test_data_dir)

    print(f"\n loaded {len(loaded)} datasets:")
    for label, (instances, optimal_values) in loaded.items():
        known_optima = len([value for value in optimal_values if value is not None])
        print(f"  {label}: {len(instances)} instances, {known_optima} with optimal values")

