import gym

import absl.app
import absl.flags

from icml_supplies.contextformer.JaxCQL.TD3_BC import ConservativeSAC
from icml_supplies.contextformer.JaxCQL.replay_buffer import subsample_batch, concatenate_batches, get_gym_d4rl_dataset_with_mc_calculation, get_hand_dataset_with_mc_calculation
from icml_supplies.contextformer.JaxCQL.jax_utils import *
from icml_supplies.contextformer.JaxCQL.model import TanhGaussianPolicy, FullyConnectedQFunction, SamplerPolicy
from icml_supplies.contextformer.JaxCQL.sampler import TrajSampler
from viskit.logging import logger, setup_logger
from icml_supplies.contextformer.JaxCQL.replay_buffer import ReplayBuffer

#from tensorboardX import Summerywriter
from tensorboardX import SummaryWriter

from icml_supplies.contextformer.JaxCQL.vcse import *


# Command-line flags and their defaults. The returned mapping is also passed to
# get_user_flags() in main() to build the experiment "variant".
# NOTE(review): define_flags_with_default presumably chooses the absl flag type
# from each default's Python type (bool/int/float/str/ConfigDict) — confirm in
# jax_utils. This is why numeric defaults that may be fractional are floats.
FLAGS_DEF = define_flags_with_default(
    env='antmaze-medium-diverse-v2',
    seed=42,
    save_model=False,  # pickle the agent via wandb_logger.save_pickle on every evaluation
    batch_size=256,  # total gradient batch size (offline + online parts)

    # Forwarded to the dataset loaders / online sampler.
    reward_scale=1.0,
    reward_bias=0.0,
    clip_action=0.99999,

    policy_arch='256-256',
    qf_arch='256-256',
    orthogonal_init=True,
    policy_log_std_multiplier=1.0,
    policy_log_std_offset=-1.0,

    # Total grad_steps of offline pretrain will be (n_train_step_per_epoch_offline * n_pretrain_epochs)
    n_train_step_per_epoch_offline=1000,
    n_pretrain_epochs=1000,
    offline_eval_every_n_epoch=10,

    max_online_env_steps=1e6,  # training stops once the replay buffer holds this many env steps
    online_eval_every_n_env_steps=1000,

    eval_n_trajs=5,
    replay_buffer_size=1000000,
    mixing_ratio=-1.0,  # offline fraction of each online batch; < 0 means size-proportional
    use_cql=True,
    online_use_cql=True,  # False switches to plain SAC once the online phase starts
    cql_min_q_weight=5.0,
    cql_min_q_weight_online=-1.0,  # < 0 keeps cql_min_q_weight for the online phase
    enable_calql=True, # Turn on for Cal-QL
    reward_bins=10,  # passed to VCSE as its first constructor argument
    n_online_traj_per_epoch=1,
    online_utd_ratio=1,  # online grad steps per collected env step
    logging_path='tmp',  # tensorboardX SummaryWriter log directory
    # Weight of the VCSE entropy bonus added to online rewards. A float default
    # so that fractional weights (e.g. --lambda1=0.5) are accepted on the CLI.
    lambda1=1.0,
    cql=ConservativeSAC.get_default_config(),
    logging=WandBLogger.get_default_config(),
)

def main(argv):
    """Run offline pre-training followed by online fine-tuning.

    For the first ``FLAGS.n_pretrain_epochs`` epochs the agent is trained on
    subsamples of the offline ``dataset`` only. After that, every epoch
    collects ``FLAGS.n_online_traj_per_epoch`` trajectories into
    ``replay_buffer``. Each gradient step then trains on an online sample,
    whose rewards are augmented with the VCSE entropy bonus (weighted by
    ``FLAGS.lambda1``), concatenated with an offline subsample. The split
    between online and offline samples is set by the mixing ratio. Training
    stops once ``replay_buffer.total_steps`` reaches
    ``FLAGS.max_online_env_steps``.

    Args:
        argv: positional arguments forwarded by ``absl.app.run`` (unused).
    """
    FLAGS = absl.flags.FLAGS
    variant = get_user_flags(FLAGS, FLAGS_DEF)
    wandb_logger = WandBLogger(config=FLAGS.logging, variant=variant)
    setup_logger(
        variant=variant,
        exp_id=wandb_logger.experiment_id,
        seed=FLAGS.seed,
        base_log_dir=FLAGS.logging.output_dir,
        include_exp_prefix_sub_dir=False
    )
    writer = SummaryWriter(FLAGS.logging_path)
    # Device for the torch-based VCSE module and the tensors fed into it.
    vcse_device = 'cuda'
    reward_modification = VCSE(FLAGS.reward_bins, vcse_device)  # entropy rewards

    def modify_reward(framework, reward, ori_obs_batch, state, action, model, mode, weight=1):
        """Return ``reward`` plus ``weight`` times the VCSE entropy bonus.

        Args:
            framework: the agent; its ``compute_q`` is used in 'critic_cond' mode.
            reward: numpy reward batch; its first dimension sets the output length.
            ori_obs_batch: numpy observations, converted to a torch tensor for VCSE.
            state: JAX observations for the Q-value computation.
            action: JAX actions for the Q-value computation.
            model: Q-function parameters forwarded to ``compute_q`` / ``forward_qf``.
            mode: 'critic_cond' conditions on Q(s, a); any other value uses
                ``forward_qf(model, state)``.
            weight: scale applied to the entropy bonus.
        """
        if mode == 'critic_cond':
            # Q-values are only a conditioning signal for VCSE; block gradients.
            cond = jax.lax.stop_gradient(framework.compute_q(model, state, action))
        else:
            cond = forward_qf(model, state)
        obs_tensor = torch.from_numpy(ori_obs_batch).to(vcse_device)
        cond_tensor = torch.from_numpy(np.array(batch_to_array(cond))).to(vcse_device).view(-1, 1)
        # Only the first output (the intrinsic reward) is used here.
        reward_mode, _n_v, _n_s, _eps, _state_norm, _value_norm = reward_modification(obs_tensor, cond_tensor)
        entropy_bonus = reward_mode.detach().cpu().numpy().reshape((reward.shape[0],))
        return reward + entropy_bonus * weight

    if FLAGS.env in ["pen-binary-v0", "door-binary-v0", "relocate-binary-v0"]:
        dataset = get_hand_dataset_with_mc_calculation(FLAGS.env, gamma=FLAGS.cql.discount, reward_scale=FLAGS.reward_scale, reward_bias=FLAGS.reward_bias, clip_action=FLAGS.clip_action)
        use_goal = True
    elif FLAGS.env in ['halfcheetah-medium-v2','hopper-medium-v2','walker2d-medium-v2',
                       'halfcheetah-medium-replay-v2','hopper-medium-replay-v2','walker2d-medium-replay-v2',
                       'ant-medium-v2','ant-medium-replay-v2']:
        dataset = get_gym_d4rl_dataset_with_mc_calculation(FLAGS.env, FLAGS.reward_scale, FLAGS.reward_bias, FLAGS.clip_action, gamma=FLAGS.cql.discount)
        use_goal = False
    else:
        dataset = get_d4rl_dataset_with_mc_calculation(FLAGS.env, FLAGS.reward_scale, FLAGS.reward_bias, FLAGS.clip_action, gamma=FLAGS.cql.discount)
        use_goal = False
    assert dataset["next_observations"].shape == dataset["observations"].shape

    set_random_seed(FLAGS.seed)
    eval_sampler = TrajSampler(gym.make(FLAGS.env).unwrapped, use_goal, gamma=FLAGS.cql.discount)
    train_sampler = TrajSampler(gym.make(FLAGS.env).unwrapped, use_goal, use_mc=True, gamma=FLAGS.cql.discount, reward_scale=FLAGS.reward_scale, reward_bias=FLAGS.reward_bias,)
    replay_buffer = ReplayBuffer(FLAGS.replay_buffer_size)

    observation_dim = eval_sampler.env.observation_space.shape[0]
    action_dim = eval_sampler.env.action_space.shape[0]

    policy = TanhGaussianPolicy(
        observation_dim, action_dim, FLAGS.policy_arch, FLAGS.orthogonal_init,
        FLAGS.policy_log_std_multiplier, FLAGS.policy_log_std_offset
    )

    qf = FullyConnectedQFunction(observation_dim, action_dim, FLAGS.qf_arch, FLAGS.orthogonal_init)

    # A non-negative target entropy in the config means "use -|A|".
    if FLAGS.cql.target_entropy >= 0.0:
        FLAGS.cql.target_entropy = -np.prod(eval_sampler.env.action_space.shape).item()

    sac = ConservativeSAC(FLAGS.cql, policy, qf)
    sampler_policy = SamplerPolicy(sac.policy, sac.train_params['policy'])

    viskit_metrics = {}
    n_train_step_per_epoch = FLAGS.n_train_step_per_epoch_offline
    cql_min_q_weight = FLAGS.cql_min_q_weight
    enable_calql = FLAGS.enable_calql
    use_cql = FLAGS.use_cql
    mixing_ratio = FLAGS.mixing_ratio

    total_grad_steps = 0
    is_online = False
    online_eval_counter = -1
    do_eval = False
    online_rollout_timer = None
    train_timer = None
    epoch = 0
    train_metrics = None
    expl_metrics = None
    while True:
        metrics = {'epoch': epoch}

        # Switch from offline pre-training to online fine-tuning exactly once.
        if epoch == FLAGS.n_pretrain_epochs:
            is_online = True
            if FLAGS.cql_min_q_weight_online >= 0:
                print(f"changing cql alpha from {cql_min_q_weight} to {FLAGS.cql_min_q_weight_online}")
                cql_min_q_weight = FLAGS.cql_min_q_weight_online

            if not FLAGS.online_use_cql and use_cql:
                print("truning off cql during online phase and use sac")
                use_cql = False
                if sac.config.cql_lagrange:
                    # The Lagrange multiplier is no longer trained without CQL.
                    model_keys = list(sac.model_keys)
                    model_keys.remove('log_alpha_prime')
                    sac._model_keys = tuple(model_keys)

        # Do evaluations when
        # 1. epoch = 0 to get initial performance
        # 2. every FLAGS.offline_eval_every_n_epoch for offline phase
        # 3. epoch == FLAGS.n_pretrain_epochs to get offline pre-trained performance
        # 4. every FLAGS.online_eval_every_n_env_steps for online phase
        # 5. when replay_buffer.total_steps >= FLAGS.max_online_env_steps to get final fine-tuned performance
        do_eval = (epoch == 0 or (not is_online and epoch % FLAGS.offline_eval_every_n_epoch == 0) or (epoch == FLAGS.n_pretrain_epochs) or (is_online and replay_buffer.total_steps // FLAGS.online_eval_every_n_env_steps > online_eval_counter) or (replay_buffer.total_steps >= FLAGS.max_online_env_steps))

        with Timer() as eval_timer:
            if do_eval:
                print(f"Starting Evaluation for Epoch {epoch}")
                trajs = eval_sampler.sample(
                    sampler_policy.update_params(sac.train_params['policy']),
                    FLAGS.eval_n_trajs, deterministic=True
                )

                metrics['evaluation/average_return'] = np.mean([np.sum(t['rewards']) for t in trajs])

                metrics['evaluation/average_traj_length'] = np.mean([len(t['rewards']) for t in trajs])
                print(metrics)
                if is_online:
                    writer.add_scalar('total/average_return',metrics['evaluation/average_return'],epoch)
                    writer.add_scalar('total/step_test',replay_buffer.total_steps,replay_buffer.total_steps)
                    writer.add_scalar('total/average_return_step',metrics['evaluation/average_return'],replay_buffer.total_steps)
                else:
                    writer.add_scalar('total/average_return',metrics['evaluation/average_return'],epoch)

                if use_goal:
                    # for adroit envs
                    metrics['evaluation/goal_achieved_rate'] = np.mean([1 in t['goal_achieved'] for t in trajs])
                    if is_online:
                        writer.add_scalar('offline/goal_achieved_rate_step',metrics['evaluation/goal_achieved_rate'],replay_buffer.total_steps)
                        writer.add_scalar('offline/goal_achieved_rate',metrics['evaluation/goal_achieved_rate'],epoch)
                    else:
                        writer.add_scalar('offline/goal_achieved_rate',metrics['evaluation/goal_achieved_rate'],epoch)

                else:
                    # for d4rl envs
                    metrics['evaluation/average_normalized_return'] = np.mean([eval_sampler.env.get_normalized_score(np.sum(t['rewards'])) for t in trajs])
                    if is_online:
                        writer.add_scalar('offline/average_normalized_return',metrics['evaluation/average_normalized_return'],epoch)
                        writer.add_scalar('offline/average_normalized_return_step',metrics['evaluation/average_normalized_return'],replay_buffer.total_steps)
                    else:
                        writer.add_scalar('offline/average_normalized_return',metrics['evaluation/average_normalized_return'],epoch)
                if is_online:
                    online_eval_counter = replay_buffer.total_steps // FLAGS.online_eval_every_n_env_steps

                if FLAGS.save_model:
                    save_data = {'sac': sac, 'variant': variant, 'epoch': epoch}
                    wandb_logger.save_pickle(save_data, 'model.pkl')

        metrics['grad_steps'] = total_grad_steps
        if is_online:
            metrics['env_steps'] = replay_buffer.total_steps
        metrics['epoch'] = epoch
        metrics['online_rollout_time'] = 0 if online_rollout_timer is None else online_rollout_timer()
        metrics['train_time'] = 0 if train_timer is None else train_timer()
        metrics['eval_time'] = eval_timer()
        metrics['epoch_time'] = eval_timer() if train_timer is None else train_timer() + eval_timer()
        if FLAGS.n_pretrain_epochs >= 0:
            metrics['mixing_ratio'] = mixing_ratio
        if train_metrics is not None:
            metrics.update(train_metrics)
        if expl_metrics is not None:
            metrics.update(expl_metrics)

        # W&B logging is intentionally disabled; metrics go to viskit only.
        viskit_metrics.update(metrics)
        logger.record_dict(viskit_metrics)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)

        if replay_buffer.total_steps >= FLAGS.max_online_env_steps:
            print("Finished Training")
            break

        with Timer() as online_rollout_timer:
            if is_online:
                print("collecting online trajs:", FLAGS.n_online_traj_per_epoch)
                trajs = train_sampler.sample(
                    sampler_policy.update_params(sac.train_params['policy']),
                    n_trajs=FLAGS.n_online_traj_per_epoch, deterministic=False, replay_buffer=replay_buffer
                )
                expl_metrics = {}
                expl_metrics['exploration/average_return'] = np.mean([np.sum(t['rewards']) for t in trajs])
                expl_metrics['exploration/average_traj_length'] = np.mean([len(t['rewards']) for t in trajs])
                writer.add_scalar('online/average_return_step',expl_metrics['exploration/average_return'],replay_buffer.total_steps)
                if use_goal:
                    expl_metrics['exploration/goal_achieved_rate'] = np.mean([1 in t['goal_achieved'] for t in trajs])
                    writer.add_scalar('online/goal_achieved_rate_step',expl_metrics['exploration/goal_achieved_rate'],replay_buffer.total_steps)
                    writer.add_scalar('online/goal_achieved_rate',expl_metrics['exploration/goal_achieved_rate'],epoch)

        if train_timer is None:
            print("jit compiling train function: will take a while")

        with Timer() as train_timer:
            # Online phase: grad steps scale with the env steps just collected.
            if FLAGS.n_pretrain_epochs >= 0 and epoch >= FLAGS.n_pretrain_epochs and FLAGS.online_utd_ratio > 0:
                n_train_step_per_epoch = np.sum([len(t["rewards"]) for t in trajs]) *  FLAGS.online_utd_ratio
            if FLAGS.n_pretrain_epochs >= 0:
                # A negative FLAGS.mixing_ratio means "proportional to data sizes".
                if FLAGS.mixing_ratio >= 0:
                    mixing_ratio = FLAGS.mixing_ratio
                else:
                    mixing_ratio = dataset['rewards'].shape[0] / (dataset['rewards'].shape[0] + replay_buffer.total_steps)
                batch_size_offline = int(FLAGS.batch_size * mixing_ratio)
                batch_size_online = FLAGS.batch_size - batch_size_offline

            for _ in range(n_train_step_per_epoch):
                if is_online:
                    # Sample the online part and add the VCSE entropy bonus to its rewards.
                    online_batch = replay_buffer.sample(batch_size_online)
                    reward_batch = online_batch['rewards']
                    ori_obs_batch = online_batch['observations']
                    ori_batch = deepcopy(online_batch)
                    jax_online_batch = batch_to_jax(online_batch)
                    entropy_reward = modify_reward(sac,
                                                   reward_batch,
                                                   ori_obs_batch,
                                                   jax_online_batch['observations'],
                                                   jax_online_batch['actions'],
                                                   model=[sac.qf1_params, sac.qf2_params],
                                                   mode='critic_cond',
                                                   weight=FLAGS.lambda1)
                    ori_batch['rewards'] = entropy_reward.astype(np.float32)
                    # Mix offline and (reward-augmented) online samples. The
                    # concatenated batch must be the one passed to training;
                    # previously it was discarded in favour of ori_batch alone.
                    offline_batch = subsample_batch(dataset, batch_size_offline)
                    batch = batch_to_jax(concatenate_batches([offline_batch, ori_batch]))
                else:
                    # pure offline
                    batch = batch_to_jax(subsample_batch(dataset, FLAGS.batch_size))
                train_metrics = prefix_metrics(sac.train(batch, use_cql=use_cql, cql_min_q_weight=cql_min_q_weight, enable_calql=enable_calql), 'sac')
            total_grad_steps += n_train_step_per_epoch
        epoch += 1

# Script entry point: absl parses the command-line flags, then calls main(argv).
if "__main__" == __name__:
    absl.app.run(main)
