import datetime
import glob
import inspect
import os
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

class Logger:
    """Collect training statistics, write them to TensorBoard and save checkpoints.

    TensorBoard events go to ``Logs/<log_dir>``; checkpoints are written through
    ``brain.save_model(path, infos)`` under ``Models/<log_dir>/``.
    """

    # Suffixes of the "Time Consumption/<name>" entries of the iteration info
    # dict that the console summary prints, in print order.
    _TIME_CONSUMPTION_KEYS = (
        "Iteration Rollout Time",
        "Iteration Training Time",
        "RND State BN Time",
        "Train Forward Time",
        "Train Backward Time",
    )

    def __init__(self, brain, config):
        """Create the logger and, for a fresh training run, its output folders.

        Args:
            brain: object exposing ``save_model(path, infos)``; used for checkpoints.
            config: attribute container (presumably an argparse ``Namespace`` --
                TODO confirm) read for ``env_name``, an optional ``log_dir``,
                ``do_test``, ``train_from_scratch``, ``interval``, ``n_workers``
                and ``rollout_length``.
        """
        self.config = config
        self.brain = brain

        configured_dir = getattr(config, "log_dir", None)
        if configured_dir:
            self.log_dir = str(configured_dir)
        else:
            # "<env_name><timestamp>". A log_dir attribute that is None/empty also
            # lands here, since "Logs/" + None below would raise TypeError.
            self.log_dir = str(config.env_name) + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

        # Start the clock now so the first reported duration is meaningful even
        # if on() is never called before the first log_iteration().
        self.start_time = time.time()
        self.duration = 0
        self.episode = 0
        self.running_iteration_logs = 0
        self.episode_ext_reward = 0
        self.episode_env_reward = 0
        self.running_ext_reward = 0
        self.running_env_reward = 0
        self.max_episode_ext_reward = -np.inf
        self.max_episode_env_reward = -np.inf
        self.episode_info = {}

        # No TensorBoard writer (and no folders) in test mode.
        if not self.config.do_test:
            self.writer = SummaryWriter("Logs/" + self.log_dir)
            if self.config.train_from_scratch:
                self.create_wights_folder()
                self.log_params()

    @staticmethod
    def exp_avg(x, y):
        """Exponential moving average step: keep 90% of *x* and blend in 10% of *y*."""
        return 0.9 * x + 0.1 * y

    def create_wights_folder(self):
        """Create ``Models/<log_dir>`` (and ``Models``) if they do not exist yet.

        Idempotent: re-using an existing ``log_dir`` no longer raises
        ``FileExistsError``.
        """
        os.makedirs(os.path.join("Models", self.log_dir), exist_ok=True)

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    create_weights_folder = create_wights_folder

    def log_params(self):
        """Write every public, non-method attribute of ``config`` to TensorBoard as text."""
        for name, value in inspect.getmembers(self.config):
            if name.startswith("_") or inspect.ismethod(value):
                continue
            self.writer.add_text(name, str(value))

    def on(self):
        """Start (or restart) the wall-clock timer."""
        self.start_time = time.time()

    def off(self):
        """Store the seconds elapsed since the last ``on()`` in ``self.duration``."""
        self.duration = time.time() - self.start_time

    def get_time_consumption_list(self):
        """Return the five timing values of the latest iteration, in print order.

        Returns ``[0.0] * 5`` before the first ``log_iteration`` call; a timing
        entry missing from the latest info dict also reads as ``0.0``.
        """
        latest = getattr(self, "running_iteration_info_dict", None)
        if latest is None:
            return [0.0] * len(self._TIME_CONSUMPTION_KEYS)
        return [latest.get(f"Time Consumption/{name}", 0.0) for name in self._TIME_CONSUMPTION_KEYS]

    def log_iteration(self, iteration, iteration_info_dict):
        """Record one training iteration: checkpoint, write scalars, maybe print.

        Args:
            iteration: iteration index. It sets the checkpoint/print cadence
                (via ``config.interval``) and the TensorBoard step
                ``iteration * n_workers * rollout_length``.
            iteration_info_dict: mapping of TensorBoard tag -> scalar value for
                this iteration; its "Time Consumption/..." entries feed the
                console summary.
        """
        info_value_array = np.array(list(iteration_info_dict.values()))
        # Smoothed copy of the raw values. It is kept as an attribute, but the raw
        # values are what gets logged below.
        # NOTE(review): confirm whether the smoothed values were meant to be logged.
        self.running_iteration_logs = self.exp_avg(self.running_iteration_logs, info_value_array)
        self.running_iteration_info_dict = dict(zip(iteration_info_dict.keys(), info_value_array))

        self._save_checkpoints(iteration)

        num_frames = iteration * self.config.n_workers * self.config.rollout_length
        self._write_iteration_scalars(num_frames)

        # Duration covers the time since the previous log_iteration (or on()).
        self.off()
        if iteration % self.config.interval == 0:
            self._print_iteration_summary(iteration)
        self.on()

    def _save_checkpoints(self, iteration):
        """Overwrite the rolling checkpoint and, far less often, keep a per-iteration snapshot."""
        # max(1, ...) prevents ZeroDivisionError when interval < 3.
        if iteration % max(1, self.config.interval // 3) == 0:
            self.save_params(self.episode, iteration)
        if iteration % (self.config.interval * 320) == 0:
            self.save_params_iter(self.episode, iteration)

    def _write_iteration_scalars(self, num_frames):
        """Write iteration info, reward statistics and scalar episode info at step *num_frames*."""
        for tag, value in self.running_iteration_info_dict.items():
            self.writer.add_scalar(tag, value, num_frames)

        performance = (
            ("Performance/Episode Ext Reward", self.episode_ext_reward),
            ("Performance/Episode Ext Reward Ground Truth", self.episode_env_reward),
            ("Performance/Running Episode Ext Reward", self.running_ext_reward),
            ("Performance/Running Episode Ext Reward Ground Truth", self.running_env_reward),
            ("Performance/Max Episode Ext Reward", self.max_episode_ext_reward),
            ("Performance/Max Episode Ext Reward Ground Truth", self.max_episode_env_reward),
        )
        for tag, value in performance:
            self.writer.add_scalar(tag, value, num_frames)

        for key, value in self.episode_info.items():
            # Container-valued entries cannot be plotted as scalars.
            if not isinstance(value, (list, dict, set)):
                self.writer.add_scalar(f"Episode Info/{key}", value, num_frames)

    def _print_iteration_summary(self, iteration):
        """Print a one-line progress summary to stdout."""
        print("Iter:{}| "
              "EP:{}| "
              "Reward:{}/{:.3f}/{}/{:.3f}| "
              "EpInfo:{}| "
              "Duration:{:.3f}| "
              "TimeCost: Rollout-{:.4f} Train-{:.4f} RND StateBN-{:.4f} Train Forward-{:.4f} Train Backward-{:.4f}| "
              "Time:{} "
              .format(iteration,
                      self.episode,
                      self.episode_ext_reward,
                      self.running_ext_reward,
                      self.episode_env_reward,
                      self.running_env_reward,
                      self.episode_info,
                      self.duration,
                      *self.get_time_consumption_list(),
                      datetime.datetime.now().strftime("%H:%M:%S"),
                      ))

    def log_episode(self, episode, episode_info_dict):
        """Record a finished episode and update max / running reward statistics.

        Args:
            episode: episode counter, stored for later logs and checkpoints.
            episode_info_dict: must contain ``"ext_reward"``, ``"env_reward"`` and
                ``"info"``. ``info["episode"]`` (default ``{}``) becomes
                ``self.episode_info``.
        """
        self.episode = episode
        self.episode_ext_reward = episode_info_dict["ext_reward"]
        self.episode_env_reward = episode_info_dict["env_reward"]
        self.episode_info = episode_info_dict["info"].get("episode", {})

        self.max_episode_ext_reward = max(self.max_episode_ext_reward, self.episode_ext_reward)
        self.max_episode_env_reward = max(self.max_episode_env_reward, self.episode_env_reward)

        self.running_ext_reward = self.exp_avg(self.running_ext_reward, self.episode_ext_reward)
        self.running_env_reward = self.exp_avg(self.running_env_reward, self.episode_env_reward)

    def _checkpoint_infos(self, episode, iteration):
        """Metadata dict passed to ``brain.save_model`` with every checkpoint."""
        return {"iteration": iteration,
                "episode": episode,
                "running_reward": self.running_ext_reward}

    def save_params(self, episode, iteration):
        """Save the rolling checkpoint ``Models/<log_dir>/params.pth``, overwriting the previous one."""
        self.brain.save_model("Models/" + self.log_dir + "/params.pth",
                              self._checkpoint_infos(episode, iteration))

    def save_params_iter(self, episode, iteration):
        """Save a snapshot named after *iteration* and *episode* that is never overwritten."""
        self.brain.save_model(f"Models/{self.log_dir}/iter{iteration}_episode{episode}params.pth",
                              self._checkpoint_infos(episode, iteration))

    def load_weights(self):
        """Load ``params.pth`` from the lexicographically last run folder under ``Models/``.

        Only directories that actually contain ``params.pth`` are considered, so a
        stray file or an empty run folder does not shadow the latest checkpoint.
        Sets ``self.log_dir`` to that folder's name so later saves go back into it.

        Returns:
            The object returned by ``torch.load`` for the checkpoint.

        Raises:
            FileNotFoundError: if no ``Models/*/params.pth`` exists.
        """
        candidates = sorted(
            path for path in glob.glob(os.path.join("Models", "*"))
            if os.path.isfile(os.path.join(path, "params.pth"))
        )
        if not candidates:
            raise FileNotFoundError("No 'Models/*/params.pth' checkpoint found")
        run_dir = candidates[-1]
        checkpoint_path = os.path.join(run_dir, "params.pth")
        try:
            checkpoint = torch.load(checkpoint_path)
        except RuntimeError:
            # The CPU retry targets checkpoints holding CUDA tensors on a host
            # without CUDA, which torch.load reports as RuntimeError.
            checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.log_dir = os.path.basename(run_dir)
        return checkpoint




