import random

from gym.wrappers.monitoring.video_recorder import VideoRecorder
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm, trange

from .sac import SAC
from .config import BaseConfig, Configurable, Require
from .dynamics import DeterministicDynamicsModel, GaussianDynamicsModel, OracleDynamics
from .env.batch import ProductEnv
from .env.util import env_dims, get_max_episode_steps, get_done
from .log import default_log as log
from .policy import UniformPolicy
from .sampling import SampleBuffer, sample_episode_unbatched, sample_episodes_batched
from .torch_util import Module, torchify, numpyify, update_ema, random_choice, gpu_mem_info, quantile
from .train import set_learning_rate
from .util import pythonic_mean, batch_iterator, batch_map


# Number of environments in SMBPO's batched evaluation env (ProductEnv), and
# number of trajectories sampled on each call to SMBPO.evaluate().
N_EVAL_TRAJ = 10


class StayinAliveSAC(SAC):
    """SAC variant that gives terminal transitions a fixed, pessimistic critic target.

    Rewards fed to the critic are clamped to ``[r_min, r_max]``. Any transition
    whose ``done`` flag is set gets the constant target ``terminal_value`` in
    place of the bootstrapped target computed by ``SAC.compute_target``.

    Args:
        config: SAC configuration, forwarded to ``SAC.__init__``.
        state_dim: Observation dimensionality, forwarded to ``SAC.__init__``.
        action_dim: Action dimensionality, forwarded to ``SAC.__init__``.
        horizon: Horizon ``H`` used in the terminal-cost formula.
        update_terminal_cost: If True, ``update_r_bounds`` recomputes the
            terminal cost from the reward bounds; otherwise it stays 0.0.
        **kwargs: Forwarded to ``SAC.__init__``.
    """

    def __init__(self, config, state_dim, action_dim, horizon, update_terminal_cost, **kwargs):
        super().__init__(config, state_dim, action_dim, **kwargs)
        self.horizon = horizon
        self.update_terminal_cost = update_terminal_cost
        self.terminal_cost = 0.0

    def update_r_bounds(self, r_min, r_max):
        """Store the reward bounds and derive the Q bounds and (optionally) the terminal cost.

        This must be called before ``critic_loss``, which reads ``self.r_min``
        and ``self.r_max``.

        Args:
            r_min: Lower reward bound.
            r_max: Upper reward bound.
        """
        self.r_min, self.r_max = r_min, r_max
        # Bounds of a discounted sum of rewards in [r_min, r_max]; stored on
        # self.q_bounds (presumably consumed by SAC — not used in this class).
        self.q_bounds = (
            r_min / (1. - self.discount),
            r_max / (1. - self.discount)
        )
        if self.update_terminal_cost:
            # C = (r_max - r_min) / gamma^H - r_max
            self.terminal_cost = (r_max - r_min) / self.discount**self.horizon - r_max
        log.message(f'r bounds: [{r_min}, {r_max}],\tQ bounds: [{self.q_bounds[0]}, {self.q_bounds[1]}],\tC = {self.terminal_cost}')

    @property
    def terminal_value(self):
        """Critic target for terminal transitions: receiving ``-terminal_cost`` forever, discounted."""
        return -self.terminal_cost / (1. - self.discount)

    def critic_loss(self, obs, action, next_obs, reward, done):
        """Compute the critic loss with clamped rewards and fixed terminal targets.

        Assumes ``done`` is a boolean mask that indexes the target tensor returned
        by ``compute_target`` — TODO confirm against SAC.

        Returns:
            Whatever ``critic_loss_given_target`` returns for the adjusted targets.
        """
        reward = reward.clamp(self.r_min, self.r_max)
        target = super().compute_target(next_obs, reward, done)
        if done.any():
            # Overwrite the bootstrapped target with the constant terminal value.
            target[done] = self.terminal_value
        return self.critic_loss_given_target(obs, action, target)


class SMBPO(Configurable, Module):
    """Model-based agent that trains a ``StayinAliveSAC`` solver on real and model-generated data.

    Each training iteration runs these steps:

    1. Collect one real episode.
    2. Refit an ensemble of learned dynamics models on the real buffer.
    3. Generate short model rollouts into a virtual buffer.
    4. Update the solver on batches that mix both buffers; a ``real_fraction``
       share of each batch comes from real data.

    Args:
        config: Configuration consumed by ``Configurable``. Fields declared in
            ``SMBPO.Config`` are read as attributes (e.g. ``self.horizon``).
        env_factory: Zero-argument callable returning a fresh environment. It is
            called once for the collection env and ``N_EVAL_TRAJ`` times for the
            batched evaluation env.
        data: Metrics sink; only its ``append(key, value)`` method is used here.
    """

    class Config(BaseConfig):
        sac_config = StayinAliveSAC.Config()
        # Dynamics-model ensemble
        n_models = 4                    # ensemble size
        use_gaussian_model = True       # GaussianDynamicsModel if True, else DeterministicDynamicsModel
        model_hidden_dim = 200
        trunk_layers = 3
        head_hidden_layers = 1          # deterministic models get trunk_layers + head_hidden_layers hidden layers
        model_initial_fit_epochs = 100  # epochs for the first fit, in setup()
        model_fit_epochs = 50           # epochs for each refit after a collected episode
        model_start_lr = 3e-4           # model learning rate is interpolated linearly from start
        model_end_lr = 1e-4             # to end over the epochs of each fit
        record_collect = False          # record video and save episode tensors during epoch() collection
        horizon = Require(int)          # model rollout length; also passed to the solver for the terminal cost
        update_terminal_cost = True     # forwarded to StayinAliveSAC
        buffer_min = 10**4              # minimum real/virtual buffer size reached in setup(); also # of critic warm-up updates
        buffer_max = 10**6              # size passed to the real and virtual SampleBuffers
        warmup_iters = 100              # rollout_and_update() calls at the end of setup()
        steps_per_epoch = 1000          # real env steps collected per epoch()
        rollout_batch_size = 100        # start states per model rollout
        solver_updates_per_step = 20    # solver updates per rollout_and_update()
        real_fraction = 0.1             # fraction of each solver batch drawn from the real buffer

    def __init__(self, config, env_factory, data):
        Configurable.__init__(self, config)
        Module.__init__(self)
        self.data = data

        self.real_env = env_factory()
        self.eval_env = ProductEnv([env_factory() for _ in range(N_EVAL_TRAJ)])
        self.state_dim, self.action_dim = env_dims(self.real_env)

        # Termination (and, when the env provides one, reward) come from the
        # true environment functions rather than from the learned models.
        base_env = self.real_env.unwrapped
        self.terminal_function = lambda states: torchify(base_env.check_done(numpyify(states)))
        if hasattr(base_env, 'reward'):
            self.reward_function = lambda states, actions, next_states: \
                torchify(base_env.reward(numpyify(states), numpyify(actions), numpyify(next_states)))
        else:
            # rollout() then keeps the rewards predicted by the dynamics models.
            self.reward_function = None

        self.solver = StayinAliveSAC(self.sac_config, self.state_dim, self.action_dim,
                                     self.horizon, self.update_terminal_cost)

        if self.use_gaussian_model:
            self.models = nn.ModuleList([
                GaussianDynamicsModel(self.state_dim, self.action_dim, self.model_hidden_dim, self.trunk_layers, self.head_hidden_layers)
                for _ in range(self.n_models)
            ])
        else:
            hidden_layers = self.trunk_layers + self.head_hidden_layers
            self.models = nn.ModuleList([
                DeterministicDynamicsModel(self.state_dim, self.action_dim, self.model_hidden_dim, hidden_layers)
                for _ in range(self.n_models)
            ])

        self.real_buffer = SampleBuffer(self.state_dim, self.action_dim, self.buffer_max)
        self.virt_buffer = SampleBuffer(self.state_dim, self.action_dim, self.buffer_max)

        self.uniform_policy = UniformPolicy(self.real_env)

        # Counters are registered buffers so they are included in the module's state_dict.
        self.register_buffer('episodes_sampled', torch.tensor(0))
        self.register_buffer('steps_sampled', torch.tensor(0))
        self.register_buffer('n_terminals', torch.tensor(0))
        self.register_buffer('epochs_completed', torch.tensor(0))

        # Critic losses since the last log_statistics() call.
        self.recent_critic_losses = []

    @property
    def actor(self):
        """The solver's actor (policy)."""
        return self.solver.actor

    def rollout_and_update(self):
        """Do one model rollout with the current actor, then ``solver_updates_per_step`` solver updates."""
        self.rollout(self.actor)
        for _ in range(self.solver_updates_per_step):
            self.update_solver()

    def collect(self, policy, enable_recorder):
        """Sample one episode in the real env with ``policy`` and add it to the real buffer.

        Args:
            policy: Policy passed to ``sample_episode_unbatched``.
            enable_recorder: If True, record a video under ``log.dir``. The
                episode tensors are also saved to ``self.episodes_dir``, which
                ``setup()`` creates.

        Returns:
            ``(safe, length)``. ``safe`` is True iff the episode reached
            ``max_episode_steps`` without a done flag.
        """
        max_episode_steps = get_max_episode_steps(self.real_env)

        log.message('Collecting...')
        if enable_recorder:
            path = log.dir/f'collect_{self.episodes_sampled.item()}.mp4'
            path = str(path.resolve())
            log.message(f'\tRecording at {path}')
        else:
            path = None
        recorder = VideoRecorder(self.real_env, path=path, enabled=enable_recorder)
        try:
            episode = sample_episode_unbatched(self.real_env, policy,
                                               eval=False,
                                               max_steps=max_episode_steps,
                                               recorder=recorder)
        finally:
            # Close the recorder even if sampling raises, so the video writer is not leaked.
            recorder.close()

        states, actions, next_states, rewards, dones = episode.get()
        r = rewards.sum().item()
        l = len(episode)
        log.message(f'\tReturn: {r:.2f}')
        log.message(f'\tLength: {l}')
        self.data.append('collect return', r)
        self.data.append('collect length', l)
        self.episodes_sampled += 1
        self.steps_sampled += l

        if dones.any():
            assert dones.sum() == 1, 'There should only be one done'
            assert dones[-1], 'The done should be in last position'
            if l == max_episode_steps:
                # NOTE(review): the terminal coincided with the step limit; presumably
                # flagged because the two cases cannot be told apart here.
                log.message('sus')
            safe = False
            self.n_terminals += 1
        else:
            assert l == max_episode_steps
            safe = True
        self.real_buffer.extend(states, actions, next_states, rewards, dones)
        self.data.append('terminals (collect)', self.n_terminals.item())

        if enable_recorder:
            episode_data = {
                'states': states.cpu(),
                'actions': actions.cpu(),
                'next_states': next_states.cpu(),
                'rewards': rewards.cpu(),
                'dones': dones.cpu()
            }
            torch.save(episode_data, self.episodes_dir/f'episode-{self.episodes_sampled.item()}.pt')

        return safe, l

    def fit_models(self, epochs):
        """Fit every ensemble model on the real buffer for ``epochs`` epochs.

        The learning rate moves linearly from ``model_start_lr`` to
        ``model_end_lr`` over the epochs. Each model's loss history is logged.
        """
        # Plain Python floats rather than 0-d tensors: some optimizer code paths
        # reject tensor learning rates (e.g. foreach Adam without capturable=True).
        lr_by_epoch = torch.linspace(self.model_start_lr, self.model_end_lr, steps=epochs).tolist()
        log.message(f'Fitting models...')
        for i, model in enumerate(self.models):
            # The closure over `model` is safe: it is only invoked by model.fit
            # within this same iteration.
            def update_lr(epochs_completed):
                if epochs_completed < epochs:
                    set_learning_rate(model.optimizer, lr_by_epoch[epochs_completed])
            update_lr(0)
            losses = model.fit(self.real_buffer, epochs, post_epoch_callback=update_lr)
            log.message(f'Model {i+1}: {np.array(losses)}')

    def evaluate_models(self):
        """Log quantiles of each model's one-step prediction error on the real buffer.

        The error for each transition is the L2 norm, over dim 1, of the
        difference between the predicted mean next state and the true next
        state. The difference is divided by the (scalar) std of all stored
        states.
        """
        states, actions, next_states = self.real_buffer.get('states', 'actions', 'next_states')
        state_std = states.std()
        for i, model in enumerate(self.models):
            with torch.no_grad():
                predicted_states = batch_map(lambda s, a: model.mean(s, a)[0], [states, actions])
            errors = torch.norm((predicted_states - next_states) / state_std, dim=1)
            log.message(f'Model {i+1} error quantiles: {quantile(errors, torch.linspace(0, 1, 11))}')

    def post_collect(self, model_fit_epochs):
        """Refit the models, then update the solver's reward bounds from the real buffer's rewards."""
        self.fit_models(model_fit_epochs)

        buffer_rewards = self.real_buffer.get('rewards')
        self.solver.update_r_bounds(buffer_rewards.min().item(), buffer_rewards.max().item())

    def rollout(self, policy,
                add_to_virt_buffer=True,    # put the sample in self.virt_buffer?
                initial_states=None,        # initial states to use, otherwise randomly chosen
                solver_updates_per_step=0): # optionally update while stepping
        """Roll ``policy`` out in the learned models for up to ``horizon`` steps.

        At each step, a randomly chosen ensemble member predicts the next
        states. Rewards are overridden by ``reward_function`` when one exists.
        Trajectories whose next state is terminal (per ``terminal_function``)
        are dropped from later steps. The rollout stops early once every
        trajectory has terminated.

        Returns:
            A ``SampleBuffer`` with every transition generated by this rollout.
        """
        if initial_states is None:
            initial_states = random_choice(self.real_buffer.get('states'), size=self.rollout_batch_size)
        # Each step adds at most len(initial_states) rows. Size from the actual
        # number of start states, since callers may pass more than
        # rollout_batch_size; never go below the original capacity.
        n_start = max(len(initial_states), self.rollout_batch_size)
        buffer = SampleBuffer(self.state_dim, self.action_dim, n_start * self.horizon)
        states = initial_states
        for t in range(self.horizon):
            with torch.no_grad():
                actions = policy.act(states, eval=False)
                next_states, rewards = random.choice(self.models).sample(states, actions)
            if self.reward_function is not None:
                rewards = self.reward_function(states, actions, next_states)
            dones = self.terminal_function(next_states)
            buffer.extend(states, actions, next_states, rewards, dones)
            for _ in range(solver_updates_per_step):
                self.update_solver()
            not_dones = ~dones
            if not_dones.sum() == 0:
                break
            states = next_states[not_dones]

        if add_to_virt_buffer:
            self.virt_buffer.extend(*buffer.get())
        return buffer

    def update_solver(self, update_actor=True):
        """Do one critic update (plus an optional actor/alpha update) on a mixed batch.

        About ``real_fraction`` of the solver's ``batch_size`` is sampled from
        the real buffer and the rest from the virtual buffer. The critic loss is
        recorded in ``recent_critic_losses``.
        """
        solver = self.solver
        n_real = int(self.real_fraction * solver.batch_size)
        real_samples = self.real_buffer.sample(n_real)
        virt_samples = self.virt_buffer.sample(solver.batch_size - n_real)

        combined_samples = [
            torch.cat([real, virt])
            for real, virt in zip(real_samples, virt_samples)
        ]
        critic_loss = solver.critic_loss(*combined_samples)
        solver.critic_optimizer.zero_grad()
        critic_loss.backward()
        solver.critic_optimizer.step()
        update_ema(solver.critic_target, solver.critic, solver.tau)

        # Save loss for later averaging/logging
        self.recent_critic_losses.append(critic_loss.detach().item())

        if update_actor:
            # combined_samples[0] is the observations batch.
            solver.update_actor_and_alpha(combined_samples[0])

    def setup(self):
        """Fill both buffers to ``buffer_min``, fit the models, and warm up the solver.

        Creates ``log.dir/'episodes'``. ``Path.mkdir`` without ``exist_ok``
        raises if that directory already exists.
        """
        self.episodes_dir = log.dir/'episodes'
        self.episodes_dir.mkdir()

        log.message(f'Collecting initial data')
        while len(self.real_buffer) < self.buffer_min:
            self.collect(self.uniform_policy, enable_recorder=False)
        self.post_collect(self.model_initial_fit_epochs)

        log.message(f'Collecting initial virtual data')
        while len(self.virt_buffer) < self.buffer_min:
            self.rollout(self.uniform_policy)

        log.message('Warming up critic')
        for _ in trange(self.buffer_min):
            self.update_solver(update_actor=False)
        for _ in trange(self.warmup_iters):
            self.rollout_and_update()

    def epoch(self):
        """Run one training epoch.

        Collects real episodes with the actor until ``steps_sampled`` reaches
        ``buffer_min + (epochs_completed + 1) * steps_per_epoch``. After each
        episode it refits the models, runs one ``rollout_and_update`` per
        collected step, and logs statistics.
        """
        expected_samples = self.buffer_min + (self.epochs_completed + 1) * self.steps_per_epoch
        while self.steps_sampled < expected_samples:
            safe, l = self.collect(self.actor, enable_recorder=self.record_collect)
            log.message('Safe! :)' if safe else 'Not safe :(')
            self.post_collect(self.model_fit_epochs)
            for _ in trange(l):
                self.rollout_and_update()
            self.log_statistics()
        self.epochs_completed += 1

    def log_statistics(self):
        """Log model errors, the average recent critic loss, buffer sizes and mean Q-values.

        Mean Q is reported separately for done / not-done transitions of the
        real and virtual buffers. Real transitions use their stored actions;
        virtual ones use the actor's eval-mode actions.
        """
        self.evaluate_models()

        avg_critic_loss = pythonic_mean(self.recent_critic_losses)
        log.message(f'Average recent critic loss: {avg_critic_loss}')
        self.data.append('critic loss', avg_critic_loss)
        self.recent_critic_losses.clear()

        log.message('Buffer sizes:')
        log.message(f'\tReal: {len(self.real_buffer)}')
        log.message(f'\tVirtual: {len(self.virt_buffer)}')

        real_states, real_actions, real_dones = self.real_buffer.get('states', 'actions', 'dones')
        virt_states, virt_dones = self.virt_buffer.get('states', 'dones')
        with torch.no_grad():
            # Batched and gradient-free: the virtual buffer can hold up to
            # buffer_max states, and one autograd-tracked pass over all of them
            # can exhaust (GPU) memory.
            virt_actions = batch_map(lambda s: self.actor.act(s, eval=True), [virt_states])
        sa_data = {
            'real (done)': (real_states[real_dones], real_actions[real_dones]),
            'real (~done)': (real_states[~real_dones], real_actions[~real_dones]),
            'virtual (done)': (virt_states[virt_dones], virt_actions[virt_dones]),
            'virtual (~done)': (virt_states[~virt_dones], virt_actions[~virt_dones])
        }
        for which, (states, actions) in sa_data.items():
            if len(states) == 0:
                mean_q = None
            else:
                with torch.no_grad():
                    qs = batch_map(lambda s, a: self.solver.critic_value(s, a, target=False), [states, actions])
                    # Record a plain float, like the other metrics, instead of a (device) tensor.
                    mean_q = qs.mean().item()
            log.message(f'Average Q {which}: {mean_q}')
            self.data.append(f'Average Q {which}', mean_q)

        if torch.cuda.is_available():
            log.message(f'GPU memory info: {gpu_mem_info()}')

    def evaluate(self):
        """Sample ``N_EVAL_TRAJ`` eval-mode trajectories and record length/return statistics.

        Also records the total number of terminal collection episodes and their
        fraction of all collected episodes. The fraction is 0.0 when no episodes
        have been collected yet.
        """
        log.message(f'Evaluating after {self.steps_sampled.item()} samples ({self.episodes_sampled.item()} episodes)')
        eval_traj = sample_episodes_batched(self.eval_env, self.solver, N_EVAL_TRAJ, eval=True)

        lengths = [len(traj) for traj in eval_traj]
        length_mean, length_std = float(np.mean(lengths)), float(np.std(lengths))
        self.data.append('eval length', length_mean)
        log.message(f'\tLength: {length_mean} +/- {length_std}')

        returns = [traj.get('rewards').sum().item() for traj in eval_traj]
        return_mean, return_std = float(np.mean(returns)), float(np.std(returns))
        self.data.append('eval return', return_mean)
        log.message(f'\tReturn: {return_mean} +/- {return_std}')

        n_terminals = self.n_terminals.item()
        n_episodes = self.episodes_sampled.item()
        # Guard against ZeroDivisionError when evaluating before any collection.
        frac_terminal = n_terminals / n_episodes if n_episodes > 0 else 0.0
        self.data.append('terminals (epoch)', n_terminals)
        self.data.append('frac episodes terminal', frac_terminal)
        log.message(f'{n_terminals} total terminals (fraction: {frac_terminal})')