import torch

from lfrl.data_management.replay_buffers.simple_replay_buffer import SimpleReplayBuffer
from lfrl.policies.base.latent_prior_policy import PriorLatentPolicy
from lfrl.policies.models.gaussian_policy import TanhGaussianPolicy
from lfrl.torch.networks import FlattenMlp
from lfrl.trainers.dads.dads import DADSTrainer
from lfrl.trainers.dads.dads_offpolicy import DADSOffTrainer
from lfrl.trainers.dads.skill_dynamics import SkillDynamics
from lfrl.trainers.qpg.sac import SACTrainer
import lfrl.torch.pytorch_util as ptu
import lfrl.util.pythonplusplus as ppp

from lfrl.core import logger
from lfrl.samplers.utils.rollout_functions import rollout_with_latent
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def get_config(
        variant,
        expl_env,
        eval_env,
        obs_dim,
        action_dim,
        replay_buffer,
) -> dict:
    """Assemble the objects needed to run a DADS experiment.

    Builds a latent-conditioned Tanh-Gaussian control policy, twin Q-functions
    with targets, a ``SkillDynamics`` discriminator, a ``SACTrainer`` for the
    control policy and the on-/off-policy DADS trainer, then bundles them in a
    config dict.

    Args:
        variant: Experiment settings. Keys read here:
            ``policy_kwargs`` (``layer_size``, ``latent_dim``),
            ``discriminator_kwargs`` (``restrict_input_size``, ``layer_size``,
            ``num_layers``; optional ``normalize_observations`` and
            ``fix_variance``, both defaulting to ``True``),
            ``policy_trainer_kwargs``, ``trainer_kwargs``,
            ``generated_replay_buffer_size``, ``dads_type``
            (``'onpolicy'`` or ``'offpolicy'``), ``env_name``, and the
            optional ``algorithm_kwargs`` / ``offline_kwargs``.
        expl_env: Exploration environment; handed to the SAC trainer and
            returned in the config.
        eval_env: Evaluation environment; returned in the config and used for
            the visualization rollouts.
        obs_dim: Observation dimensionality (added to ``latent_dim`` to size
            the policy and Q-function inputs).
        action_dim: Action dimensionality.
        replay_buffer: Buffer passed to the DADS trainer and returned in the
            config.

    Returns:
        A dict with the keys ``trainer``, ``exploration_policy``,
        ``evaluation_policy``, ``exploration_env``, ``evaluation_env``,
        ``replay_buffer``, ``generated_replay_buffer``, ``prior``,
        ``control_policy``, ``latent_dim``, ``policy_trainer``,
        ``algorithm_kwargs`` and ``offline_kwargs``.

    Raises:
        NotImplementedError: If ``variant['dads_type']`` is neither
            ``'onpolicy'`` nor ``'offpolicy'``.
        KeyError: If a required ``variant`` key is missing.

    Note:
        When ``variant`` contains ``algorithm_kwargs``, the returned
        ``config['algorithm_kwargs']`` is that very dict (not a copy), so for
        the ``'Gridworld'`` and ``'ContinualAnt'`` environments the
        ``post_epoch_funcs`` entry written below is also visible through
        ``variant`` and replaces any value already stored under that key.
    """

    # ------------------------------------------------------------------
    # Policy construction
    # ------------------------------------------------------------------

    policy_kwargs = variant['policy_kwargs']
    M = policy_kwargs['layer_size']
    latent_dim = policy_kwargs['latent_dim']
    restrict_dim = variant['discriminator_kwargs']['restrict_input_size']

    # The low-level policy sees the observation concatenated with the latent.
    control_policy = TanhGaussianPolicy(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        hidden_sizes=[M, M],
        restrict_obs_dim=restrict_dim,
    )

    # Uniform prior with bounds -1 and 1 in every latent dimension
    # (ptu.ones presumably returns a tensor of ones -- confirm device/dtype).
    prior = torch.distributions.uniform.Uniform(
        -ptu.ones(latent_dim), ptu.ones(latent_dim),
    )

    policy = PriorLatentPolicy(
        policy=control_policy,
        prior=prior,
        unconditional=True,
    )

    # Four identically configured critics: two Q-functions and their targets.
    qf1, qf2, target_qf1, target_qf2 = ppp.group_init(
        4,
        FlattenMlp,
        input_size=obs_dim + latent_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
        # restrict_dim=restrict_dim,
    )

    # ------------------------------------------------------------------
    # Discriminator
    # ------------------------------------------------------------------

    discrim_kwargs = variant['discriminator_kwargs']
    # restrict_dim == 0 means "use the full observation" for the discriminator.
    discriminator = SkillDynamics(
        observation_size=obs_dim if restrict_dim == 0 else restrict_dim,
        action_size=action_dim,
        latent_size=latent_dim,
        normalize_observations=discrim_kwargs.get('normalize_observations', True),
        fix_variance=discrim_kwargs.get('fix_variance', True),
        fc_layer_params=[discrim_kwargs['layer_size']] * discrim_kwargs['num_layers'],
        # use_latents_as_delta=variant.get('use_latents_as_delta', False),
    )

    # ------------------------------------------------------------------
    # Policy trainer
    # ------------------------------------------------------------------

    policy_trainer = SACTrainer(
        env=expl_env,
        policy=control_policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['policy_trainer_kwargs'],
    )

    # ------------------------------------------------------------------
    # Setup of intrinsic control
    # ------------------------------------------------------------------

    # Separate buffer whose "observations" have size obs_dim + latent_dim.
    latent_buffer = SimpleReplayBuffer(
        variant['generated_replay_buffer_size'],
        obs_dim + latent_dim,
        action_dim,
        dict(),
    )

    dads_type = variant['dads_type']
    if dads_type == 'onpolicy':
        trainer_class = DADSTrainer
    elif dads_type == 'offpolicy':
        trainer_class = DADSOffTrainer
    else:
        raise NotImplementedError('dads_type not recognized')

    trainer = trainer_class(
        control_policy=control_policy,
        discriminator=discriminator,
        replay_buffer=replay_buffer,
        latent_buffer=latent_buffer,
        policy_trainer=policy_trainer,
        restrict_input_size=restrict_dim,
        **variant['trainer_kwargs'],
    )

    # ------------------------------------------------------------------
    # Create config dict
    # ------------------------------------------------------------------

    # Intentionally not copied: see the Note in the docstring.
    algorithm_kwargs = variant.get('algorithm_kwargs', dict())

    config = dict(
        trainer=trainer,
        exploration_policy=policy,
        evaluation_policy=policy,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        replay_buffer=replay_buffer,
        generated_replay_buffer=latent_buffer,
        prior=prior,
        control_policy=control_policy,
        latent_dim=latent_dim,
        policy_trainer=policy_trainer,
    )
    config['algorithm_kwargs'] = algorithm_kwargs
    config['offline_kwargs'] = variant.get('offline_kwargs', dict())

    # ------------------------------------------------------------------
    # Special policy visualizations
    # ------------------------------------------------------------------

    env_name = variant['env_name']
    if env_name in ['Gridworld', 'ContinualAnt']:

        # Half-width of the square plotting window, per environment.
        size = 10 if env_name == 'ContinualAnt' else 1
        num_lats = 3
        rollouts_per_latent = 3

        def plot_graph(epoch):
            """Plot x/y traces of rollouts for a few sampled latents.

            Samples ``num_lats`` latents, rolls out ``rollouts_per_latent``
            paths for each in ``eval_env`` and saves the first two observation
            dimensions of every path to ``graphs/grid_<epoch>.png``.
            """
            # Looked up at call time from the dict stored in the config, so it
            # always agrees with config['algorithm_kwargs'].
            max_path_length = algorithm_kwargs['max_path_length']

            fig = plt.figure()
            try:
                fig.set_size_inches(4, 4)

                # Draw on an explicit Axes so nothing that touches pyplot's
                # "current figure" during a rollout can redirect the plot.
                ax = fig.add_subplot(1, 1, 1)
                ax.set_xlim(-size, size)
                ax.set_ylim(-size, size)

                colors = np.array(sns.color_palette())

                for i in range(num_lats):
                    # Hold one latent fixed across all rollouts of this colour.
                    policy.fixed_latent = True
                    try:
                        policy.sample_latent()
                        for _ in range(rollouts_per_latent):
                            path = rollout_with_latent(
                                eval_env,
                                policy,
                                max_path_length=max_path_length,
                            )
                            obs = path['observations']
                            ax.plot(
                                obs[:, 0],
                                obs[:, 1],
                                linewidth=1,
                                color=tuple(colors[i]),
                                alpha=0.5,
                            )
                    finally:
                        # Always restore: `policy` is also the exploration and
                        # evaluation policy handed out in the config.
                        policy.fixed_latent = False

                logger.savefig('graphs/grid_%d.png' % epoch, fig=fig)
            finally:
                # Close this specific figure, even on failure, so repeated
                # epochs do not accumulate open figures.
                plt.close(fig)

        config['algorithm_kwargs']['post_epoch_funcs'] = [plot_graph]

    return config
