import torch
from torch import nn
from maskrcnn_benchmark.modeling.roi_heads.relation_head.clip import clip
from collections import OrderedDict
import csv
from pathlib import Path


class Clip_PromptLearner(nn.Module):
    def __init__(self, cfg, classnames, tokenizer, clip_model):
        """Create the learnable context and the fixed prompt token embeddings.

        Two prompt sets may be built: the base prompts
        (``"<prefix> <class name>."``) and, optionally, LLM-extended prompts
        assembled from CSV metadata. ``forward`` only selects the extended
        set when the module is not in training mode.

        Args:
            cfg: Project config. ``cfg.MODEL.ROI_RELATION_HEAD.HP`` is read for
                ``SPECIFIC_PROMPT``, ``LLM_PROMPT_TEST_ENABLE``,
                ``LLM_PROMPT_SECTIONS``, ``LLM_PROMPT_CSV_PATH`` (falling back
                to ``LLM_PROMPT_CSV``) and ``CLASS_TOKEN_POSITION``.
            classnames: Class names; underscores are replaced with spaces
                before the prompts are built.
            tokenizer: Object exposing ``encode(str)``; only the length of the
                encoded class name is used.
            clip_model: Model exposing ``dtype``, ``ln_final.weight`` and
                ``token_embedding``.
        """
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 2
        ctx_init = ''
        csc = cfg.MODEL.ROI_RELATION_HEAD.HP.SPECIFIC_PROMPT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        if ctx_init:
            # Use the given words to initialize the context vectors.
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialization.
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        # Load the LLM prompt configuration.
        hp_cfg = getattr(cfg.MODEL.ROI_RELATION_HEAD, "HP", None)
        self.llm_prompt_test_enable = getattr(hp_cfg, "LLM_PROMPT_TEST_ENABLE", False) if hp_cfg is not None else False
        sections_raw = getattr(hp_cfg, "LLM_PROMPT_SECTIONS", ("description", "similar", "inverse")) if hp_cfg is not None else ()
        if isinstance(sections_raw, str):
            sections_raw = sections_raw.replace("|", ",").split(",")
        # Normalize once so that filtering and the "none" check agree: an
        # entry such as " none" (with surrounding whitespace) disables every
        # section exactly like "none" does, and blank entries are dropped.
        normalized_sections = [
            s.strip().lower() for s in sections_raw if isinstance(s, str) and s.strip()
        ]
        if "none" in normalized_sections:
            self.llm_prompt_sections = set()
        else:
            self.llm_prompt_sections = set(normalized_sections)

        csv_path_raw = getattr(hp_cfg, "LLM_PROMPT_CSV_PATH", getattr(hp_cfg, "LLM_PROMPT_CSV", "")) if hp_cfg is not None else ""
        self._llm_metadata = self._load_llm_metadata(csv_path_raw)
        self._llm_prompts_available = False
        self._llm_prompt_logged = False

        classnames = [name.replace("_", " ") for name in classnames]
        self.classnames = classnames
        name_lens = [len(tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # Index of the first non-background class; used for sample logging.
        non_bg_idx = 0
        for idx, name in enumerate(classnames):
            if name.strip().lower() not in ("background", "__background__"):
                non_bg_idx = idx
                break
        self._prompt_sample_idx = non_bg_idx
        self._base_prompt_strings = prompts
        self._active_prompt_strings = prompts

        print(f'[LLM Config] sections={sorted(self.llm_prompt_sections)}, enabled={self.llm_prompt_test_enable}')

        # Build the base prompt embeddings.
        tokenized_prompts, token_prefix, token_suffix = self._build_prompt_embeddings(
            prompts, clip_model, dtype, n_ctx
        )
        self.register_buffer("token_prefix", token_prefix)
        self.register_buffer("token_suffix", token_suffix)
        self.register_buffer("tokenized_prompts", tokenized_prompts)

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.name_lens = name_lens
        default_pos = getattr(hp_cfg, "CLASS_TOKEN_POSITION", "middle") if hp_cfg is not None else "middle"
        if default_pos not in ("front", "middle", "end"):
            default_pos = "middle"
        self.class_token_position = default_pos
        self._prompt_logged_once = False

        # Defaults for the LLM-extended prompt state; they are overwritten
        # only when extended prompts are actually built, so every code path
        # leaves these attributes defined.
        # NOTE(review): tokenized_prompts_test is a plain attribute, not a
        # registered buffer, so it does not follow module.to(device) — confirm
        # that consumers of get_tokenized_prompts() tolerate that.
        self.tokenized_prompts_test = None
        self._llm_prompt_strings = None

        # Build the LLM-extended prompts used outside training mode.
        if self.llm_prompt_test_enable and self._llm_metadata:
            prompts_llm = []
            llm_available = False
            for idx, name in enumerate(classnames):
                base_prompt = prompts[idx]
                meta = self._llm_metadata.get(name.lower())
                segments = []

                if meta:
                    # Description: strip an optional "spatial:" tag and end
                    # the sentence with a single period.
                    if "description" in self.llm_prompt_sections:
                        desc = (meta.get("description") or "").strip()
                        if desc:
                            desc_clean = desc
                            if desc_clean.lower().startswith("spatial:"):
                                desc_clean = desc_clean[len("spatial:"):].strip()
                            desc_clean = desc_clean.rstrip(".")
                            if desc_clean:
                                segments.append(f"{desc_clean}.")

                    # Similar relation: comma-terminated clause.
                    if "similar" in self.llm_prompt_sections:
                        sim = (meta.get("similar") or "").strip()
                        if sim:
                            sim_clean = sim.rstrip(".,")
                            if sim_clean:
                                segments.append(f"as same as {sim_clean},")

                    # Inverse relation: long explanatory prefix.
                    if "inverse" in self.llm_prompt_sections:
                        inv = (meta.get("inverse") or "").strip()
                        if inv:
                            segments.append(f"When the subject and object exchange, the inverse relation is {inv.rstrip('.')}.")

                if not segments:
                    # No usable metadata for this class: keep the base prompt.
                    prompts_llm.append(base_prompt)
                    continue

                llm_available = True
                base_part = f"{prompt_prefix} {name}"
                extended = " ".join(segments)
                if "is similar to" in extended:
                    # Only reachable when the metadata text itself contains
                    # "is similar to"; make sure a comma precedes that phrase.
                    head, _, tail = extended.partition("is similar to")
                    head = head.rstrip(". ")
                    prompts_llm.append(f"{base_part}. {head}, is similar to{tail}")
                else:
                    prompts_llm.append(f"{base_part}. {extended}")

            if llm_available:
                tokenized_prompts_test, token_prefix_test, token_suffix_test = self._build_prompt_embeddings(
                    prompts_llm, clip_model, dtype, n_ctx
                )
                self.register_buffer("token_prefix_test", token_prefix_test)
                self.register_buffer("token_suffix_test", token_suffix_test)
                self.tokenized_prompts_test = tokenized_prompts_test
                self._llm_prompts_available = True
                self._llm_prompt_strings = prompts_llm
                self._active_prompt_strings = prompts_llm

                # Print the complete prompts that will be used at test time.
                print("\n" + "="*100)
                print(f"[LLM Enhanced Test Prompts] Total: {len(prompts_llm)} relations")
                print(f"[LLM Enhanced Test Prompts] Sections enabled: {sorted(self.llm_prompt_sections)}")
                print("="*100)
                for idx, (rel_name, prompt) in enumerate(zip(classnames, prompts_llm)):
                    base_part = f"{prompt_prefix} {rel_name}."
                    if len(prompt) > len(base_part):
                        extended_part = prompt[len(base_part):].strip()
                        print(f"[{idx:2d}] {rel_name:20s}")
                        print(f"     Base: {base_part}")
                        print(f"     +LLM: {extended_part}")
                    else:
                        print(f"[{idx:2d}] {rel_name:20s} -> {prompt} (no LLM extension)")
                print("="*100 + "\n")
            else:
                self.llm_prompt_test_enable = False
                print("[LLM] No valid metadata found, LLM prompts disabled")

        # Print a preview of the base prompts.
        self._log_prompt_preview("base", self._base_prompt_strings)

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        use_llm_prompt = (not self.training) and self.llm_prompt_test_enable and self._llm_prompts_available
        prefix = self.token_prefix_test if use_llm_prompt else self.token_prefix
        suffix = self.token_suffix_test if use_llm_prompt else self.token_suffix

        # [+] 测试阶段首次调用时打印确认信息
        if not self.training and not getattr(self, "_test_prompt_logged", False):
            if use_llm_prompt:
                print("\n" + "!"*80)
                print("[INFERENCE MODE] Using LLM-Enhanced Prompts")
                print(f"[Sample Prompt] {self._llm_prompt_strings[self._prompt_sample_idx]}")
                print("!"*80 + "\n")
            else:
                print("\n" + "!"*80)
                print("[INFERENCE MODE] Using Base Prompts (LLM disabled or unavailable)")
                print(f"[Sample Prompt] {self._base_prompt_strings[self._prompt_sample_idx]}")
                print("!"*80 + "\n")
            self._test_prompt_logged = True

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )
        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        else:
            raise ValueError

        return prompts

    def get_tokenized_prompts(self):
        """Return the tokenized prompts that are currently active.

        The LLM-extended token ids are returned only when the module is not
        in training mode and LLM prompts are both enabled and available;
        otherwise the base token ids are returned.
        """
        llm_active = (
            not self.training
            and self.llm_prompt_test_enable
            and self._llm_prompts_available
        )
        return self.tokenized_prompts_test if llm_active else self.tokenized_prompts

    def _build_prompt_embeddings(self, prompts, clip_model, dtype, n_ctx):
        """Tokenize prompts and split their embeddings around the context slots.

        Args:
            prompts: Iterable of prompt strings.
            clip_model: Model whose ``token_embedding`` maps token ids to embeddings.
            dtype: dtype the embeddings are cast to.
            n_ctx: Number of context positions that follow the first token.

        Returns:
            ``(tokenized, token_prefix, token_suffix)``: the token ids, the
            embedding of the first token position (presumably the
            start-of-text token — confirm against the tokenizer), and the
            embeddings of every position after the ``n_ctx`` context slots.
        """
        token_rows = [clip.tokenize(text) for text in prompts]
        tokenized = torch.cat(token_rows)

        # The embedding lookup is done without tracking gradients.
        with torch.no_grad():
            embedded = clip_model.token_embedding(tokenized).type(dtype)

        first_after_ctx = 1 + n_ctx
        leading = embedded[:, :1, :]
        trailing = embedded[:, first_after_ctx:, :]
        return tokenized, leading, trailing

    def _load_llm_metadata(self, csv_path):
        """
        [*] 使用列索引直接读取CSV:
            - 第0列: relation (关系名)
            - 第1列: description (新的英文描述)
            - 第2列: similar relation (相似关系)
            - 第4列: 新的逆关系 (inverse relations)
        """
        records = {}
        if not csv_path:
            return records
        csv_path = Path(csv_path).expanduser()
        if not csv_path.is_file():
            print(f"[Clip_PromptLearner] LLM prompt csv not found: {csv_path}")
            return records
        
        try:
            with csv_path.open("r", encoding="utf-8-sig") as f:
                reader = csv.reader(f)
                
                # 跳过表头并打印以验证
                header = next(reader, None)
                if header:
                    print(f"[Clip_PromptLearner] CSV header: {header[:5]}")
                
                # 逐行读取数据
                for row in reader:
                    if len(row) < 4:  # 确保至少有4列
                        continue
                    
                    # [*] 直接使用列索引
                    relation = row[0].strip()  # 第0列: relation
                    if not relation or relation.lower() in ("relation", ""):
                        continue
                    
                    desc = row[1].strip()      # 第1列: description
                    similar = row[2].strip()   # 第2列: similar relation
                    inverse = row[4].strip() if len(row) > 4 else row[3].strip()  # 第4列优先，回退到第3列
                    
                    records[relation.lower()] = {
                        "description": desc,
                        "similar": similar,
                        "inverse": inverse,
                    }
                
                print(f"[Clip_PromptLearner] Loaded {len(records)} relation metadata from CSV")
                
                # [+] 打印前3个示例以验证加载
                if records:
                    print("[Clip_PromptLearner] Sample metadata (first 3):")
                    for idx, (rel, meta) in enumerate(list(records.items())[:3]):
                        print(f"  [{idx}] {rel}:")
                        print(f"      desc: {meta['description'][:60]}..." if len(meta['description']) > 60 else f"      desc: {meta['description']}")
                        print(f"      similar: {meta['similar']}")
                        print(f"      inverse: {meta['inverse']}")
                        
        except Exception as exc:
            print(f"[Clip_PromptLearner] Failed to load LLM csv ({csv_path}): {exc}")
            import traceback
            traceback.print_exc()
        
        return records

    def _log_prompt_preview(self, tag, prompt_list):
        """打印提示预览"""
        total = len(prompt_list)
        if total == 0:
            return
        limit = total if total <= 100 else min(10, total)
        for idx in range(limit):
            print(f'[Prompt sample - {tag} #{idx}] {prompt_list[idx]}')
        if limit < total:
            print(f'[Prompt sample - {tag}] ... ({total} total)')