import torch
from torch import nn
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip import clip
from collections import OrderedDict
# [+] 新增: CSV 解析与路径工具
import csv
from pathlib import Path


class Clip_PromptLearner(nn.Module):
    """Learnable CoOp-style context prompts for CLIP relation (predicate) classes.

    One text prompt is built per class as ``"<ctx tokens> <class name>."``.
    The ``n_ctx`` context tokens are learnable vectors (``self.ctx``).

    The token embeddings of the surrounding tokens are computed once under
    ``torch.no_grad()`` and stored as buffers:

    * ``token_prefix`` holds the first token of each prompt;
    * ``token_suffix`` holds everything after the context slots, i.e. the
      class-name tokens followed by the rest of the sequence.

    ``forward`` splices ``self.ctx`` in between them.

    Optional extras are read from ``cfg.MODEL.ROI_RELATION_HEAD.HP``:

    * an "inverse" prompt set, where each class name is replaced by the inverse
      relation name found in ``INVERSE_PROMPT_CSV_PATH``;
    * LLM-extended prompts (description / similar / inverse sentences read from
      ``LLM_PROMPT_CSV_PATH`` or ``LLM_PROMPT_CSV``). These are used only in
      eval mode (``not self.training``) when ``LLM_PROMPT_TEST_ENABLE`` is set.
    """

    def __init__(self, cfg, classnames, tokenizer, clip_model):
        """Build context vectors, prompt strings and cached token embeddings.

        Args:
            cfg: config node; ``MODEL.ROI_RELATION_HEAD.HP`` options are read.
            classnames: relation class names; underscores are replaced by spaces.
            tokenizer: object with an ``encode(str)`` method, used to count how
                many tokens each class name occupies.
            clip_model: CLIP model providing ``dtype``, ``ln_final`` and
                ``token_embedding``.
        """
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 2
        ctx_init = ''
        csc = cfg.MODEL.ROI_RELATION_HEAD.HP.SPECIFIC_PROMPT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            if csc:
                # one context per class: (n_cls, n_ctx, ctx_dim)
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                # one context shared by all classes: (n_ctx, ctx_dim)
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            # "X" placeholders occupy the n_ctx positions after the first token;
            # _build_prompt_embeddings drops them so self.ctx can take their place.
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        hp_cfg = getattr(cfg.MODEL.ROI_RELATION_HEAD, "HP", None)
        self.llm_prompt_test_enable = getattr(hp_cfg, "LLM_PROMPT_TEST_ENABLE", False) if hp_cfg is not None else False
        sections_raw = getattr(hp_cfg, "LLM_PROMPT_SECTIONS", ("description", "similar", "inverse")) if hp_cfg is not None else ()
        if isinstance(sections_raw, str):
            # Accept "a,b" or "a|b" strings as well as sequences of names.
            sections_raw = sections_raw.replace("|", ",").split(",")
        self.llm_prompt_sections = {s.strip().lower() for s in sections_raw if s and s.strip()}
        if not self.llm_prompt_sections or "none" in self.llm_prompt_sections:
            self.llm_prompt_sections = set()
        csv_path_raw = getattr(hp_cfg, "LLM_PROMPT_CSV_PATH", getattr(hp_cfg, "LLM_PROMPT_CSV", "")) if hp_cfg is not None else ""
        self._llm_metadata = self._load_llm_metadata(csv_path_raw)
        self._llm_prompts_available = False
        self._llm_prompt_logged = False
        self._test_prompt_logged = False

        classnames = [name.replace("_", " ") for name in classnames]
        self.classnames = classnames  # relation names after underscore normalization
        name_lens = [len(tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        # Index of the first non-background class; used only when logging a sample prompt.
        self._prompt_sample_idx = next(
            (
                idx
                for idx, name in enumerate(classnames)
                if name.strip().lower() not in ("background", "__background__")
            ),
            0,
        )
        self._base_prompt_strings = prompts
        self._active_prompt_strings = prompts

        print('[LLM Config]', sorted(self.llm_prompt_sections), 'enabled:', self.llm_prompt_test_enable)

        tokenized_prompts, token_prefix, token_suffix = self._build_prompt_embeddings(
            prompts, clip_model, dtype, n_ctx
        )
        self.register_buffer("token_prefix", token_prefix)
        self.register_buffer("token_suffix", token_suffix)
        self.register_buffer("tokenized_prompts", tokenized_prompts)

        # Inverse prompts: swap each class name for its inverse relation name
        # when the mapping has one, otherwise keep the class name itself.
        inverse_map = self._load_inverse_mapping(getattr(hp_cfg, "INVERSE_PROMPT_CSV_PATH", "")) if hp_cfg is not None else {}
        inverse_names = [inverse_map.get(name.strip().lower(), "") or name for name in classnames]
        inverse_prompts = [prompt_prefix + " " + inv + "." for inv in inverse_names]
        tokenized_inverse, token_prefix_inverse, token_suffix_inverse = self._build_prompt_embeddings(
            inverse_prompts, clip_model, dtype, n_ctx
        )
        self.register_buffer("token_prefix_inverse", token_prefix_inverse)
        self.register_buffer("token_suffix_inverse", token_suffix_inverse)
        self.register_buffer("tokenized_prompts_inverse", tokenized_inverse)

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.name_lens = name_lens
        self.name_lens_inverse = [len(tokenizer.encode(inv)) for inv in inverse_names]
        default_pos = getattr(hp_cfg, "CLASS_TOKEN_POSITION", "middle") if hp_cfg is not None else "middle"
        if default_pos not in ("front", "middle", "end"):
            default_pos = "middle"
        self.class_token_position = default_pos
        self._prompt_logged_once = False

        # Defaults for when no LLM-extended prompts are built.
        self.tokenized_prompts_test = None
        self._llm_prompt_strings = None
        if self.llm_prompt_test_enable and self._llm_metadata:
            prompts_llm, llm_available = self._compose_llm_prompts(classnames, prompts, prompt_prefix)
            if llm_available:
                tokenized_prompts_test, token_prefix_test, token_suffix_test = self._build_prompt_embeddings(
                    prompts_llm, clip_model, dtype, n_ctx
                )
                self.register_buffer("token_prefix_test", token_prefix_test)
                self.register_buffer("token_suffix_test", token_suffix_test)
                # NOTE(review): plain attribute rather than a buffer, so `.to(device)`
                # does not move it (unlike `tokenized_prompts`) — confirm callers cope.
                self.tokenized_prompts_test = tokenized_prompts_test
                self._llm_prompts_available = True
                self._llm_prompt_strings = prompts_llm
                self._active_prompt_strings = prompts_llm
                self._print_llm_prompt_summary(prompts_llm, prompt_prefix)
            else:
                # No class gained any LLM text: disable the test-time switch.
                self.llm_prompt_test_enable = False

        self._log_prompt_preview("base", self._active_prompt_strings)
        self._log_prompt_preview("inverse", inverse_prompts)

    def forward(self):
        """Return prompt embeddings for all classes, with ``self.ctx`` spliced in.

        In eval mode, when LLM-extended prompts were built and enabled, the
        LLM prefix/suffix embeddings are used instead of the base ones. The
        first eval call logs one sample prompt.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # Generic context: share the same vectors across all classes.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        use_llm_prompt = (not self.training) and self.llm_prompt_test_enable and self._llm_prompts_available
        prefix = self.token_prefix_test if use_llm_prompt else self.token_prefix
        suffix = self.token_suffix_test if use_llm_prompt else self.token_suffix
        # LLM prompts start with the same "<ctx> <name>." text, so the class-name
        # token counts are shared with the base prompts.
        name_lens = self.name_lens

        if (not self.training) and not getattr(self, "_test_prompt_logged", False):
            active = self._llm_prompt_strings if use_llm_prompt and getattr(self, "_llm_prompt_strings", None) else self._base_prompt_strings
            print("!" * 80)
            print("[INFERENCE MODE] Sample prompt:", active[self._prompt_sample_idx])
            print("!" * 80)
            self._test_prompt_logged = True

        if use_llm_prompt and getattr(self, "_llm_prompt_strings", None) and not getattr(self, "_llm_prompt_logged", False):
            print('[Prompt active - LLM]', self._llm_prompt_strings[self._prompt_sample_idx])
            self._llm_prompt_logged = True

        return self._assemble_prompts(ctx, prefix, suffix, name_lens)

    def get_tokenized_prompts(self):
        """Return the token ids matching the embeddings ``forward`` currently uses."""
        if (not self.training) and self.llm_prompt_test_enable and self._llm_prompts_available:
            return self.tokenized_prompts_test
        return self.tokenized_prompts

    def get_inverse_prompt_embeddings(self):
        """Return inverse-prompt embeddings with ``self.ctx`` spliced in, or ``None``."""
        if getattr(self, "token_prefix_inverse", None) is None:
            return None
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        return self._assemble_prompts(ctx, self.token_prefix_inverse, self.token_suffix_inverse, self.name_lens_inverse)

    def get_inverse_tokenized_prompts(self):
        """Return the token ids of the inverse prompts, or ``None`` if absent."""
        return getattr(self, "tokenized_prompts_inverse", None)

    def _assemble_prompts(self, ctx, prefix, suffix, name_lens):
        """Concatenate prefix, context and suffix embeddings along dim 1.

        Args:
            ctx: per-class context of shape ``(n_cls, n_ctx, dim)``.
            prefix: first-token embeddings, ``embedding[:, :1, :]``.
            suffix: embeddings after the context slots; the first
                ``name_lens[i]`` positions of row ``i`` are the class-name tokens.
            name_lens: number of class-name tokens per class.

        ``class_token_position`` selects where the class name goes:

        * ``"end"``: ``[prefix, ctx, suffix]``;
        * ``"middle"``: ``[prefix, ctx[:half], name, ctx[half:], rest]``;
        * ``"front"``: ``[prefix, name, ctx, rest]``.

        Raises:
            ValueError: if ``class_token_position`` is not one of the above.
        """
        if self.class_token_position == "end":
            return torch.cat([prefix, ctx, suffix], dim=1)
        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = name_lens[i]
                prefix_i = prefix[i: i + 1, :, :]
                class_i = suffix[i: i + 1, :name_len, :]
                suffix_i = suffix[i: i + 1, name_len:, :]
                ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :]
                prompts.append(torch.cat([prefix_i, ctx_i_half1, class_i, ctx_i_half2, suffix_i], dim=1))
            return torch.cat(prompts, dim=0)
        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = name_lens[i]
                prefix_i = prefix[i: i + 1, :, :]
                class_i = suffix[i: i + 1, :name_len, :]
                suffix_i = suffix[i: i + 1, name_len:, :]
                ctx_i = ctx[i: i + 1, :, :]
                prompts.append(torch.cat([prefix_i, class_i, ctx_i, suffix_i], dim=1))
            return torch.cat(prompts, dim=0)
        else:
            raise ValueError(
                f"Unknown class_token_position: {self.class_token_position!r} "
                "(expected 'front', 'middle' or 'end')"
            )

    def _build_prompt_embeddings(self, prompts, clip_model, dtype, n_ctx):
        """Tokenize ``prompts`` and return ``(token_ids, prefix_emb, suffix_emb)``.

        ``prefix_emb`` is the first token's embedding. ``suffix_emb`` starts
        after the ``n_ctx`` placeholder tokens. Both are computed under
        ``torch.no_grad()``.

        NOTE(review): long LLM-extended prompts may exceed the tokenizer's
        context length; whether ``clip.tokenize`` raises or truncates depends on
        the bundled clip package — confirm.
        """
        tokenized = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized).type(dtype)
        token_prefix = embedding[:, :1, :]
        token_suffix = embedding[:, 1 + n_ctx:, :]
        return tokenized, token_prefix, token_suffix

    def _llm_segments(self, meta):
        """Return the LLM sentences enabled by ``self.llm_prompt_sections``.

        ``meta`` is one record from ``_load_llm_metadata``. Each enabled,
        non-empty field becomes one sentence ending in ``"."``, in the order
        description, similar, inverse.
        """
        segments = []
        if "description" in self.llm_prompt_sections:
            desc = (meta.get("description") or "").strip()
            # Drop an optional "spatial:" tag and any trailing periods.
            if desc.lower().startswith("spatial:"):
                desc = desc[len("spatial:"):].strip()
            desc = desc.rstrip(".")
            if desc:
                segments.append(f"{desc}.")
        if "similar" in self.llm_prompt_sections:
            sim = (meta.get("similar") or "").strip().rstrip(".")
            if sim:
                segments.append(f"similar: {sim}.")
        if "inverse" in self.llm_prompt_sections:
            inv = (meta.get("inverse") or "").strip().rstrip(".")
            if inv:
                segments.append(f"inverse relation: {inv}.")
        return segments

    def _compose_llm_prompts(self, classnames, base_prompts, prompt_prefix):
        """Build one LLM-extended prompt string per class.

        Returns:
            ``(prompts, any_extended)``. Classes without usable metadata keep
            their base prompt. ``any_extended`` is True if at least one class
            received extra text.
        """
        prompts_llm = []
        any_extended = False
        for name, base_prompt in zip(classnames, base_prompts):
            # Same key normalization as the CSV loader and the inverse lookup.
            meta = self._llm_metadata.get(name.strip().lower())
            segments = self._llm_segments(meta) if meta else []
            if segments:
                any_extended = True
                extended = " ".join(segments)
                prompts_llm.append(f"{prompt_prefix} {name}. {extended}")
            else:
                prompts_llm.append(base_prompt)
        return prompts_llm, any_extended

    def _load_llm_metadata(self, csv_path):
        """Load per-relation LLM text from a CSV file; best-effort.

        The first row is treated as a header and skipped. Rows with fewer than
        4 columns are ignored. Column 0 is the relation, column 1 the
        description and column 2 the similar relations. The inverse text is
        taken from column 4 when present, else from column 3.

        Returns:
            ``{relation.lower(): {"description", "similar", "inverse"}}``.
            The dict is empty if the path is empty, missing or unreadable;
            failures are printed, not raised.
        """
        records = {}
        if not csv_path:
            return records
        csv_path = Path(csv_path).expanduser()
        if not csv_path.is_file():
            print(f"[Clip_PromptLearner] LLM prompt csv not found: {csv_path}")
            return records
        try:
            with csv_path.open("r", encoding="utf-8-sig") as f:
                reader = csv.reader(f)
                header = next(reader, None)
                if header:
                    print(f"[Clip_PromptLearner] CSV header: {header[:5]}")
                for row in reader:
                    if len(row) < 4:
                        continue
                    relation = row[0].strip()
                    # Skip empty names and repeated header rows.
                    if not relation or relation.lower() == "relation":
                        continue
                    desc = row[1].strip()
                    similar = row[2].strip()
                    inverse = row[4].strip() if len(row) > 4 else row[3].strip()
                    records[relation.lower()] = {
                        "description": desc,
                        "similar": similar,
                        "inverse": inverse,
                    }
                print(f"[Clip_PromptLearner] Loaded {len(records)} relation metadata from CSV")
                if records:
                    print("[Clip_PromptLearner] Sample metadata (first 3):")
                    for idx, (rel, meta) in enumerate(list(records.items())[:3]):
                        print(f"  [{idx}] {rel}:")
                        print(f"      desc: {meta['description']}")
                        print(f"      similar: {meta['similar']}")
                        print(f"      inverse: {meta['inverse']}")
        except Exception as exc:
            # Deliberate best-effort: a bad CSV only disables LLM prompts.
            print(f"[Clip_PromptLearner] Failed to load LLM csv ({csv_path}): {exc}")
        return records

    def _load_inverse_mapping(self, csv_path):
        """Load a ``relation -> inverse relation name`` mapping from a CSV file.

        If the header has at least two columns, values are read from the
        ``relation`` / ``inverse`` columns, with header names matched
        case-insensitively. For a row where those are empty, the first / second
        column is used instead. Otherwise the file is re-read positionally as
        ``relation,inverse`` rows.

        Returns:
            A dict with stripped, lower-cased relation keys and stripped
            inverse values. It is empty if the path is empty, missing or
            unreadable; failures are printed, not raised.
        """
        records = {}
        if not csv_path:
            return records
        csv_path = Path(csv_path).expanduser()
        if not csv_path.is_file():
            print(f"[Clip_PromptLearner] inverse csv not found: {csv_path}")
            return records
        try:
            with csv_path.open("r", encoding="utf-8-sig") as f:
                reader = csv.DictReader(f)
                fieldnames = reader.fieldnames
                if fieldnames and len(fieldnames) >= 2:
                    # DictReader keys each row by the header text exactly as written,
                    # so map lower-cased names back to the original header strings.
                    by_lower = {h.strip().lower(): h for h in fieldnames}
                    rel_key = by_lower.get("relation", fieldnames[0])
                    inv_key = by_lower.get("inverse", fieldnames[1])
                    for row in reader:
                        relation = (row.get(rel_key) or row.get(fieldnames[0]) or "").strip().lower()
                        inverse = (row.get(inv_key) or row.get(fieldnames[1]) or "").strip()
                        if relation and inverse:
                            records[relation] = inverse
                else:
                    f.seek(0)
                    reader_plain = csv.reader(f)
                    for cols in reader_plain:
                        if len(cols) < 2:
                            continue
                        relation, inverse = cols[0].strip().lower(), cols[1].strip()
                        if relation and inverse:
                            records[relation] = inverse
        except Exception as exc:
            # Deliberate best-effort: a bad CSV only disables inverse names.
            print(f"[Clip_PromptLearner] Failed to load inverse csv ({csv_path}): {exc}")
        return records

    def _log_prompt_preview(self, tag, prompt_list):
        """Print up to 100 prompts; for longer lists print the first 10 and the total."""
        total = len(prompt_list)
        if total == 0:
            return
        limit = total if total <= 100 else min(10, total)
        for idx in range(limit):
            print(f'[Prompt sample - {tag} #{idx}] {prompt_list[idx]}')
        if limit < total:
            print(f'[Prompt sample - {tag}] ... ({total} total)')

    def _print_llm_prompt_summary(self, prompts_llm, prompt_prefix):
        """Print each class's base prompt next to the LLM text appended to it."""
        print("\n" + "=" * 100)
        print(f"[LLM Enhanced Prompts] Total: {len(prompts_llm)}")
        print(f"[LLM Enhanced Prompts] Sections enabled: {sorted(self.llm_prompt_sections)}")
        print("=" * 100)
        for idx, (rel_name, prompt) in enumerate(zip(self.classnames, prompts_llm)):
            base_part = f"{prompt_prefix} {rel_name}."
            if len(prompt) > len(base_part):
                extended_part = prompt[len(base_part):].strip()
                print(f"[{idx:2d}] {rel_name:20s}")
                print(f"     Base: {base_part}")
                print(f"     +LLM: {extended_part}")
            else:
                print(f"[{idx:2d}] {rel_name:20s} -> {prompt} (no LLM extension)")
        print("=" * 100 + "\n")