from typing import List, Dict, Optional
from pydantic import BaseModel
import json
from models.entity import Entity
from models.relationship import Relationship
from openai import OpenAI
import os
from prompt.Query import QUERY_PROMPT_BASE
from utils.rag import search_docs, get_data
import pandas as pd
from retriever.wag_retriever import *
import tiktoken

# Lookup from a dataset key to its associated identifier string
# (presumably the name of that dataset's DataFrame -- confirm with callers).
dataset_path = dict(
    ifh_affect="ifh_df",
    globem="globem_df",
    lifesnap="lifesnap_df",
    pmdata="pmdata_df",
)

class Base():
    """Base class for knowledge-graph query operations.

    On construction it loads the graph (nodes, edges, prior matrix) from
    ``root_dir`` and configures the chat / embedding clients.

    NOTE: the query methods store per-call state on the instance
    (``self.par_df`` and ``self.dataset_name``), so a single instance must not
    be shared between concurrently running queries.
    """

    def __init__(
            self,
            root_dir: str = "resources/kg",
            chat_client: str = "deepseek",
            embed_client: str = "openai",
            query_prompt: str = QUERY_PROMPT_BASE,
            param: Optional[dict] = None,
    ):
        """Load the knowledge graph and configure the LLM clients.

        Args:
            root_dir: Directory containing ``nodes_with_embeddings.json`` and
                ``edges.json``; also passed to ``get_prior_matrix``.
            chat_client: ``"deepseek"`` or ``"openai"``.
            embed_client: Only ``"openai"`` is supported.
            query_prompt: System prompt used by the query methods.
            param: Free-form configuration dict stored as ``self.param``.

        Raises:
            ValueError: If ``chat_client`` or ``embed_client`` is unsupported.
        """
        self.root_dir = root_dir
        # Fresh dict per instance; a literal ``{}`` default is shared by every instance.
        self.param = {} if param is None else param

        # Load the graph: both files map an id to a serialized Entity / Relationship.
        with open(f"{root_dir}/nodes_with_embeddings.json", "r", encoding="utf-8") as f:
            raw_nodes = json.load(f)
        self.nodes_with_embeddings = {key: Entity.from_dict(value) for key, value in raw_nodes.items()}
        with open(f"{root_dir}/edges.json", "r", encoding="utf-8") as f:
            raw_edges = json.load(f)
        self.edges = {key: Relationship.from_dict(value) for key, value in raw_edges.items()}

        # get_prior_matrix is presumably supplied by the wildcard retriever import.
        self.prior_matrix = get_prior_matrix(root_dir)

        # Chat client used for query parsing and answer generation.
        if chat_client == "deepseek":
            # DeepSeek exposes an OpenAI-compatible API at its own base URL.
            self.chat_client = OpenAI(api_key=os.getenv('DEEPSEEK_API_KEY'), base_url="https://api.deepseek.com")
            self.chat_model = "deepseek-chat"
        elif chat_client == "openai":
            self.chat_client = OpenAI()
            self.chat_model = "gpt-4o"
        else:
            raise ValueError("Invalid chat client")

        # Embedding client.
        if embed_client == "openai":
            self.embed_client = OpenAI()
            # NOTE(review): "gpt-4o-mini" is a chat model, not an embedding model;
            # confirm what consumes embed_model before relying on it.
            self.embed_model = "gpt-4o-mini"
        else:
            raise ValueError("Invalid embed client")

        self.prompt = query_prompt

        # Tool definition the LLM is forced to call when parsing a query.
        self.functions = [
            {
                "type": "function",
                "function": {
                    "name": "query_parse",
                    "description": "Extract important entities from the provided query.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "metrics": {
                                "type": "array",
                                "description": "List of key metrics found in the query",
                                "items": {"type": "string"}
                            },
                            "time_range": {
                                "type": "string",
                                "description": "number of days till the start date, [1, 7, 30,..., all time]",
                            },
                            "openness_score": {
                                "type": "number",
                                "description": "Openness score of the question",
                            }
                        },
                        # Strict mode requires `required` to list exactly the declared
                        # properties; it previously named a non-existent "entities"
                        # field and omitted "metrics".
                        "required": ["metrics", "time_range", "openness_score"],
                        "additionalProperties": False
                    },
                    "strict": True
                }
            }
        ]

    @staticmethod
    def _empty_output() -> dict:
        """Return a fresh result dict whose fields are all initialised to None."""
        return dict.fromkeys((
            'entities',
            'start_date',
            'time_range',
            'openness_score',
            'matched_entities',
            'related_entities',
            'feature_context',
            'edge_context',
            'full_context',
            'data',
            'entity_dict',
            'response',
            'error',
        ))

    def llm_inference(self, messages, tools, attempts=2, tool_name: str = "query_parse"):
        """Call the chat model, retrying up to ``attempts`` times.

        Args:
            messages: Chat messages to send.
            tools: Tool definitions.  When falsy (e.g. ``None``), a plain
                completion is requested and no ``tool_choice`` is sent.
            attempts: Maximum number of tries.
            tool_name: Tool the model is forced to call when ``tools`` is given.

        Returns:
            The first truthy response object from the client.

        Raises:
            RuntimeError: If every attempt raised or returned an empty response;
                the last exception (if any) is chained as the cause.
        """
        request = {"model": self.chat_model, "messages": messages, "temperature": 0.1}
        if tools:
            # A forced tool_choice is only valid together with a tool list.
            request["tools"] = tools
            request["tool_choice"] = {"type": "function", "function": {"name": tool_name}}

        last_error = None
        for attempt in range(attempts):
            try:
                response = self.chat_client.chat.completions.create(**request)
            except Exception as e:  # API/network errors are retried
                last_error = e
                print(f"Attempt {attempt+1} failed: {e}")
                continue
            if response:
                return response
            print(f"Attempt {attempt+1} failed: Empty or invalid response.")

        raise RuntimeError(f"LLM inference failed after {attempts} attempts.") from last_error

    def _call_query_parser(self, messages) -> dict:
        """Run the forced ``query_parse`` tool call and return its decoded JSON arguments.

        Raises:
            ValueError: If the response contains no tool call.
        """
        response = self.llm_inference(messages=messages, tools=self.functions)
        tool_calls = response.choices[0].message.tool_calls
        if not tool_calls:
            raise ValueError("No function call in LLM response")
        return json.loads(tool_calls[0].function.arguments)

    def query_quick(self, user_query: str, par_df: pd.DataFrame, gt_nodes: List[str], query_info: dict, **kwargs):
        """Build context for already-known graph nodes, skipping the LLM parsing step.

        NOTE: not thread-safe -- overwrites ``self.par_df`` and
        ``self.dataset_name`` on every call.

        Args:
            user_query: The user's natural-language question.
            par_df: Participant data with a ``date`` column and one column per
                node name to render.
            gt_nodes: Node names to look up in the graph (exact name match).
            query_info: Must provide ``'query_date'`` and ``'time_granularity'``.
            **kwargs: ``dataset_name`` (required) selects each node's data
                source; ``response_gen`` (optional, default False) requests a
                final LLM answer.  All kwargs are forwarded to
                ``search_knowledge_graph``.

        Returns:
            ``(output, out_df)``: the result dict and the second value returned
            by ``search_knowledge_graph``.
        """
        output = self._empty_output()
        prompt = self.prompt

        self.par_df = par_df.copy()
        self.dataset_name = kwargs['dataset_name']

        result_dict, out_df = self.search_knowledge_graph(gt_nodes, query_info, **kwargs)
        context, data = self.build_context(
            result_dict,
            query_info['query_date'],
            query_info['time_granularity'],
        )

        final_response = None
        if kwargs.get('response_gen', False):
            messages = [
                {"role": "system", "content": prompt},
                {"role": "user", "content": 'Query: ' + user_query + '\n Context: ' + context + '\n Answer:'}
            ]
            # Plain completion: no tools, hence no forced tool_choice.
            response = self.llm_inference(messages=messages, tools=None)
            final_response = response.choices[0].message.content

        output.update({
            'result_dict': result_dict,
            'context': context,
            'data': data,
            'response': final_response
        })
        return output, out_df

    def query_parse(self, user_query: str, **kwargs):
        """Extract metrics from ``user_query`` with the LLM and match each to its closest node.

        Returns:
            List of ``(node_name, similarity)`` tuples, one per extracted metric
            for which ``search_docs`` returned at least one candidate.
        """
        messages = [
            {"role": "system", "content": self.prompt},
            {"role": "user", "content": user_query}
        ]
        args = self._call_query_parser(messages)
        print(args)

        matches = []
        for metric in args['metrics']:
            candidates = search_docs(self.nodes_with_embeddings, metric.capitalize(), top_n=1)
            if not candidates:
                continue  # nothing to match; indexing [0] would raise IndexError
            best = candidates[0]
            matches.append((best['node'].name, best['similarity']))
        return matches

    def query(self, user_query: str, par_df: pd.DataFrame, query_date: str, **kwargs):
        """Parse ``user_query`` with the LLM, search the graph, and build the context.

        NOTE: not thread-safe -- overwrites ``self.par_df`` (and
        ``self.dataset_name`` when given) on every call.

        Args:
            user_query: The user's natural-language question.
            par_df: Participant data with a ``date`` column.
            query_date: Substituted for ``{today_date}`` in the prompt, and used
                as the data window end date when the LLM returns no
                ``start_date``.
            **kwargs: ``dataset_name`` (needed unless already set on the
                instance); all kwargs are forwarded to ``search_knowledge_graph``.

        Returns:
            The result dict (``'response'`` is always None here).
        """
        output = self._empty_output()

        self.par_df = par_df.copy()
        if 'dataset_name' in kwargs:
            # build_node_context reads self.dataset_name.
            self.dataset_name = kwargs['dataset_name']

        prompt = self.prompt.format(today_date=query_date)
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_query}
        ]
        args = self._call_query_parser(messages)

        # The tool schema emits 'metrics' and no 'start_date'; accept the older
        # 'entities'/'start_date' keys when present, otherwise fall back.
        entities = args.get('entities', args.get('metrics', []))
        start_date = args.get('start_date', query_date)

        output.update({
            'entities': entities,
            'start_date': start_date,
            'time_range': args['time_range'],
            'openness_score': args['openness_score']
        })
        query_info = {
            'query_date': start_date,
            'time_granularity': args['time_range'],
            'openness': args['openness_score']
        }

        # search_knowledge_graph returns (result_dict, extra); only the dict is needed.
        result_dict, _ = self.search_knowledge_graph(entities, query_info, **kwargs)
        context, data = self.build_context(result_dict, start_date, args['time_range'])

        output.update({
            'result_dict': result_dict,
            'context': context,
            'data': data,
            'response': None
        })
        return output

    def search_knowledge_graph(self, entities: List[str], query_info, **kwargs):
        """Look up each entity in the graph by exact node name.

        Entities with no node of that name are reported and skipped; repeated
        entities are kept once.

        Returns:
            ``(result_dict, None)`` where ``result_dict`` maps entity ->
            ``{'similarity': 1.0, 'primary_node': [node, None],
            'related_nodes': [], 'demog_nodes': []}``.  The second element is
            always None in this implementation.
        """
        # Name -> node index, built once; setdefault keeps the first node per
        # name, matching the old scan-and-break lookup.
        nodes_by_name = {}
        for node in self.nodes_with_embeddings.values():
            nodes_by_name.setdefault(node.name, node)

        result_dict = {}
        for entity in entities:
            if entity in result_dict:
                continue  # exact matches all score 1.0, so a repeat can never improve it
            node = nodes_by_name.get(entity)
            if node is None:
                print(f"Entity {entity} not found in the knowledge graph")
                continue
            result_dict[entity] = {
                'similarity': 1.0,
                'primary_node': [node, None],
                'related_nodes': [],
                'demog_nodes': [],
            }
        return result_dict, None

    def build_context(
        self,
        result_dict,
        start_date: str,
        time_range: str,
        include_demog: bool = False,
    ):
        """Render search results into the text context given to the LLM.

        Nodes are de-duplicated by name across entities.  Each primary node is
        rendered once.  Each related node is rendered once, after the edge
        descriptions linking it to every primary node it was found for.

        Args:
            result_dict: Output of ``search_knowledge_graph``.
            start_date: Last date (inclusive) of the data window.
            time_range: Window length in days, or a non-numeric value such as
                ``'all'`` for the full history (see ``get_data``).
            include_demog: Also build and append a "Demographic Information"
                section.  Defaults to False, matching the earlier behaviour in
                which that section was disabled.

        Returns:
            ``(full_context, data)``: the context string and a dict mapping
            node name to an ``{"x": dates, "y": values}`` trace.
        """
        primary_nodes = {}
        related_nodes = {}
        demog_nodes = {}
        for inner_dict in result_dict.values():
            primary_node = inner_dict['primary_node'][0]
            primary_abnormality = inner_dict['primary_node'][1]
            # A later entity resolving to the same node name overwrites the earlier one.
            primary_nodes[primary_node.name] = (primary_node, primary_abnormality)
            for related_node, edge, weight, abnormality in inner_dict['related_nodes']:
                _, links = related_nodes.setdefault(related_node.name, (related_node, []))
                links.append((edge, weight, abnormality, primary_node.name))
            for demog_node, edge, weight in inner_dict['demog_nodes']:
                _, links = demog_nodes.setdefault(demog_node.name, (demog_node, []))
                links.append((edge, weight, primary_node.name))

        data = {}
        primary_parts = []
        for node, abnormality in primary_nodes.values():
            node_context, node_data = self.build_node_context(node, start_date, time_range, abnormality)
            primary_parts.append(f"{node_context}\n")
            data.update(node_data)

        related_parts = []
        for related_name, (related_node, links) in related_nodes.items():
            for edge, weight, _, primary_name in links:
                edge_context = self.build_edge_context(edge, related_name, primary_name, weight)
                related_parts.append(f"{edge_context}\n")
            # The node's data is annotated with the abnormality of its last recorded link.
            node_context, node_data = self.build_node_context(related_node, start_date, time_range, links[-1][2])
            related_parts.append(f"{node_context}\n")
            data.update(node_data)

        full_context = 'Matched nodes:\n\n' + ''.join(primary_parts)
        related_context = ''.join(related_parts)
        if related_context:
            full_context += '\nNodes related to matched nodes which might be helpful:\n\n' + related_context

        if include_demog:
            demog_parts = []
            for demog_name, (demog_node, links) in demog_nodes.items():
                demog_node_context = self.build_demog_node_context(demog_node)
                if demog_node_context == '':
                    continue  # no usable demographic data for this node
                for edge, weight, primary_name in links:
                    edge_context = self.build_edge_context(edge, demog_name, primary_name, weight)
                    demog_parts.append(f"{edge_context}\n")
                demog_parts.append(f"{demog_node_context}\n")
            demog_context = ''.join(demog_parts)
            if demog_context:
                full_context += '\nDemographic Information:\n' + demog_context

        return full_context, data

    def build_context_old(
        self,
        result_dict,
        start_date: str,
        time_range: str,
    ):
        """Legacy per-entity context builder (superseded by ``build_context``).

        Returns:
            ``(full_context, data, context_dict)`` where ``context_dict`` maps
            each entity to its own context string.
        """
        full_context = ''
        data = {}
        context_dict = {}

        full_context += "Matched Entities:\n\n"
        for entity, inner_dict in result_dict.items():
            context = ''  # context for this entity
            primary_node = inner_dict['primary_node']
            # search_knowledge_graph stores [node, abnormality]; accept both shapes.
            if isinstance(primary_node, (list, tuple)):
                primary_node = primary_node[0]
            primary_node_context, primary_node_data = self.build_node_context(primary_node, start_date, time_range)
            context += f"{primary_node_context}\n"
            data.update(primary_node_data)

            if len(inner_dict['related_nodes']) > 0:
                context += f"Nodes related to {primary_node.name} based on statistical correlation:\n"
                for related_node, edge, weight, abnormality in inner_dict['related_nodes']:
                    edge_context = self.build_edge_context(edge, related_node.name, primary_node.name, weight)
                    context += f"{edge_context}\n"
                    related_node_context, related_node_data = self.build_node_context(related_node, start_date, time_range)
                    context += f"{related_node_context}\n"
                    data.update(related_node_data)
                    if abnormality is not None:
                        context += f"Abnormality of recent {time_range} days compared to individual's average: {abnormality}\n"
                    else:
                        context += f"Abnormality of recent {time_range} days compared to individual's average: No abnormality score\n"

            full_context += f"{context}\n"
            context_dict[entity] = context

        return full_context, data, context_dict

    def build_demog_node_context(
        self,
        node
    ):
        """Describe a demographic node with its latest non-null value.

        Returns '' when the node has no entry for ``self.dataset_name`` or no
        data is available.
        """
        if self.dataset_name not in node.dataSource:
            return ''
        df = self.get_demog_data(self.par_df.copy(), node.name)
        if df.empty:
            return ''

        source = node.dataSource[self.dataset_name]
        sensor_specific_info = (
            f"1.description: {source['description']}\n"
            f"2.range: {source['range']}\n"
            f"3.unit: {source['unit']}\n"
        )
        return (
            f"{node.name}:\n"
            f"description: {node.description}\n"
            f"range: {node.range}\n"
            f"recommendation: {node.recommendation}\n"
            f"sensor specific information if available: \n {sensor_specific_info}\n"
            f"{df.to_markdown(index=False)}\n"
        )

    def build_node_context(
        self,
        node,
        start_date,
        time_range,
        abnormality=None
    ):
        """Describe ``node`` and render its data window as a markdown table.

        Args:
            node: Graph node; must expose ``name`` and ``dataSource``.
            start_date: Last date (inclusive) of the window.
            time_range: Window length in days (see ``get_data``).
            abnormality: Optional deviation score; formatted as a number of
                standard deviations when given.

        Returns:
            ``(context, data)`` where ``data`` maps the node name to an
            ``{"x": dates, "y": values}`` trace (empty when no data).
        """
        if self.dataset_name not in node.dataSource:
            return "data: No data\n", {}

        source = node.dataSource[self.dataset_name]
        node_name = node.name
        context = (
            f"{node_name}:\n"
            "sensor specific information if available: \n"
            f"1.description: {source['description']}\n"
            f"2.range: {source['range']}\n"
            f"3.unit: {source['unit']}\n"
            "Data:\n"
        )

        data = {}
        df = self.get_data(self.par_df.copy(), start_date, time_range, node_name)
        # get_data returns a column-less frame when the feature is missing.
        if node_name in df.columns:
            context += f"{df.to_markdown(index=False)}\n"
            data[node_name] = {
                "x": df['date'].tolist(),
                "y": df[node_name].tolist(),
            }
        else:
            context += "No data\n"

        if abnormality is not None:
            context += (
                f"Recent {time_range}-day value deviates from the individual's average by "
                f"{float(abnormality):.2f} standard deviations.\n"
            )
        else:
            context += (
                f"No deviation from baseline recorded for the recent {time_range}-day period.\n\n"
            )
        return context, data

    def build_edge_context(
        self,
        edge,
        related_node_name,
        primary_node_name,
        weight
    ):
        """Return a short text linking two nodes plus the edge's description.

        ``weight`` is accepted for interface compatibility but not rendered.
        """
        return (
            f"{related_node_name} is related to {primary_node_name}:\n"
            f"{edge.description} \n"
        )

    @staticmethod
    def _window_length(time_range, available: int) -> int:
        """Number of most-recent rows to keep for ``time_range``.

        Numeric values (``7``, ``"7"``, ``"7.0"``) are clamped to
        ``[0, available]``; anything non-numeric (``'all'``, ``'all time'``,
        ``''``, None) selects every available row.
        """
        try:
            days = int(float(time_range))
        except (TypeError, ValueError, OverflowError):
            return available
        return max(0, min(days, available))

    def get_data(
        self,
        par_df,
        start_date: str = "",
        time_range: str = "",
        feat_name: str = "",
        date_column: str = "date",
    ) -> pd.DataFrame:
        """Return the last ``time_range`` rows of ``feat_name`` up to ``start_date``.

        Args:
            par_df: Participant data containing ``date_column`` and ``feat_name``.
            start_date: Inclusive upper bound on ``date_column``.
            time_range: Number of most-recent rows to keep, or a non-numeric
                value such as ``'all'`` for every row up to ``start_date``.
            feat_name: Feature column to select.
            date_column: Name of the date column.

        Returns:
            A ``[date_column, feat_name]`` frame sorted by date, with the dates
            reduced to ``datetime.date``.  Returns an empty, column-less
            DataFrame when ``feat_name`` is missing.
        """
        if feat_name not in par_df.columns:
            print(f"Feature {feat_name} not found in the dataset")
            return pd.DataFrame()

        selected_df = par_df[[date_column, feat_name]].copy()
        selected_df[date_column] = pd.to_datetime(selected_df[date_column])
        history = selected_df[selected_df[date_column] <= start_date].sort_values(by=date_column)

        # tail(n) is correct for n == 0, unlike iloc[-0:], which returns every row.
        selected_rows = history.tail(self._window_length(time_range, len(history))).copy()
        selected_rows[date_column] = selected_rows[date_column].dt.date

        if selected_rows.empty:
            print(
                f"No data found between the date {start_date} and {time_range} days ago."
            )
        return selected_rows

    def get_demog_data(
        self,
        par_df,
        feat_name: str = "",
    ) -> pd.DataFrame:
        """Return the latest non-null ``feat_name`` value as a one-row ``[date, feat_name]`` frame.

        Returns an empty DataFrame when the column is missing or entirely null.
        """
        if feat_name not in par_df.columns:
            print(f"Feature {feat_name} not found in the dataset")
            return pd.DataFrame()

        selected_df = par_df[['date', feat_name]]
        latest_first = selected_df[selected_df[feat_name].notnull()].sort_values(by='date', ascending=False)
        if latest_first.empty:
            return pd.DataFrame()
        # iloc[[0]] keeps a DataFrame; iloc[0] would return a Series, whose
        # to_markdown(index=False) drops the field labels.
        return latest_first.iloc[[0]]
