import time
import requests
import json
import torch
import os
import re
import logging
from concurrent.futures import ThreadPoolExecutor
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from core.config import load_cfg
from openai import OpenAI
from transformers import BertTokenizer, BertModel
# from core.llm.baselines_lm.Qwen import Qwen


class OllamaPromptGenerator:
    """Build graph-aware classification prompts per node and query an LLM backend.

    For each node, neighbours are ranked by a blend of GAT attention and BERT
    text similarity. A prompt is assembled from the node's title/abstract and
    its top neighbours, then sent to the backend selected by
    ``cfg.llm.prompt_generator.model_name``: "llama3.2" (HTTP endpoint whose
    JSON reply carries a ``"response"`` field, presumably Ollama's
    ``/api/generate``), "deepseek-chat", "gpt-3.5-turbo", "gpt-4o", or a local
    "qwen" model. Prompts are persisted in per-batch JSON files so an
    interrupted run can resume.
    """

    # Pre-compiled patterns for the fields parsed out of node texts.
    _TITLE_RE = re.compile(r"Title:\s*(.*?)\n", re.MULTILINE)
    _ABSTRACT_RE = re.compile(r"Abstract:\s*(.*)", re.MULTILINE | re.DOTALL)
    _DESCRIPTION_RE = re.compile(r"Description:\s*(.*)", re.MULTILINE | re.DOTALL)
    # Backends understood by _request_completion.
    _SUPPORTED_MODELS = ("llama3.2", "deepseek-chat", "gpt-3.5-turbo", "gpt-4o", "qwen")

    def __init__(self, cfg, gat_model):
        """Store configuration, load BERT (on CUDA) and, if requested, the local Qwen model.

        Args:
            cfg: Project config; reads ``cfg.llm.prompt_generator.{model_name, api_url, max_workers}``.
            gat_model: Graph attention model used by ``_get_neighbor_weights``.
        """
        self.cfg = cfg
        gen_cfg = cfg.llm.prompt_generator
        self.model_name = gen_cfg.model_name
        self.api_url = gen_cfg.api_url
        self.headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "X-Stream": "false"
        }
        self.max_workers = gen_cfg.max_workers
        self.gat_model = gat_model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
        self.bert_model.to("cuda")
        if gen_cfg.model_name == "qwen":
            # Imported lazily: only the "qwen" backend needs it, and the name was
            # previously used without any import (NameError).
            from core.llm.baselines_lm.Qwen import Qwen
            self.qwen = Qwen(model_name="Qwen/Qwen2.5-7B-Instruct", classification=None, device="cuda")
        self.cora_labels = [
            "Neural_Networks",
            "Rule_Learning",
            "Reinforcement_Learning",
            "Probabilistic_Methods",
            "Theory",
            "Genetic_Algorithms",
            "Case_Based"
        ]
        logging.basicConfig(level=logging.INFO)

    def _get_label_examples(self):
        """Return one example line per Cora label, formatted ``Example(label): description``."""
        examples = {
            "Neural_Networks": "Transformer architectures with self-attention mechanisms",
            "Rule_Learning": "Inductive logic programming systems",
            "Reinforcement_Learning": "Q-learning algorithms with deep networks",
            "Probabilistic_Methods": "Bayesian belief networks",
            "Theory": "Computational complexity analysis",
            "Genetic_Algorithms": "Evolutionary optimization techniques",
            "Case_Based": "Instance-based reasoning systems"
        }
        return "\n".join(f"Example({label}): {desc}" for label, desc in examples.items())

    def _get_label_definitions(self):
        """Return one bullet line per Cora label, formatted ``- label: definition``."""
        definitions = {
            "Neural_Networks": "Artificial neural networks and deep learning architectures",
            "Rule_Learning": "Machine learning using symbolic rule systems",
            "Reinforcement_Learning": "Learning through environmental rewards/punishments",
            "Probabilistic_Methods": "Statistical and probabilistic learning methods",
            "Theory": "Theoretical computer science foundations",
            "Genetic_Algorithms": "Evolutionary computation techniques",
            "Case_Based": "Reasoning by analogy to past cases"
        }
        return "\n".join(f"- {label}: {desc}" for label, desc in definitions.items())

    def _extract_title(self, text):
        """Return the text after ``Title:`` up to the next newline, or "Unknown Title"."""
        match = self._TITLE_RE.search(text)
        return match.group(1) if match else "Unknown Title"

    def _extract_abstract(self, text, cfg, default=None):
        """Return everything after ``Abstract:`` (``Description:`` for the "history" dataset).

        Returns *default* when the field is absent. An empty field yields ``""``, not *default*.
        """
        pattern = self._DESCRIPTION_RE if cfg.dataset.source_name == "history" else self._ABSTRACT_RE
        match = pattern.search(text)
        return match.group(1) if match else default

    def _calculate_text_similarity(self, text1, text2):
        """Return the cosine similarity of the mean-pooled BERT embeddings of two texts."""
        inputs = self.tokenizer([text1, text2], return_tensors='pt', padding=True, truncation=True)
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
        hidden = outputs.last_hidden_state
        # Mask padding positions so the shorter text's embedding is not diluted by pad-token states.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        embeddings = ((hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)).cpu()
        similarity = cosine_similarity(embeddings[0].unsqueeze(0), embeddings[1].unsqueeze(0))
        return similarity[0][0]

    def _extract_keywords(self, texts, top_n=5):
        """Return the *top_n* terms with the highest summed TF-IDF over *texts*, comma-joined.

        Returns "" when no term survives stop-word removal (sklearn raises ValueError then).
        """
        vectorizer = TfidfVectorizer(stop_words="english", max_features=1000)
        try:
            X = vectorizer.fit_transform(texts)
        except ValueError:
            # "empty vocabulary": every token was a stop word or the texts were empty.
            return ""
        feature_array = vectorizer.get_feature_names_out()
        tfidf_sorting = X.toarray().sum(axis=0).argsort()[::-1]
        return ", ".join(feature_array[i] for i in tfidf_sorting[:top_n])

    def _get_neighbor_weights(self, node_idx, edge_index, node_text, neighbor_texts, cfg):
        """Rank the neighbours of *node_idx* by blended GAT attention and BERT text similarity.

        Args:
            node_idx: Global id of the node.
            edge_index: Edge tensor indexed as ``edge_index[0]`` (source) / ``edge_index[1]`` (target).
            node_text: Raw text of the node.
            neighbor_texts: Sequence indexed by global node id returning each node's text.
            cfg: Project config (reads ``cfg.dataset.source_name``).

        Returns:
            List of ``(neighbor_id, score)`` sorted by descending score. The blend is 0.7*GAT +
            0.3*similarity when the node's abstract has fewer than 100 words, else 0.4/0.6.
        """
        with torch.no_grad():
            edge_index = edge_index.to("cuda")
            num_edges = edge_index.size(1)
            # NOTE(review): assumes gat_model returns (out, (_, _, attn)) where attn[1] is per-edge
            # attention of shape (E', heads) whose first E rows follow edge_index's column order
            # (presumably self-loops are appended after the original edges) — confirm.
            _, (_, _, attn_weights) = self.gat_model(edge_index, return_attention_weights=True)
            # Read neighbour ids and attention from the SAME edges (those entering node_idx) so
            # they pair up element-for-element. Previously ids came from outgoing edges and
            # weights from incoming edges, which mis-paired them.
            incoming = edge_index[1] == node_idx
            neighbors = edge_index[0][incoming].tolist()
            gat_scores = attn_weights[1][:num_edges][incoming][:, 0].tolist()

        similarity_scores = [self._calculate_text_similarity(node_text, neighbor_texts[n]) for n in neighbors]

        abstract = self._extract_abstract(node_text, cfg, default="No abstract available.")
        # Short abstracts carry little text signal, so lean more on graph attention.
        if len(abstract.split()) < 100:
            gat_weight, sim_weight = 0.7, 0.3
        else:
            gat_weight, sim_weight = 0.4, 0.6

        combined_weights = [
            (n, w * gat_weight + s * sim_weight)
            for n, w, s in zip(neighbors, gat_scores, similarity_scores)
        ]
        combined_weights.sort(key=lambda pair: pair[1], reverse=True)
        return combined_weights

    def _generate_dynamic_prompt_strategy(self, node_text, neighbor_texts, cfg, node_degree, combined_weights):
        """Return extra prompt guidance based on node degree and abstract length.

        ``neighbor_texts`` and ``combined_weights`` are accepted for interface compatibility
        but are currently unused.
        """
        abstract = self._extract_abstract(node_text, cfg)
        abstract_len = len(abstract.split()) if abstract is not None else 0

        if node_degree < 3:
            dynamic_part = f"Note: This paper has only {node_degree} references. Pay more attention to its own content."
        else:
            dynamic_part = f"This paper is well-connected with {node_degree} references."

        if abstract_len > 150:
            dynamic_part += "\nFocus on the detailed methodology in the abstract."
        else:
            dynamic_part += "\nFocus on the key contributions."

        return dynamic_part

    def safe_parse_response(self, response_text):
        """Parse *response_text* as JSON, falling back to the text from its last ``{`` onward.

        Raises:
            json.JSONDecodeError: if neither the whole text nor the trailing object parses.
        """
        try:
            return json.loads(response_text)
        except json.JSONDecodeError:
            last_brace = response_text.rfind('{')
            if last_brace == -1:
                # No object to fall back to; previously this parsed only the last character.
                raise
            return json.loads(response_text[last_brace:])

    def _request_completion(self, prompt, cfg):
        """Send *prompt* to the backend named by ``cfg.llm.prompt_generator.model_name``.

        Returns the generated text. Raises on HTTP errors and ValueError for unknown backends.
        """
        gen_cfg = cfg.llm.prompt_generator
        model_name = gen_cfg.model_name
        if model_name == "llama3.2":
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json={
                    "model": self.model_name,
                    "prompt": prompt,
                    "stream": False,
                    "max_tokens": gen_cfg.max_tokens,
                    "temperature": gen_cfg.temperature,
                    "top_k": gen_cfg.top_k,
                    "top_p": gen_cfg.top_p,
                    # Ollama reads sampling parameters from "options"; top-level keys are ignored.
                    "options": {
                        "num_predict": gen_cfg.max_tokens,
                        "temperature": gen_cfg.temperature,
                        "top_k": gen_cfg.top_k,
                        "top_p": gen_cfg.top_p,
                    },
                },
                timeout=gen_cfg.timeout
            )
            response.raise_for_status()
            return self.safe_parse_response(response.text).get("response")

        if model_name == "deepseek-chat":
            response = requests.post(
                self.api_url,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {gen_cfg.api_key}"
                },
                json={
                    "model": self.model_name,
                    "messages": [{
                        "role": "user",
                        "content": prompt
                    }],
                    "temperature": gen_cfg.temperature,
                    "max_tokens": gen_cfg.max_tokens,
                    "top_p": gen_cfg.top_p
                },
                timeout=gen_cfg.timeout
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]

        if model_name in ("gpt-3.5-turbo", "gpt-4o"):
            client = OpenAI(api_key=gen_cfg.api_key)
            response = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=gen_cfg.max_tokens,
                temperature=gen_cfg.temperature,
                top_p=gen_cfg.top_p
            )
            return response.choices[0].message.content

        if model_name == "qwen":
            messages = [
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": prompt}
            ]
            return self.qwen.generate_response(messages, top_p=gen_cfg.top_p,
                                               temperature=gen_cfg.temperature,
                                               max_new_tokens=gen_cfg.max_tokens)

        raise ValueError(f"Unsupported prompt generator model: {model_name!r}")

    def _generate_single_prompt(self, node_text, neighbor_texts, edge_index, node_idx, label, cfg, max_retries=3):
        """Build the prompt for one node and return the backend's reply.

        Args:
            node_text: Raw text of the node.
            neighbor_texts: Sequence indexed by global node id returning each node's text.
            edge_index: Edge tensor (``edge_index[0]`` source, ``edge_index[1]`` target).
            node_idx: Global id of the node.
            label: Candidate categories, interpolated into the question via ``str(label)``.
            cfg: Project config.
            max_retries: Number of attempts before giving up.

        Raises:
            ValueError: if the configured model name is not supported.
            Exception: if every attempt fails (the last error is chained as the cause).
        """
        model_name = cfg.llm.prompt_generator.model_name
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError(f"Unsupported prompt generator model: {model_name!r}")

        # Total degree: edges leaving plus edges entering the node.
        node_degree = (edge_index[0] == node_idx).sum().item()
        node_degree += (edge_index[1] == node_idx).sum().item()

        # Computed once and reused for both neighbour selection and the reported weights
        # (previously computed twice per node: two whole-graph GAT passes plus BERT scoring).
        neighbor_weight_pairs = self._get_neighbor_weights(node_idx, edge_index, node_text, neighbor_texts, cfg)
        top_neighbor_texts = [neighbor_texts[n] for n, _ in neighbor_weight_pairs[:3]]

        title = self._extract_title(node_text)
        abstract = self._extract_abstract(node_text, cfg, default="No abstract available.")
        all_texts = [node_text] + top_neighbor_texts
        common_keywords = self._extract_keywords(all_texts, top_n=3)

        neighbor_info = []
        for neighbor_text in top_neighbor_texts:
            n_title = self._extract_title(neighbor_text)
            n_abstract = self._extract_abstract(neighbor_text, cfg, default="No abstract available.")
            neighbor_info.append(f"cited paper title: '{n_title}', abstract: {n_abstract}\n")

        dynamic_components = self._generate_dynamic_prompt_strategy(
            node_text, top_neighbor_texts, cfg, node_degree, neighbor_weight_pairs
        )
        references = "\n".join(neighbor_info)
        key_neighbors = ", ".join(f"weight={w:.2f}" for _, w in neighbor_weight_pairs[:3])

        if cfg.dataset.source_name in ("wikics", "photo"):
            prompt = f""" Input: Node and Neighbor Information: {all_texts}, Common Topics: {common_keywords},{dynamic_components}, Key References: {references}
             Graph: Connections: {node_degree}, Key Neighbors: {key_neighbors}, Common Themes: {common_keywords}
             Output: Question: Which of the following sub-categories of AI does this paper belong to: {str(label)}? \nPlease comprehensively consider the information from
                the article and its neighbors, provide a comma-separated list ordered from most to least related, and only return the categories words without other words.
            """
        else:
            prompt = f"""
"Classification Prediction:
Abstract: {abstract}
Title: {title}, {dynamic_components},
Key References: {references}, Graph Structure: 1. Connections: {node_degree}; 2. Key Neighbors: {key_neighbors},
Common Research Themes: {common_keywords}
Question: Which of the following sub-categories does this paper belong to: {str(label)}?"
Example:
 Input paper about "Graph Neural Networks for Drug Discovery"
Output: {{
"primary_category": "Graph AI",
"secondary_categories": ["Bioinformatics", "Deep Learning"],
 "confidence_score": 0.92,
 "decision_rationale": "Combines GNN architectures with molecular graph processing"
 }}
  """
        for attempt in range(max_retries):
            try:
                return self._request_completion(prompt, cfg)
            except Exception as e:
                logging.warning(f"Attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(cfg.llm.prompt_generator.retry_delay)
                else:
                    raise Exception(f"Request failed after {max_retries} attempts: {e}") from e

    def generate_batch_prompts(self, node_texts, neighbor_texts_list, edge_index, label, cfg, start_idx=0):
        """Generate prompts for a contiguous slice of nodes concurrently, preserving order.

        Args:
            start_idx: Global node id of ``node_texts[0]``. Each node is processed as
                ``start_idx + position``; previously the in-batch position alone was used,
                so every batch after the first looked up the wrong nodes in the graph.

        Returns:
            List of replies, one per node, in input order.
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [
                executor.submit(self._generate_single_prompt, node_text, neighbor_texts, edge_index,
                                start_idx + offset, label, cfg)
                for offset, (node_text, neighbor_texts) in enumerate(zip(node_texts, neighbor_texts_list))
            ]
            return [future.result() for future in futures]

    def load_prompts_from_batches(self, cfg):
        """Concatenate ``prompts_batch_{k}.json`` files for k = 0, step, 2*step, ... until one is missing.

        The result is also written to ``cfg.llm.output_file`` when non-empty.

        Raises:
            ValueError: if ``cfg.llm.batch_index_step`` is not positive (would loop forever).
            Exception: if a batch file cannot be loaded or the merged file cannot be written.
        """
        output_file = cfg.llm.output_file
        save_dir = cfg.llm.save_dir
        batch_index_step = cfg.llm.batch_index_step
        if batch_index_step <= 0:
            raise ValueError(f"batch_index_step must be positive, got {batch_index_step}")

        prompts = []
        batch_index = 0
        loaded_batches = 0

        while True:
            batch_output_file = os.path.join(save_dir, f"prompts_batch_{batch_index}.json")

            if not os.path.exists(batch_output_file):
                print(f"Batch file {batch_output_file} does not exist. Stopping loading.")
                break

            try:
                with open(batch_output_file, "r", encoding="utf-8") as f:
                    batch_prompts = json.load(f)
                prompts.extend(batch_prompts)
                print(f"Loaded {len(batch_prompts)} prompts from batch {batch_index}.")
            except json.JSONDecodeError as e:
                print(f"Failed to decode JSON from {batch_output_file}: {e}")
                raise Exception(f"Failed to load batch {batch_index}: {e}") from e
            except Exception as e:
                print(f"Unexpected error while loading batch {batch_index}: {e}")
                raise Exception(f"Failed to load batch {batch_index}: {e}") from e

            loaded_batches += 1
            batch_index += batch_index_step

        if prompts:
            try:
                with open(output_file, "w", encoding="utf-8") as f:
                    json.dump(prompts, f, ensure_ascii=False, indent=4)
                print(f"Loaded {len(prompts)} prompts from {loaded_batches} batches and saved to {output_file}")
            except Exception as e:
                print(f"Failed to save prompts to {output_file}: {e}")
                raise Exception(f"Failed to save prompts: {e}") from e

        return prompts

    def generate_and_save_prompts(self, data, save_dir):
        """Return one prompt reply per node, generating and caching them under *save_dir*.

        If ``save_dir/cfg.llm.output_file`` exists, it is loaded and returned. Entries that
        are dicts are reduced to their ``"response"`` value; otherwise the list is returned
        unchanged. If the file does not exist, missing ``prompts_batch_{offset}.json`` files
        are generated (existing ones are reused). All batches are then merged in node order
        into that file, and the merged list is returned.
        """
        output_file = self.cfg.llm.output_file
        batch_size = self.cfg.llm.batch_size

        os.makedirs(save_dir, exist_ok=True)

        full_output_file = os.path.join(save_dir, output_file)
        print(f"Checking for existing prompts file at: {full_output_file}")
        if os.path.exists(full_output_file):
            print(f"Loading prompts from existing file: {full_output_file}")
            with open(full_output_file, "r", encoding="utf-8") as f:
                prompts = json.load(f)
            print(f"Loaded {len(prompts)} prompts from {full_output_file}")
            try:
                return [i.get("response") for i in prompts]
            except (AttributeError, TypeError):
                # Entries are plain strings (or not dicts): return them as stored.
                return prompts

        print("No existing prompts found. Generating new prompts...")
        print(data)
        print(data.edge_index)

        num_nodes = len(data.x_text)
        node_texts = [data.x_text[node_idx] for node_idx in range(num_nodes)]
        # _generate_single_prompt indexes neighbour texts by global node id, so every node gets
        # the full text list. The old per-node neighbour computation (O(N*E)) was never used.
        neighbor_texts_list = [data.x_text] * num_nodes
        print(f"Collected data for {num_nodes} nodes")

        batch_offsets = range(0, num_nodes, batch_size)
        for start in batch_offsets:
            batch_output_file = os.path.join(save_dir, f"prompts_batch_{start}.json")
            batch_number = start // batch_size + 1

            if os.path.exists(batch_output_file):
                print(f"Batch {batch_number} already exists at {batch_output_file}, skipping...")
                continue

            print(f"Generating prompts for batch {batch_number}...")
            batch_prompts = self.generate_batch_prompts(
                node_texts[start:start + batch_size],
                neighbor_texts_list[start:start + batch_size],
                data.edge_index,
                data.label,
                self.cfg,
                start_idx=start,
            )

            with open(batch_output_file, "w", encoding="utf-8") as f:
                json.dump(batch_prompts, f, ensure_ascii=False, indent=4)

            print(f"Prompts for batch {batch_number} saved to {batch_output_file}")

        # Merge by numeric offset. A lexicographic glob sort would put batch 1000 before 200
        # and pick up stale files from other batch sizes, misaligning prompts with nodes.
        merged_prompts = []
        for start in batch_offsets:
            batch_output_file = os.path.join(save_dir, f"prompts_batch_{start}.json")
            with open(batch_output_file, "r", encoding="utf-8") as f:
                merged_prompts.extend(json.load(f))

        with open(full_output_file, "w", encoding="utf-8") as f:
            json.dump(merged_prompts, f, indent=2, ensure_ascii=False)
        print(f"All prompts saved to {full_output_file}")

        # Includes batches reused from disk, so the result always covers every node.
        return merged_prompts