import numpy as np
import torch
import tensordict
from tensordict import TensorDict

import tools


class OnlineTrainer:
    def __init__(self, config, replay_buffer, logger, logdir, train_envs, eval_envs, act_dim):
        """Keep references to the collaborators and derive the schedule from ``config``.

        Args:
            config: Settings object. The attributes read here are ``steps``,
                ``pretrain``, ``eval_every``, ``eval_episode_num``,
                ``video_pred_log``, ``params_hist_log``, ``batch_size``,
                ``batch_length``, ``train_ratio``, ``action_repeat`` and
                ``update_log_every``.
            replay_buffer: Buffer that ``begin`` fills and the agent trains from.
            logger: Sink for scalars, videos and histograms.
            logdir: Stored as-is; not used by the methods of this class.
            train_envs: Environments stepped by ``begin``.
            eval_envs: Environments stepped by ``eval``.
            act_dim: Stored as-is; not used by the methods of this class.
        """
        # Collaborators supplied by the caller.
        self.train_envs = train_envs
        self.eval_envs = eval_envs
        self.replay_buffer = replay_buffer
        self.logger = logger
        self.logdir = logdir
        self.act_dim = act_dim

        # Plain settings, coerced to built-in types.
        self.steps = int(config.steps)
        self.pretrain = int(config.pretrain)
        self.batch_length = int(config.batch_length)
        self.eval_every = int(config.eval_every)
        self.eval_episode_num = int(config.eval_episode_num)
        self.video_pred_log = bool(config.video_pred_log)
        self.params_hist_log = bool(config.params_hist_log)

        # train_ratio is based on data steps rather than environment steps,
        # hence the conversion of one batch worth of data into an interval
        # measured in (action-repeated) environment steps.
        action_repeat = config.action_repeat
        data_steps_per_batch = int(config.batch_size * config.batch_length)
        update_interval = data_steps_per_batch / config.train_ratio * action_repeat
        self._action_repeat = action_repeat

        # Schedule helpers from ``tools``; each is called with the current step
        # (``Once`` with no argument) inside ``begin``.
        self._updates_needed = tools.Every(update_interval)
        self._should_pretrain = tools.Once()
        self._should_log = tools.Every(config.update_log_every)
        self._should_eval = tools.Every(self.eval_every)

    def to_td(self, transition):
        """Wrap a single transition mapping into a ``TensorDict`` with a leading dim of 1.

        A zero ``"reward"`` entry of shape ``(1,)`` is added to ``transition``
        itself (in place) when the key is missing. NumPy arrays are copied
        before being handed to ``TensorDict``; other values are passed through.

        Args:
            transition: Mapping from field name to value for one step.

        Returns:
            The resulting ``TensorDict`` after ``unsqueeze(0)``.
        """
        if "reward" not in transition:
            # Note: this mutates the mapping passed in by the caller.
            transition.update(reward=torch.zeros((1,), dtype=torch.float32))

        fields = {}
        for name, value in transition.items():
            fields[name] = value.copy() if isinstance(value, np.ndarray) else value

        td = TensorDict(fields, batch_size=())
        for name in td.keys():
            entry = td[name]
            # Zero-dim entries get a trailing axis, except the "episode" entry
            # which is kept as it is.
            if entry.ndim == 0 and name != "episode":
                td[name] = entry.unsqueeze(-1)
        return td.unsqueeze(0)

    def eval(self, agent, train_step):
        """Roll out the agent in the evaluation envs and log score, length and videos.

        All ``self.eval_envs`` are stepped together until every env has
        reported ``done`` at least once. The return and the step count of an
        env stop accumulating after its first ``done``.

        Args:
            agent: Policy to evaluate. Must provide ``eval()``, ``train()``,
                ``device``, ``get_initial_state``, ``act`` and, when
                ``self.video_pred_log`` is set, ``video_pred``. It is put in
                eval mode for the rollout and switched back to train mode
                afterwards, even if the rollout or the logging raises.
            train_step: Step value handed to ``self.logger.write``.
        """
        print("Evaluating the policy...")
        envs = self.eval_envs
        agent.eval()
        try:
            # Start with done=True for every env.
            # NOTE(review): presumably this makes envs.step reset them - confirm.
            done = torch.ones(envs.env_num, dtype=torch.bool, device=agent.device)
            once_done = torch.zeros(envs.env_num, dtype=torch.bool, device=agent.device)
            steps = torch.zeros(envs.env_num, dtype=torch.int32, device=agent.device)
            returns = torch.zeros(envs.env_num, dtype=torch.float32, device=agent.device)
            cache = []
            agent_state = agent.get_initial_state(envs.env_num)
            act = agent_state["prev_action"].clone()  # (B, A)
            while not once_done.all():
                # Count a step only for envs that are running their first episode.
                steps += ~done * ~once_done
                # Step envs on CPU to avoid GPU<->CPU sync in the worker processes
                act_cpu = act.detach().to("cpu")
                done_cpu = done.detach().to("cpu")
                trans_cpu, done_cpu = envs.step(act_cpu, done_cpu)
                # Move observations back to GPU asynchronously for the agent
                trans = trans_cpu.to(agent.device, non_blocking=True)
                done = done_cpu.to(agent.device)
                # The observation and the action "leads to it" are stored together.
                trans["action"] = act
                cache.append(trans.clone())
                act, agent_state = agent.act(trans, agent_state, eval=True)
                # once_done is updated after the reward is added, so the reward
                # of the terminal step is still included.
                returns += trans["reward"][:, 0] * ~once_done
                once_done |= done
            cache = torch.stack(cache, dim=1)
            self.logger.scalar("episode/eval_score", returns.mean())
            self.logger.scalar("episode/eval_length", steps.to(torch.float32).mean())
            # Membership is tested on keys(), as to_td does, rather than on the
            # container itself.
            if "image" in cache.keys():
                self.logger.video("eval_video", tools.to_np(cache["image"][:1]))
            if self.video_pred_log:
                initial = agent.get_initial_state(1)
                prediction = agent.video_pred(
                    cache[:1, :self.batch_length],
                    (initial["stoch"], initial["deter"]),
                )
                self.logger.video("eval_open_loop", tools.to_np(prediction))
            self.logger.write(train_step)
        finally:
            # Never leave the agent stuck in eval mode.
            agent.train()

    def begin(self, agent):
        """Run the online training loop until ``self.steps`` environment steps are reached.

        Each iteration optionally evaluates, logs the episodes that just
        finished, steps the training envs, lets the agent act on the new
        observation, stores the transition in the replay buffer and, once
        enough data has been collected, updates the agent and logs metrics.

        The step counter starts from ``self.replay_buffer.count()`` times the
        action repeat and advances by the number of not-done envs times the
        action repeat on every iteration.

        Args:
            agent: Agent to train. Must provide ``device``,
                ``get_initial_state``, ``act`` and ``update``; ``video_pred``
                and ``_named_params`` are used when the corresponding logging
                options are enabled, and ``eval``/``train`` when evaluation
                runs.
        """
        envs = self.train_envs
        video_cache = []  # frames of env 0 since its last logged episode
        step = self.replay_buffer.count() * self._action_repeat
        update_count = 0
        done = torch.ones(envs.env_num, dtype=torch.bool, device=agent.device)
        returns = torch.zeros(envs.env_num, dtype=torch.float32, device=agent.device)
        lengths = torch.zeros(envs.env_num, dtype=torch.int32, device=agent.device)
        # Increment this to prevent sampling across episode boundaries
        # NOTE(review): the ids are never changed inside this method - confirm
        # that episode boundaries are handled elsewhere (e.g. by the buffer).
        episode_ids = torch.arange(envs.env_num, dtype=torch.int32, device=agent.device)
        train_metrics = {}
        agent_state = agent.get_initial_state(envs.env_num)
        act = agent_state["prev_action"].clone()  # (B, A)
        while step < self.steps:
            # Evaluation
            if self._should_eval(step) and self.eval_episode_num > 0:
                self.eval(agent, step)
            # Save metrics of the envs that just finished a non-empty episode.
            # A single nonzero() call replaces a per-element loop over `done`.
            finished = torch.nonzero(done & (lengths > 0)).flatten().tolist()
            for i in finished:
                if i == 0 and video_cache:
                    video = torch.stack(video_cache, dim=0)
                    self.logger.video("train_video", tools.to_np(video[None]))
                    video_cache = []
                self.logger.scalar("episode/score", returns[i])
                self.logger.scalar("episode/length", lengths[i])
                self.logger.write(step + i)  # to show all values on tensorboard
                returns[i] = lengths[i] = 0
            step += int((~done).sum()) * self._action_repeat  # step is based on env side
            lengths += ~done
            # Step envs on CPU to avoid GPU<->CPU sync in the worker processes
            act_cpu = act.detach().to("cpu")
            done_cpu = done.detach().to("cpu")
            trans_cpu, done_cpu = envs.step(act_cpu, done_cpu)
            # Move observations back to GPU asynchronously for the agent
            trans = trans_cpu.to(agent.device, non_blocking=True)
            done = done_cpu.to(agent.device)
            # "agent_state" is initialized based on the "is_first" flag in trans.
            act, agent_state = agent.act(trans.clone(), agent_state, eval=False)
            # Store transition. The observation and the action "taken from it" are stored together.
            trans["action"] = act * ~done.unsqueeze(-1)  # Mask action after done
            trans["stoch"] = agent_state["stoch"]
            trans["deter"] = agent_state["deter"]
            trans["episode"] = episode_ids  # Don't lift dim
            # Record env 0's frame. The check must look at the transition: the
            # cache is a list that starts empty, so testing the cache itself
            # would never let a frame in. Clone so the cached frame does not
            # keep the whole batched image tensor alive.
            if "image" in trans.keys():
                video_cache.append(trans["image"][0].clone())
            self.replay_buffer.add_transition(trans.detach())
            returns += trans["reward"][:, 0]
            # Update models after enough data has accumulated
            if step // (envs.env_num * self._action_repeat) > self.batch_length + 1:
                if self._should_pretrain():
                    update_num = self.pretrain
                else:
                    update_num = self._updates_needed(step)
                for _ in range(update_num):
                    # Only the metrics of the most recent update are kept.
                    train_metrics = agent.update(self.replay_buffer)
                update_count += update_num
                # Log training metrics
                if self._should_log(step):
                    for name, value in train_metrics.items():
                        value = tools.to_np(value) if isinstance(value, torch.Tensor) else value
                        self.logger.scalar(f"train/{name}", value)
                    self.logger.scalar("train/opt/updates", update_count)
                    if self.video_pred_log:
                        data, _, initial = self.replay_buffer.sample()
                        self.logger.video("open_loop", tools.to_np(agent.video_pred(data, initial)))
                    if self.params_hist_log:
                        for name, param in agent._named_params.items():
                            self.logger.histogram(f"{name}", tools.to_np(param))
                    self.logger.write(step, fps=True)
