import os
from pathlib import Path
from typing import Dict, Iterable, Iterator, Optional, Sequence, Tuple

import numpy as np
import torch
from dm_env import specs
from lambda_ac.replay.per_replay_buffer import PERDataBuffer

from lambda_ac.replay.replay_buffer import ReplayBufferStorage, make_replay_loader
from lambda_ac.replay.replay_memory import DataBuffer
from lambda_ac.rl_types import ActorCriticAgent, Agent, ExplorationScheduler


class EnvBufferWrapper:
    """Base type for objects that pair a training env with a replay buffer.

    The base implementation only records its constructor arguments. Every hook
    below does nothing and returns ``None``; concrete wrappers override the ones
    they need.
    """

    def __init__(self, env, eval_env, buffer_size, device, args=None):
        # ``args`` is accepted so subclasses share one signature; it is not stored.
        self.env, self.eval_env = env, eval_env
        self.buffer_size, self.device = buffer_size, device

    def sample(self, batch_size):
        """Sampling hook; does nothing in the base class and returns None."""

    def act(self, agent):
        """Acting hook; does nothing in the base class and returns None."""

    def save(self, path):
        """Persistence hook; does nothing in the base class and returns None."""

    def load(self, path):
        """Restore hook; does nothing in the base class and returns None."""


class DrQEnvBufferWrapper(EnvBufferWrapper):
    """Env/replay wrapper for dm_env-style environments (used for ``visual_dmc``).

    Every post-step ``TimeStep`` is handed to a ``ReplayBufferStorage`` rooted at
    ``<cwd>/replay_memory``. Batches are read back through the loader built by
    ``make_replay_loader`` over the same directory.
    """

    def __init__(self, env, eval_env, buffer_size, device, args):
        super().__init__(env, eval_env, buffer_size, device)

        # Fields persisted for each time step: observation, action, reward, discount.
        self.data_specs = (
            env.observation_spec(),
            env.action_spec(),
            specs.Array((1,), np.float32, "reward"),
            specs.Array((1,), np.float32, "discount"),
        )
        replay_dir = Path.cwd() / "replay_memory"
        self.replay_storage = ReplayBufferStorage(self.data_specs, replay_dir)

        # The loader's batch size is fixed here; get_batch() cannot change it.
        self._batch_size = args.model_batch_size
        self.memory = make_replay_loader(
            replay_dir,
            args.replay_size,
            args.model_batch_size,
            args.num_memory_workers,
            True,  # NOTE(review): meaning of this positional flag is defined by make_replay_loader
            # Deepest rollout any consumer needs (presumably the sampled
            # sequence length — confirm against make_replay_loader).
            max(
                args.agent.actor_rollout_depth,
                args.agent.critic_rollout_depth,
                args.agent.model_train_depth,
            ),
            args.discount,
        )
        self.time_step = self.env.reset()
        # NOTE(review): set here and in step() but not read in this class;
        # presumably consumed by callers.
        self.is_reset = True
        self.step_count = 0  # steps taken in the current training episode
        self.episode_count = 0  # completed training episodes

        self.reward_aggregate = 0  # return accumulated over the current training episode

    def init_memory(self):
        """Replace the loader with an iterator over it (call once before get_batch)."""
        self.memory = iter(self.memory)

    def get_batch(self, size=None):
        """Return the replay iterator/loader stored in ``self.memory``.

        Args:
            size: Optional requested batch size, accepted for call-compatibility
                with ``GymEnvBufferWrapper.get_batch(size)``. The loader's batch
                size is fixed at construction (``args.model_batch_size``), so only
                ``None`` or that value is allowed.

        Raises:
            ValueError: if ``size`` differs from the loader's fixed batch size.
        """
        if size is not None and size != self._batch_size:
            raise ValueError(
                f"DrQEnvBufferWrapper loader has fixed batch size "
                f"{self._batch_size}; cannot serve batch size {size}"
            )
        return self.memory

    @torch.no_grad()
    def step(
        self,
        agent: Agent,
        exploration_scheduler: Optional[ExplorationScheduler],
    ):
        """Take one training step in ``self.env`` and store the resulting time step.

        Returns:
            ``(True, episode_return)`` when the step ended the episode (the env is
            reset and per-episode counters cleared), otherwise
            ``(False, running_return)`` for the episode in progress.
        """
        action = agent.select_action(
            self.time_step.observation,
            eval=False,
            step=self.step_count,
            episode=self.episode_count,
        )
        if exploration_scheduler is not None:
            action = exploration_scheduler(action)
        next_time_step = self.env.step(action)
        # NOTE(review): only post-step time steps are stored; the time step
        # returned by env.reset() is never passed to replay_storage.add here.
        self.replay_storage.add(next_time_step)
        self.reward_aggregate += next_time_step.reward
        if next_time_step.last():
            self.time_step = self.env.reset()
            self.is_reset = True
            episode_reward = self.reward_aggregate
            self.reward_aggregate = 0
            self.step_count = 0
            self.episode_count += 1
            return True, episode_reward
        self.time_step = next_time_step
        self.step_count += 1
        return False, self.reward_aggregate

    @torch.no_grad()
    def eval(self, agent, trajectories):
        """Run ``trajectories`` greedy episodes on ``self.eval_env``.

        Returns:
            The mean undiscounted episode return. Raises ZeroDivisionError when
            ``trajectories`` is 0.
        """
        rewards = 0
        for _ in range(trajectories):
            episode_reward = 0
            time_step = self.eval_env.reset()
            # Terminate on the *eval* episode's own time step; self.time_step
            # belongs to the training env and does not change inside this loop.
            while not time_step.last():
                action = agent.select_action(
                    time_step.observation, eval=True, episode=self.episode_count
                )
                time_step = self.eval_env.step(action)
                episode_reward += time_step.reward
            rewards += episode_reward
        return rewards / trajectories

    def save(self, path):
        """No-op: the storage already lives on disk under ``<cwd>/replay_memory``."""
        pass

    def load(self, path):
        """No-op counterpart of ``save``."""
        pass


class GymEnvBufferWrapper(EnvBufferWrapper):
    """Env/replay wrapper for Gym-style envs whose ``step`` returns a 4-tuple.

    Transitions are pushed into a ``PERDataBuffer``. Batches come from iterators
    created by ``PERDataBuffer.make_iter`` and cached per batch size.
    """

    def __init__(self, env, eval_env, buffer_size, device, args):
        super().__init__(env, eval_env, buffer_size, device)
        # Deepest rollout any consumer needs; passed to the buffer as ``depth``
        # (presumably the length of sampled sub-trajectories — confirm).
        depth = max(
            args.agent.actor_rollout_depth,
            args.agent.critic_rollout_depth,
            args.agent.model_train_depth,
        )
        self.buffer = PERDataBuffer(
            buffer_size,
            env,
            device,
            return_done=True,
            depth=depth,
        )
        self.state = self.env.reset()
        self.reward_aggregate = 0  # return accumulated over the current training episode
        self.step_count = 0  # steps taken in the current training episode
        self.episode_count = 0  # completed training episodes
        # Batch iterators keyed by batch size; the default size is created eagerly,
        # others lazily in get_batch().
        self.loader: Dict[int, Iterator[Tuple[torch.Tensor, ...]]] = {
            args.model_batch_size: self.buffer.make_iter(args.model_batch_size)
        }

    def init_memory(self):
        """No-op: iterators are created in ``__init__`` and ``get_batch``."""
        pass

    def get_batch(self, size):
        """Return the cached batch iterator for ``size``, creating it on first use."""
        if size not in self.loader:
            self.loader[size] = self.buffer.make_iter(size)
        return self.loader[size]

    @torch.no_grad()
    def step(
        self,
        agent: Agent,
        exploration_scheduler: Optional[ExplorationScheduler],
    ):
        """Take one training step in ``self.env`` and push the transition.

        Returns:
            ``(True, episode_return)`` when ``done`` ended the episode (the env is
            reset and per-episode counters cleared), otherwise
            ``(False, running_return)`` for the episode in progress.
        """
        action = agent.select_action(
            self.state, eval=False, step=self.step_count, episode=self.episode_count
        )
        if exploration_scheduler is not None:
            action = exploration_scheduler(action)
        next_state, reward, done, info = self.env.step(action)
        # Assumes the env (or a wrapper) always sets info["truncated"]; a missing
        # key raises KeyError here.
        timelimit = info["truncated"]
        # NOTE(review): the fifth argument (presumably the terminal flag) is always
        # False, so `done` is never recorded in the buffer; only `timelimit` is.
        # Confirm this is intended given return_done=True.
        self.buffer.push(self.state, action, reward, next_state, False, timelimit)
        self.reward_aggregate += reward
        self.state = next_state
        if done:
            self.state = self.env.reset()
            episode_reward = self.reward_aggregate
            self.reward_aggregate = 0
            self.episode_count += 1
            self.step_count = 0
            return True, episode_reward
        self.step_count += 1
        return False, self.reward_aggregate

    @torch.no_grad()
    def eval(self, agent, trajectories):
        """Run ``trajectories`` greedy episodes on ``self.eval_env``.

        Returns:
            The mean undiscounted episode return. Raises ZeroDivisionError when
            ``trajectories`` is 0.
        """
        rewards = 0
        for _ in range(trajectories):
            episode_reward = 0
            state = self.eval_env.reset()
            done = False
            while not done:
                action = agent.select_action(
                    state=state, eval=True, episode=self.episode_count
                )
                state, reward, done, _ = self.eval_env.step(action)
                episode_reward += reward
            rewards += episode_reward
        return rewards / trajectories

    def save(self, path):
        """Save the buffer into the ``data`` directory, creating it if needed.

        NOTE(review): ``path`` is ignored; the buffer is always written to the
        relative directory "data".
        """
        os.makedirs("data", exist_ok=True)
        self.buffer.save("data")

    def load(self, path):
        """Load the buffer from the ``data`` directory (``path`` is ignored)."""
        self.buffer.load("data")


class DMCEnvBufferWrapper(GymEnvBufferWrapper):
    """``GymEnvBufferWrapper`` variant that sets ``buffer.return_done = False``.

    The declared ``(env, buffer_size, device, args)`` call form is kept. The
    evaluation env is an optional trailing argument; ``eval()`` calls
    ``self.eval_env.reset()``, so pass it whenever evaluation is needed.
    """

    def __init__(self, env, buffer_size, device, args, eval_env=None):
        # The parent signature is (env, eval_env, buffer_size, device, args).
        # Forward by keyword so no argument lands in the wrong slot; the previous
        # positional call shifted every argument and omitted `args`.
        super().__init__(
            env,
            eval_env=eval_env,
            buffer_size=buffer_size,
            device=device,
            args=args,
        )
        # Presumably stops sampled batches from including done flags — confirm
        # against PERDataBuffer.
        self.buffer.return_done = False


def make_env_storage(env_type, env, eval_env, args):
    """Construct the env/replay wrapper that matches ``env_type``.

    "gym" and "dmc" both map to GymEnvBufferWrapper and "visual_dmc" maps to
    DrQEnvBufferWrapper. The chosen class receives
    ``(env, eval_env, args.replay_size, args.device, args)``.

    Raises:
        NotImplementedError: for any other ``env_type``.
    """
    if env_type in ("gym", "dmc"):
        wrapper_cls = GymEnvBufferWrapper
    elif env_type == "visual_dmc":
        wrapper_cls = DrQEnvBufferWrapper
    else:
        raise NotImplementedError("Unknown replay buffer type")
    return wrapper_cls(env, eval_env, args.replay_size, args.device, args)
