import os
import time
from copy import deepcopy
import uuid
import pickle

import numpy as np
import pprint

import jax
import jax.numpy as jnp
import flax

import gym
import d4rl

import absl.app
import absl.flags

import d4rl_additions
from .offline_sac import OfflineSAC
from .conservative_sac import ConservativeSAC
from .pop_sac import POPSAC
from .replay_buffer import get_d4rl_dataset, get_pop_d4rl_dataset, subsample_batch
from .jax_utils import batch_to_jax, JaxRNG, next_rng, wrap_function_with_rng, no_nans
from .model import TanhGaussianPolicy, FullyConnectedQFunction, LinearDualFunction, SamplerPolicy
from .sampler import StepSampler, TrajSampler, TrajRenderer
from .utils import (
    Timer, define_flags_with_default, set_random_seed, print_flags,
    get_user_flags, prefix_metrics, WandBLogger
)
from .multi_data_loader import MultiDataLoader
from viskit.viskit_logging import logger, setup_logger

from .plotting import plot_lambda


# Command-line flags (with their default values) for this training script.
# main() reads them through absl.flags.FLAGS.
FLAGS_DEF = define_flags_with_default(
    # Environment id. A "brax!" prefix selects a brax env (prefix stripped);
    # otherwise the id is passed to gym.make. The offline dataset is loaded
    # from this env via get_d4rl_dataset.
    env='halfcheetah-medium-v2',
    # One of 'sac', 'bc', 'cql', 'pop'; main() raises ValueError otherwise.
    alg='sac',
    # Step cap for evaluation / render rollouts (TrajSampler / TrajRenderer).
    max_traj_length=1000,
    seed=42,
    # If True, pickle the agent at every evaluation epoch and once after training.
    save_model=False,
    # Transitions subsampled from the dataset per gradient step.
    batch_size=256,

    # Dataset rewards become r * reward_scale + reward_bias and, when
    # normalize_reward is set, are then divided by their root-mean-square.
    reward_scale=1.0,
    reward_bias=0.0,
    normalize_reward=True,
    # Dataset actions are clipped to [-clip_action, clip_action].
    clip_action=0.999,

    # Network settings for TanhGaussianPolicy / FullyConnectedQFunction.
    # NOTE(review): the *_arch strings look like '-'-separated hidden widths;
    # the parsing lives in the model module, so confirm there.
    policy_arch='256-256',
    qf_arch='256-256',
    orthogonal_init=False,
    policy_log_std_multiplier=1.0,
    policy_log_std_offset=-1.0,
    qf_feature_layer_norm=True,

    # Schedule: n_epochs epochs of n_train_step_per_epoch gradient steps.
    # Evaluation and the F-matrix diagnostics run at epoch 0 and on every
    # epoch where (epoch + 1) % eval_period == 0.
    n_epochs=2000,
    n_train_step_per_epoch=1000,
    eval_period=10,
    # Number of deterministic trajectories sampled per evaluation.
    eval_n_trajs=5,
    # If True, also render render_n_trajs trajectories and log them as a GIF.
    log_renders=False,
    render_n_trajs=1,

    # Algorithm sub-configs. Every algorithm starts from a copy of `sac`;
    # alg='cql' / alg='pop' overlay `cql` / `pop` on top of that copy.
    sac=OfflineSAC.get_default_config(),
    cql=ConservativeSAC.get_default_config(),
    pop=POPSAC.get_default_config(),
    # WandBLogger settings; logging.output_dir is also the viskit log dir.
    logging=WandBLogger.get_default_config(),

    # NOTE(review): load_data is not read by main() in this file; confirm
    # where (if anywhere) it is consumed.
    load_data=MultiDataLoader.get_default_config(),   # If not null, loads some external data
    # If True, call plot_lambda on the first 1000 dataset transitions at
    # every evaluation epoch.
    ongoing_plot=False,
)


def _make_eval_env(env_name):
    """Create the environment used for evaluation and for loading the dataset.

    Names prefixed with ``"brax!"`` are built through brax's gym wrapper
    (with the prefix stripped); any other name goes through ``gym.make`` and
    the result is unwrapped.
    """
    brax_prefix = "brax!"
    if env_name.startswith(brax_prefix):
        # Imported lazily so brax is only required when a brax env is asked
        # for. Import the ``envs`` submodule explicitly: importing a package
        # does not by itself guarantee its submodules are loaded.
        import brax.envs
        return brax.envs.create_gym_env(env_name=env_name[len(brax_prefix):])
    return gym.make(env_name).unwrapped


def _build_agent(flags, policy, qf, observation_dim, action_dim):
    """Instantiate the training algorithm selected by ``flags.alg``.

    Every algorithm starts from a deep copy of ``flags.sac``; 'cql' and 'pop'
    overlay their own sub-config on that copy, and 'bc' reuses OfflineSAC
    with ``fixed_qf`` enabled.

    Returns:
        ``(agent, use_reweighting)``; ``use_reweighting`` is True only for 'pop'.

    Raises:
        ValueError: if ``flags.alg`` is not one of 'sac', 'bc', 'cql', 'pop'.
    """
    alg_config = deepcopy(flags.sac)
    if flags.alg == 'sac':
        return OfflineSAC(alg_config, policy, qf), False
    if flags.alg == 'bc':
        alg_config.unlock().update({
            'fixed_qf': True,
            # NOTE(review): presumably bc_epochs is how many leading epochs use
            # a BC objective, so n_epochs + 1 covers the whole run — confirm.
            'bc_epochs': flags.n_epochs + 1,
        })
        return OfflineSAC(alg_config, policy, qf), False
    if flags.alg == 'cql':
        alg_config.unlock().update(flags.cql)
        alg_config.lock()
        return ConservativeSAC(alg_config, policy, qf), False
    if flags.alg == 'pop':
        dualf = LinearDualFunction(rank=flags.pop.dual_rank,
                                   dual_min=flags.pop.dual_min_val, dual_max=flags.pop.dual_max_val)
        gf = FullyConnectedQFunction(observation_dim, action_dim, flags.pop.g_arch, scale=1e-5, q_min=-1, q_max=1)

        alg_config.unlock().update(flags.pop)
        alg_config.lock()
        return POPSAC(alg_config, policy, qf, dualf, gf), True
    raise ValueError(f"Unknown algorithm: {flags.alg}")


def _normalized_min_eig(matrix):
    """Return min / max of the real parts of ``matrix``'s eigenvalues."""
    # eigvals skips the eigenvector computation that np.linalg.eig would do.
    real_eigs = np.real(np.linalg.eigvals(matrix))
    return real_eigs.min() / real_eigs.max()


def main(argv):
    """Train the configured offline RL agent and log metrics every epoch.

    Reads all settings from absl ``FLAGS`` (see ``FLAGS_DEF``). The steps are:
    set up wandb/viskit logging, build the evaluation env and its dataset,
    preprocess rewards/actions, build the networks and agent, then run
    ``FLAGS.n_epochs`` epochs. On evaluation epochs each epoch also evaluates
    the policy and computes the F-matrix eigenvalue diagnostics.

    Args:
        argv: leftover command-line arguments from ``absl.app.run`` (unused).
    """
    FLAGS = absl.flags.FLAGS

    variant = get_user_flags(FLAGS, FLAGS_DEF)
    wandb_logger = WandBLogger(config=FLAGS.logging, variant=variant)
    setup_logger(
        variant=variant,
        exp_id=wandb_logger.experiment_id,
        seed=FLAGS.seed,
        base_log_dir=FLAGS.logging.output_dir,
        include_exp_prefix_sub_dir=False
    )

    set_random_seed(FLAGS.seed)
    rng_generator = JaxRNG(next_rng())

    env_eval = _make_eval_env(FLAGS.env)
    eval_sampler = TrajSampler(env_eval, FLAGS.max_traj_length)
    renderer = TrajRenderer(env_eval, FLAGS.max_traj_length)

    dataset = get_d4rl_dataset(eval_sampler.env)
    dataset = d4rl_additions.modify_d4rl_dataset(dataset, env_eval)

    dataset['rewards'] = dataset['rewards'] * FLAGS.reward_scale + FLAGS.reward_bias
    if FLAGS.normalize_reward:
        reward_rms = np.sqrt((dataset['rewards'] ** 2).mean())
        # If every reward is zero (e.g. reward_scale=0 and reward_bias=0),
        # dividing would fill the dataset with NaNs; leave rewards as-is.
        if reward_rms > 0:
            dataset['rewards'] = dataset['rewards'] / reward_rms
    dataset['actions'] = np.clip(dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action)

    observation_dim = eval_sampler.env.observation_space.shape[0]
    action_dim = eval_sampler.env.action_space.shape[0]

    policy = TanhGaussianPolicy(
        observation_dim, action_dim, FLAGS.policy_arch, FLAGS.orthogonal_init,
        FLAGS.policy_log_std_multiplier, FLAGS.policy_log_std_offset
    )
    qf = FullyConnectedQFunction(observation_dim, action_dim, FLAGS.qf_arch, FLAGS.orthogonal_init,
                                 feature_layer_norm=FLAGS.qf_feature_layer_norm)

    # A non-negative target entropy is treated as "unset" and replaced with
    # -(number of action dimensions). This must run before the sac config is
    # copied in _build_agent.
    if FLAGS.sac.target_entropy >= 0.0:
        FLAGS.sac.target_entropy = -np.prod(eval_sampler.env.action_space.shape).item()

    sac, use_reweighting = _build_agent(FLAGS, policy, qf, observation_dim, action_dim)

    assert no_nans(sac.train_params['policy'])
    assert no_nans(sac.train_params['qf1'])

    # NOTE: to evaluate the 'policy_tilde' parameters instead, swap
    # sac.train_params['policy'] for sac.train_params['policy_tilde'] here and
    # in the evaluation step below.
    sampler_policy = SamplerPolicy(sac.policy, sac.train_params['policy'])

    # Envs without a normalized-score method report raw returns.
    normalized_score = getattr(eval_sampler.env, "get_normalized_score", lambda x: x)

    # Accumulates across epochs, so values from the last evaluation carry into
    # non-evaluation epochs (presumably to keep the tabular columns stable).
    viskit_metrics = {}
    # Keeps the final save well-defined when n_epochs == 0 (the loop never runs).
    epoch = -1
    for epoch in range(FLAGS.n_epochs):
        metrics = {'epoch': epoch}
        is_eval_epoch = epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0

        with Timer() as train_timer:
            for _ in range(FLAGS.n_train_step_per_epoch):
                batch = batch_to_jax(subsample_batch(dataset, FLAGS.batch_size))
                metrics.update(sac.train(batch, epoch))
                assert no_nans(sac.train_params['policy'])
                assert no_nans(sac.train_params['qf1'])
                assert no_nans(sac.train_params['qf2'])

        with Timer() as eval_timer:
            if is_eval_epoch:
                assert no_nans(sac.train_params['policy'])
                trajs = eval_sampler.sample(
                    sampler_policy.update_params(sac.train_params['policy']),
                    FLAGS.eval_n_trajs, deterministic=True
                )

                returns = [np.sum(t['rewards']) for t in trajs]
                metrics['average_return'] = np.mean(returns)
                metrics['average_traj_length'] = np.mean([len(t['rewards']) for t in trajs])
                # Key spelling kept as-is so existing dashboards keep working.
                metrics['average_normalizd_return'] = np.mean(
                    [normalized_score(ret) for ret in returns]
                )
                if FLAGS.save_model:
                    save_data = {'sac': sac, 'variant': variant, 'epoch': epoch}
                    wandb_logger.save_pickle(save_data, f'model_{epoch}.pkl')

                if FLAGS.log_renders:
                    frames = renderer.sample(sampler_policy, FLAGS.render_n_trajs, deterministic=True)
                    wandb_logger.log_video(frames, f'render_{epoch}.gif')

        with Timer() as calc_A_timer:
            if is_eval_epoch:
                if use_reweighting:
                    F_mu1, F_mu2, F_q1, F_q2, pop_objective1, pop_objective2 = \
                        sac.get_F_matrices(dataset, epoch, use_reweighting=use_reweighting)
                    metrics['norm_F_q1_min_eig'] = _normalized_min_eig(F_q1)
                    metrics['norm_F_q2_min_eig'] = _normalized_min_eig(F_q2)
                    metrics['pop_objective1'] = pop_objective1
                    metrics['pop_objective2'] = pop_objective2
                else:
                    F_mu1, F_mu2 = sac.get_F_matrices(dataset, epoch, use_reweighting=use_reweighting)

                metrics['norm_F_mu1_min_eig'] = _normalized_min_eig(F_mu1)
                metrics['norm_F_mu2_min_eig'] = _normalized_min_eig(F_mu2)

                if FLAGS.ongoing_plot:
                    @wrap_function_with_rng(rng_generator())
                    def forward_policy(rng, *args, **kwargs):
                        return sac.policy.apply(
                            *args, **kwargs,
                            rngs=JaxRNG(rng)(sac.policy.rng_keys())
                        )
                    plot_lambda(
                        lambda obs, act: sac.qf.apply(sac._train_states['qf1'].params, obs, act),
                        lambda obs, act: sac.qf.apply(sac._train_states['qf2'].params, obs, act),
                        lambda obs: forward_policy(sac._train_states['policy'].params, obs),
                        dataset['observations'][:1000],
                        dataset['next_observations'][:1000],
                        dataset['actions'][:1000]
                    )

        metrics['train_time'] = train_timer()
        metrics['eval_time'] = eval_timer()
        metrics['calc_A_time'] = calc_A_timer()
        # Total wall time of the epoch, including the F-matrix diagnostics.
        metrics['epoch_time'] = metrics['train_time'] + metrics['eval_time'] + metrics['calc_A_time']
        wandb_logger.log(metrics)
        viskit_metrics.update(metrics)
        logger.record_dict(viskit_metrics)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)

    if FLAGS.save_model:
        save_data = {'sac': sac, 'variant': variant, 'epoch': epoch}
        wandb_logger.save_pickle(save_data, 'model.pkl')


if __name__ == "__main__":
    # absl parses the command-line flags, then calls main with leftover argv.
    absl.app.run(main=main)
