import itertools

import gym
from einops import rearrange
from tqdm.auto import tqdm

from collector.replay_buffer.episode_replay import TrajectoryReplayBuffer, Transition, Trajectory
from envs.wrappers.episode_count_wrapper import EpisodeCountWrapper
from envs.wrappers.tensor_wrapper import TensorWrapper
from policy.base import BasePolicy
from policy.random_policy import RandomPolicy


def record_episode(env: gym.Env, policy: BasePolicy, explore: bool):
    """Roll out one full episode and return it as a ``Trajectory``.

    The environment is wrapped in ``TensorWrapper`` and reset, then stepped
    with actions from ``policy`` until it reports ``done``.

    NOTE: this function does not switch the policy between train/eval mode;
    that is left to the caller.

    Args:
        env: Environment to roll out in. Its ``info`` dict must contain a
            ``"done_reason"`` entry on the final step.
        policy: Policy providing ``explore(obs, h)`` and ``greedy(obs, h)``,
            each returning ``(action, hidden, action_logits, state_value,
            penalties)``.
        explore: If true, actions come from ``policy.explore``; otherwise
            (validation) from ``policy.greedy``.

    Returns:
        The recorded ``Trajectory`` with ``total_reward`` and ``done_reason``
        set.
    """
    episode = Trajectory()

    wrapped_env = TensorWrapper(env)
    current_obs = wrapped_env.reset()
    hidden = None  # the policy receives None as hidden state on the first step

    while True:
        # Prepend two singleton dims (batch and time) for the policy.
        batched_obs = rearrange(current_obs, '... -> 1 1 ...')

        choose_action = policy.explore if explore else policy.greedy
        action, hidden, logits, value, penalty = choose_action(batched_obs, hidden)

        # Strip the batch and time dims again.
        action = rearrange(action, '1 1 -> ')
        logits = rearrange(logits, '1 1 a -> a')
        penalty = rearrange(penalty, '1 1 p -> p')
        if value is not None:  # the policy may return no state value
            value = rearrange(value, '1 1 1 -> ')

        following_obs, reward, done, info = wrapped_env.step(action.item())

        # Unless the episode ended by timeout, the stored transition carries
        # no successor observation.
        if done and info["done_reason"] != 'timeout':
            following_obs = None

        episode.append(
            Transition(current_obs, action, reward, following_obs, logits, value, penalty)
        )

        if done:
            break
        current_obs = following_obs

    step_rewards = [transition.reward.item() for transition in episode]
    episode.total_reward = sum(step_rewards)
    episode.done_reason = info["done_reason"]

    return episode


class OnPolicyCollectionStrategy:
    """Collection strategy that adds one exploration episode per call."""

    def __init__(self, memory: TrajectoryReplayBuffer):
        """Keep a reference to the replay buffer that receives the episodes."""
        self.memory = memory

    def collect(self, env: EpisodeCountWrapper, policy: BasePolicy) -> TrajectoryReplayBuffer:
        """Record one episode with ``explore=True``, store it, and return the buffer."""
        episode = record_episode(env, policy, explore=True)
        buffer = self.memory
        buffer.append(episode)
        return buffer


class OffPolicyCollectionStrategy:
    """Collection strategy that adds a batch of exploration episodes per call."""

    def __init__(self, memory: TrajectoryReplayBuffer):
        """Keep a reference to the replay buffer that receives the episodes."""
        self.memory = memory

    def collect_trajectories(
            self,
            env: EpisodeCountWrapper,
            policy: BasePolicy,
            n_total: int
    ):
        """Record ``n_total`` episodes with ``explore=True`` and store all of them.

        ``policy.eval()`` is called once before the rollouts start
        (presumably to put the policy in evaluation mode; its effect is
        defined by the policy class, not here).
        """
        policy.eval()

        for _ in range(n_total):
            trajectory = record_episode(env, policy, explore=True)
            self.memory.append(trajectory)

    def collect_trajectories_count(
            self,
            env: EpisodeCountWrapper,
            policy: BasePolicy,
            n_correct: int = 10,
            n_wrong: int = 10,
            n_timeout: int = 0
    ):
        """Record episodes until a quota per ``done_reason`` has been stored.

        Episodes are rolled out with ``explore=True``. An episode is stored
        only if its ``done_reason`` is ``"correct"``, ``"wrong"`` or
        ``"timeout"`` and fewer than the requested number of that kind have
        been stored so far; every other episode is discarded.

        Args:
            env: Environment to roll out in.
            policy: Policy used for the rollouts.
            n_correct: Number of ``"correct"`` episodes to store.
            n_wrong: Number of ``"wrong"`` episodes to store.
            n_timeout: Number of ``"timeout"`` episodes to store.

        Raises:
            RuntimeError: If the quotas are still unmet after 10 001 rollouts.
        """
        tqdm.write(f"Collecting Trajectories ({n_correct}/{n_wrong}/{n_timeout})... ")

        quotas = {"correct": n_correct, "wrong": n_wrong, "timeout": n_timeout}
        stored = dict.fromkeys(quotas, 0)

        def quotas_met() -> bool:
            return all(stored[reason] >= quota for reason, quota in quotas.items())

        # Nothing requested: do not waste rollouts.
        if quotas_met():
            return

        for i in itertools.count(0):
            trajectory = record_episode(env, policy, explore=True)
            reason = trajectory.done_reason

            # Strict '<' so a category never exceeds its quota; a quota of 0
            # stores nothing of that kind.
            if reason in quotas and stored[reason] < quotas[reason]:
                self.memory.append(trajectory)
                stored[reason] += 1

            if quotas_met():
                break
            if i >= 10_000:
                raise RuntimeError("exceeded max number of rollouts")

    def collect(self, env: EpisodeCountWrapper, policy: BasePolicy) -> TrajectoryReplayBuffer:
        """Store 20 freshly recorded episodes and return the buffer.

        Uses ``collect_trajectories``; ``collect_trajectories_count`` is the
        quota-based alternative.
        """
        self.collect_trajectories(env, policy, 20)

        return self.memory
