import wandb
import os
import numpy as np
import torch
from tensorboardX import SummaryWriter
from code_ptmc_mappo.utils.shared_buffer import SharedReplayBuffer
import logging
import sys
from datetime import datetime

def _t2n(x):
    """Return the contents of torch tensor ``x`` as a numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before the conversion.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

class Runner(object):
    """
    Base class for training recurrent policies.
    :param config: (dict) Config dictionary containing parameters for training.
    """
    def __init__(self, config):
        """
        Read the run configuration and build the policy, trainer and buffer.
        :param config: (dict) must provide 'all_args', 'envs', 'eval_envs',
            'device' and 'num_agents'; 'render_envs' is optional; 'run_dir'
            is required when wandb is not used.
        """
        self.all_args = config['all_args']
        self.envs = config['envs']
        self.eval_envs = config['eval_envs']
        self.device = config['device']
        self.num_agents = config['num_agents']
        # NOTE: self.render_envs only exists when the caller supplied it.
        if "render_envs" in config:
            self.render_envs = config['render_envs']

        # parameters
        self.env_name = self.all_args.env_name
        self.algorithm_name = self.all_args.algorithm_name
        self.experiment_name = self.all_args.experiment_name
        self.use_centralized_V = self.all_args.use_centralized_V
        self.use_obs_instead_of_state = self.all_args.use_obs_instead_of_state
        self.num_env_steps = self.all_args.num_env_steps
        self.episode_length = self.all_args.episode_length
        self.n_rollout_threads = self.all_args.n_rollout_threads
        self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads
        self.n_render_rollout_threads = self.all_args.n_render_rollout_threads
        self.use_linear_lr_decay = self.all_args.use_linear_lr_decay
        self.hidden_size = self.all_args.hidden_size
        self.use_wandb = self.all_args.use_wandb
        self.use_render = self.all_args.use_render
        self.recurrent_N = self.all_args.recurrent_N
        # env-step count at which save() last wrote a periodic snapshot
        self.last_saved_step = 0

        # interval
        self.save_interval = self.all_args.save_interval
        self.use_eval = self.all_args.use_eval
        self.eval_interval = self.all_args.eval_interval
        self.log_interval = self.all_args.log_interval

        # dir
        self.model_dir = self.all_args.pretr_model_dir
        self.tacit_model_dir = self.all_args.tacit_model_dir
        self.save_replay = self.all_args.save_replay

        # pre_trained parameters:
        self.param_alpha = self.all_args.param_alpha

        if self.use_wandb:
            # With wandb both directories are plain strings and no
            # tensorboard writer (self.writter / self.log_dir) is created.
            self.save_dir = str(wandb.run.dir)
            self.run_dir = str(wandb.run.dir)
        else:
            # run_dir must support the `/` operator (e.g. pathlib.Path).
            self.run_dir = config["run_dir"]
            self.log_dir = str(self.run_dir / 'logs')
            # exist_ok avoids the check-then-create race between parallel runs
            os.makedirs(self.log_dir, exist_ok=True)
            self.writter = SummaryWriter(self.log_dir)
            self.save_dir = str(self.run_dir / 'models')
            os.makedirs(self.save_dir, exist_ok=True)

        # The MAT variants take one extra constructor argument (num_agents)
        # for both the policy and the trainer, so decide once.
        uses_mat = self.algorithm_name in ("mat", "mat_dec")
        if uses_mat:
            from algorithms.mat.mat_trainer import MATTrainer as TrainAlgo
            from algorithms.mat.algorithm.transformer_policy import TransformerPolicy as Policy
        else:
            from algorithms.r_mappo.r_mappo import R_MAPPO as TrainAlgo
            from algorithms.r_mappo.algorithm.rMAPPOPolicy import R_MAPPOPolicy as Policy

        # The critic input space is the shared one only with a centralized V.
        if self.use_centralized_V:
            share_observation_space = self.envs.share_observation_space[0]
        else:
            share_observation_space = self.envs.observation_space[0]

        print("obs_space: ", self.envs.observation_space)
        print("share_obs_space: ", self.envs.share_observation_space)
        print("act_space: ", self.envs.action_space)

        # policy network
        if uses_mat:
            self.policy = Policy(self.all_args, self.envs.observation_space[0], share_observation_space, self.envs.action_space[0], self.num_agents, device = self.device)
        else:
            self.policy = Policy(self.all_args, self.envs.observation_space[0], share_observation_space, self.envs.action_space[0], device = self.device)

        # Pre-trained weights are loaded before the trainer is built.
        if self.model_dir is not None:
            self.restore(self.model_dir)

        # algorithm
        if uses_mat:
            self.trainer = TrainAlgo(self.all_args, self.policy, self.num_agents, device = self.device)
        else:
            self.trainer = TrainAlgo(self.all_args, self.policy, device = self.device)

        # buffer
        self.buffer = SharedReplayBuffer(self.all_args,
                                        self.num_agents,
                                        self.envs.observation_space[0],
                                        share_observation_space,
                                        self.envs.action_space[0])

    def run(self):
        """Training entry point (collect, train, evaluate); subclasses must override it."""
        raise NotImplementedError()

    def warmup(self):
        """Warmup data collection hook; subclasses must override it."""
        raise NotImplementedError()

    def collect(self, step):
        """Rollout collection hook for the given step; subclasses must override it."""
        raise NotImplementedError()

    def insert(self, data):
        """
        Buffer insertion hook; subclasses must override it.
        :param data: (Tuple) data to insert into training buffer.
        """
        raise NotImplementedError()
    
    @torch.no_grad()
    def compute(self):
        """Bootstrap from the last buffer entry and compute returns.

        Queries the policy for the value of the final stored step (without
        gradient tracking) and hands the per-thread values to the buffer.
        """
        self.trainer.prep_rollout()

        rollouts = self.buffer
        # Positional inputs of policy.get_values, taken from the last slot.
        last_entries = [
            rollouts.share_obs[-1],
            rollouts.rnn_states_critic[-1],
            rollouts.masks[-1],
        ]
        if self.algorithm_name in ("mat", "mat_dec"):
            # The MAT policies additionally take the local observations as
            # the second argument.
            last_entries.insert(1, rollouts.obs[-1])

        flattened = [np.concatenate(entry) for entry in last_entries]
        bootstrap = self.trainer.policy.get_values(*flattened)

        # Split the flat result back into one chunk per rollout thread.
        per_thread = np.split(_t2n(bootstrap), self.n_rollout_threads)
        next_values = np.array(per_thread)
        rollouts.compute_returns(next_values, self.trainer.value_normalizer)
    
    def train(self,total_num_steps=0):
        """
        Run one training update on the buffered data.
        :param total_num_steps: (int) env-step count forwarded to the trainer.
        :return: whatever the trainer's train() call returns.
        """
        learner, rollouts = self.trainer, self.buffer
        learner.prep_training()
        update_infos = learner.train(rollouts, total_num_steps)
        # Let the buffer roll over only after the update has consumed it.
        rollouts.after_update()
        return update_infos

    def save(self, episode=0, total_num_steps=0):
        """
        Save policy's actor and critic networks.

        For "mat"/"mat_dec" the policy's own save() is called each time
        total_num_steps crosses a multiple of the env's snapshot interval.
        For every other algorithm the latest weights are always written to
        <save_dir>/actor.pt and <save_dir>/critic.pt; "ippo", "ptmc" and
        "mappo" additionally keep a periodic snapshot under
        <save_dir>/<total_num_steps>/.

        :param episode: (int) episode index, forwarded to policy.save for MAT.
        :param total_num_steps: (int) total number of training env steps.
        """
        # env name -> (snapshot interval, shorter interval used by "ippo"),
        # both in env steps.
        snapshot_intervals = {
            'StagHunt': (50000, 5000),
            "StarCraft2v2": (2000000, 10000),
            "StarCraft2": (5000000, 100000),
        }
        # Unlisted envs used to hit an unbound local variable below; they now
        # fall back to the StarCraft2 intervals (a chosen default).
        save_steps, pr_save_steps = snapshot_intervals.get(
            self.all_args.env_name, snapshot_intervals["StarCraft2"])

        def crossed(interval):
            # True when a multiple of `interval` lies between the last
            # snapshot and the current step count.
            return total_num_steps // interval > self.last_saved_step // interval

        if self.algorithm_name in ("mat", "mat_dec"):
            if crossed(save_steps):
                self.policy.save(self.save_dir, episode)
                self.last_saved_step = total_num_steps
            return

        save_dir = str(self.save_dir)
        policy_actor = self.trainer.policy.actor
        torch.save(policy_actor.state_dict(), os.path.join(save_dir, "actor.pt"))
        policy_critic = self.trainer.policy.critic
        torch.save(policy_critic.state_dict(), os.path.join(save_dir, "critic.pt"))

        if self.algorithm_name == "ippo":
            interval = pr_save_steps
        elif self.algorithm_name in ("ptmc", "mappo"):
            interval = save_steps
        else:
            # Other algorithms keep only the latest weights.
            return

        if crossed(interval):
            # The step-numbered snapshot goes *inside* save_dir. The previous
            # string concatenation lacked a separator and produced a sibling
            # directory such as "models50000".
            snapshot_dir = os.path.join(save_dir, str(total_num_steps))
            os.makedirs(snapshot_dir, exist_ok=True)
            torch.save(policy_actor.state_dict(), os.path.join(snapshot_dir, "actor.pt"))
            torch.save(policy_critic.state_dict(), os.path.join(snapshot_dir, "critic.pt"))
            self.last_saved_step = total_num_steps

    def restore(self, model_dir):
        """
        Restore policy's networks from a saved model.

        The actor is always loaded from <model_dir>/actor.pt. The critic is
        loaded from <model_dir>/critic.pt unless the algorithm is "ptmc" or
        rendering is enabled (all_args.use_render).

        :param model_dir: (str or path-like) directory holding the .pt files.
        """
        model_dir = str(model_dir)

        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded on the device this runner was configured with.
        actor_state = torch.load(os.path.join(model_dir, 'actor.pt'),
                                 map_location=self.device)
        self.policy.actor.load_state_dict(actor_state)

        # "ptmc" restores only the actor; rendering restores only the actor.
        if self.algorithm_name == "ptmc" or self.all_args.use_render:
            return

        critic_state = torch.load(os.path.join(model_dir, 'critic.pt'),
                                  map_location=self.device)
        self.policy.critic.load_state_dict(critic_state)

    def log_train(self, train_infos, total_num_steps):
        """
        Log training info.

        With wandb every entry is sent through wandb.log; otherwise every
        entry is written to the text logger and the tensorboard writer.

        :param train_infos: (dict) information about training update.
        :param total_num_steps: (int) total number of training env steps.
        """
        if self.use_wandb:
            # The file logger is not needed here, and building it would fail
            # because run_dir is a plain string in wandb mode.
            for k, v in train_infos.items():
                wandb.log({k: v}, step=total_num_steps)
            return

        logger = self.log_creater()
        for k, v in train_infos.items():
            logger.info("Logging %s: %s at step %d",k, {k: v}, total_num_steps)
            self.writter.add_scalars(k, {k: v}, total_num_steps)

    def log_env(self, env_infos, total_num_steps):
        """
        Log env info.

        Each non-empty entry is reduced to its mean and sent to wandb, or to
        the text logger plus the tensorboard writer when wandb is not used.
        Empty entries are skipped.

        :param env_infos: (dict) information about env state.
        :param total_num_steps: (int) total number of training env steps.
        """
        # The file logger is only built without wandb: it is unused in wandb
        # mode, and run_dir is a plain string there, which log_creater's path
        # handling does not accept.
        logger = None if self.use_wandb else self.log_creater()

        for k, v in env_infos.items():
            # len() rather than truthiness: v may be an array-like.
            if len(v) == 0:
                continue
            mean_value = np.mean(v)
            if self.use_wandb:
                wandb.log({k: mean_value}, step=total_num_steps)
            else:
                logger.info("Logging %s: %s at step %d", k, {k: mean_value}, total_num_steps)
                self.writter.add_scalars(k, {k: mean_value}, total_num_steps)

    def log_creater(self):
        """
        Return the logger named after this file, configuring it on first use.

        On the first call a file handler (<run_dir>/log_write/log.log, append
        mode) and a console handler are attached. Later calls return the same
        logger without adding handlers, so the log file is determined by the
        run_dir of whichever instance called first in this process.

        :return: (logging.Logger) logger at DEBUG level.
        """
        logger = logging.getLogger(__file__)
        logger.setLevel(logging.DEBUG)

        # Handlers are attached only once; afterwards the logger is reused.
        if logger.handlers:
            return logger

        # run_dir is path-like without wandb but a plain str with wandb, so
        # the path is built with os.path instead of the `/` operator.
        log_file_dir = os.path.join(str(self.run_dir), 'log_write')
        os.makedirs(log_file_dir, exist_ok=True)

        log_path = os.path.join(log_file_dir, 'log.log')
        # Write the command line and separator lines only when the log file
        # is created for the first time.
        if not os.path.exists(log_path):
            with open(log_path, 'w') as f:
                full_cmd = " ".join(sys.argv)
                f.write(f"# Time: {datetime.now()}\n")
                f.write(full_cmd + "\n")
                f.write("-------------------------------------------------\n")
                f.write("-------------------------------------------------\n")

        # Create the two handlers and set their log level.
        fileHandler = logging.FileHandler(log_path, mode='a')
        fileHandler.setLevel(logging.DEBUG)
        consoleHandler = logging.StreamHandler()
        consoleHandler.setLevel(logging.DEBUG)

        # Set the log record format.
        formatter = logging.Formatter('[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
                                      datefmt='%Y-%m-%d %H:%M:%S')
        consoleHandler.setFormatter(formatter)
        fileHandler.setFormatter(formatter)

        # Attach both handlers to the logger.
        logger.addHandler(fileHandler)
        logger.addHandler(consoleHandler)
        return logger

    def log_tacit(self, tacit_indicator, total_num_steps):
        """
        Log the tacit-indicator table.

        The first four rows are written to the text logger under the names
        "mean_r_tacit", "num_all", "num_positive" and "tacit_indicator". The
        last row is then logged once as scalars "tacit_1_ind".."tacit_4_ind".

        :param tacit_indicator: indexable of rows; each logged row must offer
            .tolist() (e.g. a numpy array) - presumably a 4x4 array, TODO confirm.
        :param total_num_steps: (int) total number of training env steps.
        """
        logger = self.log_creater()
        row_names = ["mean_r_tacit", "num_all", "num_positive", "tacit_indicator"]
        for i, row_name in enumerate(row_names):
            row_values = tacit_indicator[i]
            logger.info(f"Logging {row_name}: {row_values.tolist()}, at step {total_num_steps}")

        # The scalars depend only on the last row, so they are written once
        # after the loop rather than once per logged row.
        tag_names = ["tacit_1_ind", "tacit_2_ind", "tacit_3_ind", "tacit_4_ind"]
        scalars = dict(zip(tag_names, tacit_indicator[-1]))
        if self.use_wandb:
            # No tensorboard writer exists in wandb mode; log through wandb
            # like the other log_* methods do.
            wandb.log({f"tacit_indicator/{name}": value for name, value in scalars.items()},
                      step=total_num_steps)
        else:
            self.writter.add_scalars(
                main_tag="tacit_indicator",
                tag_scalar_dict=scalars,
                global_step=total_num_steps
            )
