from typing import List, Dict, Optional
from pydantic import BaseModel
import json
from models.entity import Entity
from models.relationship import Relationship
from openai import OpenAI
import os
from query.base import Base
from utils.rag import search_docs, get_data
from prompt.Query import QUERY_PROMPT_KGRAG_v3
from retriever.wag_retriever import *
import numpy as np
class Wag(Base):
    """Query class that retrieves context from the knowledge graph.

    Entities are matched to graph nodes by exact name, edge weights are
    recomputed for every matched node through ``wag_retriever``, and the
    selected nodes can be rendered into text context plus plot-ready traces.
    """

    def __init__(self, max_hops: int = 2, min_confidence: float = 0.7,  **kwargs):
        """Store the search settings and forward ``kwargs`` to ``Base``.

        Args:
            max_hops: Stored on the instance as ``self.max_hops``.
            min_confidence: Stored on the instance as ``self.min_confidence``.
            **kwargs: Passed unchanged to ``Base.__init__`` together with
                ``query_prompt=QUERY_PROMPT_KGRAG_v3``.
        """
        super().__init__(
            query_prompt=QUERY_PROMPT_KGRAG_v3,
            **kwargs
        )
        self.max_hops = max_hops
        self.min_confidence = min_confidence

    def search_knowledge_graph(self,
                               entities: List[str],
                               query_info,
                               **kwargs
                               ):
        """Collect the primary, related and demographic nodes for each entity.

        Args:
            entities: Node names to look up. Matching is by exact equality
                with ``node.name``; entities with no matching node are
                skipped.
            query_info: Mapping; ``query_info['openness']`` scales
                ``self.param['max_num_related_nodes']``. The whole mapping is
                also passed on to ``weight_recomputation``.
            **kwargs: Must contain ``rel_info``, which is passed on to
                ``weight_recomputation``.

        Returns:
            A tuple ``(result_dict, out_df)``.

            ``result_dict`` maps each matched entity to::

                {
                    'similarity': similarity,
                    'primary_node': [node, abnormality],
                    'related_nodes': [(node, edge, weight, abnormality), ...],
                    'demog_nodes': [(node, edge, edge.weight), ...],
                }

            ``out_df`` is the weight table returned for the last processed
            entity, or ``None`` when no entity was matched.

        Raises:
            KeyError: If ``rel_info`` is missing from ``kwargs`` (only
                checked once at least one entity has been matched).
        """
        # Name -> node, last node wins on duplicate names (used for related
        # and demographic nodes, as before).
        node_name_map = {node.name: node for node in self.nodes_with_embeddings.values()}
        # Name -> node, first node wins; reproduces the former linear scan
        # that stopped at the first exact match, without rescanning per entity.
        first_node_by_name = {}
        for node in self.nodes_with_embeddings.values():
            first_node_by_name.setdefault(node.name, node)

        #TODO: handle case when same node is matched multiple times
        result_dict = {}
        for entity in entities:
            if entity in result_dict:
                # Exact matches always score 1.0, so a repeated entity can
                # never improve on the entry that is already stored.
                continue
            matched_node = first_node_by_name.get(entity)
            if matched_node is None:
                # No node carries this name: skip it instead of failing with
                # an IndexError on an empty match list.
                continue
            result_dict[entity] = {
                'similarity': 1.0,
                'primary_node': [matched_node, None],
                'related_nodes': [],
                'demog_nodes': [],
            }

        if not result_dict:
            # Nothing matched (or no entities given): avoid the division by
            # zero below and the unbound ``out_df`` at the return.
            return result_dict, None

        num_related_nodes = int(query_info['openness'] * self.param['max_num_related_nodes'])
        # Number of related nodes per primary node.
        num_related_nodes_per_node = num_related_nodes // len(result_dict)
        rel_info = kwargs['rel_info']

        out_df = None
        for entity, inner_dict in result_dict.items():
            node_name = inner_dict['primary_node'][0].name

            # Get weights and abnormality scores for this primary node.
            out_df, abnormality_df = self.weight_recomputation(
                node_name, query_info, rel_info, self.prior_matrix, self.par_df.copy(), self.param
            )

            # Update the abnormality of the primary node.
            inner_dict['primary_node'][1] = abnormality_df.loc[node_name, 'abnormality']

            # Pick the top metrics: descending by weight, ties broken by the
            # index value in ascending order so the selection is stable.
            weight_sort_by = self.param['weight_sort_by']
            top_metrics = (
                out_df
                .reset_index()  # Convert index to a column temporarily
                .sort_values(
                    by=[weight_sort_by, 'index'],
                    ascending=[False, True]
                )
                .head(num_related_nodes_per_node)
                .set_index('index')  # Restore original index
            )
            ordered_metrics = top_metrics.index.tolist()
            edge_weights = top_metrics[weight_sort_by].tolist()

            # Neighbour name -> edge, for every edge touching the primary node.
            edge_dict = {
                edge.entity_2_name if edge.entity_1_name == node_name else edge.entity_1_name: edge
                for edge in self.edges.values()
                if node_name in (edge.entity_1_name, edge.entity_2_name)
            }

            # Keep only metrics that have both a graph node and an edge.
            inner_dict['related_nodes'].extend(
                (node_name_map[metric], edge_dict[metric], weight, abnormality_df.loc[metric, 'abnormality'])
                for metric, weight in zip(ordered_metrics, edge_weights)
                if metric in edge_dict and metric in node_name_map
            )

            # Add demographic nodes, applying the same presence filter as for
            # related nodes: a primary node without an edge to a demographic
            # metric (e.g. the metric itself) used to raise KeyError here.
            for demog_metric in self.param['demog_metrics']:
                demog_node = node_name_map.get(demog_metric)
                demog_edge = edge_dict.get(demog_metric)
                if demog_node is None or demog_edge is None:
                    continue
                inner_dict['demog_nodes'].append((demog_node, demog_edge, demog_edge.weight))

        return result_dict, out_df

    def weight_recomputation(self,
                               node_name,
                               query_info,
                               rel_info,
                               prior_matrix,
                               par_df,
                               hyperparameter,
                               ):
        """Recompute the weights for the given node via ``wag_retriever``.

        Args:
            node_name: Name of the node the weights are computed for.
            query_info: Passed through to ``wag_retriever``.
            rel_info: Mapping with at least ``'numeric_metrics'``. It is not
                modified; overrides are applied to a shallow copy.
            prior_matrix: Passed through to ``wag_retriever``.
            par_df: Passed through to ``wag_retriever``.
            hyperparameter: Passed through to ``wag_retriever``.

        Returns:
            The ``(out_df, abnormality_df)`` pair produced by
            ``wag_retriever(...).run()``.
        """
        numeric_metrics = rel_info['numeric_metrics']

        # Case when the node contains data that is not numeric: the
        # relationship entries are blanked out for this call only. Working on
        # a copy keeps the caller's dict intact, so one non-numeric node no
        # longer wipes the data used for the nodes processed after it.
        if node_name not in numeric_metrics:
            rel_info = dict(rel_info)
            for key in (
                'rel_pop_all',
                'rel_var_pop_all',
                'rel_pop_sample_size_all',
                'rel_ind_all',
                'rel_var_ind_all',
                'rel_ind_sample_size_all',
            ):
                rel_info[key] = None

        graph_searcher = wag_retriever(node_name, rel_info, prior_matrix, hyperparameter, par_df, query_info)
        out_df, abnormality_df = graph_searcher.run()

        return out_df, abnormality_df

    #TODO: fast mode, node already available

    def process_query(self, user_query: str, today_date: str = "2020-07-23"):
        """Process a user query; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Subclasses must implement process_query")

    def _format_sensor_info(self, node):
        """Return the three numbered sensor lines for ``self.dataset_name``."""
        source = node.dataSource[self.dataset_name]
        return (
            f"1.description: {source['description']}\n"
            f"2.range: {source['range']}\n"
            f"3.unit: {source['unit']}\n"
        )

    def build_node_context(
        self,
        node,
        start_date,
        time_range,
        abnormality = None
    ):
        """Render one node as text context plus a plot-ready trace.

        Args:
            node: Graph node; its ``dataSource`` must contain
                ``self.dataset_name`` for data to be included.
            start_date: Passed to ``self.get_data``.
            time_range: Passed to ``self.get_data`` and quoted in the
                deviation sentence as a number of days.
            abnormality: Optional value convertible with ``float``; reported
                as a number of standard deviations when given.

        Returns:
            ``(context, data)`` where ``data`` maps the node name to
            ``{"x": dates, "y": values}``. When the node has no entry for the
            current dataset, ``context`` is ``"data: No data\\n"`` and
            ``data`` is empty.
        """
        if self.dataset_name not in node.dataSource:
            return 'data: No data\n', {}

        node_name = node.name
        context = (
            f"{node_name}:\n"
            f"description: {node.description}\n"
            f"range: {node.range}\n"
            f"recommendation: {node.recommendation}\n"
            f"sensor specific information: \n{self._format_sensor_info(node)}"
        )
        context += 'Data:\n'

        df = self.get_data(self.par_df.copy(), start_date, time_range, node_name)
        mark_data = df.to_markdown(index=False)
        data = {
            node_name: {
                "x": df['date'].tolist(),
                "y": df[node_name].tolist(),
            }
        }
        context += f"{mark_data}\n"

        if abnormality is not None:
            context += (
                f"Recent {time_range}-day value deviates from the individual's average by "
                f"{float(abnormality):.2f} standard deviations.\n"
            )
        else:
            context += (
                f"No deviation from baseline recorded for the recent {time_range}-day period.\n\n"
            )

        return context, data

    def _format_feature_entity(self, entity, weight_line, start_date, end_date):
        """Render one entity for ``build_feature_context``.

        ``weight_line`` is inserted verbatim after the name line (empty for
        matched entities). Returns ``(text, trace)``; ``trace`` is ``None``
        when the entity has no entry for the current dataset.
        """
        if self.dataset_name not in entity.dataSource:
            return 'data: No data\n', None

        feat_name = entity.name
        header = (
            f"{feat_name}:\n"
            f"{weight_line}"
            f"description: {entity.description}\n"
            f"range: {entity.range}\n"
            f"recommendation: {entity.recommendation}\n"
            f"sensor specific information: {self._format_sensor_info(entity)}\n"
        )
        # NOTE(review): build_node_context calls self.get_data with a leading
        # par_df copy (four arguments) while this call passes three. The
        # get_data definition is not visible in this file -- confirm which
        # signature is current.
        df = self.get_data(start_date, end_date, feat_name)
        mark_data = df.to_markdown(index=False)
        trace = {
            "x": df['date'].tolist(),
            "y": df[feat_name].tolist(),
        }
        return f"{header}{mark_data}\n", trace

    def build_feature_context(
        self,
        matched_entities,
        related_entities,
        start_date: str ="2020-07-23",
        end_date: str = "2020-07-27",
    ):
        """Render matched and related entities as one text context.

        Args:
            matched_entities: Iterable of dicts whose ``'node'`` entry is the
                matched node.
            related_entities: Iterable of ``(node, weight)`` pairs.
            start_date: Passed to ``self.get_data``.
            end_date: Passed to ``self.get_data``.

        Returns:
            ``(context, data)`` where ``context`` holds a
            ``"Matched Entities:"`` section followed by a
            ``"Related Entities:"`` section and ``data`` maps each entity
            name that has data to ``{"x": dates, "y": values}``.
        """
        data = {}
        parts = ['Matched Entities:\n']
        for entity_dict in matched_entities:
            entity = entity_dict['node']
            text, trace = self._format_feature_entity(entity, '', start_date, end_date)
            parts.append(text)
            if trace is not None:
                data[entity.name] = trace

        parts.append('Related Entities:\n')
        for (entity, weight) in related_entities:
            text, trace = self._format_feature_entity(
                entity, f"weight: {weight}\n", start_date, end_date
            )
            parts.append(text)
            if trace is not None:
                data[entity.name] = trace

        return ''.join(parts), data

