import os
from pathlib import Path
import torch

from load_data import load_instances_with_baselines
from utils.utils_for_model import (
    create_parser,
    compute_tsp_tour_length,
    compute_vrp_tour_length,
    load_stage_ckpt,
    run_aug,
    generate_tsp_instance,
)
from utils.utilities import (
    choose_bsz,
    normalize_nodes_to_unit_board,
    load_tsplib_file,
    tsplib_collections,
    parse_tsplib_name,
    load_cvrplib_file,
    cvrplib_collections,
    parse_cvrplib_name,
)
from tsp_env import TSPEnvironment
from tsp_policy_two_stage import TSPStage1Policy, TSPStage2Policy
from vrp_policy import VRPPolicy


def _best_over_augmented(lengths: torch.Tensor, aug_num: int) -> torch.Tensor:
    """Return the best (smallest) length per base instance across its augmented variants.

    ``lengths`` is regrouped into rows of ``aug_num`` consecutive entries and the
    minimum of every row is returned.

    Args:
        lengths: Tensor holding ``base * aug_num`` lengths.
        aug_num: Number of consecutive entries that form one group.

    Returns:
        1-D tensor with ``lengths.numel() // aug_num`` row-wise minima.

    Raises:
        ValueError: If ``aug_num`` is not positive, or if it does not divide
            ``lengths.numel()``.
    """
    if aug_num <= 0:
        # Without this guard the modulo below fails with an opaque ZeroDivisionError.
        raise ValueError("aug_num must be a positive integer for best-of-aug reduction.")
    if lengths.numel() % aug_num != 0:
        raise ValueError("Length tensor size must be divisible by aug_num for best-of-aug reduction.")
    base = lengths.numel() // aug_num
    # reshape instead of view: also accepts non-contiguous inputs (copies only if needed).
    return lengths.reshape(base, aug_num).min(dim=1).values


def _reduce_augmented_lengths(lengths: torch.Tensor, aug_num: int, use_best: bool = True) -> torch.Tensor:
    """Collapse augmented lengths back to one value per base instance.

    ``lengths`` is regrouped into rows of ``aug_num`` consecutive entries; each
    row is reduced with ``min`` (``use_best=True``) or ``mean`` (``use_best=False``).

    Args:
        lengths: Tensor holding ``base * aug_num`` lengths.
        aug_num: Number of consecutive entries that form one group.
        use_best: Select the row minimum when true, the row mean otherwise.

    Returns:
        1-D tensor with ``lengths.numel() // aug_num`` reduced values.

    Raises:
        ValueError: If ``aug_num`` is not positive, or if it does not divide
            ``lengths.numel()``.
    """
    if aug_num <= 0:
        # Without this guard the modulo below fails with an opaque ZeroDivisionError.
        raise ValueError("aug_num must be a positive integer for reduction.")
    if lengths.numel() % aug_num != 0:
        raise ValueError("Length tensor size must be divisible by aug_num for reduction.")
    # reshape instead of view: also accepts non-contiguous inputs (copies only if needed).
    grouped = lengths.reshape(-1, aug_num)
    if use_best:
        return grouped.min(dim=1).values
    return grouped.mean(dim=1)


def rollout_two_stage(
    env: TSPEnvironment,
    stage1: TSPStage1Policy,
    stage2: TSPStage2Policy,
    k_promising: int,
    deterministic: bool = True,
):
    """Roll the two-stage policy forward on ``env`` and collect the resulting tours.

    At every step stage 1 proposes ``k_promising`` candidates via ``select_k`` and
    stage 2 picks the action among them via ``select_action``. The loop runs for at
    most ``env.nb_nodes - 1`` steps and stops early once ``env.step`` reports done.

    Returns:
        Tuple ``(tours, sum_log_probs)``: the tensor from ``env.get_tour_tensor()``
        and the per-step log-probabilities stacked on dim 1 and summed over it.
    """
    state = env.observation()
    step_log_probs = []
    for _step in range(env.nb_nodes - 1):
        candidates, _unused_a, _unused_b = stage1.select_k(
            state,
            k_promising=k_promising,
            deterministic=deterministic,
        )
        chosen, step_logp, _unused_c = stage2.select_action(
            state,
            selected_global_idx=candidates,
            deterministic=deterministic,
        )
        state, finished = env.step(chosen)
        step_log_probs.append(step_logp)
        if finished:
            break
    tours = env.get_tour_tensor()
    stacked = torch.stack(step_log_probs, dim=1)
    return tours, stacked.sum(dim=1)


def build_args():
    """Build and post-process the CLI arguments for evaluation.

    Defaults come from ``config_dict`` and are handed to ``create_parser``; the
    parsed namespace is then normalised:

    * ``eval_mode`` is lower-cased.
    * ``eval_sizes`` becomes a list of ints (comma-separated input).
    * ``eval_distributions`` becomes a list of stripped strings, defaulting to
      ``['uniform']`` when empty.
    * ``state_k`` becomes a list with exactly ``num_state_encoder`` entries:
      longer lists are truncated, shorter non-empty lists are padded by repeating
      their last value.
    * ``use_stage1_action_encoding`` becomes a bool (strings ``'false'``, ``'0'``
      and ``'no'`` map to ``False``, case-insensitively).

    Returns:
        The populated argument namespace.
    """
    config_dict = {
        'problem': 'tsp',  # tsp (two-stage) or cvrp (single-stage VRP evaluation)
        'bsz': 64,
        'nb_nodes': 50,
        'dim_input_nodes': 2,
        'dim_emb': 128,
        'dim_ff': 512,
        'nb_heads': 8,
        'nb_layers_action_encoder': 2,
        'nb_layers_decoder': 3,
        'batchnorm': False,
        'use_normalization_layer': False,
        'aug': 'mix',
        'aug_num': 16,
        'test_aug_num': 16,
        'data_path': './data/',
        'nb_batch_eval': 100,  # used for synthetic eval mode
        'stage1_ckpt': '',
        'stage2_ckpt': '',
        # VRP-only params
        'num_state_encoder': 2,
        'nb_layers_state_encoder': 2,
        'action_k': 15,
        'state_k': '35,50,65',
        'if_use_local_mask': False,
        'if_agg_whole_graph': False,
        'vrp_ckpt': '',
        'k_promising': 8,
        'deterministic': True,
        'use_best_over_aug': True,
        'use_stage1_action_encoding': 'True',
        'eval_mode': 'dataset',  # dataset | synthetic | tsplib | cvrplib
        'eval_sizes': '100',  # comma-separated list
        'eval_distributions': 'uniform',  # comma-separated list
        'eval_num_instances': -1,  # -1 means all available
    }
    parser, args = create_parser(config_dict)
    args = parser.parse_args(namespace=args)

    args.eval_mode = str(args.eval_mode).lower()
    args.eval_sizes = [int(s) for s in str(args.eval_sizes).split(',') if s.strip()]
    args.eval_distributions = [s.strip() for s in str(args.eval_distributions).split(',') if s.strip()]
    if not args.eval_distributions:
        args.eval_distributions = ['uniform']

    if isinstance(args.state_k, str):
        args.state_k = [int(x) for x in args.state_k.split(',') if x.strip()]
    state_k = list(args.state_k)
    shortfall = args.num_state_encoder - len(state_k)
    if shortfall > 0 and state_k:
        # Repeat the last value as often as needed; appending it only once would
        # leave the list too short whenever more than one entry is missing.
        state_k.extend(state_k[-1:] * shortfall)
    args.state_k = state_k[:args.num_state_encoder]

    if isinstance(args.use_stage1_action_encoding, str):
        args.use_stage1_action_encoding = args.use_stage1_action_encoding.lower() not in ('false', '0', 'no')
    else:
        args.use_stage1_action_encoding = bool(args.use_stage1_action_encoding)
    return args


def _require_dataset_root(data_path: str) -> None:
    """Verify that ``data_path`` contains a ``data_farm`` directory.

    Raises:
        FileNotFoundError: If ``<data_path>/data_farm`` is not an existing directory.
    """
    dataset_dir = os.path.join(data_path, 'data_farm')
    if os.path.isdir(dataset_dir):
        return
    message = (
        f"Dataset directory '{dataset_dir}' is missing. "
        "Please unzip/download the dataset and point --data_path to its root."
    )
    raise FileNotFoundError(message)


def _load_policy_state(policy: torch.nn.Module, ckpt_path: str, device: torch.device, expected_stage: str = None) -> None:
    """Load checkpoint weights into ``policy`` with a best-effort fallback.

    ``load_stage_ckpt`` is tried first. If it raises, the failure is reported and
    the file is read directly with ``torch.load``; the state dict is taken from the
    ``'policy_state_dict'`` entry when the checkpoint is a dict containing that
    key, otherwise the loaded object itself is used. Both paths load non-strictly.

    Args:
        policy: Module that receives the weights.
        ckpt_path: Path of the checkpoint file.
        device: Target used as ``map_location`` for the fallback load.
        expected_stage: Forwarded to ``load_stage_ckpt`` unchanged.
    """
    try:
        load_stage_ckpt(policy, ckpt_path, device, expected_stage=expected_stage, strict=False)
        return
    except Exception as exc:
        # Deliberately broad: any failure of the primary loader triggers the fallback,
        # but the reason is reported instead of being silently discarded.
        print(
            f"[Eval2S][Checkpoint] load_stage_ckpt failed for '{ckpt_path}' "
            f"({type(exc).__name__}: {exc}); falling back to a plain state_dict load."
        )
    # NOTE(review): torch.load unpickles the file - only point this at trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    if isinstance(ckpt, dict) and 'policy_state_dict' in ckpt:
        state_dict = ckpt['policy_state_dict']
    else:
        state_dict = ckpt
    policy.load_state_dict(state_dict, strict=False)


@torch.no_grad()
def _evaluate_on_dataset(args, device, stage1, stage2):
    """Evaluate the two-stage TSP policy on stored instances and print length/gap.

    For every combination of ``args.eval_sizes`` and ``args.eval_distributions``
    the instances and their reference lengths are loaded with
    ``load_instances_with_baselines``. Instances are processed in chunks of
    ``bsz // test_aug_num`` base instances; each one is repeated ``test_aug_num``
    times, passed through ``run_aug`` and rolled out. Tour lengths are measured on
    the un-augmented coordinates and reduced per base instance (min or mean,
    depending on ``args.use_best_over_aug``).

    Args:
        args: Parsed evaluation arguments (see ``build_args``).
        device: Device on which rollouts are executed.
        stage1: Stage-1 policy passed to ``rollout_two_stage``.
        stage2: Stage-2 policy passed to ``rollout_two_stage``.

    Raises:
        ValueError: If ``test_aug_num`` is not positive or ``bsz`` is not a
            positive multiple of ``test_aug_num``.
        FileNotFoundError: If the dataset root is missing.
    """
    if args.test_aug_num <= 0:
        raise ValueError("test_aug_num must be a positive integer for dataset evaluation.")
    if args.bsz % args.test_aug_num != 0:
        raise ValueError("bsz must be a multiple of test_aug_num for dataset evaluation.")
    base_per_batch = args.bsz // args.test_aug_num
    if base_per_batch <= 0:
        # A zero chunk size would never advance the processing loop below.
        raise ValueError("bsz must be a positive multiple of test_aug_num for dataset evaluation.")
    _require_dataset_root(args.data_path)
    reduce_label = 'best' if args.use_best_over_aug else 'mean'
    print(f"[Eval2S][Dataset] Root: {args.data_path} | Augmentation: {args.aug} | Reduce: {reduce_label}")

    for size in args.eval_sizes:
        for distribution in args.eval_distributions:
            tsp_instances, _, opt_lens = load_instances_with_baselines(args.data_path, "tsp", size, distribution)
            total_available = tsp_instances.size(0)
            # A negative eval_num_instances means "use everything that is available".
            total_target = total_available if args.eval_num_instances < 0 else min(args.eval_num_instances, total_available)
            if total_target == 0:
                print(f"[Eval2S][Dataset][tsp{size}-{distribution}] No instances found, skipping.")
                continue

            tsp_instances = tsp_instances[:total_target].float()
            # The gap is computed on CPU, so the reference lengths never need to visit the device.
            opt_cpu = torch.tensor(opt_lens[:total_target], dtype=torch.float)
            gathered_lengths = []
            processed = 0
            while processed < total_target:
                cur_base = min(base_per_batch, total_target - processed)
                coords = tsp_instances[processed:processed + cur_base].to(device)
                # Repeat each base instance test_aug_num times, keeping its copies adjacent.
                x_repeat = coords.unsqueeze(1).repeat((1, args.test_aug_num, 1, 1)).view(
                    cur_base * args.test_aug_num, size, args.dim_input_nodes
                )
                x_aug = run_aug(args.aug, x_repeat, args.test_aug_num)
                env = TSPEnvironment(x_aug)
                tours, _ = rollout_two_stage(
                    env,
                    stage1,
                    stage2,
                    k_promising=args.k_promising,
                    deterministic=args.deterministic,
                )
                # Lengths are measured on the un-augmented coordinates.
                lengths = compute_tsp_tour_length(x_repeat, tours)
                base_lengths = _reduce_augmented_lengths(lengths, args.test_aug_num, args.use_best_over_aug)
                gathered_lengths.append(base_lengths.cpu())
                processed += cur_base

            all_lengths = torch.cat(gathered_lengths)
            avg_len = all_lengths.mean().item()
            avg_gap = ((all_lengths - opt_cpu) / opt_cpu).mean().item()
            # Report the reduction that was actually applied (the label used to be hard-coded to "best").
            print(
                f"[Eval2S][Dataset][tsp{size}-{distribution}] "
                f"avg_len={avg_len:.4f}, avg_gap={avg_gap * 100:.3f}% over {all_lengths.size(0)} instances "
                f"({reduce_label}-of-{args.test_aug_num} augments)."
            )


@torch.no_grad()
def _evaluate_on_synthetic(args, device, stage1, stage2):
    """Evaluate the two-stage TSP policy on freshly generated instances.

    Runs ``args.nb_batch_eval`` batches produced by ``generate_tsp_instance`` and
    prints the mean of the per-batch average tour lengths. When
    ``args.use_best_over_aug`` is set, each batch is first reduced to the best
    length per base instance; otherwise the mean is taken over all entries.

    Args:
        args: Parsed evaluation arguments (see ``build_args``).
        device: Device on which instances are generated and rolled out.
        stage1: Stage-1 policy passed to ``rollout_two_stage``.
        stage2: Stage-2 policy passed to ``rollout_two_stage``.

    Raises:
        ValueError: If ``args.nb_batch_eval`` is not positive.
    """
    if args.nb_batch_eval <= 0:
        # Zero batches would divide by zero below; a negative count would print a bogus average.
        raise ValueError("nb_batch_eval must be a positive integer for synthetic evaluation.")

    total_len = 0.0
    for _ in range(args.nb_batch_eval):
        x_aug, x_repeat = generate_tsp_instance(args, device, if_test=True)
        env = TSPEnvironment(x_aug)
        tours, _ = rollout_two_stage(
            env,
            stage1,
            stage2,
            k_promising=args.k_promising,
            deterministic=args.deterministic,
        )
        # Lengths are measured on x_repeat, not on the augmented coordinates.
        lengths = compute_tsp_tour_length(x_repeat, tours)
        if args.use_best_over_aug:
            lengths = _best_over_augmented(lengths, args.test_aug_num)
        total_len += lengths.mean().item()

    avg_len = total_len / args.nb_batch_eval
    print(f"[Eval2S][Synthetic] avg tour length over {args.nb_batch_eval} batches: {avg_len:.4f}")


@torch.no_grad()
def _evaluate_on_tsplib(args, device, stage1, stage2):
    """Evaluate the two-stage TSP policy on every instance in ``tsplib_collections``.

    Instances are visited in the order given by the second element returned from
    ``parse_tsplib_name``. Each instance is normalised for the policy, replicated
    ``choose_bsz(size) * test_aug_num`` times, augmented and rolled out; the tour
    length is measured on the original coordinates. The best reduced length is
    compared with the stored reference value, and gaps are summarised per size
    bucket at the end.
    """
    root = Path(args.data_path)
    names = sorted(tsplib_collections.keys(), key=lambda n: parse_tsplib_name(n)[1])

    # Upper size bound (inclusive) -> bucket label; anything larger goes to the overflow bucket.
    bounded_buckets = ((100, '1-100'), (1000, '101-1000'), (10000, '1001-10000'))
    overflow_bucket = '>10000'
    gaps_by_bucket = {label: [] for _, label in bounded_buckets}
    gaps_by_bucket[overflow_bucket] = []

    print(f"[Eval2S][TSPLIB] Root: {root} | Augmentation: {args.aug} | Reduce: {'best' if args.use_best_over_aug else 'mean'}")
    for idx, name in enumerate(names):
        opt_len = tsplib_collections[name]
        instance, _ = load_tsplib_file(root, name)
        size = instance.size(0)
        total_bsz = choose_bsz(size) * args.test_aug_num

        normalized = normalize_nodes_to_unit_board(instance).float()
        policy_input = normalized.unsqueeze(0).repeat(total_bsz, 1, 1).to(device)
        original_rep = instance.float().unsqueeze(0).repeat(total_bsz, 1, 1).to(device)

        augmented = run_aug(args.aug, policy_input, args.test_aug_num)
        tours, _ = rollout_two_stage(
            TSPEnvironment(augmented),
            stage1,
            stage2,
            k_promising=args.k_promising,
            deterministic=args.deterministic,
        )
        # Tour length is evaluated on the un-normalised coordinates.
        lengths = compute_tsp_tour_length(original_rep, tours)
        reduced = _reduce_augmented_lengths(lengths, args.test_aug_num, args.use_best_over_aug)
        best_len = reduced.min().item()
        gap = best_len / opt_len - 1

        bucket = next((label for bound, label in bounded_buckets if size <= bound), overflow_bucket)
        gaps_by_bucket[bucket].append(gap)

        print(f"[Eval2S][TSPLIB][{idx:03d}] {name:12s} | size={size:5d} | len={best_len:.3f} | gap={gap * 100:.3f}%")

    def _mean_pct(label):
        """Mean gap of a bucket in percent (NaN for an empty bucket)."""
        values = gaps_by_bucket[label]
        mean_gap = sum(values) / len(values) if values else float('nan')
        return mean_gap * 100

    print("[Eval2S][TSPLIB] Summary gaps (%): "
          f"1-100={_mean_pct('1-100'):.3f}, "
          f"101-1000={_mean_pct('101-1000'):.3f}, "
          f"1001-10000={_mean_pct('1001-10000'):.3f}, "
          f">10000={_mean_pct('>10000'):.3f}")


@torch.no_grad()
def _evaluate_on_cvrplib(args, device, vrp_model: VRPPolicy):
    """Evaluate the VRP policy on every instance in ``cvrplib_collections``.

    Instances are visited in the order given by the second element returned from
    ``parse_cvrplib_name``. Customer nodes and the depot are concatenated (depot
    last), normalised, replicated ``choose_bsz(size) * test_aug_num`` times and
    augmented before being fed to ``vrp_model``. The tour length is measured on the
    original coordinates, compared with the stored reference value, and gaps are
    summarised per size bucket at the end.
    """
    root = Path(args.data_path)
    names = sorted(cvrplib_collections.keys(), key=lambda n: parse_cvrplib_name(n)[1])

    # Upper size bound (inclusive) -> bucket label; anything larger goes to the overflow bucket.
    bounded_buckets = ((100, '1-100'), (200, '101-200'), (500, '201-500'))
    overflow_bucket = '>500'
    gaps_by_bucket = {label: [] for _, label in bounded_buckets}
    gaps_by_bucket[overflow_bucket] = []

    print(f"[Eval2S][CVRPLIB] Root: {root} | Augmentation: {args.aug} | Reduce: {'best' if args.use_best_over_aug else 'mean'}")
    for idx, name in enumerate(names):
        opt_len = cvrplib_collections[name]
        _, size = parse_cvrplib_name(name)
        depot, nodes, demands, capacity, _ = load_cvrplib_file(root, name)
        # The depot is appended after the customer nodes, so it sits at index -1.
        coords_orig = torch.cat([nodes, depot.unsqueeze(0)], dim=0).float()

        coords_norm = normalize_nodes_to_unit_board(coords_orig)
        total_bsz = choose_bsz(size) * args.test_aug_num

        policy_input = coords_norm.unsqueeze(0).repeat(total_bsz, 1, 1).to(device)
        original_rep = coords_orig.unsqueeze(0).repeat(total_bsz, 1, 1).to(device)
        demand_rep = demands.to(device).unsqueeze(0).repeat(total_bsz, 1)

        augmented = run_aug(args.aug, policy_input, args.test_aug_num)
        # Split the augmented coordinates back into customers and depot.
        model_input = {
            'loc': augmented[:, :-1, :],
            'demand': demand_rep,
            'depot': augmented[:, -1, :],
        }

        tours, _ = vrp_model(
            model_input,
            args.action_k,
            args.state_k,
            capacity.item(),
            'cvrp',
            choice_deterministic=args.deterministic,
            if_use_local_mask=args.if_use_local_mask,
        )
        # Tour length is evaluated on the un-normalised coordinates.
        lengths = compute_vrp_tour_length(original_rep, tours)
        reduced = _reduce_augmented_lengths(lengths, args.test_aug_num, args.use_best_over_aug)
        best_len = reduced.min().item()
        gap = best_len / opt_len - 1

        bucket = next((label for bound, label in bounded_buckets if size <= bound), overflow_bucket)
        gaps_by_bucket[bucket].append(gap)

        print(f"[Eval2S][CVRPLIB][{idx:03d}] {name:12s} | size={size:5d} | len={best_len:.3f} | gap={gap * 100:.3f}%")

    def _mean_pct(label):
        """Mean gap of a bucket in percent (NaN for an empty bucket)."""
        values = gaps_by_bucket[label]
        mean_gap = sum(values) / len(values) if values else float('nan')
        return mean_gap * 100

    print("[Eval2S][CVRPLIB] Summary gaps (%): "
          f"1-100={_mean_pct('1-100'):.3f}, "
          f"101-200={_mean_pct('101-200'):.3f}, "
          f"201-500={_mean_pct('201-500'):.3f}, "
          f">500={_mean_pct('>500'):.3f}")


@torch.no_grad()
def evaluate_two_stage(args, device):
    """Build the requested policy, load its checkpoint(s) and run the evaluation.

    ``args.problem == 'cvrp'`` (case-insensitive) evaluates a single ``VRPPolicy``
    on CVRPLIB; every other value runs the two-stage TSP evaluation in the mode
    selected by ``args.eval_mode`` (``dataset``, ``synthetic`` or ``tsplib``).
    The evaluation mode is validated before any model is built or any checkpoint
    is read, so a bad mode is reported immediately.

    Args:
        args: Parsed evaluation arguments (see ``build_args``).
        device: Device on which models are placed and evaluated.

    Raises:
        ValueError: If the evaluation mode is unsupported for the problem or a
            required checkpoint path is not provided.
        FileNotFoundError: If a checkpoint path does not point to a file.
    """
    problem = str(args.problem).lower()

    if problem == 'cvrp':
        # Fail fast: reject an unsupported mode before the model is constructed and loaded.
        if args.eval_mode != 'cvrplib':
            raise ValueError("For CVRP, only eval_mode='cvrplib' is supported.")
        ckpt_path = args.vrp_ckpt or args.stage1_ckpt
        if not ckpt_path:
            raise ValueError("Please provide --vrp_ckpt (or --stage1_ckpt) for CVRPLIB evaluation.")
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(f"VRP checkpoint not found: {ckpt_path}")
        vrp_model = VRPPolicy(
            args.dim_input_nodes,
            args.dim_emb,
            args.dim_ff,
            args.num_state_encoder,
            args.nb_layers_state_encoder,
            args.nb_layers_action_encoder,
            args.nb_layers_decoder,
            args.nb_heads,
            batchnorm=args.batchnorm,
            if_agg_whole_graph=args.if_agg_whole_graph,
        ).to(device)
        _load_policy_state(vrp_model, ckpt_path, device, expected_stage='pretrain')
        vrp_model.eval()
        _evaluate_on_cvrplib(args, device, vrp_model)
        return

    # Default: TSP two-stage evaluation
    tsp_evaluators = {
        'dataset': _evaluate_on_dataset,
        'synthetic': _evaluate_on_synthetic,
        'tsplib': _evaluate_on_tsplib,
    }
    # Fail fast: validate the mode before the policies are constructed and loaded.
    if args.eval_mode == 'cvrplib':
        raise ValueError("Use --problem cvrp for CVRPLIB evaluation (single-stage VRP).")
    if args.eval_mode not in tsp_evaluators:
        raise ValueError("eval_mode must be one of: dataset, synthetic, tsplib, cvrplib")

    if not args.stage1_ckpt or not args.stage2_ckpt:
        raise ValueError("Please provide both --stage1_ckpt and --stage2_ckpt paths.")
    if not os.path.isfile(args.stage1_ckpt):
        raise FileNotFoundError(f"Stage 1 checkpoint not found: {args.stage1_ckpt}")
    if not os.path.isfile(args.stage2_ckpt):
        raise FileNotFoundError(f"Stage 2 checkpoint not found: {args.stage2_ckpt}")

    stage1 = TSPStage1Policy(args).to(device)
    stage2 = TSPStage2Policy(args).to(device)
    load_stage_ckpt(stage1, args.stage1_ckpt, device, expected_stage='stage1', strict=False)
    load_stage_ckpt(stage2, args.stage2_ckpt, device, expected_stage='stage2', strict=False)
    stage1.eval()
    stage2.eval()

    tsp_evaluators[args.eval_mode](args, device, stage1, stage2)


def main():
    """Parse the CLI arguments, pick CUDA when available (CPU otherwise) and run the evaluation."""
    cli_args = build_args()
    if torch.cuda.is_available():
        target_device = torch.device('cuda')
    else:
        target_device = torch.device('cpu')
    evaluate_two_stage(cli_args, target_device)


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
