import torch
from torch import nn
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip import clip
from collections import OrderedDict
# [+] 新增: CSV 解析与路径工具
import csv
from pathlib import Path


class Clip_PromptLearner(nn.Module):
    def __init__(self, cfg, classnames, tokenizer, clip_model):
        """Create the learnable context vectors and the frozen prompt token embeddings.

        Two sets of prompt embeddings may be built: the base prompts
        ("<ctx placeholders> <class name>.") and, when enabled and metadata is
        available, LLM-extended prompts that append extra sentences loaded from
        a CSV file. ``forward`` selects the extended set only in eval mode.

        Args:
            cfg: Project config. ``cfg.MODEL.ROI_RELATION_HEAD.HP`` (optional) is
                read for ``SPECIFIC_PROMPT``, ``LLM_PROMPT_TEST_ENABLE``,
                ``LLM_PROMPT_SECTIONS``, ``LLM_PROMPT_CSV_PATH`` (falling back
                to ``LLM_PROMPT_CSV``) and ``CLASS_TOKEN_POSITION``.
            classnames: Relation class names; underscores are replaced by spaces.
            tokenizer: Object whose ``encode(name)`` length is used as the token
                count of each class name.
            clip_model: Model providing ``dtype``, ``ln_final.weight`` and
                ``token_embedding``.
        """
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 2
        ctx_init = ''
        # HP is treated as optional by every later read, so resolve it once and
        # read SPECIFIC_PROMPT through the same guard instead of dereferencing
        # ``.HP`` unconditionally.
        hp_cfg = getattr(cfg.MODEL.ROI_RELATION_HEAD, "HP", None)
        csc = hp_cfg.SPECIFIC_PROMPT if hp_cfg is not None else False
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init

        else:
            # random initialization
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            # One "X" placeholder word per context slot; its embedding is later
            # cut out and replaced by the learned context in forward().
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        self.llm_prompt_test_enable = getattr(hp_cfg, "LLM_PROMPT_TEST_ENABLE", False) if hp_cfg is not None else False
        sections_raw = getattr(hp_cfg, "LLM_PROMPT_SECTIONS", ("description", "similar", "inverse")) if hp_cfg is not None else ()
        if isinstance(sections_raw, str):
            # Accept "a,b" as well as "a|b".
            sections_raw = sections_raw.replace("|", ",").split(",")
        # ``or ()`` tolerates a config value of None.
        requested_sections = {
            s.strip().lower() for s in (sections_raw or ()) if s and s.strip().lower() != "none"
        }
        # Only the inverse-relation section is honoured: any other requested
        # section is dropped, so the "description"/"similar" handling in
        # _format_llm_segments is currently never triggered from here.
        self.llm_prompt_sections = {"inverse"} if "inverse" in requested_sections else set()

        csv_path_raw = getattr(hp_cfg, "LLM_PROMPT_CSV_PATH", getattr(hp_cfg, "LLM_PROMPT_CSV", "")) if hp_cfg is not None else ""
        self._llm_metadata = self._load_llm_metadata(csv_path_raw)
        self._llm_prompts_available = False
        self._llm_prompt_logged = False

        classnames = [name.replace("_", " ") for name in classnames]
        self.classnames = classnames  # relation names with underscores replaced
        name_lens = [len(tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # Index of the first non-background class (0 when there is none).
        background_names = ("background", "__background__")
        self._prompt_sample_idx = next(
            (
                idx
                for idx, name in enumerate(classnames)
                if name.strip().lower() not in background_names
            ),
            0,
        )
        self._base_prompt_strings = prompts
        self._active_prompt_strings = prompts  # prompt strings reported by the preview log
        print('[LLM] sections:', sorted(self.llm_prompt_sections), 'enabled:', self.llm_prompt_test_enable)

        # Base prompt embeddings, built through the shared helper.
        tokenized_prompts, token_prefix, token_suffix = self._build_prompt_embeddings(
            prompts, clip_model, dtype, n_ctx
        )
        self.register_buffer("token_prefix", token_prefix)
        self.register_buffer("token_suffix", token_suffix)
        self.register_buffer("tokenized_prompts", tokenized_prompts)

        # Attributes read by forward().
        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.name_lens = name_lens
        default_pos = getattr(hp_cfg, "CLASS_TOKEN_POSITION", "middle") if hp_cfg is not None else "middle"
        if default_pos not in ("front", "middle", "end"):
            default_pos = "middle"
        self.class_token_position = default_pos
        self._prompt_logged_once = False

        # Defaults for "no LLM prompts"; assigned up front so both attributes
        # exist on every path, including "enabled but no CSV row matched".
        self.tokenized_prompts_test = None
        self._llm_prompt_strings = None

        if self.llm_prompt_test_enable and self._llm_metadata:
            prompts_llm = []
            llm_available = False
            for name, base_prompt in zip(classnames, prompts):
                segments = self._format_llm_segments(self._llm_metadata.get(name.lower()))
                if segments:
                    llm_available = True
                    extended = " ".join(segments)
                    prompts_llm.append(f"{prompt_prefix} {name}. {extended}")
                else:
                    # No usable metadata for this relation: keep its base prompt.
                    prompts_llm.append(base_prompt)

            if llm_available:
                tokenized_prompts_test, token_prefix_test, token_suffix_test = self._build_prompt_embeddings(
                    prompts_llm, clip_model, dtype, n_ctx
                )
                self.register_buffer("token_prefix_test", token_prefix_test)
                self.register_buffer("token_suffix_test", token_suffix_test)
                # NOTE(review): stored as a plain attribute, whereas
                # `tokenized_prompts` is a registered buffer, so this tensor is
                # not moved by Module.to()/cuda() - confirm that is intended.
                self.tokenized_prompts_test = tokenized_prompts_test
                self._llm_prompts_available = True
                self._llm_prompt_strings = prompts_llm
                self._active_prompt_strings = prompts_llm

                # Print every extended prompt.
                print("\n" + "="*80)
                print("[LLM Enhanced Prompts] Total relations:", len(prompts_llm))
                print("="*80)
                for idx, prompt in enumerate(prompts_llm):
                    print(f"[{idx:2d}] {classnames[idx]:20s} -> {prompt}")
                print("="*80 + "\n")
            else:
                # Metadata was loaded but nothing applied to these classes.
                self.llm_prompt_test_enable = False

        self._log_prompt_preview("base", self._active_prompt_strings)  # log the prompts recorded as active

    def _format_llm_segments(self, meta):
        """Return the extra prompt sentences for one relation's metadata dict.

        Only sections present in ``self.llm_prompt_sections`` are emitted; a
        missing/empty ``meta`` or empty fields yield no segments.
        """
        segments = []
        if not meta:
            return segments

        if "description" in self.llm_prompt_sections:
            desc = (meta.get("description") or "").strip()
            if desc:
                desc_clean = desc.rstrip(".")
                # Drop a leading "spatial:" tag from the description text.
                if desc_clean.lower().startswith("spatial:"):
                    desc_clean = desc_clean[len("spatial:"):].strip()
                if desc_clean:
                    segments.append(f"description: {desc_clean}.")

        if "similar" in self.llm_prompt_sections:
            sim = (meta.get("similar") or "").strip()
            if sim:
                sim_clean = sim.rstrip(".")
                if sim_clean:
                    segments.append(f"similar: {sim_clean}.")

        if "inverse" in self.llm_prompt_sections:
            inv = (meta.get("inverse") or "").strip()
            if inv:
                segments.append(f"inverse relation: {inv.rstrip('.')}.")

        return segments

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        use_llm_prompt = (not self.training) and self.llm_prompt_test_enable and self._llm_prompts_available
        prefix = self.token_prefix_test if use_llm_prompt else self.token_prefix
        suffix = self.token_suffix_test if use_llm_prompt else self.token_suffix

        if use_llm_prompt and getattr(self, "_llm_prompt_strings", None) and not getattr(self, "_llm_prompt_logged", False):
            print('[Prompt active - LLM]', self._llm_prompt_strings[0])  # [+] 新增: 推理时确认实际使用的模板
            self._llm_prompt_logged = True  # [+] 新增: 只打印一次以免刷屏

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )
        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError

        return prompts

    # [+] 新增: 返回当前激活的 tokenized prompts（供文本编码器使用）
    def get_tokenized_prompts(self):
        if (not self.training) and self.llm_prompt_test_enable and self._llm_prompts_available:
            return self.tokenized_prompts_test
        return self.tokenized_prompts

    # [+] 新增: 构建提示嵌入的通用函数
    def _build_prompt_embeddings(self, prompts, clip_model, dtype, n_ctx):
        """Tokenize ``prompts`` and split their embeddings around the context slots.

        Returns:
            A tuple ``(tokenized, token_prefix, token_suffix)`` where
            ``tokenized`` is the concatenation of ``clip.tokenize`` over every
            prompt, ``token_prefix`` is the embedding at position 0 and
            ``token_suffix`` is the embedding from position ``1 + n_ctx`` on,
            i.e. everything after the ``n_ctx`` context slots.
        """
        token_ids = torch.cat([clip.tokenize(text) for text in prompts])
        # The embeddings are constants here; no gradient is tracked for them.
        with torch.no_grad():
            embedded = clip_model.token_embedding(token_ids).type(dtype)
        suffix_start = 1 + n_ctx
        leading = embedded[:, :1, :]
        trailing = embedded[:, suffix_start:, :]
        return token_ids, leading, trailing

    # [+] 新增: 读取 LLM 描述 CSV
    def _load_llm_metadata(self, csv_path):
        """Load per-relation LLM metadata from a CSV file, addressed by column index.

        Columns (the first row is always treated as a header and skipped):
            0: relation name
            1: description
            2: similar relation(s)
            3: inverse relation(s)

        Rows with fewer than four columns, an empty relation name, or a
        relation cell equal to "relation" (a repeated header) are ignored.
        Keys are lower-cased with underscores replaced by spaces, mirroring how
        ``__init__`` normalises class names before looking them up.

        Loading is best-effort: a missing file or a read/parse error is
        reported on stdout and whatever was parsed so far is returned.

        Args:
            csv_path: Path to the CSV file; a falsy value disables loading.

        Returns:
            dict mapping normalised relation name to a dict with the keys
            "description", "similar" and "inverse".
        """
        records = {}
        if not csv_path:
            return records
        csv_path = Path(csv_path).expanduser()
        if not csv_path.is_file():
            print(f"[Clip_PromptLearner] LLM prompt csv not found: {csv_path}")
            return records

        try:
            # newline="" lets the csv module do its own newline handling, which
            # it needs to parse quoted fields containing line breaks correctly.
            with csv_path.open("r", encoding="utf-8-sig", newline="") as f:
                reader = csv.reader(f)

                # Skip the header row, echoing its first columns for verification.
                header = next(reader, None)
                if header:
                    print(f"[Clip_PromptLearner] CSV header: {header[:5]}")

                for row in reader:
                    if len(row) < 4:  # need all four columns
                        continue

                    relation, desc, similar, inverse = (cell.strip() for cell in row[:4])
                    if not relation or relation.lower() == "relation":
                        continue

                    # Same normalisation as the lookup key built in __init__.
                    key = relation.replace("_", " ").lower()
                    records[key] = {
                        "description": desc,
                        "similar": similar,
                        "inverse": inverse,
                    }

                print(f"[Clip_PromptLearner] Loaded {len(records)} relation metadata from CSV")

                # Show the first three entries so the load can be verified by eye.
                if records:
                    print("[Clip_PromptLearner] Sample metadata (first 3):")
                    for idx, (rel, meta) in enumerate(list(records.items())[:3]):
                        print(f"  [{idx}] {rel}:")
                        print(f"      desc: {meta['description'][:60]}...")
                        print(f"      similar: {meta['similar']}")
                        print(f"      inverse: {meta['inverse']}")

        except Exception as exc:
            # Deliberately broad: prompt metadata is optional, so a bad file
            # must not abort model construction. The failure is still reported.
            print(f"[Clip_PromptLearner] Failed to load LLM csv ({csv_path}): {exc}")
            import traceback
            traceback.print_exc()

        return records
    def _log_prompt_preview(self, tag, prompt_list):  # [+]
        """Print a preview of ``prompt_list`` labelled with ``tag``.

        Lists of up to 100 prompts are printed in full; longer lists show only
        the first 10 followed by a summary line with the total count. An empty
        list prints nothing.
        """
        count = len(prompt_list)
        if not count:
            return
        shown = count if count <= 100 else 10
        for position, text in zip(range(shown), prompt_list):
            print(f'[Prompt sample - {tag} #{position}] {text}')
        if shown < count:
            print(f'[Prompt sample - {tag}] ... ({count} total)')