import itertools
import numpy as np

import torch
import torch.nn.functional as F
from torch.optim import Adam

from algorithms.sac import SAC
from common.buffers import ReplayBuffer
from common.utils import to_torch
from models.dynamics import EnsembleTransitionRewardModel

# https://github.com/Xingyu-Lin/mbpo_pytorch/


class EnvSampler:
    """Drive a single environment one transition at a time.

    The sampler holds the observation the agent should act from next and
    calls ``env.reset()`` lazily whenever it holds none: at construction,
    after ``reset()``, or after a step that reported ``done``.
    """

    def __init__(self, env):
        self.env = env
        # Observation to act from on the next sample(); None => reset first.
        self.curr_obs = None

    def reset(self):
        """Drop the current episode so the next ``sample()`` starts a new one."""
        self.curr_obs = None

    def sample(self, agent, evaluate=False):
        """Take one step with ``agent`` and return the resulting transition.

        Args:
            agent: object exposing ``select_action(obs, evaluate)``.
            evaluate: forwarded unchanged to ``agent.select_action``.

        Returns:
            Tuple ``(obs, action, next_obs, reward, done, info)``. ``obs`` is
            the observation the action was chosen from; the remaining entries
            come straight from ``env.step(action)``.
        """
        obs = self.curr_obs
        if obs is None:
            obs = self.curr_obs = self.env.reset()
        action = agent.select_action(obs, evaluate)
        next_obs, reward, done, info = self.env.step(action)
        # Keep next_obs only while the episode is still running.
        self.curr_obs = next_obs if not done else None
        return obs, action, next_obs, reward, done, info


class PredictEnv:
    """Environment-like wrapper that steps a learned dynamics model.

    ``step`` queries ``model.predict`` for the next observation and reward,
    then applies a hand-written termination rule for the environment named
    ``env_name``.
    """

    def __init__(self, model, env_name):
        self.model = model
        self.env_name = env_name

    def _termination_fn(self, obs, act, next_obs):
        """Return a boolean ``(batch, 1)`` array marking terminal transitions.

        Hopper-v2 and Walker2d-v2 apply healthy-state checks to ``next_obs``;
        any other environment name never terminates.

        Args:
            obs: 2-D batch of current observations.
            act: 2-D batch of actions.
            next_obs: 2-D batch of predicted next observations. For
                Hopper/Walker, column 0 is read as height and column 1 as angle.

        Returns:
            ``np.ndarray`` of dtype bool and shape ``(batch, 1)``.
        """
        if self.env_name == "Hopper-v2":
            assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
            height, angle = next_obs[:, 0], next_obs[:, 1]
            not_done = (
                np.isfinite(next_obs).all(axis=-1)
                # abs() must wrap the values, not the comparison: every entry
                # after height has to lie strictly inside (-100, 100).
                & (np.abs(next_obs[:, 1:]) < 100).all(axis=-1)
                & (height > 0.7)
                & (np.abs(angle) < 0.2)
            )
        elif self.env_name == "Walker2d-v2":
            assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
            height, angle = next_obs[:, 0], next_obs[:, 1]
            not_done = (
                (height > 0.8) & (height < 2.0) & (angle > -1.0) & (angle < 1.0)
            )
        else:
            # No termination rule known for this environment.
            return np.zeros((obs.shape[0], 1), bool)
        return ~not_done[:, None]

    def step(self, obs, act, deterministic=False):
        """Predict one model step for a batch of ``(obs, act)`` pairs.

        Returns:
            ``(next_obs, reward, done, info)`` where ``next_obs`` and
            ``reward`` come from ``model.predict``, ``done`` is the boolean
            ``(batch, 1)`` termination mask, and ``info`` is an empty dict.
        """
        next_obs, rew = self.model.predict(obs, act, deterministic)
        done = self._termination_fn(obs, act, next_obs)
        return next_obs, rew, done, {}


class MBPO(SAC):
    """Model-Based Policy Optimization on top of the SAC agent.

    An ensemble dynamics model is fit to real transitions in ``self.buffer``.
    Short model rollouts branched from real states fill ``self.model_buffer``.
    The SAC agent is then trained on a mix of real and model transitions, with
    the real share controlled by ``config.real_ratio``.
    """

    def __init__(self, config, env, logger):
        super().__init__(config, env, logger)
        self.env_sampler = EnvSampler(env)

        # Model buffer, sized for the initial rollout length of 1 used by train().
        obs_shape = self.env.observation_space.shape
        act_shape = self.env.action_space.shape
        model_buffer_size = self.compute_model_buffer_size(1)
        self.model_buffer = ReplayBuffer(model_buffer_size, obs_shape, act_shape)

        # Dynamics model: ensemble with 4 hidden layers of equal width.
        hidden_dims = [config.model_hidden_size for _ in range(4)]
        self.dynamics = EnsembleTransitionRewardModel(
            env.observation_space.shape,
            env.action_space.shape,
            hidden_dims,
            config.ensemble_size,
            normalize=config.normalize,
        )
        self.dynamics_optim = Adam(
            self.dynamics.parameters(),
            lr=config.model_lr,
            weight_decay=config.model_wd,
        )
        self.predict_env = PredictEnv(self.dynamics, config.env_id)

    def compute_rollout_length(self, epoch_step):
        """Linearly anneal the model rollout length with the epoch index.

        Interpolates from ``rollout_min_length`` at ``rollout_min_epoch`` to
        ``rollout_max_length`` at ``rollout_max_epoch``. The result is clamped
        to ``[rollout_min_length, rollout_max_length]`` and truncated to int.
        """
        min_ep, max_ep = self.c.rollout_min_epoch, self.c.rollout_max_epoch
        min_len, max_len = self.c.rollout_min_length, self.c.rollout_max_length
        if max_ep == min_ep:
            # Degenerate schedule: switch at max_ep instead of dividing by zero.
            rollout_length = max_len if epoch_step >= max_ep else min_len
        else:
            frac = (epoch_step - min_ep) / (max_ep - min_ep)
            rollout_length = min_len + frac * (max_len - min_len)
        rollout_length = min(max(rollout_length, min_len), max_len)
        return int(rollout_length)

    def compute_model_buffer_size(self, rollout_length):
        """Return a model-buffer capacity holding ``model_retain_epochs`` epochs.

        Each epoch performs ``epoch_length / model_train_freq`` rollout phases.
        Each phase adds up to ``rollout_batch_size * rollout_length`` transitions.
        """
        rollouts_per_epoch = (
            self.c.rollout_batch_size * self.c.epoch_length / self.c.model_train_freq
        )
        model_steps_per_epoch = int(rollout_length * rollouts_per_epoch)
        return self.c.model_retain_epochs * model_steps_per_epoch

    def resize_buffer(self, buffer, new_buffer_size):
        """Return a new ReplayBuffer of ``new_buffer_size`` holding ``buffer``'s data.

        All stored transitions are drawn without replacement (so their order
        is shuffled) and pushed into the new buffer. What happens when
        ``new_buffer_size < len(buffer)`` depends on
        ``ReplayBuffer.push_batch`` — NOTE(review): confirm it keeps the
        most recently pushed entries.
        """
        obs_shape = buffer.observations.shape[1:]
        act_shape = buffer.actions.shape[1:]
        new_buffer = ReplayBuffer(new_buffer_size, obs_shape, act_shape)
        new_buffer.push_batch(*buffer.sample(len(buffer), replace=False))
        return new_buffer

    def collect_env_steps(self, env_buffer, num_steps):
        """Take ``num_steps`` real environment steps and store them in ``env_buffer``."""
        for _ in range(num_steps):
            obs, act, next_obs, rew, done, info = self.env_sampler.sample(self)
            # A time-limit truncation is not a true terminal state, so store
            # done=0 for it (the value target should still bootstrap).
            real_done = 0 if info.get("TimeLimit.truncated", False) else float(done)
            env_buffer.push(obs, act, rew, next_obs, real_done)

    def collect_model_rollouts(self, env_buffer, model_buffer, rollout_length):
        """Branch model rollouts of up to ``rollout_length`` steps from real states.

        Starts from ``rollout_batch_size`` observations sampled from
        ``env_buffer``. Every imagined transition is pushed to
        ``model_buffer``, and terminated trajectories are dropped after each
        step.
        """
        obs = env_buffer.sample(self.c.rollout_batch_size)[0]
        for _ in range(rollout_length):
            act = self.select_action(obs)
            next_obs, rew, done, _ = self.predict_env.step(obs, act)
            model_buffer.push_batch(obs, act, rew, next_obs, done)
            # Continue only from trajectories that have not terminated.
            nonterm_mask = ~done.squeeze(-1)
            if nonterm_mask.sum() == 0:
                break
            obs = next_obs[nonterm_mask]

    def train_dynamics(self, env_buffer):
        """Fit the dynamics ensemble for ``model_train_epochs`` passes over ``env_buffer``."""
        for _ in range(self.c.model_train_epochs):
            for obs, act, rew, next_obs, _ in env_buffer.iterate(
                self.c.model_batch_size
            ):
                dyn_loss = self.dynamics.compute_loss(
                    to_torch(obs), to_torch(act), to_torch(rew), to_torch(next_obs)
                )

                self.dynamics_optim.zero_grad()
                dyn_loss.backward()
                self.dynamics_optim.step()
                self.logger.record("loss/dynamics", dyn_loss.item())

    def train_policy(self, env_buffer, model_buffer):
        """Run ``num_train_repeats`` SAC updates on mixed real/model batches.

        Each batch holds ``int(batch_size * real_ratio)`` real transitions,
        and the rest come from ``model_buffer``. While the model buffer is
        empty (or the model share is zero), only the real part is used, so
        the batch is smaller than ``batch_size``.

        Returns:
            The number of updates performed (``num_train_repeats``).
        """
        # Batch split is loop-invariant; compute it once.
        env_batch_size = int(self.c.batch_size * self.c.real_ratio)
        model_batch_size = int(self.c.batch_size - env_batch_size)
        for updates in range(self.c.num_train_repeats):
            # Sample real transitions from the environment buffer.
            env_batch = env_buffer.sample(env_batch_size)
            if model_batch_size > 0 and len(model_buffer) > 0:
                # Top up with imaginary transitions from the model buffer.
                model_batch = model_buffer.sample(model_batch_size)
                batch = tuple(
                    np.concatenate((env_data, model_data), axis=0)
                    for env_data, model_data in zip(env_batch, model_batch)
                )
            else:
                batch = env_batch
            # `updates` is the update index within this call (restarts at 0).
            (
                critic_1_loss,
                critic_2_loss,
                policy_loss,
                ent_loss,
            ) = self.update_parameters(*batch, updates)
            self.logger.record("loss/critic_1", critic_1_loss)
            self.logger.record("loss/critic_2", critic_2_loss)
            self.logger.record("loss/policy", policy_loss)
            self.logger.record("loss/entropy_loss", ent_loss)
            self.logger.record("entropy/alpha", self.alpha.item())
        return self.c.num_train_repeats

    def _refresh_model_data(self, epoch_step, rollout_length):
        """Refit the dynamics model and regenerate model rollouts.

        If the scheduled rollout length changed, the model buffer is first
        resized to match. Returns the rollout length now in effect.
        """
        self.train_dynamics(self.buffer)

        new_rollout_length = self.compute_rollout_length(epoch_step)
        if rollout_length != new_rollout_length:
            rollout_length = new_rollout_length
            self.model_buffer = self.resize_buffer(
                self.model_buffer, self.compute_model_buffer_size(rollout_length)
            )

        self.collect_model_rollouts(self.buffer, self.model_buffer, rollout_length)
        return rollout_length

    def _run_test_episode(self):
        """Play one full episode with ``evaluate=True`` action selection.

        The sampler is reset first, so any in-progress training episode is
        abandoned, and the next training step starts from ``env.reset()``.

        Returns:
            ``(episode_reward, episode_success)``: the summed reward and the
            summed ``info["success"]`` values (0 when the key is absent).
        """
        self.env_sampler.reset()
        done = False
        episode_reward = 0
        episode_success = 0
        while not done:
            _, _, _, rew, done, info = self.env_sampler.sample(self, evaluate=True)
            episode_reward += rew
            episode_success += info.get("success", 0)
        return episode_reward, episode_success

    def train(self):
        """Run the MBPO training loop.

        After ``init_exploration_steps`` of initial collection, each of
        ``num_epochs`` epochs takes at least ``epoch_length`` real steps. It
        takes more if the buffer has not yet exceeded ``min_buffer_size``.
        The schedule within an epoch is:

        - every ``model_train_freq`` steps (when ``real_ratio < 1``), refit
          the model and refresh the model rollouts;
        - every ``policy_train_freq`` steps, update the policy;
        - every ``epoch_length`` steps, run a test episode.
        """
        # Initial exploration
        self.collect_env_steps(self.buffer, self.c.init_exploration_steps)

        rollout_length = 1
        total_step = 0
        for epoch_step in range(self.c.num_epochs):
            start_step = total_step
            policy_step = 0  # policy updates performed in this epoch
            while True:
                curr_step = total_step - start_step
                if (
                    curr_step >= self.c.epoch_length
                    and len(self.buffer) > self.c.min_buffer_size
                ):
                    break

                if curr_step % self.c.model_train_freq == 0 and self.c.real_ratio < 1.0:
                    rollout_length = self._refresh_model_data(
                        epoch_step, rollout_length
                    )

                # Collect one real environment step (truncation-aware done).
                self.collect_env_steps(self.buffer, 1)
                total_step += 1

                # Train policy.
                # NOTE(review): policy_step resets every epoch but is capped
                # against the global total_step, so the
                # max_train_repeats_per_step cap loosens with every epoch.
                # Compare against (total_step - start_step) if a per-epoch cap
                # is intended.
                if (
                    total_step % self.c.policy_train_freq == 0
                    and policy_step <= total_step * self.c.max_train_repeats_per_step
                    and len(self.buffer) > self.c.min_buffer_size
                ):
                    policy_step += self.train_policy(self.buffer, self.model_buffer)

                # Evaluate policy with evaluation-mode (evaluate=True) actions.
                if total_step % self.c.epoch_length == 0:
                    episode_reward, episode_success = self._run_test_episode()
                    self.logger.record("test/return", episode_reward)
                    self.logger.record("test/success", float(episode_success > 0))

            self.logger.record("train/epoch", epoch_step)
            self.logger.record("train/rollout_length", rollout_length)
            self.logger.dump(step=total_step)
