import time

from torch.multiprocessing import Process, Pipe
from tqdm import tqdm

from Algorithms.Common.brain_common import Brain, BaseWorkflowController
from Common import Config
from Common.logger.logger_general import Logger
from Common.runner import Worker


def run_workers(worker, conn) -> None:
    """Process entry point: pass a pipe connection to ``worker.step``.

    Used as the ``target`` of each ``Process`` started by
    ``BaseWorkflow.init_workers``.

    Args:
        worker: Object exposing a ``step(conn)`` method.
        conn: Connection handed to ``worker.step`` unchanged (the child end
            of a ``Pipe`` when launched from ``init_workers``).
    """
    worker.step(conn)


class BaseWorkflow:
    """Rollout/train loop driving worker processes through pipes.

    The workflow owns a brain, a logger, a list of worker processes and a
    workflow controller that holds the per-rollout buffers.  Each iteration
    of ``run`` resets the controller, collects one rollout, asks the brain
    for intrinsic rewards, calls the controller's ``calc_gae()`` and then
    trains the brain on the controller.

    Subclasses customise behaviour by overriding the ``get_*_cls`` factories,
    ``init_wf_controller`` and the no-op hooks ``roll_s_spec``,
    ``roll_t_spec``, ``roll_extra`` and ``get_intrinsic_post``.
    """

    desc: str = "base"
    brain: Brain

    def __init__(self, config: Config):
        """Create brain and logger, optionally restore a checkpoint, start workers.

        Args:
            config: Run configuration.  When ``config.train_from_scratch`` is
                falsy, weights and counters are restored through the logger.
        """
        self.config = config
        self.brain = self.get_brain_cls()(config)
        self.logger = Logger(self.brain, config)

        if config.train_from_scratch:
            self.init_iteration = 0
            self.episode = 0
        else:
            checkpoint = self.logger.load_weights()
            self.brain.set_from_checkpoint(checkpoint)
            restored_reward = checkpoint["running_reward"]
            self.init_iteration = checkpoint["iteration"]
            self.episode = checkpoint["episode"]
            # Keep the logger's counters in step with the restored run.
            self.logger.running_ext_reward = restored_reward
            self.logger.episode = self.episode

        self.init_workers()
        self.init_wf_controller()

    def get_worker_cls(self):
        """Return the class instantiated for every worker."""
        return Worker

    def get_brain_cls(self):
        """Return the class instantiated as ``self.brain``."""
        return Brain

    def init_wf_controller(self):
        """Create the controller that stores the rollout buffers."""
        self.wf_controller = BaseWorkflowController(self.config)

    def init_workers(self):
        """Build the workers, start one daemon process per worker, reset counters.

        The parent end of every pipe is stored in ``self.parents`` in worker
        order; the child end is given to ``run_workers`` in the new process.
        """
        self.workers = [
            self.get_worker_cls()(index, self.config)
            for index in range(self.config.n_workers)
        ]

        self.parents = []
        for worker in self.workers:
            parent_end, child_end = Pipe()
            process = Process(
                target=run_workers,
                args=(worker, child_end),
                daemon=True,
            )
            # Register the parent end before the child process is launched.
            self.parents.append(parent_end)
            process.start()

        self.logger.on()
        # Running reward sums for the episode currently in progress.
        self.episode_ext_reward = 0
        self.episode_ext_reward_gt = 0

    def roll_s(self, t):
        """Receive the current state from every worker and store it at step ``t``.

        For the ``"Atari"`` env type only the last slice along the first axis
        of the state is stored in ``obs_states``; otherwise the whole state is.
        """
        ctrl = self.wf_controller
        for worker_id, parent in enumerate(self.parents):
            state = parent.recv()
            ctrl.states[worker_id, t] = state
            if self.config.env_type == "Atari":
                observation = state[-1:, ...]
            else:
                observation = state
            ctrl.obs_states[worker_id, t] = observation
            self.roll_s_spec(worker_id, t, state)

    def roll_s_spec(self, id, t, s_prev):
        """Hook called per worker after its state is stored; no-op here."""

    def roll_a(self, t):
        """Query the brain for step ``t``, store its outputs, send the actions.

        With ``config.policy_use_rnn`` set, the hidden state fed to the brain
        is recorded in ``hidden_states[:, t]`` first and the hidden state the
        brain returns is written back into ``rollout_last_hidden_states``.
        """
        ctrl = self.wf_controller
        recurrent = bool(self.config.policy_use_rnn)

        if recurrent:
            ctrl.hidden_states[:, t] = ctrl.rollout_last_hidden_states
            outputs = self.brain.get_actions_and_values_rnn(
                ctrl.states[:, t],
                ctrl.rollout_last_hidden_states[:],
                batch=True,
            )
            (actions, int_values, ext_values,
             log_probs, action_probs, _, next_hidden) = outputs
        else:
            outputs = self.brain.get_actions_and_values(
                ctrl.states[:, t], batch=True
            )
            actions, int_values, ext_values, log_probs, action_probs, _ = outputs
            next_hidden = None

        ctrl.actions[:, t] = actions
        ctrl.int_values[:, t] = int_values
        ctrl.ext_values[:, t] = ext_values
        ctrl.log_probs[:, t] = log_probs
        ctrl.action_probs[:, t] = action_probs
        if recurrent:
            # In-place write so the existing buffer object is reused.
            ctrl.rollout_last_hidden_states[:] = next_hidden

        self.send_a(t)

    def send_a(self, t):
        """Send each worker its action for step ``t`` through its pipe."""
        step_actions = self.wf_controller.actions[:, t]
        for parent, action in zip(self.parents, step_actions):
            parent.send(action)

    def roll_next(self, t):
        """Receive each worker's transition for step ``t`` and store it.

        Every worker is expected to send ``(next_state, reward, done, info)``.

        Returns:
            The list of ``info`` objects, one per worker, in worker order.
        """
        ctrl = self.wf_controller
        infos = []
        for worker_id, parent in enumerate(self.parents):
            next_state, reward, done, info = parent.recv()
            infos.append(info)

            ctrl.ext_rewards[worker_id, t] = reward
            ctrl.dones[worker_id, t] = done
            ctrl.rollout_last_states[worker_id] = next_state
            ctrl.next_states[worker_id, t] = next_state
            if self.config.env_type == "Atari":
                next_observation = next_state[-1:, ...]
            else:
                next_observation = next_state
            ctrl.obs_next_states[worker_id, t] = next_observation

            if "cell_state" in info:
                # The environment supplied the cell states itself.
                ctrl.cell_states[worker_id, t] = info["cell_state"]
                ctrl.cell_next_states[worker_id, t] = info["next_cell_state"]
            else:
                # Fall back to full states when the shapes agree, otherwise
                # to the observation buffers.
                if self.config.cell_shape == self.config.state_shape:
                    current_src, next_src = ctrl.states, ctrl.next_states
                else:
                    current_src, next_src = ctrl.obs_states, ctrl.obs_next_states
                ctrl.cell_states[worker_id, t] = current_src[worker_id, t]
                ctrl.cell_next_states[worker_id, t] = next_src[worker_id, t]

            ctrl.episode_steps[worker_id, t] = info['episode_step']
            ctrl.ext_rewards_gt[worker_id, t] = info['gt_reward']
            ctrl.erir_mask[worker_id, t] = info['erir_mask']

            if self.config.policy_use_rnn:
                # Scales this worker's hidden state by (1 - done); presumably
                # done is 0/1 so a finished episode clears it -- TODO confirm.
                ctrl.rollout_last_hidden_states[worker_id] *= (1 - done)

            self.roll_t_spec(worker_id, t, next_state, info)
        return infos

    def roll_t_spec(self, worker_id, t, s_, info):
        """Hook called per worker after its transition is stored; no-op here."""

    def roll_extra(self, t):
        """Hook called once per step after all transitions are stored; no-op here."""

    def rollout(self):
        """Collect ``config.rollout_length`` steps and log worker 0's episodes.

        Only worker 0's rewards are accumulated; whenever its done flag is
        truthy the episode counter is advanced, the totals are logged together
        with worker 0's ``info``, and the totals are reset to zero.
        """
        for t in range(self.config.rollout_length):
            self.roll_s(t)
            self.roll_a(t)
            infos = self.roll_next(t)
            self.roll_extra(t)

            ctrl = self.wf_controller
            self.episode_ext_reward += ctrl.ext_rewards[0, t]
            self.episode_ext_reward_gt += ctrl.ext_rewards_gt[0, t]
            if not ctrl.dones[0, t]:
                continue

            self.episode += 1
            episode_summary = {
                "ext_reward": self.episode_ext_reward,
                "env_reward": self.episode_ext_reward_gt,
                "info": infos[0]
            }
            self.logger.log_episode(self.episode, episode_summary)
            self.episode_ext_reward = 0
            self.episode_ext_reward_gt = 0

    def get_intrinsic(self):
        """Store the brain's intrinsic rewards on the controller, then run the hook."""
        int_rewards, int_info = self.brain.calc_rollout_int_reward(self.wf_controller)
        self.wf_controller.int_rewards = int_rewards
        self.get_intrinsic_post(int_info)

    def get_intrinsic_post(self, info):
        """Hook receiving the extra info from the intrinsic-reward step; no-op here."""
        # Example override kept for reference:
        # self.wf_controller.sla_values, self.wf_controller.sla_next_values = sla_reward_info["sla_value"], \
        # sla_reward_info["sla_next_value"]
        # self.wf_controller.calc_sla_gae()

    def rollout_post(self):
        """Evaluate the final rollout states and call the controller's ``calc_gae()``.

        The second and third outputs of the brain call are stored as
        ``rollout_last_int_values`` and ``rollout_last_ext_values``; all other
        outputs are discarded.
        """
        ctrl = self.wf_controller
        if self.config.policy_use_rnn:
            outputs = self.brain.get_actions_and_values_rnn(
                ctrl.rollout_last_states,
                ctrl.rollout_last_hidden_states,
                batch=True,
            )
        else:
            outputs = self.brain.get_actions_and_values(
                ctrl.rollout_last_states, batch=True
            )
        _, last_int_values, last_ext_values, *_ = outputs
        ctrl.rollout_last_int_values = last_int_values
        ctrl.rollout_last_ext_values = last_ext_values
        ctrl.calc_gae()

    def get_rollout_log(self):
        """Return rollout statistics for worker 0, keyed by logger tag."""
        ctrl = self.wf_controller
        mean_int_reward = ctrl.int_rewards[0].mean()
        mean_top_action_prob = ctrl.action_probs[0].max(-1).mean()
        return {
            "Performance/Mean Intrinsic Reward": mean_int_reward,
            "Performance/Mean Action Stochasticity": mean_top_action_prob,
        }

    def run(self):
        """Run iterations from ``init_iteration + 1`` to ``total_rollouts_per_env``.

        Each iteration resets the controller, collects a rollout, computes
        intrinsic rewards and GAE, trains the brain, and logs the training
        info together with wall-clock timings and the rollout statistics.
        """
        first_iteration = self.init_iteration + 1
        last_iteration = self.config.total_rollouts_per_env
        progress = tqdm(
            range(first_iteration, last_iteration + 1),
            desc=self.desc + self.config.env_name,
        )
        for iteration in progress:
            self.wf_controller.reset()

            rollout_started = time.time()
            self.rollout()
            self.get_intrinsic()
            self.rollout_post()

            # One timestamp marks both the end of the rollout phase and the
            # start of the training phase.
            train_started = time.time()
            rollout_seconds = train_started - rollout_started

            iteration_metrics = self.brain.train(self.wf_controller)

            train_seconds = time.time() - train_started

            iteration_metrics.update({
                "Time Consumption/Iteration Rollout Time": rollout_seconds,
                "Time Consumption/Iteration Training Time": train_seconds,
            })
            iteration_metrics.update(self.get_rollout_log())
            self.logger.log_iteration(iteration, iteration_metrics)

def main_common(config: Config):
    """Build a ``BaseWorkflow`` from ``config`` and execute its ``run()`` loop."""
    BaseWorkflow(config).run()

