from collections import OrderedDict
import gtimer as gt

import numpy as np
import torch

import lfrl.torch.pytorch_util as ptu
from lfrl.torch.pytorch_util import np_to_pytorch_batch
from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer
from lfrl.samplers import model_policy_rollout_torch


def always_one(train_steps):
    """Constant rollout-length schedule.

    Args:
        train_steps: number of training steps taken so far; accepted so the
            function matches the schedule call signature, but ignored.

    Returns:
        Always ``1``.
    """
    constant_rollout_len = 1
    return constant_rollout_len


class MBPOTrainer(TorchTrainer):
    """Trainer that updates a policy on a mix of real and model-generated data.

    Every ``rollout_generation_freq`` training steps the trainer:

    1. rolls the policy out through ``dynamics_model`` from start states
       sampled out of ``replay_buffer`` and stores the paths in
       ``generated_data_buffer``,
    2. fits ``discriminator`` on batches drawn from both buffers,
    3. generates a second set of rollouts of length
       ``rollout_len_func(train_steps)``.

    Every training step it then runs ``num_policy_updates`` updates of
    ``policy_trainer`` on batches that concatenate samples from both buffers.
    The dynamics model itself is not trained by this class.
    """

    # Path length used for the rollouts generated just before the
    # discriminator is trained.
    # NOTE(review): the original call site had ``rollout_len`` commented out
    # in favour of this literal -- confirm the fixed length is intentional.
    _DISCRIMINATOR_ROLLOUT_LEN = 50

    def __init__(
            self,
            policy_trainer,
            dynamics_model, # MBPOTrainer not responsible for training this
            discriminator,
            replay_buffer,
            generated_data_buffer,
            num_model_rollouts=400,
            rollout_generation_freq=250,
            num_policy_updates=20,
            rollout_len_func=always_one,
            rollout_batch_size=int(1e3),
            real_data_pct=0.05,
            **kwargs
    ):
        """Store collaborators and hyper-parameters.

        Args:
            policy_trainer: trainer whose ``train_from_torch`` performs the
                actual policy updates; its ``policy`` is used for rollouts.
            dynamics_model: model used to generate rollouts (not trained here).
            discriminator: module exposing ``parameters()`` and ``get_loss``;
                optimized here with Adam (lr=1e-3).
            replay_buffer: buffer of real data; source of rollout start states.
            generated_data_buffer: buffer receiving model-generated paths.
            num_model_rollouts: factor in the number of samples generated per
                generation phase.
            rollout_generation_freq: generate rollouts every this many
                training steps (also a factor in the sample count).
            num_policy_updates: policy updates per training step.
            rollout_len_func: maps the total number of training steps to the
                rollout length.
            rollout_batch_size: maximum number of start states per rollout
                call.
            real_data_pct: fraction of each policy batch sampled from
                ``replay_buffer``; the rest comes from
                ``generated_data_buffer``.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.policy_trainer = policy_trainer
        self.policy = policy_trainer.policy
        self.dynamics_model = dynamics_model
        self.replay_buffer = replay_buffer
        self.generated_data_buffer = generated_data_buffer

        self.num_model_rollouts = num_model_rollouts
        self.rollout_generation_freq = rollout_generation_freq
        self.num_policy_updates = num_policy_updates
        self.rollout_len_func = rollout_len_func
        self.rollout_batch_size = rollout_batch_size
        self.real_data_pct = real_data_pct

        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()
        # Summary of the most recent rollout generation.  Cached so that eval
        # statistics can be published on steps that do not generate rollouts.
        self._rollout_statistics = OrderedDict()

        self.discriminator = discriminator
        self.discriminator_optim = torch.optim.Adam(
            discriminator.parameters(), lr=1e-3,
        )

    def train_from_torch(self, batch):
        """Run one training step.

        Args:
            batch: only ``batch['observations'].shape[0]`` is read, to size
                the mixed batches; the contents are otherwise ignored
                (slight violation of abstraction).
        """
        # Generate data using the dynamics model.
        if self._n_train_steps_total % self.rollout_generation_freq == 0:
            rollout_len = self.rollout_len_func(self._n_train_steps_total)
            total_samples = min(
                self.rollout_generation_freq * self.num_model_rollouts * rollout_len,
                self.generated_data_buffer.max_replay_buffer_size()
            )

            self._generate_rollouts(
                total_samples, self._DISCRIMINATOR_ROLLOUT_LEN)
            self._train_discriminator()
            paths = self._generate_rollouts(total_samples, rollout_len)
            self._cache_rollout_statistics(paths, rollout_len)

            gt.stamp('generating rollouts', unique=False)

        # Update the policy on both real and generated data.
        self._update_policy(batch['observations'].shape[0])

        # Save some statistics for eval.  The values come from the cache
        # because this step may not have generated any rollouts itself.
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics.update(self._rollout_statistics)

        self._n_train_steps_total += 1

    def _generate_rollouts(self, total_samples, max_path_length):
        """Roll the policy out in the model until ``total_samples`` are stored.

        Returns the paths produced by the last rollout call, or an empty list
        when no call was made.
        """
        paths = []
        num_samples = 0
        while num_samples < total_samples:
            batch_samples = min(
                self.rollout_batch_size, total_samples - num_samples)
            real_batch = self.replay_buffer.random_batch(batch_samples)
            start_states = real_batch['observations']

            with torch.no_grad():
                paths = model_policy_rollout_torch(
                    self.dynamics_model,
                    self.policy_trainer.policy,
                    start_states,
                    max_path_length=max_path_length,
                )

            samples_added = 0
            for path in paths:
                self.generated_data_buffer.add_path(path)
                samples_added += len(path['observations'])

            if samples_added == 0:
                # No progress is possible; stop instead of looping forever.
                break
            num_samples += samples_added

        return paths

    def _train_discriminator(self):
        """Fit the discriminator until its loss stops improving, then log it.

        Training stops after 10 consecutive steps without a relative
        improvement of more than 1% over the best loss seen so far.
        """
        batch_size, test_batch_size = 256, 4096
        best_loss = float(int(1e6))
        num_epochs_since_last_update = 0
        num_steps = 0
        while num_epochs_since_last_update < 10:
            # NOTE(review): the meaning of the extra positional arguments
            # (0.5, 1) depends on the replay buffer implementation -- confirm.
            real_batch = np_to_pytorch_batch(
                self.replay_buffer.random_batch(batch_size, 0.5, 1))
            fake_batch = np_to_pytorch_batch(
                self.generated_data_buffer.random_batch(batch_size))

            loss = self.discriminator.get_loss(real_batch, fake_batch)

            self.discriminator_optim.zero_grad()
            loss.backward()
            self.discriminator_optim.step()

            # Track the best loss as a plain float so the autograd graph of
            # an old step is not kept alive.  Written as a multiplication so
            # that a zero ``best_loss`` cannot raise ZeroDivisionError.
            loss_value = loss.item()
            if best_loss - loss_value > 0.01 * abs(best_loss):
                best_loss = loss_value
                num_epochs_since_last_update = 0
            else:
                num_epochs_since_last_update += 1
            num_steps += 1

        with torch.no_grad():
            real_batch = np_to_pytorch_batch(
                self.replay_buffer.random_batch(test_batch_size))
            fake_batch = np_to_pytorch_batch(
                self.generated_data_buffer.random_batch(test_batch_size))
            _, accuracy = self.discriminator.get_loss(
                real_batch, fake_batch, return_accuracy=True)
            self.eval_statistics['Discriminator Accuracy'] = accuracy.item()
            self.eval_statistics['Discriminator Gradient Steps'] = num_steps

    def _cache_rollout_statistics(self, paths, rollout_len):
        """Summarize ``paths`` and remember the result for eval reporting."""
        stats = OrderedDict()
        stats['MBPO Rollout Length'] = rollout_len

        if paths:
            returns, terminated = [], []
            for path in paths:
                returns.append(np.sum(path['rewards']))
                terminated.append(np.sum(path['terminals']) > 0)

            average_return = sum(returns) / len(paths)
            terminated_pct = sum(terminated) / len(paths)

            stats['MBPO Rollout Return Average'] = average_return
            stats['MBPO Rollout Return Per Timestep'] = \
                average_return / rollout_len
            stats['MBPO Rollout Terminated Percent'] = terminated_pct

        self._rollout_statistics = stats

    def _update_policy(self, batch_size):
        """Run ``num_policy_updates`` policy updates on mixed batches."""
        n_real_data = int(self.real_data_pct * batch_size)
        n_generated_data = batch_size - n_real_data

        for _ in range(self.num_policy_updates):
            mixed_batch = self.replay_buffer.random_batch(n_real_data)
            generated_batch = self.generated_data_buffer.random_batch(
                n_generated_data)

            # Only these keys are concatenated and converted; any other key
            # of the real batch is passed through untouched.
            for k in ('rewards', 'terminals', 'observations',
                      'actions', 'next_observations'):
                mixed_batch[k] = ptu.from_numpy(
                    np.concatenate((mixed_batch[k], generated_batch[k]), axis=0))

            self.policy_trainer.train_from_torch(mixed_batch)

    def get_diagnostics(self):
        """Return eval statistics merged with those of the policy trainer."""
        self.eval_statistics.update(self.policy_trainer.eval_statistics)
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm eval-statistics collection and forward to the policy trainer."""
        self._need_to_update_eval_statistics = True
        self.policy_trainer.end_epoch(epoch)

    @property
    def networks(self):
        """Networks of the wrapped policy trainer."""
        return self.policy_trainer.networks

    def get_snapshot(self):
        """Return the policy trainer's snapshot plus the model and discriminator.

        Keys from the policy trainer's snapshot take precedence on collision.
        """
        mbpo_snapshot = dict(
            dynamics_model=self.dynamics_model,
            discriminator=self.discriminator,
        )
        mbpo_snapshot.update(self.policy_trainer.get_snapshot())
        return mbpo_snapshot

