from datetime import datetime
from typing import *
import dataclasses
import json
import numpy as np
import logging
import os
import time


@dataclasses.dataclass
class StepInfo:
    """Everything recorded about one environment step.

    The per-agent lists (``agent_infos``, ``rewards``, ``dones``) are indexed
    by agent position, which is how ``Logger.step`` reads them.
    """

    actions: List[int]  # actions of this step (presumably one per agent)
    agent_infos: List[Optional[Dict[str, Any]]]  # None if the agent does not act
    rewards: List[float]  # per-agent reward for this step
    dones: List[bool]  # per-agent done flag
    env_info: Dict[int, str]  # environment info for this step; NOTE(review): key/value types unverified


@dataclasses.dataclass()
class PredictInfo:
    """Outcome of a single strategic-reasoning prediction.

    ``PredictLogger.log_predict`` stores every field and adds ``reward`` to
    its running accuracy total.
    """

    # Action the agent chose.
    choose_action: int
    # Target actions the choice is judged against (presumably the correct answers).
    target_actions: List[int]
    # Extra agent output; may contain "token_info" used for price accounting.
    agent_info: Dict[str, Any]
    # Score of this prediction.
    reward: int

@dataclasses.dataclass()
class PerceptionInfo:
    """Outcome of a single perception prediction.

    ``PerceptionLogger.log_predict`` stores every field (``output`` under the
    key "model_output") and adds ``reward`` to its running score total.
    """

    # What the model produced.
    output: Dict
    # Reference answer the output is compared against.
    label: Dict
    # Extra agent output; may contain "token_info" used for price accounting.
    agent_info: Dict[str, Any]
    # Score of this prediction.
    reward: float

class Logger:
    """A logger that writes messages to both console and files (TXT and JSON).

    Collects per-step records, per-episode returns and step counts, and
    accumulated token prices for a decision-making run. Output goes to
    ``<results_dir>/decision-making/<env>/<agents>/<exp>/<timestamp>/``.
    """

    # Directory for human-agent JSONL logs when an agent config gives no "path".
    _DEFAULT_HUMAN_LOG_DIR = "/share/public-nfs/vsbench-HCI/human"

    def __init__(self, config):
        """Set up result buffers, the output directory and the text logger.

        Args:
            config: Experiment configuration with ``experiment``,
                ``environment`` and ``agents`` sections.
        """
        self.start_time = time.time()
        self.exp_name = config["experiment"].get("name", "default")
        self.num_episodes = config["experiment"].get("num_episodes", 1)
        self.env_name = config["environment"]["type"]
        self.num_agents = len(config["agents"])
        self.returns = np.zeros((self.num_episodes, self.num_agents))
        self.steps = np.zeros(self.num_episodes, dtype=int)

        # Human-controlled agents: their ids and the directory of each one's
        # JSONL log. ``human_agent_path`` keeps its original meaning (the path
        # of the last human agent in the config) for existing readers.
        self.human_agent = []
        self.human_agent_path = None
        self._human_agent_paths = {}
        for agent_config in config["agents"]:
            if agent_config["type"] == "human_agent":
                agent_id = agent_config["params"]["id"]
                path = agent_config["params"].get("path", self._DEFAULT_HUMAN_LOG_DIR)
                self.human_agent.append(agent_id)
                self._human_agent_paths[agent_id] = path
                self.human_agent_path = path

        self.price_info = {f"episode_{i}": self._empty_price_info() for i in range(self.num_episodes)}
        self.results = {
            "config": config,
            "returns_mean": {f"agent_{i}": 0 for i in range(self.num_agents)},
            "returns_std": {f"agent_{i}": 0 for i in range(self.num_agents)},
            "steps_mean": 0,
            "steps_std": 0,
            "price_info": self._empty_price_info(),
        }
        self.results.update({f"episode_{i}": {} for i in range(self.num_episodes)})

        # log dir
        results_dir = config["experiment"].get("results_dir", "results")
        agents_name = "+".join(self._agent_label(agent_config) for agent_config in config["agents"])
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_dir = os.path.join(results_dir, "decision-making", self.env_name, agents_name, self.exp_name,
                                    timestamp)
        self.log_file = os.path.join(self.log_dir, "output.log")
        self.json_file = os.path.join(self.log_dir, "results.json")
        os.makedirs(self.log_dir, exist_ok=True)

        # id(self) keeps the name unique when several Loggers are created in the
        # same second: logging.getLogger returns the same object for the same
        # name, which would otherwise stack both runs' handlers on one logger.
        logger_name = f"{self.env_name}_{agents_name}_{os.getpid()}_{timestamp}_{id(self)}"
        self.logger = self._build_logger(logger_name, self.log_file)

        self.logger.info(f"Text logging to {self.log_file}")
        self.logger.info(f"Json logging to {self.json_file}")
        self.logger.info(f"Environment: {self.env_name}")
        self.logger.info(f"Agents: {agents_name}")
        self.logger.info(f"Experiment: {self.exp_name}")

    @staticmethod
    def _empty_price_info():
        """Return a fresh zeroed price accumulator."""
        return {"total_price": 0, "prompt_price": 0, "completion_price": 0}

    @staticmethod
    def _agent_label(agent_config):
        """Return ``type`` or ``type(model_name)`` for one agent config entry."""
        type_ = agent_config["type"]
        model = agent_config["params"].get("model")
        if model is None:
            return type_
        return f"{type_}({model['params']['name']})"

    @staticmethod
    def _build_logger(name, log_file):
        """Create a non-propagating INFO logger writing to ``log_file`` and the console."""
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        logger.propagate = False
        formatter = logging.Formatter("%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        for handler in (logging.FileHandler(log_file, encoding="utf-8"), logging.StreamHandler()):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def get_episode_dir(self, episode_id):
        """Return (creating it if needed) the directory for ``episode_id``'s artefacts."""
        episode_dir = os.path.join(self.log_dir, f"episode_{episode_id}")
        os.makedirs(episode_dir, exist_ok=True)
        return episode_dir

    def episode_start(self, episode_id):
        """Log the start marker of an episode."""
        self.logger.info(f"===== Episode {episode_id} start =====")

    def step(self, episode_id, step_info):
        """Record one step of ``episode_id`` and add its token prices.

        Args:
            episode_id: Index of the running episode.
            step_info: Object with ``actions``, ``agent_infos``, ``rewards``,
                ``dones`` and ``env_info`` attributes (see ``StepInfo``).
        """
        step = self.steps[episode_id]
        self.logger.info(f"step {step}: actions={step_info.actions}, env_info={step_info.env_info}")

        episode_prices = self.price_info[f"episode_{episode_id}"]
        info = {}
        for i, agent_info in enumerate(step_info.agent_infos):
            # Copy so the caller's dict is neither mutated nor aliased by the
            # stored record (agents that reuse one dict would otherwise
            # overwrite earlier steps). None becomes an empty record.
            record = dict(agent_info or {})
            record["reward"] = step_info.rewards[i]
            record["done"] = step_info.dones[i]
            info[f"agent_{i}"] = record

            token_info = record.get("token_info")
            if token_info is not None:
                episode_prices["total_price"] += token_info["total"]["total_price"]
                episode_prices["prompt_price"] += token_info["prompt"]["prompt_price"]
                episode_prices["completion_price"] += token_info["completion"]["completion_price"]
        info["env_info"] = step_info.env_info
        self.results[f"episode_{episode_id}"][f"step_{step}"] = info

        self.steps[episode_id] += 1

    def episode_end(self, episode_id, env_info):
        """Store an episode's final info and returns and notify human-agent logs.

        Args:
            episode_id: Index of the finished episode.
            env_info: Final environment info; must contain ``"returns"`` with
                one value per agent.
        """
        self.results[f"episode_{episode_id}"]["result"] = env_info
        self.returns[episode_id] = np.array(env_info["returns"])
        self.logger.info(f"done: returns = {env_info['returns']}")
        self.logger.info(f"===== Episode {episode_id} end =====")
        # Numpy scalars are unwrapped so json.dumps accepts numpy returns too.
        returns = [r.item() if isinstance(r, np.generic) else r for r in env_info["returns"]]
        self._append_human_log({
            "episode_end": True,
            "timestamp": time.time(),
            "episode_id": episode_id,
            "env_info": returns,
        })

    def _append_human_log(self, entry):
        """Append ``entry`` as one JSON line to each human agent's log file.

        Log files that do not already exist are skipped, never created.
        """
        for agent_id in self.human_agent:
            log_path = os.path.join(self._human_agent_paths[agent_id], f"player_{agent_id}_log.jsonl")
            if not os.path.exists(log_path):
                continue
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(entry) + "\n")

    def _aggregate_results(self):
        """Fill the summary statistics of ``self.results`` from the per-episode data.

        Everything is recomputed from scratch rather than accumulated, so
        repeated calls (e.g. repeated ``save_results``) give the same values.
        """
        for i in range(self.num_agents):
            self.results["returns_std"][f"agent_{i}"] = np.std(self.returns[:, i])
            self.results["returns_mean"][f"agent_{i}"] = np.mean(self.returns[:, i])
        self.results["steps_std"] = np.std(self.steps)
        self.results["steps_mean"] = np.mean(self.steps)
        for key in ("total_price", "prompt_price", "completion_price"):
            self.results["price_info"][key] = sum(prices[key] for prices in self.price_info.values())

    def save_results(self):
        """Compute summary statistics, write ``results.json`` and notify human-agent logs."""
        self._aggregate_results()

        self.logger.info(f"num_episodes: {self.num_episodes}")
        self.logger.info(f"returns_mean: {self.results['returns_mean']}")
        self.logger.info(f"returns_std: {self.results['returns_std']}")
        self.logger.info(f"steps_mean: {self.results['steps_mean']}")
        self.logger.info(f"steps_std: {self.results['steps_std']}")

        results = convert_numpy_types(self.results)
        with open(self.json_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
        self.logger.info(f"Results saved to {self.json_file}")

        self._append_human_log({
            "game_end": True,
            "results": [float(self.results["returns_mean"][f"agent_{i}"]) for i in range(self.num_agents)],
        })

    def close(self):
        """Close all handlers to ensure proper cleanup."""
        end_time = time.time()
        self.logger.info(f"Experiment completed in {end_time - self.start_time:.2f}s.")

        for handler in self.logger.handlers[:]:
            handler.close()
            self.logger.removeHandler(handler)

class PredictLogger:
    """Logger for strategic-reasoning (prediction) evaluations.

    Records one entry per prediction and accumulates token prices.
    ``save_results`` writes the mean reward as a percentage (``accuracy``) to
    ``<results_dir>/strategic-reasoning/<env>/<agent>/<exp>/<timestamp>/results.json``.
    """

    def __init__(self, config, env_name, num_prediction):
        """Set up result buffers, the output directory and the text logger.

        Args:
            config: Experiment configuration with ``experiment`` and ``agent`` sections.
            env_name: Environment name used in the output path and logger name.
            num_prediction: Number of predictions; the denominator of ``accuracy``.
        """
        self.start_time = time.time()
        self.exp_name = config["experiment"].get("name", "default")
        self.num_prediction = num_prediction
        self.env_name = env_name
        self.price_info = {f"predict_{i}": self._empty_price_info() for i in range(self.num_prediction)}
        self.results = {
            "config": config,
            "num_predicts": self.num_prediction,
            "accuracy": 0,
            "price_info": self._empty_price_info(),
        }
        self.results.update({f"predict_{i}": {} for i in range(self.num_prediction)})
        # Raw reward sum, kept apart from results["accuracy"] so that
        # save_results can run more than once without re-dividing.
        self._reward_sum = 0

        results_dir = config["experiment"].get("results_dir", "results")
        agent_config = config["agent"]
        type_ = agent_config["type"]
        model = agent_config["params"].get("model")
        agent_name = type_ if model is None else f"{type_}({model['params']['name']})"

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_dir = os.path.join(results_dir, "strategic-reasoning", self.env_name, agent_name, self.exp_name,
                                    timestamp)
        self.log_file = os.path.join(self.log_dir, "output.log")
        self.json_file = os.path.join(self.log_dir, "results.json")
        os.makedirs(self.log_dir, exist_ok=True)

        # id(self) keeps the name unique when two loggers are created in the
        # same second; getLogger would otherwise hand back one shared instance
        # and stack both runs' handlers on it.
        logger_name = f"{self.env_name}_{agent_name}_{os.getpid()}_{timestamp}_{id(self)}"
        self.logger = self._build_logger(logger_name, self.log_file)

        self.logger.info(f"Text logging to {self.log_file}")
        self.logger.info(f"Json logging to {self.json_file}")
        self.logger.info(f"Environment: {self.env_name}")
        self.logger.info(f"Agent: {agent_name}")
        self.logger.info(f"Experiment: {self.exp_name}")

    @staticmethod
    def _empty_price_info():
        """Return a fresh zeroed price accumulator."""
        return {"total_price": 0, "prompt_price": 0, "completion_price": 0}

    @staticmethod
    def _build_logger(name, log_file):
        """Create a non-propagating INFO logger writing to ``log_file`` and the console."""
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        logger.propagate = False
        formatter = logging.Formatter("%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        for handler in (logging.FileHandler(log_file, encoding="utf-8"), logging.StreamHandler()):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def log_predict(self, predict_id, predict_info: "PredictInfo"):
        """Record one prediction's outcome and add its token cost.

        Args:
            predict_id: Index of the prediction. It must be in
                ``range(num_prediction)`` when ``agent_info`` carries
                ``token_info`` (KeyError otherwise).
            predict_info: Object with ``choose_action``, ``target_actions``,
                ``reward`` and ``agent_info`` attributes.
        """
        self.results[f"predict_{predict_id}"] = {
            "choose_action": predict_info.choose_action,
            "target_actions": predict_info.target_actions,
            "reward": predict_info.reward,
            "agent_info": predict_info.agent_info,
        }
        # results["accuracy"] holds the running sum until save_results normalises it.
        self.results["accuracy"] += predict_info.reward
        self._reward_sum += predict_info.reward

        agent_info = predict_info.agent_info
        token_info = agent_info.get("token_info") if agent_info is not None else None
        if token_info is not None:
            prices = self.price_info[f"predict_{predict_id}"]
            prices["total_price"] += token_info["total"]["total_price"]
            prices["prompt_price"] += token_info["prompt"]["prompt_price"]
            prices["completion_price"] += token_info["completion"]["completion_price"]

        self.logger.info(f"===== Predict {predict_id} end. Reward = {predict_info.reward} =====")

    def _aggregate_results(self):
        """Recompute price totals and the accuracy percentage (idempotent).

        Accuracy is 0 when ``num_prediction`` is 0 instead of dividing by zero.
        """
        for key in ("total_price", "prompt_price", "completion_price"):
            self.results["price_info"][key] = sum(prices[key] for prices in self.price_info.values())
        if self.num_prediction:
            self.results["accuracy"] = (self._reward_sum / self.num_prediction) * 100
        else:
            self.results["accuracy"] = 0

    def save_results(self):
        """Compute the summary and write ``results.json``."""
        self._aggregate_results()

        self.logger.info(f"num_prediction: {self.num_prediction}")
        self.logger.info(f"final accuracy: {self.results['accuracy']}")

        results = convert_numpy_types(self.results)
        with open(self.json_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
        self.logger.info(f"Results saved to {self.json_file}")

    def close(self):
        """Log the elapsed time, then close and detach all handlers."""
        self.logger.info(f"Experiment completed in {time.time() - self.start_time:.2f}s.")
        for handler in self.logger.handlers[:]:
            handler.close()
            self.logger.removeHandler(handler)

class PerceptionLogger:
    """Logger for perception evaluations.

    Records the model output and label of each prediction and accumulates
    token prices. ``save_results`` writes the mean reward as a percentage
    (``score``) to ``<results_dir>/perception/<env>/<agent>/<exp>/<timestamp>/results.json``.
    """

    def __init__(self, config, env_name, num_prediction):
        """Set up result buffers, the output directory and the text logger.

        Args:
            config: Experiment configuration with ``experiment`` and ``agent`` sections.
            env_name: Environment name used in the output path and logger name.
            num_prediction: Number of predictions; the denominator of ``score``.
        """
        self.start_time = time.time()
        self.exp_name = config["experiment"].get("name", "default")
        self.num_prediction = num_prediction
        self.env_name = env_name

        self.price_info = {f"predict_{i}": self._empty_price_info() for i in range(self.num_prediction)}
        self.results = {
            "config": config,
            "num_predicts": self.num_prediction,
            "score": 0,
            "price_info": self._empty_price_info(),
        }
        self.results.update({f"predict_{i}": {} for i in range(self.num_prediction)})
        # Raw reward sum, kept apart from results["score"] so that
        # save_results can run more than once without re-dividing.
        self._reward_sum = 0

        results_dir = config["experiment"].get("results_dir", "results")
        agent_config = config["agent"]
        type_ = agent_config["type"]
        model = agent_config["params"].get("model")
        agent_name = type_ if model is None else f"{type_}({model['params']['name']})"

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_dir = os.path.join(results_dir, "perception", self.env_name, agent_name, self.exp_name,
                                    timestamp)
        self.log_file = os.path.join(self.log_dir, "output.log")
        self.json_file = os.path.join(self.log_dir, "results.json")
        os.makedirs(self.log_dir, exist_ok=True)

        # id(self) keeps the name unique when two loggers are created in the
        # same second; getLogger would otherwise hand back one shared instance
        # and stack both runs' handlers on it.
        logger_name = f"{self.env_name}_{agent_name}_{os.getpid()}_{timestamp}_{id(self)}"
        self.logger = self._build_logger(logger_name, self.log_file)

        self.logger.info(f"Text logging to {self.log_file}")
        self.logger.info(f"Json logging to {self.json_file}")
        self.logger.info(f"Environment: {self.env_name}")
        self.logger.info(f"Agent: {agent_name}")
        self.logger.info(f"Experiment: {self.exp_name}")

    @staticmethod
    def _empty_price_info():
        """Return a fresh zeroed price accumulator."""
        return {"total_price": 0, "prompt_price": 0, "completion_price": 0}

    @staticmethod
    def _build_logger(name, log_file):
        """Create a non-propagating INFO logger writing to ``log_file`` and the console."""
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        logger.propagate = False
        formatter = logging.Formatter("%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        for handler in (logging.FileHandler(log_file, encoding="utf-8"), logging.StreamHandler()):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def log_predict(self, predict_id, perception_info: "PerceptionInfo"):
        """Record one perception prediction and add its token cost.

        Args:
            predict_id: Index of the prediction. It must be in
                ``range(num_prediction)`` when ``agent_info`` carries
                ``token_info`` (KeyError otherwise).
            perception_info: Object with ``output``, ``label``, ``reward`` and
                ``agent_info`` attributes.
        """
        self.results[f"predict_{predict_id}"] = {
            "model_output": perception_info.output,
            "label": perception_info.label,
            "reward": perception_info.reward,
            "agent_info": perception_info.agent_info,
        }
        # results["score"] holds the running sum until save_results normalises it.
        self.results["score"] += perception_info.reward
        self._reward_sum += perception_info.reward

        agent_info = perception_info.agent_info
        token_info = agent_info.get("token_info") if agent_info is not None else None
        if token_info is not None:
            prices = self.price_info[f"predict_{predict_id}"]
            prices["total_price"] += token_info["total"]["total_price"]
            prices["prompt_price"] += token_info["prompt"]["prompt_price"]
            prices["completion_price"] += token_info["completion"]["completion_price"]

        self.logger.info(f"===== Predict {predict_id} end. Reward = {perception_info.reward} =====")

    def _aggregate_results(self):
        """Recompute price totals and the score percentage (idempotent).

        Score is 0 when ``num_prediction`` is 0 instead of dividing by zero.
        """
        for key in ("total_price", "prompt_price", "completion_price"):
            self.results["price_info"][key] = sum(prices[key] for prices in self.price_info.values())
        if self.num_prediction:
            self.results["score"] = (self._reward_sum / self.num_prediction) * 100
        else:
            self.results["score"] = 0

    def save_results(self):
        """Compute the summary and write ``results.json``."""
        self._aggregate_results()

        self.logger.info(f"num_prediction: {self.num_prediction}")
        self.logger.info(f"final score: {self.results['score']}")

        results = convert_numpy_types(self.results)
        with open(self.json_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
        self.logger.info(f"Results saved to {self.json_file}")

    def close(self):
        """Log the elapsed time, then close and detach all handlers."""
        self.logger.info(f"Experiment completed in {time.time() - self.start_time:.2f}s.")
        for handler in self.logger.handlers[:]:
            handler.close()
            self.logger.removeHandler(handler)


def convert_numpy_types(obj):
    """Convert numpy types to native Python types for JSON serialization.

    Recurses into dicts, lists and tuples. Numpy bool/integer/floating scalars
    become ``bool``/``int``/``float``, and arrays become nested lists. Numpy
    scalar dict keys are unwrapped as well, because ``json`` rejects them.
    Any other object is returned unchanged.
    """
    # np.bool_ is not an np.integer subclass and json cannot encode it.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_numpy_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    return obj
