import argparse

import numpy as np

import sys
import os
import json
from typing import Dict, List

from agent.nesy_agent.nesy_agent import NesyAgent
from utils.time_utils import TimeUtils




class RuleDrivenAgent(NesyAgent):
    def __init__(self, backbone_llm=None, **kwargs):
        """Create the rule-driven agent.

        The ``method`` keyword is always forced to ``"RuleNeSy"`` (any value
        supplied by the caller is overwritten) before everything is forwarded
        to the parent class initializer.

        Args:
            backbone_llm: Forwarded unchanged to the parent initializer.
            **kwargs: Extra keyword arguments for the parent initializer.
        """
        kwargs.update(method="RuleNeSy")
        super().__init__(backbone_llm=backbone_llm, **kwargs)

    def ranking_intercity_transport_go(self, transport_info:List[Dict], query:Dict)->List[int]:
        time_list = [transport["depature_time"] for transport in transport_info]
        sorted_lst = sorted(enumerate(time_list), key=lambda x: x[1])
        sorted_indices = [index for index, value in sorted_lst]
        time_ranking = np.zeros_like(sorted_indices)
        for i, idx in enumerate(sorted_indices):
            time_ranking[idx] = i + 1

        price_list = [transport["price"] for transport in transport_info]
        price_ranking = np.argsort(np.array(price_list))

        ranking_idx = np.argsort(time_ranking + price_ranking)

        return ranking_idx

    def ranking_intercity_transport_back(self, transport_info:List[Dict], query:Dict, selected_go:List[int])->List[int]:
        time_list = [transport["depature_time"] for transport in transport_info]
        sorted_lst = sorted(enumerate(time_list), key=lambda x: x[1])
        sorted_indices = [index for index, value in sorted_lst]
        sorted_indices.reverse()  # departure time from late to early
        time_ranking = np.zeros_like(sorted_indices)
        for i, idx in enumerate(sorted_indices):
            time_ranking[idx] = i + 1

        price_list = [transport["price"] for transport in transport_info]
        price_ranking = np.argsort(np.array(price_list))

        ranking_idx = np.argsort(time_ranking + price_ranking)

        return ranking_idx

    def ranking_hotel(self, hotel_info:List[Dict], query:Dict)->List[int]:

        # ranking by cost
        cost_list = []
        for hotel_id in hotel_info:
            hotel = self.poi_analyzer.get_hotel_info(hotel_id, query["locale"])
            cost_list.append(hotel['price'])
        sorted_lst = sorted(enumerate(cost_list), key=lambda x: x[1])
        sorted_indices = [index for index, value in sorted_lst]

        return sorted_indices[:5]
    
    def check_if_too_late(
        self, query, current_day, current_time, current_position, poi_plan
     ):

        if current_time != "" and TimeUtils.compare_times_later("23:00", current_time)==-1:
            print("too late, after 23:00")
            return True

        if current_time != "" and current_day == query["day"] - 1:
            # We should go back in time ...
            transports_ranking = self.ranking_innercity_transport(
                current_position,
                poi_plan["back_transport"]["key"].split("->")[0],
                current_day,
                current_time,
            )

            for transport_type_sel in transports_ranking:

                flag = True
                try:
                    back_position = (poi_plan["back_transport"]["from"]["lat"], poi_plan["back_transport"]["from"]["lng"])
                except: 
                    back_position = current_position

                if "back_transport" in poi_plan:
                    transports_sel = self.collect_innercity_transport(
                        query["arrive"],
                        current_position,
                        back_position,
                        current_time,
                        transport_type_sel,
                    )
                    if not isinstance(transports_sel, list):
                            continue
                    if len(transports_sel) == 0:
                        arrived_time = current_time
                    else:
                        arrived_time = transports_sel[-1]["end_time"]

                    if TimeUtils.compare_times_later(poi_plan["back_transport"]["depature_time"],arrived_time)>=0:
                        flag = False
                if flag:
                    print(
                        "Can not go back source-city in time, current POI {}, station arrived time: {}".format(
                            current_position, arrived_time
                        )
                    )
                    return True

        elif current_time != "":
            if "accommodation" in poi_plan:
                hotel_sel = poi_plan["accommodation"]
                transports_ranking = self.ranking_innercity_transport(
                    current_position,(hotel_sel["lat"], hotel_sel["lon"]), current_day, current_time
                )

                for transport_type_sel in transports_ranking:

                    flag = True
                    if "back_transport" in poi_plan:
                        transports_sel = self.collect_innercity_transport(
                            query["arrive"],
                            current_position,
                            (hotel_sel["lat"], hotel_sel["lon"]),
                            current_time,
                            transport_type_sel,
                        )

                        flag = True
                        if not isinstance(transports_sel, list):
                            continue
                        if len(transports_sel) == 0:
                            arrived_time = current_time
                        else:
                            arrived_time = transports_sel[-1]["end_time"]
                        # the check meaning of this code: check if the time to arrive the hotel has exceeded 24:00 (midnight), if it has, set flag to False, indicating that it is too late to return to the hotel
                        if TimeUtils.compare_times_later(arrived_time, "24:00") >= 0:
                            flag = False
                    if flag:
                        print(
                            "Can not go back to hotel, current POI {}, hotel arrived time: {}".format(
                                current_position, arrived_time
                            )
                        )
                        return True

        return False

    def select_next_poi_type(
        self,
        candidates_type,
        plan,
        poi_plan,
        current_day,
        current_time,
        current_position,
    ):
        """Choose the type of the next plan item.

        Rules, applied in order:

        1. On the last day (``current_day == self.query["day"] - 1``), if the
           current time plus 3 hours reaches the return departure time, go to
           the return inter-city transport.
        2. If a hotel is a valid option (an ``"accommodation"`` entry exists
           and it is not the last day) and the current time plus 2 hours is
           past 22:30, go to the hotel.
        3. Otherwise visit an attraction.

        Args:
            candidates_type: Ignored; the candidate list is rebuilt here.
            plan: Not used by this rule.
            poi_plan: Partial plan; ``"back_transport"`` and
                ``"accommodation"`` are consulted.
            current_day: Zero-based index of the day being planned.
            current_time: Current time string.
            current_position: Not used by this rule.

        Returns:
            A ``(selected_type, candidate_types)`` tuple.
        """
        last_day_index = self.query["day"] - 1

        if current_day == last_day_index:
            # Leave a 3 hour margin before the return departure.
            latest_start = TimeUtils.time_operation_add(current_time, 3)
            departure = poi_plan["back_transport"]["depature_time"]
            if TimeUtils.compare_times_later(latest_start, departure) >= 0:
                return "back-intercity-transport", ["back-intercity-transport"]

        hotel_allowed = "accommodation" in poi_plan and current_day < last_day_index
        options = ["attraction", "hotel"] if hotel_allowed else ["attraction"]

        # Too late for another 2 hour stop: head to the hotel if that is possible.
        after_next_stop = TimeUtils.time_operation_add(current_time, 2.0)
        running_late = TimeUtils.compare_times_later(after_next_stop, "22:30") > 0
        if running_late and hotel_allowed:
            return "hotel", ["hotel"]

        return "attraction", options

    def ranking_attractions(
        self, 
        current_time:str,
        current_position:tuple,
        intercity_with_hotel_cost:float
     )->List[int]:

        # ranking by distance
        num_attractions = len(self.memory["attractions"])
        attr_info = []
        for poiid in self.memory["attractions"]:
            poi_api_info, min_spend_time, poi_info = self.poi_analyzer.read_poi_api_info(poiid, self.query["locale"])
            attr_info.append(poi_api_info)

        attr_dist = []
        if current_position != (0,0):
            for i in range(num_attractions):
                transports_sel = self.collect_innercity_transport(
                    self.query["target_city"],
                    current_position,
                    (attr_info[i]['lat'], attr_info[i]['lon']),
                    current_time,
                    "walk",
                )
                if len(transports_sel) == 0:
                    attr_dist.append(0)
                else:
                    attr_dist.append(transports_sel[0]["distance"])
            # print(attr_dist)
        else:
            attr_dist = [0 for i in range(num_attractions)]
            
        attr_price = [attr_info[i].get("price", 0) for i in range(num_attractions)]

        ranking_price = np.argsort(np.array(attr_price))

        ranking_dist = np.argsort(np.array(attr_dist))

        ranking_idx = np.argsort(ranking_price + ranking_dist)


        return ranking_idx

    def ranking_innercity_transport(
        self, current_position, target_position, current_day, current_time
     ):

        return ["metro", "taxi", "walk"]

    def ranking_innercity_transport_from_query(self, query:Dict)->List[str]:

        return ["metro", "taxi", "walk"]

    def select_poi_time(
        self,
        plan,
        poi_plan,
        current_day,
        start_time,
        poi_name,
        poi_type,
        poi_hours
     ):
        if poi_hours:
            min_hours, max_hours = poi_hours
            return min_hours
        else:
            return 1.5

    def decide_rooms(self, query):
        return None, None
    def extract_budget(self, query):
        return None







