import tensorflow as tf
from utils.models import LatentModel, Actor, Critic
from tensorflow.keras.mixed_precision import experimental as prec
from utils.utils_tools import load_dataset, episode_itterator, Capturing
from utils.utils_tools import Adam
from tqdm import tqdm
from algo.LOMPO import Lompo
from algo.ORAAC import Oraac
from algo.COMBO import Combo
from algo.LODAC import Lodac
import os
from utils.utils_tools import preprocess_raw, count_episodes
from utils.utils_tools import flatten, preprocess_latent
from utils.buffer import LatentReplayBuffer
from math import ceil
from tensorflow_probability import distributions as tfd
from utils.wrappers import *
from utils.utils_tools import save_episodes, summarize_episode
import cv2
import cv2.cv2
from pathlib import Path
import functools


class Algo(tf.Module):
    def __init__(self, config):
        """Build the models, agent, buffers, dataset iterators and optimizer.

        Args:
            config: Experiment configuration. It is read both as attributes
                (e.g. ``config.lmbd``) and as a mapping (e.g.
                ``config['logdir']``), so it must support both access styles.

        Raises:
            NotImplementedError: If ``config.latent_algo`` is not one of
                ``'lompo'``, ``'oraac'``, ``'combo'`` or ``'lodac'``.
        """
        # tf.Module stores its name / name-scope state in attributes created
        # by its own constructor, so that constructor has to run first.
        super().__init__()

        self._config = config
        self._latent_model_step = 0
        self._num_actor_critic_train_step = 0
        self._lmbd = config.lmbd
        self._episode_length = config.time_limit / config.action_repeat

        # models
        self._latent = LatentModel(config)

        # build actor/critic
        agent_classes = {
            'lompo': Lompo,
            'oraac': Oraac,
            'combo': Combo,
            'lodac': Lodac,
        }
        if config.latent_algo not in agent_classes:
            raise NotImplementedError(config.latent_algo)
        self._actor_critic = agent_classes[config.latent_algo](config)

        # tensorboard writer
        tf_dir = config['logdir'] / 'tensorboard'
        self._writer = tf.summary.create_file_writer(str(tf_dir), max_queue=1000, flush_millis=20000)
        self._writer.set_as_default()

        # testing environment
        self._env = self._make_env(prefix='test')

        # buffers: both capacities passed to the buffer are the step count
        # reported by count_episodes for the offline dataset.
        self._actdim = self._env.action_space.shape[0]
        episodes, steps = count_episodes(config.datadir, self._episode_length)
        self.latent_buffer = LatentReplayBuffer(steps,
                                                steps,
                                                self._config.deter + self._config.stoch,
                                                self._actdim)

        # dataset
        self._float = prec.global_policy().compute_dtype
        self._dataset = iter(load_dataset(config.datadir, self._config))
        self._episode_itterator = episode_itterator(config['datadir'])

        # optimizer
        Optimizer = functools.partial(Adam,
                                      wd=self._config['weight_decay'],
                                      clip=self._config['grad_clip'])
        latent_modules = [self._latent.latent_dynamic, self._latent.encoder, self._latent.decoder, self._latent.reward]
        self._latent_model_opt = Optimizer('latent_model', latent_modules, self._config.latent_model_lr)

        # Runs one non-training model step and one evaluation episode.
        self._variables_initializations()

    def _variables_initializations(self):
        """Run one non-training model step and one evaluation episode.

        The model step is called with ``prefix='eval'``, so no optimizer
        update is applied and the latent-model step counter is not advanced.
        """
        first_batch = next(self._dataset)
        self._latent_model_train_step(first_batch, prefix='eval')
        self.evaluate()

    def save_latent_model(self, filename):
        """Save the latent model into the directory ``filename``.

        Args:
            filename: Target directory (``str`` or path-like). It is created,
                together with any missing parent directory, when absent.
        """
        # makedirs(exist_ok=True) removes the check-then-create race and also
        # works when the parent directory does not exist yet.
        os.makedirs(filename, exist_ok=True)
        self._latent.save(filename)

    def load_latent_model(self, filename):
        """Restore the latent model by delegating to ``self._latent.load``.

        Args:
            filename: Location handed unchanged to ``self._latent.load``.
        """
        self._latent.load(filename)

    def load_agent(self, filename):
        """Restore the actor by delegating to ``self._actor_critic.load_actor``.

        Args:
            filename: Location handed unchanged to ``load_actor``.
        """
        self._actor_critic.load_actor(filename)

    def latent_model_training(self, itters):
        """Train the latent model for ``itters`` batches, checkpointing on the way.

        A checkpoint is written every ``config.save_every`` iterations
        (iteration 0 included) and a final one once the loop is over.

        Args:
            itters: Number of training iterations to run.
        """
        progress = tqdm(range(itters))
        for iteration in progress:
            progress.set_description("Latent variable model training.")
            self._latent_model_train_step(next(self._dataset))
            if iteration % self._config.save_every != 0:
                continue
            checkpoint_name = 'latent_training_step_{}'.format(iteration)
            self.save_latent_model(self._config.trained_latent_dir / checkpoint_name)
        self.save_latent_model(self._config.trained_latent_dir / 'final_latent_model')

    def _latent_model_train_step(self, data, prefix='train'):
        """Compute the latent-model loss on one batch and optionally update.

        The loss is ``kl_scale * KL(posterior || prior)`` minus the masked mean
        log-likelihoods of the image and reward predictions.

        Args:
            data: Batch mapping; the keys read here are ``'action'``,
                ``'image'``, ``'reward'`` and ``'mask'`` (the encoder receives
                the whole mapping).
            prefix: When equal to ``'train'`` the optimizer is applied and the
                step counter is incremented; any other value only evaluates
                the loss.
        """
        dynamics = self._latent.latent_dynamic
        valid = data['mask']
        is_training = prefix == 'train'

        def masked_mean_log_prob(dist, target):
            # Mean log-likelihood over the entries selected by the mask.
            return tf.reduce_mean(tf.boolean_mask(dist.log_prob(target), valid))

        with tf.GradientTape() as model_tape:
            embed = self._latent.encoder(data)
            post, prior = dynamics.observe(embed, data['action'])
            feat = dynamics.get_feat(post)
            image_pred = self._latent.decoder(feat)
            reward_pred = self._latent.reward(feat)
            likes = {
                'image': masked_mean_log_prob(image_pred, data['image']),
                'reward': masked_mean_log_prob(reward_pred, data['reward']),
            }

            # Drop masked-out entries from both state dicts before the KL term.
            for name in prior.keys():
                prior[name] = tf.boolean_mask(prior[name], valid)
                post[name] = tf.boolean_mask(post[name], valid)

            prior_dist = dynamics.get_dist(prior)
            post_dist = dynamics.get_dist(post)
            div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
            model_loss = self._config['kl_scale'] * div - sum(likes.values())

        if is_training:
            model_norm = self._latent_model_opt(model_tape, model_loss)
            self._latent_model_step += 1

        if self._latent_model_step % self._config['log_every'] != 0:
            return

        model_summaries = {
            'latent_model_train/KL Divergence': tf.reduce_mean(div),
            'latent_model_train/image_recon': tf.reduce_mean(likes['image']),
            'latent_model_train/reward_recon': tf.reduce_mean(likes['reward']),
            'latent_model_train/model_loss': tf.reduce_mean(model_loss),
        }
        if is_training:
            # The gradient norm only exists when the optimizer actually ran.
            model_summaries['latent_model_train/model_norm'] = tf.reduce_mean(model_norm)

        self._write_summaries(model_summaries, self._latent_model_step)

    def train(self):
        """Run the full agent training procedure.

        Steps, in order:
          1. Fill the latent buffer with an initial amount of synthetic data.
          2. For ``'oraac'`` only: load a saved imitation actor, or train one
             in chunks of 10000 steps (evaluating and checkpointing it after
             the first three chunks).
          3. Alternate actor-critic training with adding fresh latent data.
             Every ``config.num_evaluate`` loops the agent is evaluated over
             10 episodes and saved; the agent with the highest CVaR seen so
             far is also saved as ``best_agent``.
        """
        # Generate latent synthetic data
        self.generate_latent_synthetic_data(num_data=self._config.num_first_synthetic_data, print_process=True)
        if self._config.latent_algo == 'oraac':
            if self._config.load_imit_actor:
                self._actor_critic.load_imit_actor(self._config.logdir / 'best_imit')
                print("Evaluation of the loaded imitation actor.")
                self.evaluate(evaluate_imitation_=True, episodes=10)
            else:
                print("First, we train the imitation agent.")
                for step in range(self._config.num_imit_training_steps//10000):
                    self._imit_actor_training(num_trainings_steps=10000, print_process=True)
                    if step > 2:
                        print('#' * 50)
                        print("After {} training steps".format((step+1) * 10000))
                        self.evaluate(evaluate_imitation_=True, episodes=10)
                        print('#' * 50)
                        print('\n')
                        self._actor_critic.save_imit_actor(
                            self._config.logdir / '_imit_{}_training_steps'.format((step+1) * 10000))

        print("We begin the training of the actor")
        # Start from -inf (not 0) so that a best agent is still recorded when
        # every evaluated CVaR is zero or negative.
        best_cvar = float('-inf')
        for step in range(self._config.num_actor_critic_loop):
            if step % self._config.num_evaluate == 0:
                num_train = step * self._config.num_actor_critic_training_per_loop
                # Skip the evaluation before any training has happened.
                if num_train > 1:
                    print('#' * 50)
                    print("After {} training steps.".format(num_train))
                    current_cvar = self.evaluate(episodes=10)
                    if current_cvar > best_cvar:
                        self._actor_critic.save_actor_critic(self._config.logdir / 'best_agent')
                        best_cvar = current_cvar
                    print('#' * 50)
                    print('\n')
                    name = 'agent_' + str(num_train)
                    self._actor_critic.save_actor_critic(self._config.logdir / name)

            self.actor_critic_training(self._config.num_actor_critic_training_per_loop)
            self.add_latent_data(num_episodes=1)

    def add_latent_data(self, num_episodes=1):
        """Add real and synthetic latent transitions to the buffer.

        Encodes ``num_episodes`` stored episodes into the buffer, then runs one
        synthetic roll-out starting from the next dataset batch.

        Args:
            num_episodes: Number of stored episodes to encode.
        """
        self.process_data_to_latent(num_episodes=num_episodes)
        batch = next(self._dataset)
        self._generate_synthetic_step(batch)

    def actor_critic_training(self, num_trainings_steps, print_process=False):
        """Run ``num_trainings_steps`` actor-critic updates from the latent buffer.

        For ``'combo'`` each update gets half a batch of real and half a batch
        of synthetic samples; every other algorithm gets one mixed batch.

        Args:
            num_trainings_steps: Number of update steps.
            print_process: Show a tqdm progress bar when true.
        """
        iterations = range(num_trainings_steps)
        if print_process:
            iterations = tqdm(iterations)

        for _ in iterations:
            if print_process:
                iterations.set_description("Actor critic training.")

            if self._config.latent_algo != 'combo':
                batch = preprocess_latent(self.latent_buffer.sample(self._config.actor_batch_size))
                self._actor_critic.actor_critic_train_step(batch)
                continue

            # Real samples are drawn before synthetic ones.
            real_batch = preprocess_latent(
                self.latent_buffer.sample_real(self._config.actor_batch_size//2))
            synthetic_batch = preprocess_latent(
                self.latent_buffer.sample_synthetic(self._config.actor_batch_size//2))
            self._actor_critic.actor_critic_train_step(real_batch, synthetic_batch)

    def _imit_actor_training(self, num_trainings_steps, print_process=False):
        """Train the ORAAC imitation actor on real latent data only.

        Args:
            num_trainings_steps: Number of imitation update steps.
            print_process: Show a tqdm progress bar when true.
        """
        iterations = range(num_trainings_steps)
        if print_process:
            iterations = tqdm(iterations)

        for _ in iterations:
            if print_process:
                iterations.set_description("Imitation actor training.")
            # Only real (non-synthetic) samples are used for imitation.
            real_batch = self.latent_buffer.sample_real(self._config.imit_batch_size)
            self._actor_critic.imit_actor_training_step(preprocess_latent(real_batch))

    def generate_latent_synthetic_data(self, num_data, print_process=False):
        """Run enough synthetic roll-outs to cover ``num_data`` samples.

        The iteration count is ``ceil(num_data / per_step)`` with ``per_step``
        equal to ``latent_batch_size * latent_batch_length * horizon``.

        Args:
            num_data: Requested number of synthetic samples.
            print_process: Show a tqdm progress bar when true.
        """
        cfg = self._config
        per_step = cfg.latent_batch_size*cfg.latent_batch_length*cfg.horizon
        iterations = range(ceil(num_data / per_step))
        if print_process:
            iterations = tqdm(iterations)

        for _ in iterations:
            if print_process:
                iterations.set_description("Generating latent synthetic data.")
            self._generate_synthetic_step(next(self._dataset))

    def _generate_synthetic_step(self, data):
        """Roll the latent dynamics forward from a real batch and store the result.

        The batch is encoded into posterior latent states, which serve as the
        starting points of an imagined roll-out of ``config.horizon`` steps.
        At every step the (noisy) policy picks an action, every ensemble member
        predicts a next state, one randomly chosen member provides the state
        that is actually followed, and the resulting one-step transitions are
        added to ``self.latent_buffer`` with ``sample_type='latent'``.

        Args:
            data: Batch mapping; ``'action'`` and ``'mask'`` are read here and
                the whole mapping is passed to the encoder.
        """
        embed = self._latent.encoder(data)
        post, prior = self._latent.latent_dynamic.observe(embed, data['action'])
        # Keep only the entries selected by the mask as roll-out start states.
        for key in post.keys():
            post[key] = tf.boolean_mask(post[key], data['mask'])
        start = post
        # Behaviour policy: actor action on the state features plus exploration
        # noise, detached from the gradient graph.
        policy = lambda state: tf.stop_gradient(
            self._exploration(self._actor_critic.select_action(self._latent.latent_dynamic.get_feat(state)))
        )

        # One accumulator list per flattened component of the state structure.
        obs = [[] for _ in tf.nest.flatten(start)]
        next_obs = [[] for _ in tf.nest.flatten(start)]
        actions = []
        full_posts = [[[] for _ in tf.nest.flatten(start)] for _ in range(self._config.num_models)]
        prev = start

        for index in range(self._config.horizon):
            [o.append(l) for o, l in zip(obs, tf.nest.flatten(prev))]
            a = policy(prev)
            actions.append(a)
            # Prediction of every ensemble member (used for the penalty below).
            for i in range(self._config.num_models):
                p = self._latent.latent_dynamic.img_step(prev, a, k=i)
                [o.append(l) for o, l in zip(full_posts[i], tf.nest.flatten(p))]
            # The state that is actually followed comes from a random member.
            prev = self._latent.latent_dynamic.img_step(prev, a, k=np.random.choice(self._config.num_models, 1)[0])
            [o.append(l) for o, l in zip(next_obs, tf.nest.flatten(prev))]

            # The accumulators are reset at the end of each iteration, so each
            # stack below holds exactly one step along axis 0.
            obs = self._latent.latent_dynamic.get_feat(tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in obs]))
            stoch = tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in next_obs])['stoch']
            next_obs = self._latent.latent_dynamic.get_feat(tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in next_obs]))
            actions = tf.stack(actions, 0)
            rewards = self._latent.reward(next_obs).mode()
            # Imagined transitions are never marked as terminal.
            dones = tf.zeros_like(rewards)

            dists = [self._latent.latent_dynamic.get_dist(
                tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in full_posts[i]]))
                for i in range(self._config.num_models)]

            # Compute penalty : only for lompo
            # The penalty is the standard deviation, across ensemble members,
            # of the log-probability each member assigns to the followed state.
            if self._config.latent_algo == 'lompo':
                log_prob_vars = tf.math.reduce_std(
                    tf.stack([d.log_prob(stoch) for d in dists], 0),
                    axis=0)
                rewards = rewards - self._lmbd * log_prob_vars

            self.latent_buffer.add_samples(flatten(obs).numpy(),
                                           flatten(actions).numpy(),
                                           flatten(next_obs).numpy(),
                                           flatten(rewards).numpy(),
                                           flatten(dones),
                                           sample_type='latent')

            # Fresh accumulators for the next roll-out step.
            obs = [[] for _ in tf.nest.flatten(start)]
            next_obs = [[] for _ in tf.nest.flatten(start)]
            actions = []
            full_posts = [[[] for _ in tf.nest.flatten(start)] for _ in range(self._config.num_models)]

            # NOTE(review): dones is all zeros, so this mask is all ones and
            # presumably keeps every state - confirm against `flatten`.
            for key in prev.keys():
                prev[key] = tf.boolean_mask(prev[key], flatten(1.0 - dones))

    def _exploration(self, action):
        """Perturb ``action`` according to ``config.noise_type``.

        * ``'additive_gaussian'``: sample from a normal centred on ``action``
          with scale ``config.amount_action_noise``, clipped to ``[-1, 1]``.
        * ``'completely_random'``: ignore ``action`` (except for its shape) and
          draw uniformly from ``[-1, 1)``.

        Raises:
            NotImplementedError: For any other ``config.noise_type``.
        """
        noise_scale = self._config.amount_action_noise
        noise_type = self._config.noise_type

        if noise_type == 'additive_gaussian':
            noisy_action = tfd.Normal(action, noise_scale).sample()
            return tf.clip_by_value(noisy_action, -1, 1)

        if noise_type == 'completely_random':
            return tf.random.uniform(action.shape, -1, 1)

        raise NotImplementedError(self._config.noise_type)

    def _write_summaries(self, metrics, step=None):
        """Write scalar metrics to TensorBoard and echo them on stdout.

        Args:
            metrics: Mapping from tag to a value convertible with ``float``.
            step: Global step for the summaries. It goes through ``int()``, so
                despite the ``None`` default a value must be supplied.
        """
        step = int(step)
        scalars = [(tag, float(value)) for tag, value in metrics.items()]

        with self._writer.as_default():
            tf.summary.experimental.set_step(step)
            for tag, value in scalars:
                tf.summary.scalar(tag, value, step=step)

        summary_line = ' / '.join(f'{k} {v:.1f}' for k, v in scalars)
        print(f'[{step}]', summary_line)
        self._writer.flush()

    def evaluate_expert_agent(self, num_episodes=1):
        """Print the cumulative reward of the next stored episodes.

        Episodes that cannot be loaded are reported and skipped; they still
        count towards ``num_episodes``.

        Args:
            num_episodes: Number of episode files to pull from the iterator.
        """
        for _ in range(num_episodes):
            episode_path = next(self._episode_itterator)
            try:
                with episode_path.open('rb') as handle:
                    archive = np.load(handle)
                    episode = {key: archive[key] for key in archive.keys()}
            except Exception as e:
                print(f'Could not load episode: {e}')
                continue
            total_reward = np.sum(episode['reward'])
            print("The cumulative reward is ", total_reward)

    def process_data_to_latent(self, print_process=False, num_episodes=None):
        """Encode stored episodes and add them to the buffer as real samples.

        Args:
            print_process: Show a tqdm progress bar when true.
            num_episodes: Number of episodes to encode. ``None`` means every
                episode counted in ``config.datadir``; the progress bar is
                then always shown.
        """
        encode_all = num_episodes is None
        if encode_all:
            print_process = True
            num_episodes, _ = count_episodes(self._config.datadir, self._episode_length)

        iterations = range(num_episodes)
        if print_process:
            iterations = tqdm(iterations)

        for _ in iterations:
            if print_process:
                iterations.set_description("Transforming the data into latent.")

            episode_path = next(self._episode_itterator)
            try:
                with episode_path.open('rb') as handle:
                    archive = np.load(handle)
                    episode = {key: archive[key] for key in archive.keys()}
            except Exception as e:
                print(f'Could not load episode: {e}')
                continue

            obs = preprocess_raw(episode, self._config)
            # Stored episodes are recorded without any terminal flag set.
            obs['terminal'] = tf.zeros_like(obs['reward'])

            with tf.GradientTape(watch_accessed_variables=False):
                embed = self._latent.encoder(obs)
                # A leading axis of size 1 is added before calling observe().
                post, prior = self._latent.latent_dynamic.observe(
                    tf.expand_dims(embed, 0),
                    tf.expand_dims(obs['action'], 0))
                features = flatten(self._latent.latent_dynamic.get_feat(post)).numpy()
                # Transition t -> t+1 is paired with the action, reward and
                # terminal flag stored at index t+1.
                self.latent_buffer.add_samples(features[:-1],
                                               obs['action'].numpy()[1:],
                                               features[1:],
                                               obs['reward'].numpy()[1:],
                                               obs['terminal'].numpy()[1:],
                                               sample_type='real')

    def _make_env(self, prefix, store=False):
        """Build the wrapped DeepMind Control environment.

        Args:
            prefix: Label passed to ``summarize_episode``; episode summaries
                are only registered when it equals ``'test'``.
            store: When true, finished episodes are also saved to
                ``config.datadir``.

        Returns:
            The fully wrapped environment.
        """
        cfg = self._config

        env = DeepMindControl(cfg.task)
        env = ActionRepeat(env, cfg.action_repeat, cfg)
        env = NormalizeActions(env)
        env = TimeLimit(env, cfg.time_limit / cfg.action_repeat)
        env = StochasticEnv(env,
                            risky=cfg.risky,
                            penal=cfg.penal,
                            prob=cfg.prob,
                            max_reward=cfg.max_reward)

        # Callbacks are handed to the Collect wrapper in registration order.
        episode_callbacks = []
        if store:
            episode_callbacks.append(lambda ep: save_episodes(self._config.datadir, [ep]))
        if prefix == 'test':
            episode_callbacks.append(lambda ep: summarize_episode(ep, self._config, self._writer, prefix))

        env = Collect(env, episode_callbacks, cfg.precision)
        return RewardObs(env)

    def action_from_raw(self, obs, done, latent_state=None, imitation_actor=False):
        """Select an action from a raw observation, resetting state after ``done``.

        Args:
            obs: Raw observation mapping, forwarded to :meth:`action`.
            done: Episode-termination flag of the previous step. Either a
                scalar (Python/NumPy bool or int) or a 1-D array.
            latent_state: ``(latent, action)`` pair returned by the previous
                call, or ``None`` at the very first step.
            imitation_actor: Use the imitation actor instead of the main one.

        Returns:
            Tuple ``(action, state)`` where ``state`` is to be passed back as
            ``latent_state`` on the next call.
        """
        if latent_state is not None and done:
            # Zero the carried state of finished episodes. The reshape makes
            # the mask a column whether `done` is a scalar or a 1-D array;
            # indexing a scalar tensor with [:, None] would raise.
            mask = tf.reshape(tf.cast(1 - done, self._float), [-1, 1])
            latent_state = tf.nest.map_structure(lambda x: x * mask, latent_state)
        action, state = self.action(obs, latent_state, imitation_actor)
        return action, state

    def action(self, obs, state, imitation=False):
        """Advance the latent state with ``obs`` and pick the next action.

        Args:
            obs: Raw observation mapping; ``obs['image']`` gives the batch size.
            state: ``(latent, previous_action)`` pair, or ``None`` to start
                from the initial latent state and an all-zero action.
            imitation: Query the imitation actor instead of ``select_action``.

        Returns:
            Tuple ``(action, (latent, action))``.
        """
        dynamics = self._latent.latent_dynamic

        if state is not None:
            latent, previous_action = state
        else:
            batch_size = len(obs['image'])
            latent = dynamics.initial(batch_size)
            previous_action = tf.zeros((len(obs['image']), self._actdim), self._float)

        embed = self._latent.encoder(preprocess_raw(obs, self._config))
        latent, _ = dynamics.obs_step(latent, previous_action, embed)
        feat = dynamics.get_feat(latent)

        if imitation:
            chosen = self._actor_critic._actor._imit_actor(feat)
        else:
            chosen = self._actor_critic.select_action(feat)

        return chosen, (latent, chosen)

    def evaluate_dataset(self, max_episodes=1):
        """Print the return of stored episodes and their mean.

        Args:
            max_episodes: Maximum number of files from ``config.datadir`` to
                read (exactly this many at most).
        """
        all_reward = []
        for index, file in enumerate(Path(self._config.datadir).glob('*')):
            # Stop once max_episodes files were read (no extra episode).
            if index >= max_episodes:
                break
            # Close the archive once the reward array has been read.
            with np.load(file) as episode:
                episode_return = np.sum(episode['reward'])
            print('The cumulative rewards is ', episode_return)
            all_reward.append(episode_return)

        if not all_reward:
            # Avoid taking the mean of an empty list (NaN plus a warning).
            print('No episode was evaluated.')
            return
        print('Which gives a returned mean of ', np.round(np.mean(all_reward), 2))

    def random_policy(self, num_episodes=1):
        """Run uniformly random actions in the test environment and print returns.

        Args:
            num_episodes: Number of episodes to play.
        """
        action_space = self._env.action_space
        episode_returns = []

        for _ in range(num_episodes):
            state = self._env.reset()
            # The reset observation already carries a reward entry.
            cum_reward = state['reward']
            done = False
            while not done:
                action = np.random.uniform(action_space.low, action_space.high, action_space.shape)
                # Swallow whatever the environment prints during the step.
                with Capturing():
                    state, reward, done, _, _ = self._env.step(action)
                cum_reward += reward
            episode_returns.append(cum_reward)
            print("With a random policy, we get a return of {}".format(np.round(cum_reward, 2)))

        mean_return = np.round(np.mean(episode_returns), 2)
        print('Which gives a total mean return of {}'.format(mean_return))

    def evaluate(self, episodes=1, step_=0, state_=None, evaluate_imitation_=False):
        """Evaluate the agent and report per-episode returns, mean and CVaR.

        Each episode is run through :meth:`evaluate_step` while stdout is
        captured; the first captured line is parsed as that episode's return.

        Args:
            episodes: Number of evaluation episodes (at least one is run).
            step_: Forwarded to ``evaluate_step`` as ``steps``.
            state_: Forwarded to ``evaluate_step`` as ``state``.
            evaluate_imitation_: Evaluate the imitation actor instead.

        Returns:
            The CVaR (mean of the worst ``config.alpha_cvar`` fraction of
            returns) when ``episodes > 1``, otherwise ``None``.
        """
        returns = []
        for _ in range(max(1, episodes)):
            with Capturing() as output:
                self.evaluate_step(steps=step_, state=state_, imitation=evaluate_imitation_)
            returns.append(np.asarray(output, dtype=float)[0])

        for episode_return in returns:
            print("Test episode : return {}".format(episode_return))

        if episodes <= 1:
            return None

        # Plain arithmetic mean; no need to round-trip the floats through
        # np.fromstring, whose binary mode is deprecated.
        mean_return = np.mean(returns)

        # compute cvar: mean of the worst alpha fraction, using at least one
        # episode so that the result is never the mean of an empty slice.
        num_data = max(1, round(self._config.alpha_cvar*len(returns)))
        cvar = np.mean(sorted(returns)[:num_data])

        print("Which gives a returned mean of {}, "
              "and a cvar of {}".format(round(mean_return), round(cvar)))

        return cvar

    def save_evaluation_video(self, episodes=1, steps=0):
        """Roll the current policy out and save the observed frames as an mp4.

        The file is written to ``config.videos_dir`` and is named after the
        current actor-critic training step counter.

        Args:
            episodes: Stop after this many finished episodes (0 disables).
            steps: Stop once the step counter reaches this value (0 disables).
        """
        step, episode = 0, 0
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        done = np.ones(1, bool)
        length = np.zeros(1, np.int32)
        obs = [None]
        latent_state = None
        frames = []
        while (steps and step < steps) or (episodes and episode < episodes):
            # Reset if necessary
            if done:
                obs = self._env.reset()

            # actor step: add a leading batch axis of size 1 to every entry
            obs = {k: np.stack([obs[k]]) for k in obs}
            frames.append(obs['image'][0])
            action, latent_state = self.action_from_raw(obs, done, latent_state)
            action = np.squeeze(np.stack(action))

            promises = self._env.step(action)
            obs, _, done = promises[:3]
            done = int(done)
            episode += int(done)
            length += 1
            step += 1
            step += (done * length).sum()
            length *= (1 - done)

        name = str(self._config.videos_dir) + '/' + 'actor_' + str(self._num_actor_critic_train_step) + '.mp4'
        # NOTE(review): frames are assumed to be 64x64 uint8 images in the
        # channel order OpenCV expects - confirm against the environment.
        out = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*'mp4v'), 30, (64, 64))
        try:
            for frame in frames:
                out.write(frame)
        finally:
            # release() must actually be called for the file to be finalised.
            out.release()

    def evaluate_step(self, episodes=1, steps=0, state=None, imitation=False):
        """Roll the policy out in the test environment.

        Nothing is returned; any reporting is left to the environment.

        Args:
            episodes: Stop after this many finished episodes (0 disables).
            steps: Stop once the step counter reaches this value (0 disables).
            state: Optional ``(step, episode, done, length, obs, latent_state)``
                tuple to resume from; ``None`` starts from a fresh reset.
            imitation: Act with the imitation actor instead of the main one.
        """
        if state is None:
            step, episode = 0, 0
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            # Starting with done=True forces a reset on the first iteration.
            done = np.ones(1, bool)
            length = np.zeros(1, np.int32)
            obs = [None]
            latent_state = None
        else:
            step, episode, done, length, obs, latent_state = state
        while (steps and step < steps) or (episodes and episode < episodes):
            # Reset if necessary
            if done:
                obs = self._env.reset()

            # actor step: add a leading batch axis of size 1 to every entry
            obs = {k: np.stack([obs[k]]) for k in obs}
            action, latent_state = self.action_from_raw(obs, done, latent_state, imitation_actor=imitation)
            action = np.squeeze(np.stack(action))

            promises = self._env.step(action)
            obs, _, done = promises[:3]
            done = int(done)
            episode += int(done)
            length += 1
            step += 1
            step += (done * length).sum()
            length *= (1 - done)


