import json
import os
import random
from collections import Counter, deque
from typing import Any, Dict, List, Optional, Tuple

import requests
import torch
import torch.nn.functional as F
from openai import OpenAI
from torch import Tensor

from utlis import PROMPT_EXTRACT_KEYWORDS, PROMPT_EXTRACT_DESCRIPTIONS, PROMPT_FINAL_ANALYSIS


def _metrics_url(api_url) -> str:
    """Map an OpenAI-compatible base URL (``.../v1``) to its ``/metrics`` URL.

    Only the last ``v1`` occurrence is replaced, so host names or ports that
    happen to contain ``v1`` are left intact. URLs without ``v1`` are returned
    unchanged. ``api_url`` may be any object whose ``str()`` is the URL.
    """
    url = str(api_url)
    head, sep, tail = url.rpartition('v1')
    if not sep:
        return url
    return head + 'metrics' + tail


def _parse_request_counts(metrics_text: str) -> Tuple[Optional[int], Optional[int]]:
    """Extract ``(running, waiting)`` gauges from Prometheus text output.

    Comment lines (starting with ``#``) are skipped. For each gauge the value
    on the last matching line wins. A gauge that never appears is ``None``.
    """
    running = waiting = None
    for line in metrics_text.split('\n'):
        if line.startswith('#'):
            continue
        if 'num_requests_running' in line:
            running = int(float(line.split()[-1]))
        elif 'num_requests_waiting' in line:
            waiting = int(float(line.split()[-1]))
    return running, waiting


def check_busy(api_url: str, timeout: float = 5.0):
    """Return the ``(running, waiting)`` request counts reported by a server.

    The server's ``/metrics`` endpoint is derived from ``api_url``.

    Args:
        api_url: OpenAI-compatible base URL, e.g. ``http://host:port/v1``.
        timeout: Seconds to wait for the metrics endpoint. This prevents a
            dead server from blocking client selection forever.

    Returns:
        ``(running, waiting)`` as ints. ``(999, 999)`` is returned when the
        endpoint is unreachable, returns an HTTP error, or lacks either gauge,
        so that such servers are ranked as heavily loaded.
    """
    print(f"api_url: {api_url}")
    try:
        resp = requests.get(_metrics_url(api_url), timeout=timeout)
        resp.raise_for_status()
        running, waiting = _parse_request_counts(resp.text)
        if running is None or waiting is None:
            raise ValueError("num_requests_running / num_requests_waiting not found in metrics")
        return running, waiting
    except Exception as e:
        # Deliberate best-effort probe: any failure just marks the server as busy.
        print(f"Error: {e}")
        return 999, 999


# OpenAI-compatible client for the embedding server used by get_embedding().
# The defaults are the original hard-coded values. Set EMBEDDING_BASE_URL /
# EMBEDDING_API_KEY to point at a different server without editing the code.
embedding_client = OpenAI(
    base_url=os.environ.get("EMBEDDING_BASE_URL", "http://localhost:XXXX/v1"),
    api_key=os.environ.get("EMBEDDING_API_KEY", "empty"),
)


def get_embedding(text: str) -> list[float]:
    """Embed one string with the ``Qwen3-Embedding-8B`` model.

    Sends a single request through the module-level ``embedding_client``
    and returns the vector of the first item in the response.
    """
    request = {"model": "Qwen3-Embedding-8B", "input": text}
    first_item = embedding_client.embeddings.create(**request).data[0]
    return first_item.embedding


class MetaphorVU_Boost:
    """Knowledge-augmented pipeline for understanding metaphors in videos.

    ``process`` runs three steps with a VLM client:

    1. Extract keywords or scene descriptions from the video frames.
    2. Augment them with external knowledge chosen by ``aug_mode``: a
       metaphor graph (``"mkg"``), a commonsense KG client (``"ckg"``),
       metaphor-sentence retrieval (``"rag"``) or LLM self-association
       (``"self"``).
    3. Run the final metaphor analysis with that reference in the prompt.
    """

    def __init__(
        self,
        ckg_client,
        vlm_client_list,

        max_retry: int = 5,

        aug_mode: str = "mkg",
        query_mode: str = "keywords",
        reference_mode: str = "simple",
        top_k: int = 10,
        indexing_type: str = "word",
        num_hops: int = 1,
        data_dir: str = "/XXXX/utlis",
    ):
        """Validate the configuration and load the resources ``aug_mode`` needs.

        Args:
            ckg_client: Commonsense-KG client exposing ``get_profile`` and
                ``return_profile``. Only used when ``aug_mode == "ckg"``.
            vlm_client_list: Non-empty list of VLM clients. Each client needs a
                ``base_url`` (probed for load) plus ``call_openai_vl`` and
                ``call_openai``.
            max_retry: Number of attempts ``process`` makes before giving up.
            aug_mode: ``"mkg"``, ``"ckg"``, ``"rag"`` or ``"self"``.
            query_mode: What step 1 asks the VLM to extract, either
                ``"keywords"`` or ``"descriptions"``. ``"ckg"`` requires
                ``"keywords"``.
            reference_mode: ``"simple"`` lists graph associations only.
                ``"examples"`` also adds example sentences.
            top_k: Number of associations or sentences to retrieve. A value
                ``<= 0`` disables augmentation entirely.
            indexing_type: ``"word"`` or ``"word_neighbors"``. Selects which
                graph-embedding file is loaded in mkg mode.
            num_hops: Graph-expansion depth for mkg mode (``>= 1``).
            data_dir: Directory holding the graph/sentence JSON files and
                their ``.pt`` embedding files.

        Raises:
            AssertionError: On an invalid configuration value.
        """
        self.ckg_client = ckg_client
        self.vlm_client_list = vlm_client_list
        # Default client; process() re-selects the least-loaded one per attempt.
        self.vlm_client = self.vlm_client_list[0]

        self.max_retry = max_retry

        self.aug_mode = aug_mode
        self.query_mode = query_mode
        self.reference_mode = reference_mode
        self.top_k = top_k
        self.indexing_type = indexing_type
        self.num_hops = num_hops
        self.data_dir = data_dir

        assert aug_mode in ["mkg", "ckg", "rag", "self"], f'aug_mode must be "mkg" or "ckg" or "rag" or "self", got "{aug_mode}"'
        assert query_mode in ["keywords", "descriptions"], f'query_mode must be "keywords" or "descriptions", got "{query_mode}"'
        assert reference_mode in ["simple", "examples"], f'reference_mode must be "simple" or "examples", got "{reference_mode}"'
        assert indexing_type in ["word", "word_neighbors"], f'indexing_type must be "word" or "word_neighbors", got "{indexing_type}"'
        assert num_hops >= 1, f'num_hops must be >= 1, got {num_hops}'

        print(">>>>> Initializing MetaphorVU_Boost with parameters:")

        if self.aug_mode == "mkg":
            print("Loading metaphor graph")
            with open(os.path.join(data_dir, 'metaphor_graph.json'), 'r', encoding='utf-8') as f:
                # Keys are stringified node indices; values hold "word" and "adjacency".
                self.graph = json.load(f)
            print(f"Graph loaded with {len(self.graph)} nodes")
            print("Loading graph embeddings")
            self.graph_embeddings = torch.load(
                os.path.join(data_dir, f'metaphor_graph_embedding_{self.indexing_type}.pt'),
                map_location='cpu',
            )
            print(f"Graph embeddings shape: {self.graph_embeddings.shape}")

            # Reverse lookup used by multi-hop expansion: word -> node index.
            # A later duplicate word overwrites an earlier one.
            self.word_to_idx = {node["word"]: int(idx_str) for idx_str, node in self.graph.items()}

        elif self.aug_mode == 'rag':
            print("Loading metaphor sentences")
            with open(os.path.join(data_dir, 'metaphor_sentences.json'), 'r', encoding='utf-8') as f:
                self.sentences_data = json.load(f)
            print(f"Sentences loaded with {len(self.sentences_data)} entries")
            print("Loading sentence embeddings")
            self.sentences_embeddings = torch.load(
                os.path.join(data_dir, 'metaphor_sentences_embedding.pt'),
                map_location='cpu',
            )
            print(f"Sentences embeddings shape: {self.sentences_embeddings.shape}")
        elif self.aug_mode == 'ckg':
            assert self.query_mode == "keywords", "CKG mode only supports keyword queries"
            print("Using commonsense knowledge graph client")
        else:
            print("Using self-association mode (LLM-based)")

        print(f"> Aug mode: {aug_mode}")
        print(f"> Query mode: {query_mode}")
        print(f"> Reference mode: {reference_mode}")
        print(f"> Top K: {top_k}")
        print(f"> Indexing type: {indexing_type}")
        print(f"> Num hops: {num_hops}")
        print(">>>>> Initialization completed")

    def select_best_vlm_client(self):
        """Return the VLM client with the fewest running + waiting requests.

        Clients are probed in random order, so ties go to a random client.
        The shuffle is applied to a copy, so ``self.vlm_client_list`` (which
        may be shared with the caller) is not reordered.

        Raises:
            IndexError: If ``vlm_client_list`` is empty.
        """
        candidates = list(self.vlm_client_list)
        random.shuffle(candidates)
        best_client = candidates[0]
        min_load = float('inf')

        for client in candidates:
            running, waiting = check_busy(client.base_url)
            load = running + waiting
            if load < min_load:
                min_load = load
                best_client = client

        print(f"best_client: {best_client.base_url}")
        return best_client

    def get_text_embedding(self, texts: List[str]) -> Tensor:
        """Embed ``texts`` one request at a time and L2-normalise each row.

        Returns a float32 CPU tensor of shape ``[len(texts), dim]``.
        ``texts`` must be non-empty, because ``F.normalize(dim=1)`` needs a
        2-D tensor.
        """
        vectors = [get_embedding(text) for text in texts]
        embeddings = torch.tensor(vectors, dtype=torch.float32, device=torch.device('cpu'))
        return F.normalize(embeddings, p=2, dim=1)

    # ------------------------------------------------------------------ #
    # Metaphor-graph retrieval (aug_mode == "mkg")
    # ------------------------------------------------------------------ #

    def _match_start_nodes(self, queries: List[str]) -> List[Tuple[str, int]]:
        """Pair each query with the index of its most similar graph node."""
        query_texts = [f"This is a scene description related to metaphor understanding: {q}" for q in queries]
        query_embeddings = self.get_text_embedding(query_texts)  # [num_queries, dim]
        # Match dtypes so a half-precision embedding file does not break the matmul.
        node_embeddings = self.graph_embeddings.to(query_embeddings.dtype)
        similarities = query_embeddings @ node_embeddings.T  # [num_queries, num_nodes]
        return list(zip(queries, similarities.argmax(dim=1).tolist()))

    def _describe_nodes(self, indices) -> List[Dict]:
        """Return id/word/adjacency records for distinct ``indices`` in first-seen order."""
        records = []
        for idx in dict.fromkeys(indices):
            node = self.graph[str(idx)]
            records.append({
                "node_id": idx,
                "word": node["word"],
                "adjacency": node.get("adjacency", {}),
            })
        return records

    @staticmethod
    def _accumulate_adjacency(adjacency_counter: Counter, adjacency_info: Dict,
                              adj_word: str, adj_info: Dict, source_query: str) -> Dict:
        """Count one sighting of ``adj_word`` and merge its source query and sentences.

        Returns the (possibly new) info entry for ``adj_word``.
        """
        adjacency_counter[adj_word] += 1
        entry = adjacency_info.setdefault(adj_word, {"source_words": [], "sentences": []})
        entry["source_words"].append(source_query)
        for sent in adj_info.get("sentences", []):
            if sent not in entry["sentences"]:
                entry["sentences"].append(sent)
        return entry

    def _rank_adjacencies(self, adjacency_counter: Counter, adjacency_info: Dict) -> List[Tuple[str, int, Dict]]:
        """Return the ``top_k`` most frequent adjacent words as ``(word, count, info)``."""
        return [
            (adj_word, count, adjacency_info[adj_word])
            for adj_word, count in adjacency_counter.most_common(self.top_k)
        ]

    def retrieve_subnetwork(self, queries: List[str]) -> Tuple[List[Dict], List[Tuple[str, int, Dict]]]:
        """Match each query to its closest graph node and rank the nodes' direct neighbours.

        Returns:
            ``(retrieved_nodes, top_adjacencies)``. ``retrieved_nodes`` holds
            one record per distinct matched node. ``top_adjacencies`` holds up
            to ``top_k`` tuples ``(adjacent_word, count, info)``. ``count`` is
            the number of matched (query, node) pairs that list the word, and
            ``info`` has ``"source_words"`` (queries) and de-duplicated
            ``"sentences"``. Both lists are empty for empty ``queries``.
        """
        if not queries:
            return [], []

        query_node_pairs = self._match_start_nodes(queries)
        retrieved_nodes = self._describe_nodes(idx for _, idx in query_node_pairs)

        adjacency_counter = Counter()
        adjacency_info = {}
        for query, idx in query_node_pairs:
            adjacency = self.graph[str(idx)].get("adjacency", {})
            for adj_word, adj_info in adjacency.items():
                self._accumulate_adjacency(adjacency_counter, adjacency_info, adj_word, adj_info, query)

        return retrieved_nodes, self._rank_adjacencies(adjacency_counter, adjacency_info)

    def retrieve_subnetwork_multi_hop(self, queries: List[str]) -> Tuple[List[Dict], List[Tuple[str, int, Dict]]]:
        """Like ``retrieve_subnetwork`` but expands up to ``num_hops`` hops per query.

        A breadth-first search starts from each query's matched node. Every
        node reached at hop ``< num_hops`` contributes its neighbours. A
        neighbour word is followed further only if it is itself a graph node
        (via ``word_to_idx``) not yet visited in that query's search. Each
        ``info`` additionally carries ``"min_hop"``, the smallest hop at which
        the word was seen.
        """
        if not queries:
            return [], []

        query_node_pairs = self._match_start_nodes(queries)
        retrieved_nodes = self._describe_nodes(idx for _, idx in query_node_pairs)

        adjacency_counter = Counter()
        adjacency_info = {}
        for query, start_idx in query_node_pairs:
            visited = {start_idx}
            queue = deque([(start_idx, 0)])  # (node_idx, hop distance from start)

            while queue:
                current_idx, current_hop = queue.popleft()
                if current_hop >= self.num_hops:
                    continue
                next_hop = current_hop + 1
                adjacency = self.graph[str(current_idx)].get("adjacency", {})

                for adj_word, adj_info in adjacency.items():
                    entry = self._accumulate_adjacency(adjacency_counter, adjacency_info, adj_word, adj_info, query)
                    entry["min_hop"] = min(entry.get("min_hop", next_hop), next_hop)

                    adj_idx = self.word_to_idx.get(adj_word)
                    if adj_idx is not None and adj_idx not in visited:
                        visited.add(adj_idx)
                        queue.append((adj_idx, next_hop))

        return retrieved_nodes, self._rank_adjacencies(adjacency_counter, adjacency_info)

    def format_external_reference(
        self,
        retrieved_nodes: List[Dict],
        top_adjacencies: List[Tuple[str, int, Dict]],
        max_examples_per_adj: int = 1,
        max_chars: int = 4000,
    ) -> str:
        """Render graph associations as a bullet list for the analysis prompt.

        In ``"examples"`` mode each bullet is followed by up to
        ``max_examples_per_adj`` example sentences, flattened to one line.
        Output longer than ``max_chars`` is cut at the last line break before
        the limit, and an "omitted" marker is appended.

        Args:
            retrieved_nodes: Unused; kept for interface compatibility.
            top_adjacencies: ``(word, count, info)`` tuples from retrieval.

        Raises:
            ValueError: For an unknown ``reference_mode``, when there is
                something to format.
        """
        if not top_adjacencies:
            return ""
        if self.reference_mode not in ("simple", "examples"):
            raise ValueError(f"Unknown reference_mode: {self.reference_mode}")

        parts = []
        for adj_word, _count, info in top_adjacencies:
            # dict.fromkeys de-duplicates in first-seen order. A set's order for
            # str depends on the hash seed, which made prompts non-reproducible.
            source_words = list(dict.fromkeys(info["source_words"]))
            source_words_str = ", ".join(f'"{w}"' for w in source_words)
            entry = f"• The concept {source_words_str} is possibly associated with \"{adj_word}\""

            if self.reference_mode == "examples":
                sentences = info["sentences"][:max_examples_per_adj]
                if sentences:
                    entry += "\n  Example usage:"
                    for i, sent in enumerate(sentences, 1):
                        clean_sent = sent.replace('\n', ' ').strip()
                        entry += f"\n    {i}. {clean_sent}"
            parts.append(entry)

        result = "\n".join(parts)
        if len(result) > max_chars:
            # Prefer cutting at a line break so the last association is not cut mid-word.
            cut = result.rfind("\n", 0, max_chars)
            result = result[:cut if cut > 0 else max_chars] + "\n... (more associations omitted)"

        return result

    # ------------------------------------------------------------------ #
    # Pipeline steps
    # ------------------------------------------------------------------ #

    def extract_queries(self, image_dir_path: str) -> Tuple[List[str], str, Any]:
        """Step 1: ask the VLM for keywords or descriptions of the video.

        Returns:
            ``(queries, raw_output, saved_image_messages)``. The image
            messages are reused in step 3 so the frames are not re-encoded.

        Raises:
            AssertionError: If the parsed JSON lacks a list under the expected key.
        """
        print("=" * 50)
        print(f"Step 1: Extracting {'keywords' if self.query_mode == 'keywords' else 'descriptions'} from video")
        print("=" * 50)

        if self.query_mode == "keywords":
            prompt = PROMPT_EXTRACT_KEYWORDS
            output_key = "keywords"
        else:  # descriptions
            prompt = PROMPT_EXTRACT_DESCRIPTIONS
            output_key = "descriptions"

        output, _total_output, saved_image_messages = self.vlm_client.call_openai_vl(
            prompt=prompt,
            image_dir_path=image_dir_path,
            return_image_messages=True
        )

        output_json = self.get_clear_output_json(output)

        assert output_key in output_json, f'"{output_key}" not in output_json: {output_json}'
        assert isinstance(output_json[output_key], list), f'{output_key} is not a list: {output_json}'

        queries = output_json[output_key]
        print(f"Extracted {len(queries)} {self.query_mode}:")
        for i, q in enumerate(queries, 1):
            print(f"  {i}. {q}")

        return queries, output, saved_image_messages

    def final_analysis(
        self,
        external_reference: str,
        saved_image_messages: Any,
        image_dir_path: str,
        title: str
    ) -> Tuple[Dict, str]:
        """Step 3: run the metaphor analysis with ``external_reference`` in the prompt.

        Reuses ``saved_image_messages`` when available and otherwise reloads
        the frames from ``image_dir_path``.

        Returns:
            ``(analysis_dict, raw_output)``.

        Raises:
            AssertionError: If the parsed JSON has no ``"analysis_dict"``.
        """
        print("=" * 50)
        print("Step 3: Final metaphor analysis with external reference")
        print("=" * 50)

        prompt = PROMPT_FINAL_ANALYSIS.format(external_reference=external_reference, title=title)

        if saved_image_messages is not None:
            output, _total_output = self.vlm_client.call_openai_vl(
                prompt=prompt,
                saved_image_messages=saved_image_messages
            )
        else:
            output, _total_output = self.vlm_client.call_openai_vl(
                prompt=prompt,
                image_dir_path=image_dir_path
            )

        output_json = self.get_clear_output_json(output)

        assert "analysis_dict" in output_json, f'"analysis_dict" not in output_json: {output_json}'

        analysis_dict = output_json["analysis_dict"]
        print(f"Generated {len(analysis_dict)} analysis entries")

        return analysis_dict, output

    def _augment(self, queries: List[str], process_record: Dict) -> str:
        """Step 2: build the external reference for ``queries`` according to ``aug_mode``.

        Stores retrieval details in ``process_record["step2_retrieval"]``.
        When ``top_k <= 0`` nothing is stored and the reference is ``""``.

        Raises:
            ValueError: For an unknown ``aug_mode``.
        """
        if self.top_k <= 0:
            return ""

        if self.aug_mode == 'mkg':
            if self.num_hops > 1:
                retrieved_nodes, top_adjacencies = self.retrieve_subnetwork_multi_hop(queries)
            else:
                retrieved_nodes, top_adjacencies = self.retrieve_subnetwork(queries)
            external_reference = self.format_external_reference(retrieved_nodes, top_adjacencies)
            process_record["step2_retrieval"] = {
                "num_retrieved_nodes": len(retrieved_nodes),
                "retrieved_nodes": retrieved_nodes,
                "top_adjacencies": [
                    (adj, cnt, {"source_words": info["source_words"], "num_sentences": len(info["sentences"])})
                    for adj, cnt, info in top_adjacencies
                ],
                "external_reference": external_reference
            }
            return external_reference

        augmenters = {'ckg': self.aug_ckg, 'rag': self.aug_rag, 'self': self.aug_self}
        if self.aug_mode not in augmenters:
            raise ValueError(f"Unknown aug_mode: {self.aug_mode}")
        external_reference = augmenters[self.aug_mode](queries)
        process_record["step2_retrieval"] = {"external_reference": external_reference}
        return external_reference

    def process(self, image_dir_path: str, title: str) -> Tuple[Any, Dict]:
        """Run the full three-step pipeline, retrying up to ``max_retry`` times.

        Each attempt first re-selects the least-loaded VLM client. An
        exception inside an attempt is recorded as ``error_attempt_<n>`` and
        the next attempt starts from step 1.

        Returns:
            ``(final_answer, process_record)``. On success ``final_answer`` is
            ``{"analysis_dict": ...}``. If every attempt fails it is the string
            ``"error"``. ``process_record`` holds the config and the most
            recent output of each step that completed.
        """
        process_record = {
            "config": {
                "aug_mode": self.aug_mode,
                "query_mode": self.query_mode,
                "reference_mode": self.reference_mode,
                "top_k": self.top_k,
                "indexing_type": self.indexing_type,
                "num_hops": self.num_hops,
            }
        }

        for retry in range(self.max_retry):
            attempt = retry + 1
            print(f"\n{'#' * 60}")
            print(f"Process Attempt {attempt}/{self.max_retry}")
            print(f"Config: aug_mode={self.aug_mode}, query_mode={self.query_mode}, reference_mode={self.reference_mode}")
            print(f"{'#' * 60}\n")

            self.vlm_client = self.select_best_vlm_client()

            try:
                queries, queries_output, saved_image_messages = self.extract_queries(image_dir_path)
                process_record["step1_extraction"] = {
                    "mode": self.query_mode,
                    "queries": queries,
                    "original_output": queries_output
                }

                print("=" * 50)
                print(f"Step 2: Do augmentation using {self.aug_mode}")
                print("=" * 50)

                external_reference = self._augment(queries, process_record)
                print(f"\nExternal reference:\n{external_reference}")

                analysis_dict, analysis_output = self.final_analysis(
                    external_reference=external_reference,
                    saved_image_messages=saved_image_messages,
                    image_dir_path=image_dir_path,
                    title=title
                )

                process_record["step3_analysis"] = {
                    "analysis_dict": analysis_dict,
                    "original_output": analysis_output
                }

                final_answer = {"analysis_dict": analysis_dict}

                print("=" * 50)
                print("Process completed successfully!")
                print(f"Final answer: {json.dumps(final_answer, ensure_ascii=False, indent=2)}")
                print("=" * 50)

                return final_answer, process_record

            except Exception as e:
                print(f"Error in process attempt {attempt}: {e}")
                process_record[f"error_attempt_{attempt}"] = str(e)

        print("=" * 50)
        print("Process failed after all retries!")
        print("=" * 50)

        return "error", process_record

    def get_clear_output_json(self, output: str) -> Dict:
        """Parse the JSON payload out of a raw model response.

        The following are handled in order:

        - a reasoning trace: only the text after the last ``</think>`` is kept;
        - ``<|begin_of_box|>`` / ``<|end_of_box|>`` tokens: removed anywhere;
        - Markdown fences: the content of the last ```` ```json ```` fence, or
          else of the first bare ```` ``` ```` fence.

        ASCII control characters other than newline, CR and tab are dropped
        before parsing.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
        """
        if "</think>" in output:
            output = output.split("</think>")[-1]
        # Removing the tokens unconditionally also covers responses where
        # whitespace (e.g. the newline after </think>) precedes the box token.
        output = output.replace('<|begin_of_box|>', '').replace('<|end_of_box|>', '')

        if "```json" in output:
            output_clear = output.split("```json")[-1].split("```")[0]
        elif "```" in output:
            output_clear = output.split("```")[1]
        else:
            output_clear = output
        output_clear = ''.join(char for char in output_clear if ord(char) >= 32 or char in '\n\r\t')

        return json.loads(output_clear)

    # ------------------------------------------------------------------ #
    # Alternative augmentation modes
    # ------------------------------------------------------------------ #

    def aug_self(self, queries: List[str]) -> str:
        """Ask the current VLM client for metaphorical associations of ``queries``.

        Any reasoning trace before ``</think>`` is dropped and the text is
        stripped. If the call fails, a generic placeholder sentence is
        produced for each of the first ``top_k`` queries instead.
        """
        prompt = f"Please generate metaphorical associations for the following concepts. {queries}"

        try:
            result, _ = self.vlm_client.call_openai(prompt)
            if "</think>" in result:
                result = result.split("</think>")[-1]
            result = result.strip()
        except Exception as e:
            # Deliberate fallback so one failed LLM call does not abort the attempt.
            print(f"Error calling LLM for self-association: {e}")
            parts = [
                f"• The concept \"{q}\" may symbolize deeper meanings related to emotions, experiences, or abstract ideas."
                for q in queries[:self.top_k]
            ]
            result = "\n".join(parts)

        return result

    def aug_rag(self, queries: List[str]) -> str:
        """Retrieve the metaphorical sentences most similar to the joined queries.

        ``top_k`` is clamped to the number of stored sentences, so a small
        corpus does not make ``topk`` raise. The header reports how many
        sentences were actually returned.
        """
        query_sentence = "A video containing: " + ", ".join(queries)
        query_embedding = self.get_text_embedding([query_sentence])  # [1, dim]

        sentence_embeddings = self.sentences_embeddings.to(query_embedding.dtype)
        similarities = (query_embedding @ sentence_embeddings.T).squeeze(0)  # [num_sentences]

        k = min(self.top_k, similarities.numel())
        top_k_indices = similarities.topk(k).indices.tolist()

        parts = [f"Here are {k} relevant metaphorical expressions found in the knowledge base:\n"]
        for rank, idx in enumerate(top_k_indices, 1):
            text = self.sentences_data[idx].get("text", "")
            clean_text = text.replace('\n', ' ').strip()
            parts.append(f"{rank}. {clean_text}")

        return "\n".join(parts)

    def aug_ckg(self, queries: List[str]) -> str:
        """Look up each keyword in the commonsense KG and join the rendered profiles.

        Keywords for which ``get_profile`` returns a falsy result are skipped.
        The rendered profiles are joined by newlines, so the prompt receives
        text rather than a Python list repr.
        """
        profiles = []
        for keyword in queries:
            result = self.ckg_client.get_profile(keyword, relations=["IsA", "RelatedTo"])
            if result:
                # str() is a no-op if return_profile already returns text
                # (presumably the case - confirm against the KG client).
                profiles.append(str(self.ckg_client.return_profile(result)))
        return "\n".join(profiles)


if __name__ == "__main__":
    # Smoke test: run the full pipeline once per augmentation mode on one video.
    from use_vllm_api import VLLMClient  # noqa: F401  (currently unused)
    from use_wanqing_api import WangQingClient
    from kg.use import ConceptNetClient

    vlm_client = WangQingClient(model_name="Doubao-1_5-Vision-Pro-250328")
    ckg_client = ConceptNetClient()

    image_dir_path = "/XXX/benchmark/videos/118540876747"
    title = "Test Video"

    for aug_mode in ("mkg", "self", "rag", "ckg"):
        print("\n" + "=" * 80)
        print(f"Testing: aug_mode={aug_mode}")
        print("=" * 80)

        # MetaphorVU_Boost takes a *list* of VLM clients (it load-balances
        # across them via select_best_vlm_client); there is no vlm_client kwarg.
        boost = MetaphorVU_Boost(
            vlm_client_list=[vlm_client],
            ckg_client=ckg_client,
            top_k=10,
            max_retry=5,
            aug_mode=aug_mode,
            query_mode="keywords",
            reference_mode="simple",
        )
        final_answer, process_record = boost.process(image_dir_path, title)
        print(f"Result: {final_answer}")
