import torch

from lfrl.data_management.replay_buffers.simple_replay_buffer import SimpleReplayBuffer
from lfrl.models.dynamics_models.probabilistic_ensemble import ProbabilisticEnsemble
from lfrl.policies.base.latent_prior_policy import PriorLatentPolicy
from lfrl.policies.models.gaussian_policy import TanhGaussianPolicy
from lfrl.torch.networks import FlattenMlp
from lfrl.trainers.dads.dads_model import DADSModelTrainer
from lfrl.trainers.dads.skill_dynamics import SkillDynamics
from lfrl.trainers.mbrl.mbrl import MBRLTrainer
from lfrl.trainers.qpg.sac import SACTrainer
import lfrl.torch.pytorch_util as ptu
import lfrl.util.pythonplusplus as ppp


from lfrl.core import logger
from lfrl.samplers.utils.rollout_functions import rollout_with_latent
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def get_config(
        variant,
        expl_env,
        eval_env,
        obs_dim,
        action_dim,
        replay_buffer,
):
    """Build the policy, models, trainers and buffers for a DADS model run.

    Args:
        variant: Nested experiment dict. Keys read here:
            ``policy_kwargs`` (``layer_size``, ``latent_dim``),
            ``discriminator_kwargs`` (``restrict_input_size``, ``layer_size``,
            ``num_layers``), ``policy_trainer_kwargs``, ``mbrl_kwargs``
            (``layer_size``, ``ensemble_size``, ``noise_clip``; the whole dict
            is also forwarded to ``MBRLTrainer``), ``rollout_len_schedule``,
            ``generated_replay_buffer_size``, ``trainer_kwargs``, ``env_name``
            and, optionally, ``algorithm_kwargs`` and ``offline_kwargs``.
        expl_env: Exploration environment (passed to the SAC trainer and
            returned in the config).
        eval_env: Evaluation environment (returned in the config and used for
            the Gridworld plots).
        obs_dim: Observation dimensionality.
        action_dim: Action dimensionality.
        replay_buffer: Replay buffer handed to ``DADSModelTrainer`` and
            returned in the config.

    Returns:
        dict with the constructed objects under the keys ``trainer``,
        ``model_trainer``, ``exploration_policy``, ``evaluation_policy``,
        ``exploration_env``, ``evaluation_env``, ``replay_buffer``,
        ``generated_replay_buffer``, ``dynamics_model``, ``prior``,
        ``control_policy``, ``latent_dim``, ``policy_trainer``,
        ``rollout_len_func``, ``algorithm_kwargs`` and ``offline_kwargs``.

    Raises:
        KeyError: If a required ``variant`` entry is missing.
    """

    # ---- Policy construction -------------------------------------------
    # NOTE: the construction order of the modules below is kept stable on
    # purpose; reordering could change seeded weight initialisation.

    policy_layer_size = variant['policy_kwargs']['layer_size']
    latent_dim = variant['policy_kwargs']['latent_dim']
    restrict_dim = variant['discriminator_kwargs']['restrict_input_size']

    # The low-level policy sees the observation concatenated with the latent.
    control_policy = TanhGaussianPolicy(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        hidden_sizes=[policy_layer_size, policy_layer_size],
        restrict_obs_dim=restrict_dim,
    )

    # Uniform prior over [-1, 1] in every latent dimension.
    prior = torch.distributions.uniform.Uniform(
        -ptu.ones(latent_dim), ptu.ones(latent_dim),
    )

    policy = PriorLatentPolicy(
        policy=control_policy,
        prior=prior,
        unconditional=True,
    )

    qf1, qf2, target_qf1, target_qf2 = ppp.group_init(
        4,
        FlattenMlp,
        input_size=obs_dim + latent_dim + action_dim,
        output_size=1,
        hidden_sizes=[policy_layer_size, policy_layer_size],
    )

    # ---- Discriminator ---------------------------------------------------

    discrim_kwargs = variant['discriminator_kwargs']
    discriminator = SkillDynamics(
        observation_size=obs_dim,
        action_size=action_dim,
        latent_size=latent_dim,
        normalize_observations=True,
        fix_variance=True,
        fc_layer_params=[discrim_kwargs['layer_size']] * discrim_kwargs['num_layers'],
    )

    # ---- Policy trainer --------------------------------------------------

    # SAC trains the latent-conditioned control policy directly, not the
    # PriorLatentPolicy wrapper.
    policy_trainer = SACTrainer(
        env=expl_env,
        policy=control_policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['policy_trainer_kwargs'],
    )

    # ---- Model-based reinforcement learning (MBRL) dynamics models -------

    model_layer_size = variant['mbrl_kwargs']['layer_size']

    dynamics_model = ProbabilisticEnsemble(
        ensemble_size=variant['mbrl_kwargs']['ensemble_size'],
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[model_layer_size, model_layer_size],
        noise_clip=variant['mbrl_kwargs']['noise_clip'],
    )
    model_trainer = MBRLTrainer(
        ensemble=dynamics_model,
        **variant['mbrl_kwargs'],
    )

    rollout_len_schedule = variant['rollout_len_schedule']

    def rollout_len(train_steps):
        """Return the model rollout length for the given training step.

        rollout_len_schedule: [a, b, len_a, len_b]
        Linearly increases the length from len_a -> len_b over epochs a -> b.
        Before epoch a the length is 1 (not len_a); from epoch b onwards it
        is len_b. An epoch is taken to be 1000 train steps here.
        """
        epoch = train_steps // 1000
        start_epoch = rollout_len_schedule[0]
        if epoch < start_epoch:
            return 1
        end_epoch = rollout_len_schedule[1]
        if epoch >= end_epoch:
            return rollout_len_schedule[3]
        start_len = rollout_len_schedule[2]
        end_len = rollout_len_schedule[3]
        # start_epoch <= epoch < end_epoch here, so the denominator is > 0.
        progress = (epoch - start_epoch) / (end_epoch - start_epoch)
        return int(progress * (end_len - start_len)) + start_len

    # ---- Setup of intrinsic control --------------------------------------

    # Buffer for generated data; observations are stored with the latent
    # appended (obs_dim + latent_dim).
    latent_buffer = SimpleReplayBuffer(
        variant['generated_replay_buffer_size'],
        obs_dim + latent_dim,
        action_dim,
        dict(),
    )

    trainer = DADSModelTrainer(
        dynamics_model=dynamics_model,
        rollout_len_func=rollout_len,
        control_policy=control_policy,
        discriminator=discriminator,
        replay_buffer=replay_buffer,
        latent_buffer=latent_buffer,
        policy_trainer=policy_trainer,
        **variant['trainer_kwargs']
    )

    # ---- Create config dict ----------------------------------------------

    config = dict(
        trainer=trainer,
        model_trainer=model_trainer,
        exploration_policy=policy,
        evaluation_policy=policy,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        replay_buffer=replay_buffer,
        generated_replay_buffer=latent_buffer,
        dynamics_model=dynamics_model,
        prior=prior,
        control_policy=control_policy,
        latent_dim=latent_dim,
        policy_trainer=policy_trainer,
        rollout_len_func=rollout_len,
    )
    # NOTE: when the key is present, .get() returns the caller's own dict, so
    # config['algorithm_kwargs'] aliases variant['algorithm_kwargs'] and the
    # hook installed below is visible through ``variant`` as well.
    config['algorithm_kwargs'] = variant.get('algorithm_kwargs', dict())
    config['offline_kwargs'] = variant.get('offline_kwargs', dict())

    if variant['env_name'] in ('Gridworld',):

        def plot_graph(epoch):
            """Save a plot of 2-D trajectories for a few sampled latents.

            Draws 3 latents and runs 3 rollouts per latent in ``eval_env``,
            plotting the first two observation dimensions of each path, then
            saves the figure as ``graphs/grid_<epoch>.png`` via the logger.
            """
            max_path_length = variant['algorithm_kwargs']['max_path_length']

            fig = plt.figure()
            try:
                fig.set_size_inches(4, 4)

                # Draw on this figure's own axes so the result does not
                # depend on pyplot's "current figure" staying unchanged
                # while the rollouts run.
                ax = fig.gca()
                size = 1
                ax.set_xlim(-size, size)
                ax.set_ylim(-size, size)

                colors = np.array(sns.color_palette())

                num_lats = 3
                for i in range(num_lats):
                    # Presumably fixed_latent keeps the latent drawn by
                    # sample_latent() for all rollouts below -- confirm
                    # against PriorLatentPolicy.
                    policy.fixed_latent = True
                    try:
                        policy.sample_latent()
                        for _ in range(3):
                            path = rollout_with_latent(
                                eval_env, policy,
                                max_path_length=max_path_length,
                            )
                            obs = path['observations']
                            ax.plot(obs[:, 0], obs[:, 1], linewidth=1,
                                    color=tuple(colors[i]), alpha=0.5)
                    finally:
                        # ``policy`` is also the exploration/evaluation
                        # policy; never leave it pinned to one latent, even
                        # if a rollout fails.
                        policy.fixed_latent = False

                logger.savefig('graphs/grid_%d.png' % epoch, fig=fig)
            finally:
                # Close exactly this figure (not whichever is current) so
                # figures do not accumulate across epochs or on errors.
                plt.close(fig)

        config['algorithm_kwargs']['post_epoch_funcs'] = [plot_graph]

    return config
