import argparse
import ast
import json
import os
import re
import sys
import time
from typing import Dict, List

import numpy as np


from agent.nesy_agent.nesy_agent import NesyAgent
from utils.time_utils import TimeUtils

from agent.nesy_agent.prompts.PROMPTS import (
    NEXT_POI_TYPE_INSTRUCTION,
    INTERCITY_TRANSPORT_GO_INSTRUCTION,
    INTERCITY_TRANSPORT_BACK_INSTRUCTION,
    HOTEL_RANKING_INSTRUCTION,
    ATTRACTION_RANKING_INSTRUCTION,
    RESTAURANT_RANKING_INSTRUCTION,
    SELECT_POI_TIME_INSTRUCTION, 
    ROOMS_PLANNING_INSTRUCTION, 
    BUDGETS_INSTRUCTION, 
    INNERCITY_TRANSPORTS_SELECTION_INSTRUCTION, 
)


class LLMDrivenAgent(NesyAgent):
    def __init__(self, **kwargs):
        """Initialise the LLM-driven agent on top of ``NesyAgent``.

        The ``method`` keyword is forced to ``"LLMNeSy"`` before the parent
        initialiser runs.  ``backbone_llm`` is a required keyword; it is stored
        on the instance and called with a prompt string by the other methods
        of this class.
        """
        kwargs.update(method="LLMNeSy")
        super().__init__(**kwargs)

        self.backbone_llm = kwargs["backbone_llm"]
        # Both "already ranked" flags start cleared.
        self.ranking_attractions_flag = self.ranking_restaurants_flag = False

    def ranking_intercity_transport_go(self, transport_info:List[Dict], query:Dict):
        """Rank the outbound inter-city transport candidates.

        The backbone LLM is prompted with the user query and the candidate
        list and is expected to answer with ``IDList: [...]`` holding
        ``planid`` values in preference order.  When that answer is missing,
        cannot be parsed, or names no known candidate, the candidates are
        ranked heuristically instead (earlier departure and lower price are
        better) and ``llm_rec_format_error`` is incremented.

        Args:
            transport_info: Candidate transports; each dict is read for
                ``planid``, ``depature_time`` and ``price``.
            query: Request dict; only ``query['userQuery']`` is read here.

        Returns:
            A list of indices into ``transport_info``, best candidate first.
            On the LLM path only the candidates named by the LLM are listed;
            the heuristic path ranks every candidate.
        """
        time_before = time.time()

        query_message = INTERCITY_TRANSPORT_GO_INSTRUCTION.format(user_requirements=query['userQuery'], transport_info=str(transport_info))
        answer = self.backbone_llm(query_message)

        self.llm_inference_time_count += time.time() - time_before
        self.llm_rec_count += 1

        print(answer)
        match = re.search(r'IDList:\s*(\[[^\]]+\])', answer)

        ranking_idx = []
        failure = None
        if match is None:
            failure = "no IDList found in the LLM answer"
        else:
            try:
                # NOTE: the answer is untrusted model output, so it is parsed
                # with ast.literal_eval (literals only) instead of eval().
                intercity_transport_list = ast.literal_eval(match.group(1))
                print('selected intercity_transports: ', intercity_transport_list)

                for cand_i in intercity_transport_list:
                    for idx, transport in enumerate(transport_info):
                        # Skip indices already ranked so a repeated ID does
                        # not produce duplicate entries.
                        if transport['planid'] == cand_i and idx not in ranking_idx:
                            ranking_idx.append(idx)
            except Exception as e:
                # Best-effort boundary: any parsing problem triggers the
                # heuristic ranking below rather than aborting the planner.
                failure = e
            else:
                if not ranking_idx:
                    failure = "no returned ID matches a candidate planid"

        if failure is not None:
            print("!!!Error in eval intercity_transport_list", failure)
            self.llm_rec_format_error += 1

            num_candidates = len(transport_info)

            # Rank (0 = best) of every candidate by ascending departure time.
            time_order = sorted(range(num_candidates), key=lambda i: transport_info[i]["depature_time"])
            time_ranking = np.empty(num_candidates, dtype=int)
            time_ranking[np.asarray(time_order, dtype=int)] = np.arange(num_candidates)

            # argsort of argsort turns the sort permutation into per-candidate
            # ranks, so the price term is comparable with the time term.
            price_list = [transport["price"] for transport in transport_info]
            price_ranking = np.argsort(np.argsort(np.array(price_list), kind="stable"), kind="stable")

            ranking_idx = np.argsort(time_ranking + price_ranking, kind="stable").tolist()

        return ranking_idx
        
    
    def ranking_intercity_transport_back(self, transport_info, query, selected_go):
        """Rank the return inter-city transport candidates.

        The backbone LLM is prompted with the user query, the candidate list
        and the already selected outbound transport, and is expected to answer
        with ``IDList: [...]`` holding ``planid`` values in preference order.
        When that answer is missing, cannot be parsed, or names no known
        candidate, the candidates are ranked heuristically instead (later
        departure and lower price are better) and ``llm_rec_format_error`` is
        incremented.

        Args:
            transport_info: Candidate transports; each dict is read for
                ``planid``, ``depature_time`` and ``price``.
            query: Request dict; only ``query['userQuery']`` is read here.
            selected_go: The chosen outbound transport; only its string form
                is inserted into the prompt.

        Returns:
            A list of indices into ``transport_info``, best candidate first.
            On the LLM path only the candidates named by the LLM are listed;
            the heuristic path ranks every candidate.
        """
        time_before = time.time()
        query_message = INTERCITY_TRANSPORT_BACK_INSTRUCTION.format(user_requirements=query['userQuery'], transport_info=str(transport_info), selected_go_info=str(selected_go))
        answer = self.backbone_llm(query_message)

        self.llm_inference_time_count += time.time() - time_before
        self.llm_rec_count += 1

        print(answer)
        match = re.search(r'IDList:\s*(\[[^\]]+\])', answer)

        ranking_idx = []
        failure = None
        if match is None:
            failure = "no IDList found in the LLM answer"
        else:
            try:
                # NOTE: the answer is untrusted model output, so it is parsed
                # with ast.literal_eval (literals only) instead of eval().
                intercity_transport_list = ast.literal_eval(match.group(1))
                print('selected intercity_transports: ', intercity_transport_list)

                for cand_i in intercity_transport_list:
                    for idx, transport in enumerate(transport_info):
                        # Skip indices already ranked so a repeated ID does
                        # not produce duplicate entries.
                        if transport['planid'] == cand_i and idx not in ranking_idx:
                            ranking_idx.append(idx)
            except Exception as e:
                # Best-effort boundary: any parsing problem triggers the
                # heuristic ranking below rather than aborting the planner.
                failure = e
            else:
                if not ranking_idx:
                    failure = "no returned ID matches a candidate planid"

        if failure is not None:
            print("!!!Error in eval intercity_transport_list", failure)
            self.llm_rec_format_error += 1

            num_candidates = len(transport_info)

            # Rank (0 = best) of every candidate by descending departure time:
            # sort ascending, then reverse so the latest departure comes first.
            time_order = sorted(range(num_candidates), key=lambda i: transport_info[i]["depature_time"])
            time_order.reverse()
            time_ranking = np.empty(num_candidates, dtype=int)
            time_ranking[np.asarray(time_order, dtype=int)] = np.arange(num_candidates)

            # argsort of argsort turns the sort permutation into per-candidate
            # ranks, so the price term is comparable with the time term.
            price_list = [transport["price"] for transport in transport_info]
            price_ranking = np.argsort(np.argsort(np.array(price_list), kind="stable"), kind="stable")

            ranking_idx = np.argsort(time_ranking + price_ranking, kind="stable").tolist()

        return ranking_idx
    
    def ranking_hotel(self, hotel_info, query):
        """Rank hotel candidates with the backbone LLM, cheapest-first on failure.

        The details of every candidate are fetched through
        ``self.poi_analyzer.get_hotel_info`` and shown to the LLM, which is
        expected to answer with ``HotelIDList: ["id", ...]`` (double-quoted
        IDs) in preference order.  If the list is missing, empty, or contains
        an ID that is not in ``hotel_info``, the hotels are ranked by
        ascending ``price`` instead and ``llm_rec_format_error`` is
        incremented.

        Args:
            hotel_info: Sequence of hotel IDs (must support ``.index``).
            query: Request dict; ``query['userQuery']`` and
                ``query['locale']`` are read.

        Returns:
            A list of indices into ``hotel_info``, best candidate first.
        """
        info_list = [
            self.poi_analyzer.get_hotel_info(hotel_id, query["locale"])
            for hotel_id in hotel_info
        ]

        time_before = time.time()
        # Let the LLM filter/rank the hotels against the user requirements.
        query_message = HOTEL_RANKING_INSTRUCTION.format(user_requirements=query['userQuery'], hotel_info=str(info_list))
        answer = self.backbone_llm(query_message)

        self.llm_inference_time_count += time.time() - time_before
        self.llm_rec_count += 1

        print(answer)
        match = re.search(r'HotelIDList:\s*\[(.*?)\]', answer, re.DOTALL)

        ranking_idx = []
        if match is not None:
            hotel_id_list = re.findall(r'"([^"]+)"', match.group(1))
            print('selected HotelNameList: ', hotel_id_list)
            try:
                ranking_idx = [hotel_info.index(cand_i) for cand_i in hotel_id_list]
            except ValueError:
                # An ID the LLM returned is not a known candidate: distrust
                # the whole answer and use the price fallback.
                ranking_idx = []

        if not ranking_idx:
            print("!!!Error in eval HotelIDList")
            self.llm_rec_format_error += 1

            # Reuse the details fetched for the prompt instead of querying
            # the analyzer a second time.
            cost_list = [hotel['price'] for hotel in info_list]
            ranking_idx = sorted(range(len(cost_list)), key=lambda i: cost_list[i])

        return ranking_idx
    

    def select_next_poi_type(self, candidates_type, plan, poi_plan, current_day, current_time, current_position):
        """Ask the backbone LLM which type of POI to schedule next.

        On the last day (``current_day == self.query["day"] - 1``) the method
        may short-circuit to ``"back-intercity-transport"`` based on a
        ``TimeUtils`` comparison between ``current_time`` plus 3 and the
        return departure time (presumably: less than three hours remain
        before departure; confirm against ``TimeUtils``).

        Otherwise the LLM is prompted and expected to answer with
        ``Type: <name>``.  The name may contain hyphens, because hyphenated
        types such as ``back-intercity-transport`` exist.

        Args:
            candidates_type: Allowed POI types; the first one is the default.
            plan: Unused here.
            poi_plan: Current POI plan; inserted into the prompt and read for
                ``poi_plan["back_transport"]["depature_time"]``.
            current_day: Zero-based day index (shown to the LLM as
                ``current_day + 1``).
            current_time: Current time, in the format ``TimeUtils`` expects.
            current_position: Current position, inserted into the prompt.

        Returns:
            A ``(poi_type, candidate_list)`` tuple.  ``poi_type`` is the LLM's
            choice when it is one of ``candidates_type``, otherwise
            ``candidates_type[0]``.
        """
        if current_day == self.query["day"] - 1:
            # 3 hours later to the destination
            if TimeUtils.compare_times_later(TimeUtils.time_operation_add(current_time, 3),poi_plan["back_transport"]["depature_time"])>=0:
                return "back-intercity-transport", ["back-intercity-transport"]

        time_before = time.time()
        query_message = NEXT_POI_TYPE_INSTRUCTION.format(self.query['userQuery'], poi_plan,current_day+1, current_time, current_position,candidates_type)
        answer = self.backbone_llm(query_message)

        self.llm_rec_count += 1
        self.llm_inference_time_count += time.time() - time_before

        # [\w-]+ rather than \w+ so that hyphenated type names are captured
        # in full instead of being cut at the first hyphen.
        match = re.search(r'Type:\s*([\w-]+)', answer)
        poi_type = match.group(1) if match else None
        if poi_type is None:
            self.llm_rec_format_error += 1

        if poi_type is not None and poi_type in candidates_type:
            return poi_type, candidates_type

        print("The selected POI type is not in the candidate POI type list.")
        return candidates_type[0], candidates_type
    
    def ranking_attractions(self, current_time, current_position, intercity_with_hotel_cost):
        """Rank the attractions in ``self.memory["attractions"]``.

        The first call asks the backbone LLM for an ``AttractionIDList: [...]``
        of suggested attractions and caches it in
        ``self.suggested_attractions_from_query``; later calls reuse the cache
        (guarded by ``self.ranking_attractions_flag``).  That LLM call is
        counted in ``llm_rec_count`` and a parse failure in
        ``llm_rec_format_error``, as the other methods of this class do.

        Every attraction gets a score equal to its price rank plus its
        walking-distance rank from ``current_position`` (lower is better).
        The distance term is zero for all attractions when
        ``current_position`` is ``(0, 0)``.  LLM-suggested attractions are
        then given negative scores so they come first, in the LLM's order.
        Suggested IDs that are not in ``self.memory["attractions"]`` are
        ignored.

        Args:
            current_time: Passed through to ``collect_innercity_transport``.
            current_position: ``(lat, lon)`` tuple; ``(0, 0)`` disables the
                distance term.
            intercity_with_hotel_cost: Cost already committed; inserted into
                the LLM prompt as ``past_cost``.

        Returns:
            A numpy array of indices into ``self.memory["attractions"]``,
            best candidate first.
        """
        attraction_ids = self.memory["attractions"]

        attr_info = []
        for poiid in attraction_ids:
            poi_api_info, _, poi_info = self.poi_analyzer.read_poi_api_info(poiid, self.query["locale"])
            time_window = self.poi_analyzer.parse_poi_time_window(poi_info).split("-")
            poi_api_info['id'] = poi_api_info['id'].replace("poi_", "")
            poi_api_info['opentime'] = time_window[0]
            poi_api_info['endtime'] = time_window[1]
            attr_info.append(poi_api_info)

        if not self.ranking_attractions_flag:
            time_before = time.time()
            query_message = ATTRACTION_RANKING_INSTRUCTION.format(user_requirements=self.query['userQuery'], attraction_info=str(attr_info), past_cost=intercity_with_hotel_cost)
            answer = self.backbone_llm(query_message)

            self.llm_inference_time_count += time.time() - time_before
            self.llm_rec_count += 1

            attraction_list = []
            match = re.search(r'AttractionIDList:\s*(\[[^\]]+\])', answer)
            if match:
                try:
                    # NOTE: the answer is untrusted model output, so it is
                    # parsed with ast.literal_eval instead of eval().
                    attraction_list = list(ast.literal_eval(match.group(1)))
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    print("!!!Error in eval attraction_list")
                    self.llm_rec_format_error += 1
                    attraction_list = []
            print('selected attractions: ', attraction_list)
            self.suggested_attractions_from_query = attraction_list
            self.ranking_attractions_flag = True

        attraction_list = self.suggested_attractions_from_query
        num_attractions = len(attraction_ids)

        # argsort of argsort converts a sort permutation into the rank of
        # each attraction, which is what the summed score needs.
        attr_price = [info.get("price", 0) for info in attr_info]
        ranking_price = np.argsort(np.argsort(np.array(attr_price), kind="stable"), kind="stable")

        if current_position != (0,0):
            attr_dist = []
            for info in attr_info:
                target = (info['lat'], info['lon'])
                if current_position == target:
                    attr_dist.append(0)
                else:
                    transports_sel = self.collect_innercity_transport(self.query["target_city"], current_position, target, current_time, "walk")
                    attr_dist.append(transports_sel[0]["distance"])
            ranking_dist = np.argsort(np.argsort(np.array(attr_dist), kind="stable"), kind="stable")
        else:
            # No known position: distance must not influence the ranking.
            ranking_dist = np.zeros(num_attractions, dtype=int)

        score = ranking_price + ranking_dist

        num_suggested = len(attraction_list)
        for order, selected_i in enumerate(attraction_list):
            try:
                attr_i = attraction_ids.index(selected_i)
            except ValueError:
                # The LLM named an ID that is not a current candidate.
                continue
            # Negative score (all others are >= 0) keeps suggested
            # attractions ahead of the rest, in the LLM's order.
            score[attr_i] = 2 * (order - num_suggested)

        ranking_idx = np.argsort(score, kind="stable")

        return ranking_idx

    def select_poi_time(self, plan, poi_plan, current_day, start_time, poi_name, poi_type, recommended_visit_time):
        """Ask the backbone LLM how many hours to spend at a POI.

        The LLM is expected to answer with ``Time: <number>``; integer and
        decimal numbers (e.g. ``1.5``) are both accepted.  When no such
        answer is found, ``llm_rec_format_error`` is incremented and the
        minimum recommended time is returned.

        Args:
            plan: Current travel plan, inserted into the prompt.
            poi_plan: Read for ``poi_plan["back_transport"]["depature_time"]``.
            current_day: Zero-based day index (shown to the LLM as
                ``current_day + 1``).
            start_time: Time at which the visit would start.
            poi_name: Name of the POI under consideration.
            poi_type: Type of the POI under consideration.
            recommended_visit_time: ``(min_hours, max_hours)`` pair, or a
                falsy value when unknown, in which case 1.5 hours is used as
                the minimum.

        Returns:
            The visit duration as a float parsed from the LLM answer, or the
            minimum recommended time when the answer cannot be parsed.
        """
        if recommended_visit_time:
            min_hours, _ = recommended_visit_time
        else:
            # No recommendation available: fall back to a 1.5 hour default
            # instead of failing on the tuple unpacking.
            min_hours = 1.5

        time_before = time.time()
        query_message = SELECT_POI_TIME_INSTRUCTION.format(user_requirements=self.query['userQuery'],
                                                        current_travel_plans = plan,
                                                        current_date = current_day+1,
                                                        current_time = start_time,
                                                        current_poi = poi_name,
                                                        poi_type = poi_type,
                                                        recommended_visit_time = min_hours,
                                                        back_transport_time = poi_plan["back_transport"]["depature_time"])
        answer = self.backbone_llm(query_message)

        self.llm_rec_count += 1
        self.llm_inference_time_count += time.time() - time_before

        # Optional fractional part so "1.5" is not truncated to "1".
        match = re.search(r'Time:\s*(\d+(?:\.\d+)?)', answer)
        if match:
            return float(match.group(1))

        self.llm_rec_format_error += 1
        print("The selected POI time is not in the candidate POI time list.")
        return min_hours

    def decide_rooms(self, query):
        """Ask the backbone LLM for the required room count and room type.

        The prompt is built from ``self.query['userQuery']`` (the ``query``
        argument itself is not read).  The LLM is expected to answer with
        ``RoomInfo: [<rooms>, <beds>]`` where each entry is a non-negative
        integer or ``-1``.

        Returns:
            A ``(num_rooms, num_beds)`` tuple.  An entry is ``None`` when the
            LLM value is below 1; both are ``None`` (and
            ``llm_rec_format_error`` is incremented) when no ``RoomInfo``
            answer is found.
        """
        started_at = time.time()
        prompt = ROOMS_PLANNING_INSTRUCTION.format(user_requirements=self.query['userQuery'])
        answer = self.backbone_llm(prompt)
        self.llm_inference_time_count += time.time() - started_at

        self.llm_rec_count += 1

        found = re.search(r'RoomInfo:\s*\[\s*(\d+|\-1)\s*,\s*(\d+|\-1)\s*\]', answer)

        if found is None:
            print("!!!Error in matching RoomInfo")
            num_rooms = num_beds = None
            self.llm_rec_format_error += 1
        else:
            parsed = []
            for raw_value in found.groups():
                value = int(raw_value)
                # Values below 1 mean "not specified".
                parsed.append(value if value >= 1 else None)
            num_rooms, num_beds = parsed

        print("extracted room_number: ", num_rooms, "room_type:", num_beds)
        return num_rooms, num_beds
    def extract_budget(self, query:Dict):
        """Ask the backbone LLM for the user's overall budget.

        The prompt is built from ``self.query['userQuery']`` (the ``query``
        argument itself is not read).  The LLM is expected to answer with
        ``Budget: <integer>``.  Any whitespace after the colon is tolerated,
        and a negative number is accepted so that a sentinel such as ``-1``
        (presumably "no budget"; confirm against the prompt) reaches the
        ``budget < 1`` check instead of being counted as a format error.

        Returns:
            The budget as an ``int``, or ``None`` when the value is below 1
            or no ``Budget`` answer is found (the latter also increments
            ``llm_rec_format_error``).
        """
        time_before = time.time()

        query_message = BUDGETS_INSTRUCTION.format(user_requirements=self.query['userQuery'])
        answer = self.backbone_llm(query_message)

        self.llm_inference_time_count += time.time() - time_before
        self.llm_rec_count += 1

        match = re.search(r"Budget:\s*(-?\d+)", answer)

        if match:
            budget = int(match.group(1))
            if budget < 1:
                budget = None
        else:
            print("!!!Error in extracting budget")
            budget = None
            self.llm_rec_format_error += 1

        print("extracted budget: ", budget)
        return budget
    
    def ranking_innercity_transport_from_query(self, query:Dict )->List[str]:
        """Ask the backbone LLM which inner-city transport modes the user prefers.

        The LLM is expected to answer with
        ``TransportRanking: ["mode", ...]`` (double-quoted names).  Only
        ``"metro"``, ``"taxi"`` and ``"walk"`` are kept, in the LLM's order.

        Args:
            query: Request dict; only ``query['userQuery']`` is read.

        Returns:
            The filtered ranking, or the default
            ``["metro", "taxi", "walk"]`` when the answer has no
            ``TransportRanking`` list (also counted in
            ``llm_rec_format_error``) or names no supported mode.
        """
        time_before = time.time()

        query_message = INNERCITY_TRANSPORTS_SELECTION_INSTRUCTION.format(user_requirements=query['userQuery'])
        answer = self.backbone_llm(query_message)

        print(answer)

        self.llm_inference_time_count += time.time() - time_before
        self.llm_rec_count += 1

        # Built per call so callers can mutate the returned list safely.
        default_ranking = ["metro", "taxi", "walk"]

        match = re.search(r'TransportRanking:\s*\[(.*?)\]', answer, re.DOTALL)
        if match is None:
            # The only real failure mode of the parsing: the expected list is
            # not present in the answer at all.
            print("!!!Error in eval TransportRanking")
            self.llm_rec_format_error += 1
            return default_ranking

        parsed_ranking = re.findall(r'"([^"]+)"', match.group(1))
        print('selected TransportRanking: ', parsed_ranking)

        supported = [mode for mode in parsed_ranking if mode in default_ranking]
        return supported if supported else default_ranking

