import math
import numpy as np
import torch
import gymnasium as gym
from collections import namedtuple
from gymnasium.wrappers import FlattenObservation, TransformReward


class ContToDiscreteActWrap(gym.ActionWrapper):
    """Expose a bounded continuous action space as ``Discrete(2 * dim + 1)``.

    The wrapped env's action space must provide ``low``/``high`` bound arrays.
    Discrete action 0 maps to the midpoint of the bounds. Actions ``1..dim``
    set a single dimension to its lower bound. Actions ``dim+1..2*dim`` set a
    single dimension to its upper bound. Every other dimension stays at the
    midpoint.
    """

    def __init__(self, env):
        super().__init__(env)
        bounds = env.action_space
        self.neutral_action = (bounds.low + bounds.high) / 2
        self.dim_base = len(self.neutral_action)
        self.action_space = gym.spaces.Discrete(2 * self.dim_base + 1)

    def action(self, a):
        """Translate discrete index ``a`` into a continuous action vector."""
        out = self.neutral_action.copy()
        if not a:
            # Index 0: leave every dimension at the midpoint.
            return out
        bounds = self.env.action_space
        dim = (a - 1) % self.dim_base
        pick_low = a <= self.dim_base
        out[dim] = bounds.low[dim] if pick_low else bounds.high[dim]
        return out


class ContToDiscreteActWrapCartesian(gym.ActionWrapper):
    """Discretise each action dimension into three bins.

    The action space becomes ``Discrete(3 ** dim)``. A discrete action is read
    as a base-3 number whose least-significant digit controls the LAST
    dimension. Each digit selects a value for its dimension:

    - 0: the midpoint of the bounds
    - 1: the lower bound
    - 2: the upper bound
    """

    def __init__(self, env):
        super().__init__(env)
        self.neutral_action = (env.action_space.low + env.action_space.high) / 2
        self.dim_base = len(self.neutral_action)
        self.action_space = gym.spaces.Discrete(3**self.dim_base)

    def action(self, a):
        """Decode discrete index ``a`` into a continuous action vector.

        Args:
            a: int-like action index. A NumPy array is unwrapped with ``.item()``.

        Raises:
            AssertionError: if ``a`` is not int-like after unwrapping.
        """
        if isinstance(a, np.ndarray):
            a = a.item()

        assert isinstance(a, (int, np.integer)), f'Expected int-like, got {type(a)}'
        # Convert to a Python int and decode with integer divmod. This keeps the
        # base-3 decode exact for any number of dimensions; float division
        # would lose precision once 3**dim_base exceeds 2**53.
        code = int(a)
        act = self.neutral_action.copy()
        low = self.env.action_space.low
        high = self.env.action_space.high

        for d in reversed(range(self.dim_base)):
            code, digit = divmod(code, 3)
            if digit == 1:
                act[d] = low[d]
            elif digit == 2:
                act[d] = high[d]

        return act


class Sampler:
    """Collect SARSA-style transitions from a gymnasium-style environment.

    Between calls to :meth:`rollouts` the sampler keeps the generator of the
    in-progress episode. An episode cut short by ``max_trans`` is therefore
    resumed on the next call rather than restarted.
    """

    def __init__(self, env, device='cpu'):
        self.curr_rollout = []  # in-progress episode generator; empty until first rollouts() call
        self.policy = None
        # Add the episode-statistics wrapper unless the env already exposes `return_queue`.
        if not hasattr(env, 'return_queue'):
            env = gym.wrappers.RecordEpisodeStatistics(env)
        self.env = env
        self.Return = namedtuple('Return', 'step value')
        self.max_abs_reward = 0   # largest |reward| observed so far, across all rollouts
        self.curr_return = 0      # running reward sum of the in-progress episode
        self.curr_entropy = 0.    # running entropy sum of the in-progress episode
        self.total_step = 0       # transitions consumed across all rollouts() calls
        self.Trans = namedtuple('Trans', 'obs act rwd terminated truncated nobs nact')  # order must match yield below
        self.device = device

    def _rollout(self, render=False):
        """Yield ``(Trans, entropy, step)`` for every step of one fresh episode.

        ``entropy`` is the value the policy returned alongside ``Trans.act``.
        ``step`` is the 1-based index of the transition within the episode.
        The next action is chosen before yielding, so each transition also
        carries ``nact`` (SARSA-style).
        """
        obs, _ = self.env.reset()
        act, entr = self.policy(obs)
        done = False
        step = 0
        while not done:
            step += 1
            if render:
                self.env.render()
            nobs, rwd, terminated, truncated, _ = self.env.step(act)
            if np.abs(rwd) > self.max_abs_reward:
                self.max_abs_reward = np.abs(rwd)
            nact, nent = self.policy(nobs)
            yield self.Trans(obs, act, rwd, terminated, truncated, nobs, nact), entr, step
            obs = nobs
            act = nact
            entr = nent
            done = terminated or truncated

    def rollouts(self, policy, min_trans, max_trans=math.inf, render=False):
        """Collect transitions until at least ``min_trans`` have been gathered.

        Whole episodes are generated until ``min_trans`` is reached. A finite
        ``max_trans`` may stop collection mid-episode; that episode is resumed
        on the next call. With ``min_trans == max_trans`` exactly that many
        transitions are returned.

        Args:
            policy: callable mapping an observation to ``(action, entropy)``.
            min_trans: minimum number of transitions to collect.
            max_trans: hard cap on transitions collected in this call (default: no cap).
            render: call ``env.render()`` before every step.

        Returns:
            A tuple ``(paths, returns, entropies, episode_lengths)``:

            - ``paths``: dict with one float tensor per ``Trans`` field, in field order.
            - ``returns``: one ``Return(step=total_step, value=reward sum)`` per finished episode.
            - ``entropies``: per finished episode, the sum of the policy entropies.
            - ``episode_lengths``: the length of each episode that finished, plus the
              step reached by an episode interrupted by ``max_trans``.
        """
        assert min_trans <= max_trans
        returns = []
        entropies = []
        episode_lengths = []
        all_trans = []
        self.policy = policy
        while len(all_trans) < min_trans:
            for trans, entr, step in self.curr_rollout:
                all_trans.append(trans)
                self.curr_return += trans.rwd
                self.curr_entropy += entr
                self.total_step += 1
                episode_done = trans.truncated or trans.terminated
                if episode_done:
                    returns.append(self.Return(self.total_step, self.curr_return))
                    entropies.append(self.curr_entropy)
                    self.curr_return = 0
                    self.curr_entropy = 0.
                    episode_lengths.append(step)
                if len(all_trans) >= max_trans:
                    # Record the partial length of an interrupted episode only.
                    # An episode that ended on this very step was recorded above.
                    if not episode_done:
                        episode_lengths.append(step)
                    break
            if len(all_trans) < max_trans:
                self.curr_rollout = self._rollout(render)

        return self._to_tensors(all_trans), returns, entropies, episode_lengths

    def _to_tensors(self, all_trans):
        """Stack each ``Trans`` field over ``all_trans`` into a float tensor on ``self.device``.

        Per-step scalar fields become ``(N, 1)`` columns.
        """
        paths = {}
        for key in self.Trans._fields:
            values = np.asarray([getattr(t, key) for t in all_trans])
            tensor = torch.tensor(values, device=self.device, dtype=torch.float)
            if tensor.ndim == 1:
                tensor = tensor[:, None]
            paths[key] = tensor
        return paths


class ReplayMemory:
    """Fixed-capacity ring buffer of transitions held in preallocated tensors.

    When the buffer is full, new transitions overwrite the oldest ones.
    """

    def __init__(self, max_size, s_dim, device):
        self.device = device
        self.max_size = max_size
        self.size = 0       # number of valid rows (never exceeds max_size)
        self.write_idx = 0  # row the next transition is written to
        self.ReplayMemorySamples = namedtuple('ReplayMemorySamples', ['obs', 'act', 'rwd',
                                                                      'terminated', 'nobs', 'nact'])
        self.repmem = self.ReplayMemorySamples(obs=torch.zeros(max_size, s_dim, device=self.device),
                                               act=torch.zeros(max_size, 1, device=self.device).long(),
                                               rwd=torch.zeros(max_size, 1, device=self.device),
                                               terminated=torch.zeros(max_size, 1, device=self.device),
                                               nobs=torch.zeros(max_size, s_dim, device=self.device),
                                               nact=torch.zeros(max_size, 1, device=self.device).long())

    def add_trans(self, trans):
        """Append a batch of transitions, overwriting the oldest rows when full.

        Args:
            trans: mapping containing at least the keys ``obs, act, rwd,
                terminated, nobs, nact``. Each value is indexed along its first
                dimension, whose length is the batch size. Extra keys are
                ignored. ``act``/``nact`` are cast with ``.long()``.

        A batch longer than ``max_size`` keeps only its newest ``max_size``
        transitions. They land in the same rows that one-at-a-time insertion
        would give.
        """
        add_len = len(trans['rwd'])
        n_keep = min(add_len, self.max_size)
        # Row of the first kept transition. Advancing past the dropped prefix
        # gives the same layout as writing it and then overwriting it.
        start = (self.write_idx + add_len - n_keep) % self.max_size
        first_len = min(n_keep, self.max_size - start)  # rows written before wrapping
        wrap_len = n_keep - first_len                    # rows written from index 0
        for k, buf in self.repmem._asdict().items():
            v = trans[k][add_len - n_keep:]
            if k == 'act' or k == 'nact':
                v = v.long()
            buf[start:start + first_len] = v[:first_len]
            buf[:wrap_len] = v[first_len:]
        self.write_idx = (self.write_idx + add_len) % self.max_size
        self.size = min(self.size + add_len, self.max_size)

    def get(self, key):
        """Return the valid rows (first ``self.size``) of field ``key``."""
        return getattr(self.repmem, key)[:self.size]

    def sample_with_idxs(self, batch_size, device=None):
        """Sample ``batch_size`` distinct rows uniformly at random.

        Args:
            batch_size: number of rows. ``np.random.choice`` raises
                ``ValueError`` if it exceeds ``self.size``.
            device: if given, every sampled tensor is moved there.

        Returns:
            ``(ReplayMemorySamples, idxs)``, where ``idxs`` are the sampled row indices.
        """
        idxs = np.random.choice(self.size, batch_size, replace=False)
        picked = (buf[idxs] for buf in self.repmem)
        if device is not None:
            picked = (t.to(device=device) for t in picked)
        return self.ReplayMemorySamples(*picked), idxs

    def sample(self, batch_size, device=None):
        """Like :meth:`sample_with_idxs` but return only the samples."""
        return self.sample_with_idxs(batch_size, device)[0]


def make_env(
    env_name,
    reward_scale=1.0,
    env_constructor=gym.make,
    discrete_act_wrapper=ContToDiscreteActWrap,
    env_render_mode=None,
    env_eval_render_mode=None,
):
    """Build a ``(train_env, eval_env)`` pair for ``env_name``.

    Wrapping depends on the environment:

    - Names starting with ``'MinAtar'``: observations are flattened in both envs.
    - The listed MuJoCo v4 tasks: both envs get ``discrete_act_wrapper``. The
      training env additionally has its rewards multiplied by ``reward_scale``.
    - Anything else is returned unwrapped.

    Note: the training env is always created with ``gym.make``.
    ``env_constructor`` is used only for the evaluation env.
    """
    mujoco_tasks = ('Hopper-v4', 'Ant-v4', 'Walker2d-v4', 'HalfCheetah-v4', 'Humanoid-v4')

    train_env = gym.make(env_name, render_mode=env_render_mode)
    eval_env = env_constructor(env_name, render_mode=env_eval_render_mode)

    if env_name.startswith('MinAtar'):
        train_env, eval_env = FlattenObservation(train_env), FlattenObservation(eval_env)
    elif env_name in mujoco_tasks:
        train_env, eval_env = discrete_act_wrapper(train_env), discrete_act_wrapper(eval_env)
        # Only the training env sees scaled rewards.
        train_env = TransformReward(train_env, lambda r: reward_scale * r)

    return train_env, eval_env
