import numpy as np
from torch.utils.tensorboard import SummaryWriter

class RunningMeanStd(object):
    """Track a running mean and (population) variance over batches of data.

    Statistics are reduced along axis 0 of every batch, so ``mean`` and
    ``var`` have the shape of a single sample. Batches are merged with the
    parallel variance-combination formula, so the result after several
    ``update`` calls equals the statistics of all samples seen so far.

    Attributes:
        mean: Running mean (initialised from the ``mean`` argument).
        var: Running variance (initialised from ``std ** 2``).
        count: Number of samples accumulated so far.
    """

    def __init__(self, mean=0.0, std=1.0):
        # ``std`` is a standard deviation while the attribute holds a
        # variance, so square it. Because ``count`` starts at 0 the initial
        # values carry no weight once the first batch has been merged; they
        # only matter to readers of ``mean``/``var`` before any update.
        self.mean, self.var = mean, std ** 2
        self.count = 0

    def update(self, data_array):
        """Merge a batch of samples into the running statistics.

        Args:
            data_array: Sequence/array of samples; axis 0 indexes samples.
                An empty batch leaves the statistics untouched.
        """
        batch_count = len(data_array)
        if batch_count == 0:
            # The mean/variance of nothing is NaN and would permanently
            # poison the accumulated statistics, so ignore empty batches.
            return

        batch_mean = np.mean(data_array, axis=0)
        batch_var = np.var(data_array, axis=0)

        delta = batch_mean - self.mean
        total_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / total_count
        # Sum of squared deviations of the old data, of the new batch, and
        # the correction term for the shift between the two means.
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count
        new_var = m_2 / total_count

        self.mean, self.var = new_mean, new_var
        self.count = total_count

class Logger(object):
    """Write training, update and test scalars to TensorBoard.

    Every ``log_*`` method builds a ``{tag: value}`` dict and hands it to
    ``write``, which emits one scalar per entry and flushes the writer.
    """

    # Slice layout of the per-environment reward array used by
    # ``log_train_data``: the first ``_HEAD_COUNT`` entries feed the
    # "sreward" tags, the last ``_TAIL_COUNT`` entries feed the "treward"
    # tags, and indices ``_GROUP_START.._GROUP_STOP`` are additionally split
    # into chunks of ``_GROUP_SIZE`` logged as "treward1", "treward2", ...
    _HEAD_COUNT = 32
    _TAIL_COUNT = 32
    _GROUP_START = 32
    _GROUP_STOP = 64
    _GROUP_SIZE = 4

    # (TensorBoard tag, key in ``update_result``) pairs for log_update_data.
    _UPDATE_TAGS = (
        ("update/mean/loss", "loss"),
        ("update/mean/clip_loss", "loss/clip"),
        ("update/mean/vf_loss", "loss/vf"),
        ("update/mean/ent_loss", "loss/ent"),
        ("update/mean/ratio", "ratio"),
        ("update/mean/kl", "kl"),
    )

    def __init__(self, log_dir):
        """Create a ``SummaryWriter`` that writes into ``log_dir``."""
        self.log_dir = log_dir
        self.writer = SummaryWriter(log_dir=log_dir)

    def close(self):
        """Close the underlying writer once logging is finished."""
        self.writer.close()

    def write(self, step, data):
        """Emit every ``tag -> value`` pair of ``data`` at ``step`` and flush."""
        for tag, value in data.items():
            self.writer.add_scalar(tag, value, global_step=step)
        self.writer.flush()

    @staticmethod
    def _valid_mean(values):
        """Return ``values.mean()``, or ``None`` when there is nothing to average.

        ``None`` is returned for an empty slice (whose mean would be NaN) and
        for a masked array whose entries are all masked (whose mean is the
        ``np.ma.masked`` singleton).
        """
        if len(values) == 0:
            return None
        mean = values.mean()
        return None if mean is np.ma.masked else mean

    def log_train_data(self, collect_result, step):
        """Log reward statistics of a collection round.

        Args:
            collect_result: Mapping with a scalar ``"rew"`` and an array
                ``"rews"`` (plain or masked) indexed by environment.
            step: Global step the scalars are recorded at.

        Nothing is written when ``collect_result["rew"]`` equals 0.
        NOTE(review): presumably 0 signals "no finished episode"; a genuine
        mean reward of exactly 0 is skipped as well -- confirm with caller.
        """
        if collect_result["rew"] == 0:
            return

        rews = collect_result["rews"]
        log_data = {}

        # Mean over all envs: only useful for compare_ntl.py.
        overall_mean = self._valid_mean(rews)
        if overall_mean is not None:
            log_data["train/reward"] = overall_mean

        # The former envs.
        head = rews[:self._HEAD_COUNT]
        head_mean = self._valid_mean(head)
        if head_mean is not None:
            log_data["train/sreward"] = head_mean
            log_data["train/sreward_std"] = head.std()

        # The latter envs.
        tail = rews[-self._TAIL_COUNT:]
        tail_mean = self._valid_mean(tail)
        if tail_mean is not None:
            log_data["train/treward"] = tail_mean
            log_data["train/treward_std"] = tail.std()

        # Fixed-size groups inside [_GROUP_START, _GROUP_STOP), numbered
        # from 1. Groups with no (unmasked) data are skipped.
        group_starts = range(self._GROUP_START, self._GROUP_STOP, self._GROUP_SIZE)
        for number, start in enumerate(group_starts, start=1):
            group_mean = self._valid_mean(rews[start:start + self._GROUP_SIZE])
            if group_mean is not None:
                log_data[f'train/treward{number}'] = group_mean

        self.write(step, log_data)

    def log_update_data(self, update_result, step):
        """Log the mean of every per-step update statistic at ``step``.

        Args:
            update_result: Mapping holding a sequence of values under each of
                ``"loss"``, ``"loss/clip"``, ``"loss/vf"``, ``"loss/ent"``,
                ``"ratio"`` and ``"kl"``.
            step: Global step the scalars are recorded at.
        """
        log_data = {
            tag: np.array(update_result[key]).mean()
            for tag, key in self._UPDATE_TAGS
        }
        self.write(step, log_data)

    def log_test_data(self, test_result, step):
        """Log the test reward, its std and the best reward so far at ``step``."""
        log_data = {
            "test/reward": test_result["test_reward"],
            "test/reward_std": test_result["test_reward_std"],
            "test/best_reward": test_result["best_reward"],
        }
        self.write(step, log_data)
