# evaluate_op_diffusion.py

import torch
import numpy as np
import os
import time
import argparse
import matplotlib.pyplot as plt
from omegaconf import OmegaConf, DictConfig
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# --- OP-specific imports from your project ---
from data_loader_sparse import OPConditionalSuffixDataset, op_custom_collate_fn
from diffusion_model_sparse import ConditionalOPSuffixDiffusionModel
from discrete_diffusion_sparse import AdjacencyMatrixDiffusion

# ==============================================================================
# === OP-SPECIFIC HELPER AND DECODING FUNCTIONS ===
# ==============================================================================

def calculate_op_metrics(instance_locs, tour, prizes):
    """
    Score an OP route, returning ``(total_prize, total_distance)``.

    ``tour`` holds only customer node indices in visiting order (e.g. ``[3, 1, 5]``);
    the route is closed implicitly by leaving from and returning to the depot,
    node 0. An empty tour scores ``(0.0, 0.0)``.

    Distances are Euclidean lengths between consecutive rows of ``instance_locs``.
    The prize is the sum of ``prizes`` over the indices in ``tour`` only, so the
    depot's entry is never counted unless ``tour`` itself contains 0.
    """
    if not tour:
        return 0.0, 0.0

    # Close the route at the depot on both ends.
    route = [0] + tour + [0]
    coords = instance_locs[route]

    # Length of each leg between consecutive stops.
    leg_deltas = coords[1:] - coords[:-1]
    leg_lengths = leg_deltas.pow(2).sum(dim=1).sqrt()
    travelled_distance = leg_lengths.sum().item()

    collected_prize = torch.sum(prizes[tour]).item()

    return collected_prize, travelled_distance

def decode_op_solution_from_heatmap(adj_probs, instance_locs, prizes, max_length):
    """
    Construct an OP tour from a model heatmap with a budget-aware greedy search.

    Starting at the depot (node 0), each step considers every unvisited customer
    ``j`` for which ``current_length + d(current, j) + d(j, 0) <= max_length``
    (i.e. the tour could still return to the depot). Each candidate is scored as
    ``prize[j] * p(current, j) / (d(current, j) + 1e-6)``, and the highest-scoring
    one is appended. Ties go to the lowest node index. The search stops when no
    feasible candidate remains.

    Args:
        adj_probs: Square edge-probability matrix indexed ``[i, j]``; it is
            symmetrized before use. NOTE(review): assumes a dense
            (num_nodes, num_nodes) matrix -- confirm for the sparse model output.
        instance_locs: (num_nodes, coord_dim) node coordinates; row 0 is the depot.
        prizes: Per-node prizes indexed by node id (flattened; entries beyond
            ``num_nodes`` are ignored).
        max_length: Tour-length budget (Python float or single-element tensor).

    Returns:
        List of Python ints: the customer nodes in visiting order, depot excluded.
    """
    num_nodes = instance_locs.shape[0]

    # Symmetrize the heatmap: direction of travel does not matter for edge scores.
    sym_probs = ((adj_probs + adj_probs.T) / 2.0)[:num_nodes, :num_nodes]
    node_prizes = prizes.reshape(-1)[:num_nodes]

    # Use the exact (non-matmul) Euclidean path. torch.cdist otherwise switches to
    # a matmul formulation for >25 points, whose rounding error can make a tour
    # look feasible here but exceed the budget when re-measured with the direct
    # sqrt-of-squared-differences formula used by calculate_op_metrics.
    dists = torch.cdist(
        instance_locs, instance_locs, p=2,
        compute_mode="donot_use_mm_for_euclid_dist",
    )
    if num_nodes <= 1:
        return []

    return_dists = dists[:, 0]  # distance from each node back to the depot
    neg_inf = float("-inf")

    visited = torch.zeros(num_nodes, dtype=torch.bool, device=dists.device)
    visited[0] = True  # the depot is never a candidate
    tour = []
    current_node = 0
    current_length = 0.0

    while not bool(visited.all()):
        out_edges = dists[current_node]
        # Tour length if we go to j next and then straight back to the depot.
        projected = current_length + out_edges + return_dists
        # High prize and high model probability are good; a long edge is bad.
        scores = (node_prizes * sym_probs[current_node]) / (out_edges + 1e-6)

        # `scores > -inf` is False for NaN, so NaN-scored nodes are never chosen.
        candidates = (~visited) & (projected <= max_length) & (scores > neg_inf)
        if not bool(candidates.any()):
            break  # nothing else fits in the remaining budget

        # argmax returns the first maximal index, i.e. the lowest node id on ties.
        next_node = int(scores.masked_fill(~candidates, neg_inf).argmax())

        tour.append(next_node)
        current_length = current_length + out_edges[next_node]
        visited[next_node] = True
        current_node = next_node

    return tour

def visualize_op_solution(instance_locs, tour, prizes, title="OP Solution", reward=None, cost=None, ax=None):
    """
    Draw an OP instance and (optionally) a tour onto a Matplotlib Axes.

    The depot (node 0) is drawn as a black square. Every customer is drawn in
    light blue with marker size ``20 + 100 * prize``, and visited customers are
    re-drawn on top in red. The tour path is drawn as depot -> tour -> depot.

    Args:
        instance_locs: Tensor of node coordinates; row 0 is the depot and columns
            0/1 are plotted as x/y.
        tour: List of visited customer indices (depot excluded); may be empty.
        prizes: Tensor of per-node prizes, indexed like ``instance_locs``.
        title: Axes title.
        reward, cost: When both are given, they are appended to the title.
        ax: Axes to draw on; a new 8x8-inch figure is created when None.

    Returns:
        The Axes that was drawn on (the newly created one when ``ax`` is None),
        so callers can reach the figure via ``ax.figure`` to save or close it.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    locs_cpu = instance_locs.cpu().numpy()
    prizes_cpu = prizes.cpu().numpy()

    # Draw depot
    ax.scatter(locs_cpu[0, 0], locs_cpu[0, 1], c='black', s=150, label='Depot', zorder=5, marker='s')

    # Draw all customers with size proportional to prize (visited ones are
    # overdrawn in red below). NOTE(review): the 20 + 100 * prize scaling
    # presumably assumes prizes of order 1 -- confirm for the dataset.
    customer_locs = locs_cpu[1:]
    customer_prizes = prizes_cpu[1:]
    ax.scatter(customer_locs[:, 0], customer_locs[:, 1], c='lightblue', s=20 + customer_prizes * 100, label='Unvisited Customers', zorder=2)

    if tour:
        # Highlight visited customers
        visited_indices = np.array(tour)
        visited_locs = locs_cpu[visited_indices]
        visited_prizes = prizes_cpu[visited_indices]
        ax.scatter(visited_locs[:, 0], visited_locs[:, 1], c='red', s=20 + visited_prizes * 100, label='Visited Customers', zorder=3)

        # Draw the closed tour path through the depot
        tour_with_depot = [0] + tour + [0]
        tour_locs_plot = locs_cpu[tour_with_depot]
        ax.plot(tour_locs_plot[:, 0], tour_locs_plot[:, 1], color='maroon', marker='o', markersize=4, zorder=1, linestyle='-')

    plot_title = title
    if reward is not None and cost is not None:
        plot_title += f"\nReward: {reward:.4f} | Cost: {cost:.2f}"

    ax.set_title(plot_title)
    ax.legend()
    ax.set_aspect('equal', adjustable='box')
    return ax

@torch.no_grad()
def evaluate(cfg: DictConfig, model_checkpoint_path: str):
    """
    Solve test OP instances with the diffusion model and compare them to ground truth.

    The first ``cfg.eval.num_samples_to_eval`` instances of ``cfg.data.test_path``
    are evaluated in order with no prefix (``prefix_k_options=[0]``).
    For each instance, a heatmap is sampled with DDIM, decoded greedily under
    the instance's length budget, and scored against the reference tour stored
    in ``dataset.instances``. Up to ``cfg.eval.num_samples_to_visualize``
    side-by-side plots are saved to ``cfg.eval.visualization_dir``. A summary
    (average rewards, gap, budget violations) is printed at the end.

    Args:
        cfg: Merged config; reads ``data.test_path``, ``model.*``,
            ``diffusion.num_timesteps`` and ``eval.*``.
        model_checkpoint_path: File whose ``torch.load`` result is passed
            directly to ``model.load_state_dict``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data Loading for OP ---
    dataset = OPConditionalSuffixDataset(
        txt_file_paths=[cfg.data.test_path],
        prefix_k_options=[0], # Prefix is 0 for from-scratch evaluation
        sparse_factor=cfg.model.sparse_factor
    )
    num_samples_to_evaluate = min(cfg.eval.num_samples_to_eval, len(dataset))
    # A contiguous prefix + shuffle=False keeps batch_idx * batch_size + i equal
    # to the instance's index in `dataset`, which the ground-truth lookup relies on.
    eval_dataset = torch.utils.data.Subset(dataset, range(num_samples_to_evaluate))
    dataloader = DataLoader(
        eval_dataset, batch_size=cfg.eval.batch_size,
        shuffle=False, collate_fn=op_custom_collate_fn
    )

    # --- Load OP Model ---
    model = ConditionalOPSuffixDiffusionModel(
        num_nodes=cfg.model.num_nodes,
        node_coord_dim=cfg.model.node_coord_dim,
        pos_embed_num_feats=cfg.model.pos_embed_num_feats,
        node_embed_dim=cfg.model.node_embed_dim,
        prefix_node_embed_dim=cfg.model.node_embed_dim,  # reuses node_embed_dim
        prefix_enc_hidden_dim=cfg.model.prefix_enc_hidden_dim,
        prefix_cond_dim=cfg.model.prefix_cond_dim,
        max_length_embed_dim=cfg.model.max_length_embed_dim,
        gnn_n_layers=cfg.model.gnn_n_layers,
        gnn_hidden_dim=cfg.model.gnn_hidden_dim,
        gnn_aggregation=cfg.model.gnn_aggregation,
        gnn_norm=cfg.model.gnn_norm,
        gnn_learn_norm=cfg.model.gnn_learn_norm,
        gnn_gated=cfg.model.gnn_gated,
        time_embed_dim=cfg.model.time_embed_dim,
        sparse_factor=cfg.model.sparse_factor
    ).to(device)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
    model.eval()
    print(f"Loaded trained OP model from {model_checkpoint_path}")

    # --- Diffusion Handler ---
    diffusion_handler = AdjacencyMatrixDiffusion(
        num_nodes=cfg.model.num_nodes,
        num_timesteps=cfg.diffusion.num_timesteps,
        device=device,
        sparse_factor=cfg.model.sparse_factor
    )

    # Create the output directory once, and only if plots will be written.
    vis_dir = cfg.eval.visualization_dir
    num_to_visualize = cfg.eval.num_samples_to_visualize
    if num_to_visualize > 0:
        os.makedirs(vis_dir, exist_ok=True)

    total_generated_reward, total_gt_reward = 0.0, 0.0
    num_evaluated = 0
    num_generated_over_budget = 0
    num_visualized = 0
    start_time = time.time()

    # --- Evaluation Loop ---
    for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Solving OP Batches")):
        instance_locs_batch = batch_data["instance_locs_orig"].to(device)
        prizes_batch = batch_data["prizes"].to(device)
        max_lengths_batch = batch_data["max_lengths"].to(device)

        _, generated_adj_probs_batch, _ = diffusion_handler.p_sample_loop_ddim(
            denoiser_model=model,
            batch_data={k: v.to(device) for k, v in batch_data.items() if torch.is_tensor(v)},
            num_inference_steps=cfg.eval.num_inference_steps
        )

        for i in range(instance_locs_batch.shape[0]):
            adj_probs = generated_adj_probs_batch[i]
            locs = instance_locs_batch[i]
            prizes = prizes_batch[i]
            max_length = max_lengths_batch[i]

            # Decode solution and calculate its metrics
            decoded_tour = decode_op_solution_from_heatmap(adj_probs, locs, prizes, max_length)
            generated_reward, generated_cost = calculate_op_metrics(locs, decoded_tour, prizes)
            # The decoder enforces the budget, so this count should stay at 0;
            # a non-zero value flags a decoder/metric mismatch. The 1e-5 slack
            # absorbs float32 rounding (assumes coordinates of order 1).
            if generated_cost > max_length.item() + 1e-5:
                num_generated_over_budget += 1

            # Get Ground Truth metrics (depot entries are stripped from the tour).
            original_data_index = batch_idx * cfg.eval.batch_size + i
            # int() accepts tensor elements, NumPy scalars and plain ints alike.
            gt_tour_nodes = [
                int(node) for node in dataset.instances[original_data_index]['tour']
                if int(node) != 0
            ]
            gt_reward, gt_cost = calculate_op_metrics(locs, gt_tour_nodes, prizes)

            total_generated_reward += generated_reward
            total_gt_reward += gt_reward
            num_evaluated += 1

            # Visualize if needed
            if num_visualized < num_to_visualize:
                fig, axes = plt.subplots(1, 2, figsize=(18, 8))

                visualize_op_solution(locs, decoded_tour, prizes, "Generated Solution", generated_reward, generated_cost, ax=axes[0])
                visualize_op_solution(locs, gt_tour_nodes, prizes, "Ground Truth", gt_reward, gt_cost, ax=axes[1])

                fig.suptitle(f"Instance #{original_data_index} Comparison (Max Length: {max_length.item():.2f})", fontsize=16)
                save_path = os.path.join(vis_dir, f"op_comparison_{original_data_index}.png")
                # Save via the figure handle rather than pyplot's implicit
                # "current figure", so the right figure is always written.
                fig.savefig(save_path)
                plt.close(fig)
                print(f"Saved visualization to {save_path}")
                num_visualized += 1

    total_time = time.time() - start_time
    avg_generated_reward = total_generated_reward / num_evaluated if num_evaluated > 0 else 0.0
    avg_gt_reward = total_gt_reward / num_evaluated if num_evaluated > 0 else 0.0

    print("\n" + "="*60)
    print("--- OP Diffusion Model Evaluation Summary ---")
    print(f"Evaluated {num_evaluated} instances in {total_time:.2f}s.")
    print(f"Average Generated Reward: {avg_generated_reward:.4f}")
    print(f"Average Ground Truth Reward: {avg_gt_reward:.4f}")
    # Relative shortfall of the generated reward vs. ground truth (0 when GT is 0).
    gap = ((avg_gt_reward - avg_generated_reward) / avg_gt_reward) * 100 if avg_gt_reward > 0 else 0.0
    print(f"Gap to Ground Truth: {gap:.2f}%")
    print(f"Generated tours over length budget: {num_generated_over_budget}/{num_evaluated}")
    print("="*60)
    print("="*60)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a trained Diffusion Model on the OP")
    parser.add_argument("--config", type=str, default="op100_config.yaml", help="Path to the base config file (e.g., op100_config.yaml)")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained model .pth checkpoint file")
    parser.add_argument(
        "overrides", nargs="*", default=[],
        help="Optional OmegaConf dotlist overrides applied last, "
             "e.g. eval.batch_size=32 data.test_path=op_dataset/val.txt",
    )
    args = parser.parse_args()

    # Fail fast before doing any config work.
    if not os.path.exists(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found at: {args.checkpoint}")

    cfg = OmegaConf.load(args.config)

    # Evaluation defaults. OmegaConf.merge lets later arguments override earlier
    # ones, so these are merged FIRST. Any value set in the config file (e.g.
    # data.test_path or eval.batch_size) then takes precedence, and CLI dotlist
    # overrides win over both.
    eval_defaults = OmegaConf.create({
        'data': {
            'test_path': "op_dataset/test.txt", # Default test path, can be overridden in config
        },
        'eval': {
            'batch_size': 16,
            'num_samples_to_eval': 500,
            'num_samples_to_visualize': 5,
            'visualization_dir': './eval_visualizations_op',
            'num_inference_steps': 50,
        }
    })
    final_cfg = OmegaConf.merge(eval_defaults, cfg, OmegaConf.from_dotlist(args.overrides))

    print("--- Final OP Evaluation Configuration ---")
    print(OmegaConf.to_yaml(final_cfg))

    evaluate(final_cfg, args.checkpoint)