import torch
from torch import nn
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip import clip
from collections import OrderedDict
# CSV parsing and filesystem path helpers used by the prompt-metadata loaders
from pathlib import Path
import csv


class Clip_PromptLearner(nn.Module):
    """Learnable prompt context for CLIP text encoding of relation classes.

    For every relation class a text prompt ``"<ctx> <relation name>."`` is
    tokenized with CLIP. The token embeddings of everything except the
    ``n_ctx`` context slots are frozen into buffers (``token_prefix`` = first
    token, ``token_suffix`` = everything after the context slots), and the
    learnable context ``self.ctx`` is spliced back in (class name placed at the
    front / middle / end relative to the context) whenever prompt embeddings
    are requested.

    Optional features, configured through ``cfg.MODEL.ROI_RELATION_HEAD.HP``:

    * inverse prompts: relation names replaced by their inverse, read from
      ``INVERSE_PROMPT_CSV_PATH`` (column chosen by ``INVERSE_PROMPT_COL``);
    * LLM-enriched prompts, used only in eval mode: description / similar /
      inverse sentences from ``LLM_PROMPT_CSV_PATH`` appended to each prompt;
    * superclass prompts ``"<ctx> <subject super> <relation> <object super>."``
      built on demand and cached (``USE_SUPERCLASS_PROMPT``).
    """

    def __init__(self, cfg, classnames, tokenizer, clip_model, obj_classes=None):
        """Build the learnable context and all frozen prompt buffers.

        Args:
            cfg: experiment config; prompt options are read from
                ``cfg.MODEL.ROI_RELATION_HEAD.HP`` and fall back to defaults
                when that node or an individual option is absent.
            classnames: relation class names (underscores become spaces).
            tokenizer: object with an ``encode(str)`` method; only used to
                count the tokens of each class name so the name can be located
                at the start of the frozen suffix.
            clip_model: CLIP model providing ``dtype``, ``ln_final`` and
                ``token_embedding``.
            obj_classes: optional object class names, only needed for
                superclass prompts.
        """
        super().__init__()
        hp_cfg = getattr(cfg.MODEL.ROI_RELATION_HEAD, "HP", None)

        n_cls = len(classnames)
        n_ctx = 2
        ctx_init = ''
        csc = getattr(hp_cfg, "SPECIFIC_PROMPT", False) if hp_cfg is not None else False
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        if ctx_init:
            # Initialise the context from the embeddings of the given words.
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Skip the first token; it is kept separately as ``token_prefix``.
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialisation. The "X" placeholders reserve the context
            # slots in the text prompt (assuming each "X" encodes to a single
            # token) so that the frozen suffix starts at the class name.
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)
        # Kept so superclass prompts can be tokenized and embedded lazily later.
        self._tokenizer = tokenizer
        self._prompt_prefix_text = prompt_prefix
        self._clip_token_embedding = clip_model.token_embedding
        self._clip_token_dtype = dtype
        self.clip_model = clip_model
        self.dtype = dtype
        # Superclass-prompt state; populated by _init_superclass_support.
        self._super_enabled = False
        self._super_lookup = []
        self._super_token_cache = {}
        self._super_prompt_cache = {}
        self._super_prompt_logged_keys = set()

        # ---- LLM prompt configuration -------------------------------------
        self.llm_prompt_test_enable = getattr(hp_cfg, "LLM_PROMPT_TEST_ENABLE", False) if hp_cfg is not None else False
        sections_raw = getattr(hp_cfg, "LLM_PROMPT_SECTIONS", ("description", "similar", "inverse")) if hp_cfg is not None else ()
        if isinstance(sections_raw, str):
            # Accept "a,b,c" or "a|b|c" in addition to a list/tuple.
            sections_raw = sections_raw.replace("|", ",").split(",")
        self.llm_prompt_sections = {s.strip().lower() for s in sections_raw if s and s.strip()}
        if not self.llm_prompt_sections or "none" in self.llm_prompt_sections:
            self.llm_prompt_sections = set()
        csv_path_raw = getattr(hp_cfg, "LLM_PROMPT_CSV_PATH", getattr(hp_cfg, "LLM_PROMPT_CSV", "")) if hp_cfg is not None else ""
        self._llm_metadata = self._load_llm_metadata(csv_path_raw)
        self._llm_prompts_available = False
        self._llm_prompt_logged = False

        classnames = [name.replace("_", " ") for name in classnames]
        self.classnames = classnames  # relation names with spaces instead of underscores

        # ---- Base prompts and class-name token lengths ----------------------
        prompts = [f"{prompt_prefix} {name}." for name in classnames]
        base_name_lens = [len(tokenizer.encode(name)) for name in classnames]

        # ---- Inverse prompts (column i1 / i2 of the inverse CSV) -----------
        inverse_map = self._load_inverse_mapping(
            getattr(hp_cfg, "INVERSE_PROMPT_CSV_PATH", ""),
            mode=getattr(hp_cfg, "INVERSE_PROMPT_COL", "i1"),
        ) if hp_cfg is not None else {}
        if inverse_map:
            # Relations without an inverse entry keep their own name.
            inv_names = [inverse_map.get(name.lower(), name) for name in classnames]
            inverse_prompts = [f"{prompt_prefix} {name}." for name in inv_names]
            inverse_name_lens = [len(tokenizer.encode(name)) for name in inv_names]
            # The full prompts (not just the names) must be embedded so the
            # suffix lines up with the context slots.
            tokenized_inv, token_prefix_inv, token_suffix_inv = self._build_prompt_embeddings(inverse_prompts, clip_model, dtype, n_ctx)
            self.register_buffer("token_prefix_inverse", token_prefix_inv)
            self.register_buffer("token_suffix_inverse", token_suffix_inv)
            self.name_lens_inverse = inverse_name_lens
            self.register_buffer("tokenized_prompts_inverse", tokenized_inv)
            self.inverse_classnames = inv_names
        else:
            self.token_prefix_inverse = None
            self.token_suffix_inverse = None
            self.name_lens_inverse = None
            self.tokenized_prompts_inverse = None
            inverse_prompts = []
            self.inverse_classnames = classnames

        print('[LLM Config]', sorted(self.llm_prompt_sections), 'enabled:', self.llm_prompt_test_enable)

        tokenized_prompts, token_prefix, token_suffix = self._build_prompt_embeddings(
            prompts, clip_model, dtype, n_ctx
        )
        self.register_buffer("token_prefix", token_prefix)
        self.register_buffer("token_suffix", token_suffix)
        self.register_buffer("tokenized_prompts", tokenized_prompts)

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.name_lens = base_name_lens
        default_pos = getattr(hp_cfg, "CLASS_TOKEN_POSITION", "middle") if hp_cfg is not None else "middle"
        if default_pos not in ("front", "middle", "end"):
            default_pos = "middle"
        self.class_token_position = default_pos
        self._prompt_logged_once = False
        # Base / currently active prompt strings and the index used for log samples.
        self._base_prompt_strings = prompts
        self._active_prompt_strings = prompts
        self._base_prompt_extras = [""] * n_cls
        self._prompt_sample_idx = 0

        # ---- LLM-enriched prompts (eval mode only) --------------------------
        # Defaults for the "no LLM prompts" case so these attributes always
        # exist; overwritten below when LLM text is available.
        self.tokenized_prompts_test = None
        self._llm_prompt_strings = None
        self._llm_prompt_extras = [""] * n_cls
        if self.llm_prompt_test_enable and self._llm_metadata:
            prompts_llm = []
            llm_available = False
            for name, base_prompt in zip(classnames, prompts):
                segments = self._llm_segments(self._llm_metadata.get(name.lower()))
                if segments:
                    llm_available = True
                    extended = " ".join(segments)
                    prompts_llm.append(f"{prompt_prefix} {name}. {extended}")
                else:
                    prompts_llm.append(base_prompt)
            if llm_available:
                tokenized_prompts_test, token_prefix_test, token_suffix_test = self._build_prompt_embeddings(
                    prompts_llm, clip_model, dtype, n_ctx
                )
                self.register_buffer("token_prefix_test", token_prefix_test)
                self.register_buffer("token_suffix_test", token_suffix_test)
                # NOTE(review): stored as a plain attribute (not a buffer), so it
                # does not follow module.to(device) like ``tokenized_prompts`` does.
                self.tokenized_prompts_test = tokenized_prompts_test
                self._llm_prompts_available = True
                self._llm_prompt_strings = prompts_llm
                # Only the appended LLM text, reused by superclass prompts.
                self._llm_prompt_extras = [
                    (pl[len(base):].strip() if len(pl) > len(base) else "")
                    for pl, base in zip(prompts_llm, self._base_prompt_strings)
                ]
                self._active_prompt_strings = prompts_llm
            else:
                # No relation received any LLM text: behave as if disabled.
                self.llm_prompt_test_enable = False
        self._log_prompt_preview("base", self._active_prompt_strings)
        self._log_prompt_preview("inverse", inverse_prompts)
        if hp_cfg is not None:
            self._init_superclass_support(hp_cfg, obj_classes or [])

    def _expanded_ctx(self):
        """Return the context as ``(n_cls, n_ctx, ctx_dim)``.

        A shared (2-D) context is broadcast to every class via ``expand``
        (no copy); a class-specific (3-D) context is returned unchanged.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        return ctx

    def forward(self):
        """Return prompt embeddings for all relation classes.

        In eval mode the LLM-enriched prompts are used when they were built;
        otherwise the base prompts are used. The first eval call logs a sample.
        """
        ctx = self._expanded_ctx()

        use_llm_prompt = self._resolve_use_llm_flag(None)
        prefix = self.token_prefix_test if use_llm_prompt else self.token_prefix
        suffix = self.token_suffix_test if use_llm_prompt else self.token_suffix
        name_lens = self.name_lens

        if (not self.training) and not getattr(self, "_test_prompt_logged", False):
            active = self._llm_prompt_strings if use_llm_prompt and getattr(self, "_llm_prompt_strings", None) else self._base_prompt_strings
            print("!" * 80)
            print("[INFERENCE MODE] Sample prompt:", active[self._prompt_sample_idx])
            print("!" * 80)
            self._test_prompt_logged = True

        if use_llm_prompt and getattr(self, "_llm_prompt_strings", None) and not getattr(self, "_llm_prompt_logged", False):
            print('[Prompt active - LLM]', self._llm_prompt_strings[self._prompt_sample_idx])
            self._llm_prompt_logged = True

        return self._assemble_prompts(ctx, prefix, suffix, name_lens)

    def get_tokenized_prompts(self):
        """Return token ids matching the prompts ``forward`` currently produces."""
        if self._resolve_use_llm_flag(None):
            return self.tokenized_prompts_test
        return self.tokenized_prompts

    def get_inverse_tokenized_prompts(self):
        """Return token ids of the inverse prompts, or None when not configured."""
        return getattr(self, "tokenized_prompts_inverse", None)

    def get_inverse_prompt_embeddings(self):
        """Return prompt embeddings built from inverse relation names.

        Returns None when no inverse mapping was loaded.
        """
        if getattr(self, "token_prefix_inverse", None) is None:
            return None
        return self._assemble_prompts(
            self._expanded_ctx(),
            self.token_prefix_inverse,
            self.token_suffix_inverse,
            self.name_lens_inverse,
        )

    def _assemble_prompts(self, ctx, prefix, suffix, name_lens):
        """Splice the learnable context into frozen prompt embeddings.

        Args:
            ctx: ``(n_cls, n_ctx, D)`` context vectors.
            prefix: ``(n_cls, 1, D)`` embedding of each prompt's first token.
            suffix: ``(n_cls, T, D)`` embeddings following the context slots;
                the first ``name_lens[i]`` rows of class ``i`` are its name.
            name_lens: per-class token count of the class name.

        Returns:
            ``(n_cls, 1 + n_ctx + T, D)`` embeddings ordered according to
            ``self.class_token_position``.

        Raises:
            ValueError: if ``class_token_position`` is not front/middle/end.
        """
        if self.class_token_position == "end":
            return torch.cat([prefix, ctx, suffix], dim=1)
        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = name_lens[i]
                prefix_i = prefix[i: i + 1, :, :]
                class_i = suffix[i: i + 1, :name_len, :]
                suffix_i = suffix[i: i + 1, name_len:, :]
                ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :]
                prompts.append(torch.cat([prefix_i, ctx_i_half1, class_i, ctx_i_half2, suffix_i], dim=1))
            return torch.cat(prompts, dim=0)
        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = name_lens[i]
                prefix_i = prefix[i: i + 1, :, :]
                class_i = suffix[i: i + 1, :name_len, :]
                suffix_i = suffix[i: i + 1, name_len:, :]
                ctx_i = ctx[i: i + 1, :, :]
                prompts.append(torch.cat([prefix_i, class_i, ctx_i, suffix_i], dim=1))
            return torch.cat(prompts, dim=0)
        else:
            raise ValueError(
                f"Unknown class_token_position {self.class_token_position!r}; "
                "expected 'front', 'middle' or 'end'."
            )

    def _build_prompt_embeddings(self, prompts, clip_model, dtype, n_ctx):
        """Tokenize ``prompts`` and split their frozen embeddings.

        Returns:
            (tokenized, token_prefix, token_suffix): token ids on the embedding
            weight's device, the first-token embeddings, and the embeddings
            after the ``n_ctx`` context slots (computed without gradients).
        """
        tokenized = torch.cat([clip.tokenize(p) for p in prompts])
        tokenized = tokenized.to(clip_model.token_embedding.weight.device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized).type(dtype)
        token_prefix = embedding[:, :1, :]
        token_suffix = embedding[:, 1 + n_ctx:, :]
        return tokenized, token_prefix, token_suffix

    def _load_llm_metadata(self, csv_path):
        """Load per-relation LLM text (description / similar / inverse) from a CSV.

        Column layout, by position after a header row:

        * column 0: relation name;
        * column 1: description;
        * column 2: similar relations;
        * inverse relation: column 4 when present, otherwise column 3.

        Rows with fewer than four columns, an empty relation, or a repeated
        "relation" header are skipped.

        Best effort: a missing file or a read error is logged, and an empty
        (or partially filled) dict is returned.

        Returns:
            dict mapping lower-cased relation name to
            ``{"description", "similar", "inverse"}``.
        """
        records = {}
        if not csv_path:
            return records
        csv_path = Path(csv_path).expanduser()
        if not csv_path.is_file():
            print(f"[Clip_PromptLearner] LLM prompt csv not found: {csv_path}")
            return records
        try:
            with csv_path.open("r", encoding="utf-8-sig", newline="") as f:
                reader = csv.reader(f)
                header = next(reader, None)
                if header:
                    print(f"[Clip_PromptLearner] CSV header: {header[:5]}")
                for row in reader:
                    if len(row) < 4:
                        continue
                    relation = row[0].strip()
                    if not relation or relation.lower() == "relation":
                        continue
                    desc = row[1].strip()
                    similar = row[2].strip()
                    inverse = row[4].strip() if len(row) > 4 else row[3].strip()
                    records[relation.lower()] = {
                        "description": desc,
                        "similar": similar,
                        "inverse": inverse,
                    }
                print(f"[Clip_PromptLearner] Loaded {len(records)} relation metadata from CSV")
                if records:
                    print("[Clip_PromptLearner] Sample metadata (first 3):")
                    for idx, (rel, meta) in enumerate(list(records.items())[:3]):
                        print(f"  [{idx}] {rel}:")
                        print(f"      desc: {meta['description']}")
                        print(f"      similar: {meta['similar']}")
                        print(f"      inverse: {meta['inverse']}")
        except Exception as exc:
            # Deliberately best effort: LLM text is optional.
            print(f"[Clip_PromptLearner] Failed to load LLM csv ({csv_path}): {exc}")
        return records

    def _llm_segments(self, meta):
        """Return the LLM sentences to append for one relation (possibly empty).

        Only sections listed in ``self.llm_prompt_sections`` are used. They are
        emitted in a fixed order: description, similar, inverse.

        * description: an optional "spatial:" prefix is removed;
        * similar: rendered as "similar: ...";
        * inverse: rendered as "inverse relation: ...".

        Each sentence is normalised to end with exactly one trailing period.
        """
        segments = []
        if not meta:
            return segments
        if "description" in self.llm_prompt_sections:
            desc = (meta.get("description") or "").strip()
            if desc.lower().startswith("spatial:"):
                desc = desc[len("spatial:"):].strip()
            desc = desc.rstrip(".")
            if desc:
                segments.append(f"{desc}.")
        if "similar" in self.llm_prompt_sections:
            sim = (meta.get("similar") or "").strip().rstrip(".")
            if sim:
                segments.append(f"similar: {sim}.")
        if "inverse" in self.llm_prompt_sections:
            inv = (meta.get("inverse") or "").strip().rstrip(".")
            if inv:
                segments.append(f"inverse relation: {inv}.")
        return segments

    def _load_inverse_mapping(self, csv_path, mode="i1"):
        """Load a relation -> inverse-relation name mapping from a CSV file.

        Expected columns (header names matched case-insensitively):
        original_relation, inverse_passive, inverse_semantic, hbt, super, super2.

        Args:
            csv_path: path to the CSV; a falsy value disables inverse prompts.
            mode: 'i1' -> inverse_passive, 'i2' -> inverse_semantic; any other
                value falls back to inverse_passive.

        Returns:
            dict mapping lower-cased original relation to inverse name
            (whitespace-stripped, case preserved). The dict is empty when the
            path is unset, the file is missing or empty, or the target column
            is absent.
        """
        if not csv_path:
            return {}
        p = Path(csv_path).expanduser()
        if not p.is_file():
            print(f"[Inverse] CSV not found: {p}")
            return {}
        col_map = {"i1": "inverse_passive", "i2": "inverse_semantic"}
        target_col = col_map.get(mode, "inverse_passive")
        # utf-8-sig strips a leading BOM. Otherwise the BOM would be glued onto
        # the first header and "original_relation" could never match.
        with p.open("r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            # lower-cased header -> actual header, for case-insensitive lookup.
            # fieldnames is None for an empty file.
            header_by_lower = {h.strip().lower(): h for h in (reader.fieldnames or []) if h}
            if target_col not in header_by_lower:
                print(f"[Inverse] target column '{target_col}' not in CSV, disable inverse.")
                return {}
            orig_key = header_by_lower.get("original_relation")
            target_key = header_by_lower[target_col]
            mapping = {}
            if orig_key is not None:
                for row in reader:
                    orig = (row.get(orig_key) or "").strip().lower()
                    inv = (row.get(target_key) or "").strip()
                    if orig and inv:
                        mapping[orig] = inv
            print(f"[Inverse] loaded {len(mapping)} mappings from '{target_col}'")
            return mapping

    def _log_prompt_preview(self, tag, prompt_list):
        """Print prompts: all of them if there are at most 100, otherwise the first 10."""
        total = len(prompt_list)
        if total == 0:
            return
        limit = total if total <= 100 else min(10, total)
        for idx in range(limit):
            print(f'[Prompt sample - {tag} #{idx}] {prompt_list[idx]}')
        if limit < total:
            print(f'[Prompt sample - {tag}] ... ({total} total)')

    def _print_llm_prompt_summary(self, prompts_llm, prompt_prefix):
        """Print each relation's base prompt next to the LLM text appended to it."""
        print("\n" + "=" * 100)
        print(f"[LLM Enhanced Prompts] Total: {len(prompts_llm)}")
        print(f"[LLM Enhanced Prompts] Sections enabled: {sorted(self.llm_prompt_sections)}")
        print("=" * 100)
        for idx, (rel_name, prompt) in enumerate(zip(self.classnames, prompts_llm)):
            base_part = f"{prompt_prefix} {rel_name}."
            if len(prompt) > len(base_part):
                extended_part = prompt[len(base_part):].strip()
                print(f"[{idx:2d}] {rel_name:20s}")
                print(f"     Base: {base_part}")
                print(f"     +LLM: {extended_part}")
            else:
                print(f"[{idx:2d}] {rel_name:20s} -> {prompt} (no LLM extension)")
        print("=" * 100 + "\n")

    def _load_superclass_mapping(self, csv_path):
        """Read an object -> superclass CSV (first two columns); best effort.

        The first row is always treated as a header and skipped. A first row
        with fewer than two columns cannot hold a mapping either, so it is
        skipped rather than indexed. Rows with fewer than two columns or an
        empty cell are ignored.

        Returns:
            dict mapping lower-cased object name to superclass name (stripped).
            The dict is empty when the path is unset or unreadable.
        """
        mapping = {}
        if not csv_path:
            return mapping
        path = Path(csv_path).expanduser()
        if not path.is_file():
            print(f"[Clip_PromptLearner] Superclass csv not found: {path}")
            return mapping
        try:
            with path.open("r", encoding="utf-8-sig", newline="") as f:
                reader = csv.reader(f)
                next(reader, None)  # header row
                for row in reader:
                    if len(row) < 2:
                        continue
                    obj = row[0].strip().lower()
                    sup = row[1].strip()
                    if obj and sup:
                        mapping[obj] = sup
        except Exception as exc:
            # Best effort: unmapped objects fall back to the configured superclass.
            print(f"[Clip_PromptLearner] Failed to load superclass csv ({path}): {exc}")
        return mapping

    def _init_superclass_support(self, hp_cfg, obj_classes):
        """Enable superclass prompts when ``USE_SUPERCLASS_PROMPT`` is set.

        Each object class is mapped to a superclass name, falling back to
        ``OBJ_SUPERCLASS_FALLBACK`` or "object". The token embeddings of every
        distinct superclass name are pre-computed into ``_super_token_cache``.
        """
        csv_path = getattr(hp_cfg, "OBJ_SUPERCLASS_CSV_PATH", "")
        fallback = (getattr(hp_cfg, "OBJ_SUPERCLASS_FALLBACK", "") or "").strip() or "object"
        use_flag = getattr(hp_cfg, "USE_SUPERCLASS_PROMPT", False)
        if not use_flag or not obj_classes:
            return
        mapping = self._load_superclass_mapping(csv_path)
        lookup = [mapping.get(name.lower(), fallback) for name in obj_classes]
        device = self._clip_token_embedding.weight.device
        with torch.no_grad():
            for sup in set(lookup):
                tokenized = clip.tokenize(sup).to(device)
                emb = self._clip_token_embedding(tokenized).type(self._clip_token_dtype)
                self._super_token_cache[sup.lower()] = {
                    "tokenized": tokenized,
                    # Drops the first and last positions of the tokenized sequence.
                    "embedding": emb[:, 1:-1, :].contiguous()
                }
        self._super_lookup = lookup
        self._super_fallback = fallback
        self._super_enabled = True

    def super_enabled(self):
        """Return True when superclass prompts were successfully initialised."""
        return self._super_enabled

    def get_superclass_name(self, obj_idx):
        """Return the superclass name for object class index ``obj_idx``.

        Returns "" when superclass prompts are disabled. Returns the fallback
        name when the index is out of range.
        """
        if not self._super_enabled:
            return ""
        if 0 <= obj_idx < len(self._super_lookup):
            return self._super_lookup[obj_idx]
        return self._super_fallback

    def prepare_super_prompts(self, subject_super, object_super, inverse=False, use_llm=None, **kwargs):
        """Build embeddings for ``"<ctx> <subject super> <relation> <object super>."``.

        Text prompts and their frozen token embeddings are cached on CPU per
        (subject, object, inverse, use_llm) key. The learnable context is
        spliced in on every call, so gradients still reach ``self.ctx``.

        Args:
            subject_super: subject superclass; empty/None uses the fallback.
            object_super: object superclass; empty/None uses the fallback.
            inverse: use the inverse relation names instead of the originals.
            use_llm: append the LLM extras; None decides from training mode
                and configuration.
            **kwargs: accepted and ignored.

        Returns:
            (prompt_embeddings, tokenized): one prompt per relation class, and
            token ids on the token-embedding weight's device.

        Raises:
            RuntimeError: if superclass prompts were not enabled.
        """
        if not self._super_enabled:
            raise RuntimeError("Superclass prompt requested but mapping is disabled.")
        use_llm = self._resolve_use_llm_flag(use_llm)
        subj_clean = (subject_super or self._super_fallback).strip()
        obj_clean = (object_super or self._super_fallback).strip()
        rel_names = self.inverse_classnames if inverse else self.classnames
        extras_list = self._llm_prompt_extras if use_llm else self._base_prompt_extras
        key = (subj_clean.lower(), obj_clean.lower(), bool(inverse), bool(use_llm))
        cache = self._super_prompt_cache.get(key)
        if cache is None:
            prompts = []
            name_lens = []
            for idx, rel_name in enumerate(rel_names):
                core_tokens = " ".join(filter(None, [subj_clean, rel_name, obj_clean])).strip()
                if not core_tokens:
                    core_tokens = rel_name
                prompt_full = f"{self._prompt_prefix_text} {core_tokens}".strip() + "."
                extra = extras_list[idx] if idx < len(extras_list) else ""
                if extra:
                    prompt_full = f"{prompt_full} {extra}".strip()
                prompts.append(prompt_full)
                # The whole "subject relation object" phrase is treated as the class name.
                name_lens.append(len(self._tokenizer.encode(core_tokens)))
            tokenized_dyn, token_prefix, token_suffix = self._build_prompt_embeddings_dynamic(prompts)
            tokenized_cpu = tokenized_dyn.cpu()
            cache = (tokenized_cpu, token_prefix.cpu(), token_suffix.cpu(), name_lens, prompts)
            self._super_prompt_cache[key] = cache
        tokenized_cpu, token_prefix, token_suffix, name_lens, prompt_strings = cache
        ctx = self._expanded_ctx()
        device = ctx.device
        prompt_embeddings = self._assemble_prompts(
            ctx,
            token_prefix.to(device),
            token_suffix.to(device),
            name_lens,
        )
        tokenized = tokenized_cpu.to(self._clip_token_embedding.weight.device, non_blocking=True)
        if key not in self._super_prompt_logged_keys and prompt_strings:
            sample_idx = min(self._prompt_sample_idx, len(prompt_strings) - 1)
            print(f"[Prompt sample - super {subj_clean}->{obj_clean} (inverse={inverse})] {prompt_strings[sample_idx]}")
            self._super_prompt_logged_keys.add(key)
        return prompt_embeddings, tokenized

    def _build_prompt_embeddings_dynamic(self, prompts):
        """Like ``_build_prompt_embeddings`` but using the stored CLIP embedding and ``self.n_ctx``."""
        tokenized = torch.cat([clip.tokenize(p) for p in prompts], dim=0)
        tokenized = tokenized.to(self._clip_token_embedding.weight.device)
        with torch.no_grad():
            embedding = self._clip_token_embedding(tokenized).type(self._clip_token_dtype)
        token_prefix = embedding[:, :1, :]
        token_suffix = embedding[:, 1 + self.n_ctx:, :]
        return tokenized, token_prefix, token_suffix

    def _resolve_use_llm_flag(self, explicit):
        """Return ``explicit`` if given.

        Otherwise, return whether LLM prompts are active, i.e. the module is
        in eval mode, LLM prompts are enabled, and they were built.
        """
        if explicit is not None:
            return explicit
        return (not self.training) and self.llm_prompt_test_enable and self._llm_prompts_available