import pandas as pd
from langchain.tools import tool

class Analytics:
    """
    In-memory view over the website analytics CSV plus a log of created plots.

    Attributes
    ----------
    ANALYTICS_DATA : pandas.DataFrame
        Rows loaded from the analytics CSV. Every column is read as ``str``
        except ``user_engaged``, which is converted to ``bool`` on load.
    PLOTS_DATA : pandas.DataFrame
        Single-column (``file_path``) frame recording every plot path returned
        by :meth:`create_plot`.
    METRICS : list of str
        Column names of the supported metrics.
    METRIC_NAMES : list of str
        Human-readable names, index-aligned with ``METRICS``.
    """

    # Location of the source data, shared by __init__ and reset_state.
    _DATA_PATH = "data/processed/analytics_data.csv"
    # Single source of truth for create_plot validation AND its error messages,
    # so the two can never drift apart.
    _PLOT_VALUES = (
        "total_visits",
        "session_duration_seconds",
        "user_engaged",
        "visits_direct",
        "visits_referral",
        "visits_search_engine",
        "visits_social_media",
    )
    _PLOT_TYPES = ("bar", "line", "scatter", "histogram")

    def __init__(self):
        # Load the data once at construction time.
        self.ANALYTICS_DATA = self._load_analytics_data()
        self.PLOTS_DATA = pd.DataFrame(columns=["file_path"])
        self.METRICS = ["total_visits", "session_duration_seconds", "user_engaged"]
        self.METRIC_NAMES = ["total visits", "average session duration", "engaged users"]

    def _load_analytics_data(self):
        """Reads the analytics CSV as strings and converts ``user_engaged`` to bool."""
        data = pd.read_csv(self._DATA_PATH, dtype=str)
        # Only the exact text "True" maps to True; anything else becomes False.
        data["user_engaged"] = data["user_engaged"] == "True"
        return data

    def _filter_by_date(self, time_min=None, time_max=None):
        """
        Returns a copy of the analytics rows restricted to an inclusive date range.

        ``date_of_visit`` is stored as a string, so the bounds are compared
        lexicographically. NOTE(review): this matches chronological order only
        if the column holds "YYYY-MM-DD" values - confirm against the data file.
        A falsy bound (``None`` or ``""``) leaves that side of the range open.
        """
        data = self.ANALYTICS_DATA
        if time_min:
            data = data.loc[data["date_of_visit"] >= time_min]
        if time_max:
            data = data.loc[data["date_of_visit"] <= time_max]
        # Copy so callers can add/convert columns without touching
        # ANALYTICS_DATA (and without SettingWithCopyWarning).
        return data.copy()

    def reset_state(self):
        """Resets the analytics data to the original state and clears the plot log."""
        self.ANALYTICS_DATA = self._load_analytics_data()
        self.PLOTS_DATA = pd.DataFrame(columns=["file_path"])

    def get_visitor_information_by_id(self, visitor_id=None):
        """
        Returns the analytics data for a given visitor ID.

        Parameters
        ----------
        visitor_id : str, optional
            ID of the visitor.

        Returns
        -------
        visitor_data : list of dict or str
            One record (column name -> value) per matching row. If no ID is
            given or no row matches, an explanatory message string is returned
            instead.
        """
        if not visitor_id:
            return "Visitor ID not provided."
        matches = self.ANALYTICS_DATA[self.ANALYTICS_DATA["visitor_id"] == visitor_id]
        visitor_data = matches.to_dict(orient="records")
        if visitor_data:
            return visitor_data
        return "Visitor not found."

    def create_plot(self, time_min=None, time_max=None, value_to_plot=None, plot_type=None):
        """
        Plots the analytics data for a given time range and value.

        No image is rendered by this method: it validates the arguments,
        builds the plot file path and records it in ``PLOTS_DATA``.

        Parameters
        ----------
        time_min : str, optional
            Start date of the time range. Date format is "YYYY-MM-DD".
        time_max : str, optional
            End date of the time range. Date format is "YYYY-MM-DD".
        value_to_plot : str, optional
            Value to plot. Available values are: "total_visits", "session_duration_seconds", "user_engaged", "visits_direct", "visits_referral", "visits_search_engine", "visits_social_media"
        plot_type : str, optional
            Type of plot. Can be "bar", "line", "scatter" or "histogram"

        Returns
        -------
        file_path : str
            Path to the plot file. Filename is {{time_min}}_{{time_max}}_{{value_to_plot}}_{{plot_type}}.png.
            If an argument is missing or invalid, an explanatory message string
            is returned instead and nothing is recorded.
        """
        if not time_min:
            return "Start date not provided."
        if not time_max:
            return "End date not provided."
        if value_to_plot not in self._PLOT_VALUES:
            # Build the message from the accepted values so it always names
            # exactly what the check above accepts.
            allowed = ", ".join(f"'{value}'" for value in self._PLOT_VALUES)
            return f"Value to plot must be one of {allowed}"
        if plot_type not in self._PLOT_TYPES:
            return "Plot type must be one of 'bar', 'line', 'scatter', or 'histogram'"

        # Record the path of the plot; appending at index len() adds a new row.
        file_path = f"plots/{time_min}_{time_max}_{value_to_plot}_{plot_type}.png"
        self.PLOTS_DATA.loc[len(self.PLOTS_DATA)] = [file_path]
        return file_path

    def total_visits_count(self, time_min=None, time_max=None):
        """
        Returns the number of visits per day within a specified time range.

        Parameters
        ----------
        time_min : str, optional
            Start date of the time range. Date format is "YYYY-MM-DD".
        time_max : str, optional
            End date of the time range. Date format is "YYYY-MM-DD".

        Returns
        -------
        total_visits : dict
            Maps each ``date_of_visit`` in the range to its number of rows.
        """
        data = self._filter_by_date(time_min, time_max)
        return data.groupby("date_of_visit").size().to_dict()

    def engaged_users_count(self, time_min=None, time_max=None):
        """
        Returns the number of engaged users per day within a specified time range.

        Parameters
        ----------
        time_min : str, optional
            Start date of the time range. Date format is "YYYY-MM-DD".
        time_max : str, optional
            End date of the time range. Date format is "YYYY-MM-DD".

        Returns
        -------
        engaged_users : dict
            Maps each ``date_of_visit`` in the range to the number of rows
            whose ``user_engaged`` flag is true.
        """
        data = self._filter_by_date(time_min, time_max)
        data["user_engaged"] = data["user_engaged"].astype(bool).astype(int)

        # Select the column BEFORE aggregating: summing the whole frame would
        # also try to "sum" every unrelated string column.
        return data.groupby("date_of_visit")["user_engaged"].sum().to_dict()

    def traffic_source_count(self, time_min=None, time_max=None, traffic_source=None):
        """
        Returns the number of visits per day from a specific traffic source within a specified time range.

        Parameters
        ----------
        time_min : str, optional
            Start date of the time range. Date format is "YYYY-MM-DD".
        time_max : str, optional
            End date of the time range. Date format is "YYYY-MM-DD".
        traffic_source : str, optional
            Traffic source to filter the visits. Available values are: "direct", "referral", "search engine", "social media"

        Returns
        -------
        traffic_source_visits : dict
            Maps each ``date_of_visit`` in the range to the number of visits
            from ``traffic_source``. Dates with no visit from that source are
            kept with a count of 0. When ``traffic_source`` is not given, all
            visits are counted.
        """
        data = self._filter_by_date(time_min, time_max)

        if not traffic_source:
            return data.groupby("date_of_visit").size().to_dict()

        # A 0/1 indicator summed per day keeps days with zero matching visits.
        data["visits_from_source"] = (data["traffic_source"] == traffic_source).astype(int)
        return data.groupby("date_of_visit")["visits_from_source"].sum().to_dict()

    def get_average_session_duration(self, time_min=None, time_max=None):
        """
        Returns the average session duration per day within a specified time range.

        Parameters
        ----------
        time_min : str, optional
            Start date of the time range. Date format is "YYYY-MM-DD".
        time_max : str, optional
            End date of the time range. Date format is "YYYY-MM-DD".

        Returns
        -------
        average_session_duration : dict
            Maps each ``date_of_visit`` in the range to the mean of
            ``session_duration_seconds`` (as float) for that day.
        """
        data = self._filter_by_date(time_min, time_max)
        # The column is loaded as text; convert before averaging.
        data["session_duration_seconds"] = data["session_duration_seconds"].astype(float)

        return data.groupby("date_of_visit")["session_duration_seconds"].mean().to_dict()
