# hybrid_op100_solver.py (Upgraded with KL Divergence Trigger)

import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import argparse
from tqdm.auto import tqdm
from omegaconf import OmegaConf, DictConfig
from torch.utils.data import DataLoader
from tensordict import TensorDict
from sklearn.neighbors import KDTree
from collections import defaultdict

# --- RL4CO Imports (for OP) ---
from rl4co.models.zoo import AttentionModel
from rl4co.envs import OPEnv

# --- Your Project's Custom Imports (for OP) ---
from diffusion_model_sparse import ConditionalOPSuffixDiffusionModel
from discrete_diffusion_sparse import AdjacencyMatrixDiffusion
from data_loader_sparse import OPConditionalSuffixDataset
from evaluation_op100 import calculate_op_metrics, decode_op_solution_from_heatmap

class HybridSolver:
    """Hybrid OP solver: greedy RL decoding plus Diffusion-Model (DM) proposals.

    At each decoding step the RL policy's action distribution is inspected.
    For active instances that look uncertain (see ``_should_trigger_dm``), the
    DM is run on the current path extended by each of the policy's top
    candidate nodes. The best decoded DM tour per instance is kept as a
    proposal. The final answer per instance is whichever of the greedy hybrid
    path and the best DM proposal has the higher reward.
    """

    def __init__(self, cfg: "DictConfig"):
        self.cfg = cfg
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Solver using device: {self.device}")

        self.rl_policy, self.rl_env = self._load_rl_policy()
        self.dm_model = self._load_dm_model()

        self.diffusion_handler = AdjacencyMatrixDiffusion(
            num_nodes=cfg.model.num_nodes,
            num_timesteps=cfg.diffusion.num_timesteps,
            device=self.device,
            sparse_factor=cfg.model.get("sparse_factor", -1)
        )

    def _load_rl_policy(self):
        """Load the rl4co AttentionModel checkpoint and return ``(policy, env)``.

        The policy is moved to ``self.device`` and switched to eval mode.

        Raises:
            SystemExit: with a non-zero status (message printed to stderr) if the
                environment or checkpoint cannot be loaded.
        """
        print(f"Loading RL model for {self.cfg.rl_model.problem} from: {self.cfg.rl_model.ckpt_path}")
        try:
            # OPEnv counts customer nodes only; cfg.model.num_nodes includes the depot.
            env = OPEnv(generator_kwargs={'num_loc': self.cfg.model.num_nodes - 1})
            model = AttentionModel.load_from_checkpoint(
                self.cfg.rl_model.ckpt_path, env=env, strict=False
            )
        except Exception as e:
            # A string SystemExit is printed to stderr and yields exit status 1,
            # so calling scripts/pipelines see the failure (bare exit() gives 0).
            raise SystemExit(f"Error loading RL model: {e}") from e
        policy = model.policy.to(self.device)
        policy.eval()
        return policy, model.env

    def _load_dm_model(self):
        """Build ConditionalOPSuffixDiffusionModel from cfg, load its weights and return it in eval mode."""
        print(f"Loading Diffusion model for OP from: {self.cfg.dm_model.ckpt_path}")
        model = ConditionalOPSuffixDiffusionModel(
            num_nodes=self.cfg.model.num_nodes,
            node_coord_dim=self.cfg.model.node_coord_dim,
            pos_embed_num_feats=self.cfg.model.pos_embed_num_feats,
            node_embed_dim=self.cfg.model.node_embed_dim,
            prefix_node_embed_dim=self.cfg.model.node_embed_dim,
            prefix_enc_hidden_dim=self.cfg.model.prefix_enc_hidden_dim,
            prefix_cond_dim=self.cfg.model.prefix_cond_dim,
            max_length_embed_dim=self.cfg.model.max_length_embed_dim,
            gnn_n_layers=self.cfg.model.gnn_n_layers,
            gnn_hidden_dim=self.cfg.model.gnn_hidden_dim,
            gnn_aggregation=self.cfg.model.gnn_aggregation,
            gnn_norm=self.cfg.model.gnn_norm,
            gnn_learn_norm=self.cfg.model.gnn_learn_norm,
            gnn_gated=self.cfg.model.gnn_gated,
            time_embed_dim=self.cfg.model.time_embed_dim,
            sparse_factor=self.cfg.model.sparse_factor
        ).to(self.device)
        # weights_only=True: only tensors are unpickled from the checkpoint file.
        model.load_state_dict(torch.load(self.cfg.dm_model.ckpt_path, map_location=self.device, weights_only=True))
        model.eval()
        return model

    def _should_trigger_dm(self, probs, mask, active_mask):
        """Decide, per instance, whether to invoke the Diffusion Model at this step.

        An active instance is flagged as uncertain when either
          * the entropy of ``probs`` exceeds ``cfg.solver.entropy_threshold``, or
          * the KL divergence from the uniform distribution over legal actions
            is below ``cfg.solver.kl_div_threshold``.
        Instances with at most one legal action are never flagged. There is no
        decision to make there, yet their KL divergence is exactly 0 and would
        always satisfy the KL trigger.

        Args:
            probs: action probabilities, (batch, n_actions); masked entries
                are assumed to be 0.
            mask: boolean legal-action mask with the same shape as ``probs``.
            active_mask: boolean (batch,) mask of instances that are not done.

        Returns:
            Boolean (batch,) tensor, True where the DM should be triggered.
        """
        if not active_mask.any():
            return torch.zeros_like(active_mask, dtype=torch.bool)

        log_probs = torch.log(probs + 1e-9)

        # Metric 1: entropy. High entropy means probability mass is spread out.
        entropy = -torch.sum(probs * log_probs, dim=-1)
        trigger_by_entropy = entropy > self.cfg.solver.entropy_threshold

        # Metric 2: KL divergence from uniform over the legal actions,
        # D_KL(P || U) = sum P(x) * (log P(x) - log U(x)). Low KL means the
        # policy is close to random guessing.
        num_available_actions = mask.sum(dim=-1, keepdim=True).float()
        uniform_dist = mask.float() / (num_available_actions + 1e-9)
        kl_div = torch.sum(probs * (log_probs - torch.log(uniform_dist + 1e-9)), dim=-1)
        trigger_by_kl = kl_div < self.cfg.solver.kl_div_threshold

        has_choice = num_available_actions.squeeze(-1) > 1
        return active_mask & has_choice & (trigger_by_entropy | trigger_by_kl)

    @torch.no_grad()
    def solve_batch_hybrid_vs_proposals(self, td_rl, td_dm, env):
        """Roll out the RL policy greedily, collecting DM proposals at uncertain steps.

        Args:
            td_rl: TensorDict in the layout ``env.reset`` expects.
            td_dm: TensorDict with ``locs`` (depot at index 0), ``prize`` and
                ``max_length`` for the same instances, consumed by the DM.
            env: rl4co environment used to step the greedy rollout.

        Returns:
            ``(final_solutions, run_statistics)``. For each instance this is the
            chosen tour and a dict with a float ``"reward"`` and a ``"source"``.
            Winning DM proposals also carry ``"tour"`` and ``"generation_step"``.
        """
        B = td_dm['locs'].shape[0]

        td_step = env.reset(td_rl.clone())
        hybrid_tours = [[] for _ in range(B)]
        # Best DM proposal per instance; rewards are kept as plain Python floats.
        dm_proposal_stats = [{"reward": -1.0, "tour": [], "generation_step": -1} for _ in range(B)]
        # The sparse graph depends only on node coordinates: build it once per instance.
        graph_cache = {}

        node_embeds, _ = self.rl_policy.encoder(td_step)
        cached_embeds = self.rl_policy.decoder._precompute_cache(node_embeds)

        # Global decoding step. len(hybrid_tours[0]) stops advancing once
        # instance 0 is done, so it cannot serve as the step index.
        step = 0
        while not td_step["done"].all():
            logits, _ = self.rl_policy.decoder(td_step, cached_embeds)
            mask = td_step["action_mask"]
            probs = F.softmax(logits + mask.log(), dim=-1)
            rl_greedy_choice = probs.argmax(-1)

            # reshape(B) (not squeeze) keeps a 1-D mask even when B == 1.
            active_mask = ~td_step["done"].reshape(B)
            if not active_mask.any():
                break

            is_uncertain_mask = self._should_trigger_dm(probs, mask, active_mask)
            uncertain_indices = is_uncertain_mask.nonzero().squeeze(-1)

            if uncertain_indices.numel() > 0:
                print(f"--- Step {step}: DM triggered for {uncertain_indices.numel()}/{active_mask.sum().item()} instances. ---")
                self._run_dm_proposals(
                    step, uncertain_indices, probs, mask, hybrid_tours,
                    td_dm, dm_proposal_stats, graph_cache,
                )

            greedy_nodes = rl_greedy_choice.tolist()
            for i, is_active in enumerate(active_mask.tolist()):
                if is_active:
                    hybrid_tours[i].append(greedy_nodes[i])

            td_step.set("action", rl_greedy_choice)
            td_step = env.step(td_step)["next"]
            step += 1

        final_solutions, run_statistics = [], []
        for i in range(B):
            hybrid_reward, _ = calculate_op_metrics(td_dm['locs'][i], hybrid_tours[i], td_dm['prize'][i])
            hybrid_reward = float(hybrid_reward)
            best = dm_proposal_stats[i]

            if best["reward"] > hybrid_reward:
                final_solutions.append(best["tour"])
                # Spread first so no stored field can override the explicit ones.
                run_statistics.append({**best, "reward": best["reward"], "source": "DM Proposal"})
            else:
                final_solutions.append(hybrid_tours[i])
                run_statistics.append({"reward": hybrid_reward, "source": "Hybrid Path"})

        return final_solutions, run_statistics

    def _num_candidates(self, uncertain_probs, uncertain_mask):
        """Per row, how many top actions are needed to reach the cumulative-probability threshold.

        Uses ``cfg.solver.dynamic_n_cumulative_threshold``. If the threshold is
        never reached (e.g. a threshold of 1.0 plus floating-point round-off),
        all legal actions are used. The result is kept in [1, #legal actions],
        so no more candidates than legal actions are requested.
        """
        threshold = self.cfg.solver.dynamic_n_cumulative_threshold
        sorted_probs = torch.sort(uncertain_probs, dim=-1, descending=True).values
        reached = torch.cumsum(sorted_probs, dim=-1) >= threshold
        num_legal = uncertain_mask.sum(dim=-1).clamp(min=1)
        first_hit = reached.int().argmax(dim=-1) + 1
        n_candidates = torch.where(reached.any(dim=-1), first_hit, num_legal)
        return torch.minimum(n_candidates.clamp(min=1), num_legal)

    def _run_dm_proposals(self, step, uncertain_indices, probs, mask, hybrid_tours,
                          td_dm, dm_proposal_stats, graph_cache):
        """Run the DM for uncertain instances and update ``dm_proposal_stats`` in place."""
        uncertain_probs = probs[uncertain_indices]
        n_candidates = self._num_candidates(uncertain_probs, mask[uncertain_indices]).tolist()
        proposals = torch.topk(uncertain_probs, k=max(n_candidates), dim=-1).indices

        # One DM sample per (instance, candidate): path so far (from the depot) + candidate.
        dm_prefixes, dm_to_instance_map = [], []
        for row, idx in enumerate(uncertain_indices.tolist()):
            path_so_far = [0] + hybrid_tours[idx]
            for candidate_node in proposals[row, :n_candidates[row]].tolist():
                dm_prefixes.append(path_so_far + [candidate_node])
                dm_to_instance_map.append(idx)

        dm_batch_data = self._build_dm_batch(td_dm, dm_prefixes, dm_to_instance_map, graph_cache)

        _, adj_probs_batch, _ = self.diffusion_handler.p_sample_loop_ddim(
            denoiser_model=self.dm_model, batch_data=dm_batch_data,
            num_inference_steps=self.cfg.solver.dm_inference_steps
        )

        # Decode every sample; keep a strictly better tour per instance (first wins on ties).
        for k, idx in enumerate(dm_to_instance_map):
            dm_tour = decode_op_solution_from_heatmap(
                adj_probs_batch[k], td_dm['locs'][idx],
                td_dm['prize'][idx], td_dm['max_length'][idx]
            )
            dm_reward, _ = calculate_op_metrics(td_dm['locs'][idx], dm_tour, td_dm['prize'][idx])
            dm_reward = float(dm_reward)
            if dm_reward > dm_proposal_stats[idx]["reward"]:
                dm_proposal_stats[idx] = {
                    "reward": dm_reward, "tour": dm_tour, "generation_step": step
                }

    def _build_dm_batch(self, td_dm, dm_prefixes, dm_to_instance_map, graph_cache):
        """Collate one DM input per (instance, prefix) pair into a single disjoint-union graph batch."""
        items = []
        for idx, prefix in zip(dm_to_instance_map, dm_prefixes):
            if idx not in graph_cache:
                graph_cache[idx] = self._build_sparse_graph(td_dm[idx]['locs'])
            items.append(self._prepare_dm_input(td_dm[idx], prefix, graph=graph_cache[idx]))

        num_nodes = td_dm['locs'].shape[1]
        # Every graph has the same node count, so graph g's nodes start at g * num_nodes.
        edge_index = torch.cat([d['edge_index'] + g * num_nodes for g, d in enumerate(items)], dim=1)
        node_to_graph_batch = torch.arange(len(items), device=self.device).repeat_interleave(num_nodes)

        def stacked(key):
            return torch.cat([d[key] for d in items], dim=0)

        return {
            "prefix_nodes": stacked('prefix_nodes'),
            "prefix_lengths": stacked('prefix_lengths'),
            "instance_locs": stacked('instance_locs'),
            "node_state_features": stacked('node_state_features'),
            "edge_index": edge_index,
            "dist_feature": stacked('dist_feature'),
            "node_to_graph_batch": node_to_graph_batch,
            "max_lengths": stacked('max_lengths'),
        }

    def _build_sparse_graph(self, instance_locs):
        """Build the kNN edge list and min-max-normalized edge lengths for one instance.

        Returns:
            ``(edge_index, dist_feature)`` with shapes (2, E) and (E, 1).
        """
        N = self.cfg.model.num_nodes
        sparse_k = self.cfg.model.sparse_factor

        locs_np = instance_locs.cpu().numpy()
        _, knn_indices = KDTree(locs_np, metric='euclidean').query(locs_np, k=sparse_k)
        source_nodes = torch.arange(N, device=self.device).view(-1, 1).repeat(1, sparse_k).flatten()
        target_nodes = torch.from_numpy(knn_indices).to(self.device).flatten()
        edge_index = torch.stack([source_nodes, target_nodes], dim=0)
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        # Drop self-loops (each point is its own nearest neighbour).
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
        # Sorting each column makes every edge (min, max); unique then keeps each
        # undirected edge once. NOTE(review): this makes the flip above a no-op;
        # confirm the one-direction layout matches the training data loader.
        edge_index = torch.unique(torch.sort(edge_index, dim=0).values, dim=1)

        src, dst = edge_index[0], edge_index[1]
        distances = torch.linalg.norm(instance_locs[src] - instance_locs[dst], dim=-1)
        normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-8)
        return edge_index, normalized_distances.unsqueeze(-1)

    def _prepare_dm_input(self, td_dm_instance, prefix_nodes_list, graph=None):
        """Build the single-instance DM input dict for one prefix.

        Args:
            td_dm_instance: one instance with ``locs``, ``prize`` and ``max_length``.
            prefix_nodes_list: node indices of the conditioning prefix.
            graph: optional precomputed ``(edge_index, dist_feature)`` from
                ``_build_sparse_graph``; built on the fly when omitted.
        """
        instance_locs = td_dm_instance['locs']
        prizes = td_dm_instance['prize']
        max_length = td_dm_instance['max_length']
        N = self.cfg.model.num_nodes

        if graph is None:
            graph = self._build_sparse_graph(instance_locs)
        edge_index, dist_feature = graph

        prefix_len = len(prefix_nodes_list)
        prefix_tensor = torch.tensor(prefix_nodes_list, device=self.device, dtype=torch.long)

        # Per-node features: [is node 0 (depot), prize, in prefix].
        node_state_features = torch.zeros(N, 3, device=self.device)
        node_state_features[0, 0] = 1.0
        node_state_features[:, 1] = prizes
        if prefix_len > 0:
            node_state_features[prefix_tensor, 2] = 1.0

        return {
            "instance_locs": instance_locs.unsqueeze(0),
            "node_state_features": node_state_features.unsqueeze(0),
            "edge_index": edge_index,
            "dist_feature": dist_feature,
            "node_to_graph_batch": torch.zeros(N, dtype=torch.long, device=self.device),
            "max_lengths": max_length.unsqueeze(0),
            # Right-pad the prefix with zeros to a fixed length N.
            "prefix_nodes": F.pad(prefix_tensor, (0, N - prefix_len)).unsqueeze(0),
            "prefix_lengths": torch.tensor([prefix_len], device=self.device),
            "prizes": prizes.unsqueeze(0),
        }


def run(cfg: DictConfig):
    """Evaluate the hybrid solver on ``cfg.data.test_path`` and print a summary.

    Args:
        cfg: merged configuration with ``model``, ``solver``, ``data``,
            ``eval``, ``rl_model``, ``dm_model`` and ``diffusion`` sections.
    """
    solver = HybridSolver(cfg)
    device = solver.device

    dataset = OPConditionalSuffixDataset(
        txt_file_paths=[cfg.data.test_path],
        prefix_k_options=[0],
        sparse_factor=cfg.model.sparse_factor
    )
    dataloader = DataLoader(dataset, batch_size=cfg.eval.batch_size, shuffle=False)

    all_stats = []
    start_time = time.perf_counter()

    for data_dict in tqdm(dataloader, desc="Solving OP Batches"):
        all_locs = data_dict["instance_locs"]
        all_prizes = data_dict["prizes"]
        B = all_locs.shape[0]

        # The RL env takes the depot (index 0) separately from the customer nodes.
        td_rl = TensorDict({
            "depot": all_locs[:, 0, :],
            "locs": all_locs[:, 1:, :],
            "prize": all_prizes[:, 1:],
            "max_length": data_dict["max_length"],
        }, batch_size=B).to(device)

        # The DM works on the full node set, with the depot kept at index 0.
        td_dm = TensorDict({
            "locs": all_locs,
            "prize": all_prizes,
            "max_length": data_dict["max_length"],
        }, batch_size=B).to(device)

        _, batch_stats = solver.solve_batch_hybrid_vs_proposals(td_rl, td_dm, solver.rl_env)
        all_stats.extend(batch_stats)

    total_time = time.perf_counter() - start_time

    if not all_stats:
        print("No instances were solved; nothing to summarize.")
        return

    # Rewards may be Python floats or 0-d tensors (possibly on GPU);
    # float() handles both, whereas np.mean cannot read CUDA tensors.
    rewards = [float(s['reward']) for s in all_stats]
    avg_reward = float(np.mean(rewards))

    print("\n" + "="*60)
    print("--- OP Hybrid Solver Evaluation Summary ---")
    print(f"Total time: {total_time:.2f}s for {len(all_stats)} instances.")
    print(f"Average Final Reward: {avg_reward:.4f}")

    dm_count = sum(1 for s in all_stats if s['source'] == 'DM Proposal')
    print(f"Solutions chosen from DM proposals: {dm_count}/{len(all_stats)}")
    print("="*60)

    # Optional: Print detailed stats for the first few instances
    print("\n--- Detailed Analysis of First 5 Instances ---")
    for i, (stats, reward) in enumerate(zip(all_stats[:5], rewards)):
        print(f"\nInstance {i}:")
        print(f"  Final Reward: {reward:.4f}")
        print(f"  Winning Source: {stats['source']}")
        if stats['source'] == 'DM Proposal':
            print(f"  DM Proposal generated at step: {stats['generation_step']}")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Hybrid RL-DM Solver for OP")
    cli.add_argument("--config", type=str, default="hybrid_op100_config.yaml")
    config_path = cli.parse_args().config

    user_cfg = OmegaConf.load(config_path)

    # Fallback solver parameters. OmegaConf.merge lets later arguments win,
    # so any value present in the user's config overrides these.
    fallback_cfg = OmegaConf.create(
        {"solver": {"dynamic_n_cumulative_threshold": 0.8}}
    )

    run(OmegaConf.merge(fallback_cfg, user_cfg))