import functools

import d4rl
import gym
from absl import app
from absl import flags
import acme
from acme.agents.jax import actor_core as actor_core_lib
from acme.agents.jax import actors
from acme.jax import variable_utils
from acme.utils import counting
import jax
from ml_collections import config_flags
import numpy as np
import optax
import tqdm

from otr import dataset_utils
from otr import evaluation
from otr import experiment_utils
from otr.agents import iql
from otr.agents.otil import rewarder as rewarder_lib
import reward_learning.TREX_ensemble_modular as Rewardlearner
from otr.dataset_utils import relabel_rewards as oracle_relabel_rewards
import copy

# Command-line flags (registered at import time).
# --config: ml_collections config file; defaults to the MuJoCo OTR/IQL config.
_CONFIG = config_flags.DEFINE_config_file("config", "configs/otr_iql_mujoco.py")
# --workdir: directory handed to the logger factory in main().
_WORKDIR = flags.DEFINE_string('workdir', '/tmp/otr', '')
# --reward_file: read in main() but not otherwise used there.
# NOTE(review): the default is the literal string 'None', not the None object.
_REWARD_FILE = flags.DEFINE_string('reward_file', 'None', '')


def relabel_rewards(rewarder, trajectory):
    """Build a relabeled copy of `trajectory` using rewards from `rewarder`.

    Args:
        rewarder: object exposing `compute_offline_rewards(trajectory)`; the
            rewards it returns are paired positionally with the transitions.
        trajectory: iterable of transitions that support
            `_replace(reward=...)`.

    Returns:
        A list holding the result of `_replace(reward=...)` for every
        (transition, reward) pair.
    """
    new_rewards = rewarder.compute_offline_rewards(trajectory)
    # zip() stops at the shorter of the two sequences.
    return [
        step._replace(reward=new_reward)
        for step, new_reward in zip(trajectory, new_rewards)
    ]


def compute_iql_reward_scale(trajs):
    """Compute the reward scaling factor from the spread of returns.

    The factor is 1000 divided by the gap between the highest and the lowest
    trajectory return in `trajs`, with 1e-5 added to the gap. According to
    the original author this mirrors the original IQL implementation.

    Args:
        trajs: list of trajectories; each trajectory is an iterable of steps
            with a `reward` attribute. The list itself is left untouched.

    Returns:
        The scalar reward scale.
    """

    def episode_return(episode):
        return sum([transition.reward for transition in episode])

    # Rank a fresh list so the caller's ordering is preserved.
    ranked = sorted(trajs, key=episode_return)
    lowest, highest = ranked[0], ranked[-1]
    return_gap = episode_return(highest) - episode_return(lowest) + 1e-5
    return 1000.0 / return_gap



def find_top_k_trajs(offline_dataset_name, offline_traj, k):
    """Pick the `k` highest-scoring trajectories from `offline_traj`.

    A trajectory's score is the sum of its step rewards. For dataset names
    containing "antmaze" that sum is divided by
    `1e-4 + norm(first_step.observation[:2])` (the original author describes
    this as weighting by inverse distance from the bottom-right corner).

    Args:
        offline_dataset_name: dataset identifier; only checked for the
            substring "antmaze".
        offline_traj: sequence of trajectories (sequences of steps exposing
            `reward` and, for antmaze, `observation`).
        k: number of trajectories to select.

    Returns:
        A tuple `(expert_demo, returns, idx)`: the selected trajectories, the
        score of every trajectory, and the selected indices as produced by
        `np.argpartition` (not sorted by score).
    """

    def raw_return(episode):
        return sum([step.reward for step in episode])

    weight_by_start = "antmaze" in offline_dataset_name
    scores = []
    for episode in offline_traj:
        score = raw_return(episode)
        if weight_by_start:
            score = score / (
                1e-4 + np.linalg.norm(episode[0].observation[:2]))
        scores.append(score)

    # The last k slots of the partition hold the k largest scores.
    top_idx = np.argpartition(scores, -k)[-k:]

    picked_scores = [scores[i] for i in top_idx]
    print(f"demo returns {picked_scores}, mean {np.mean(picked_scores)}")

    demos = [offline_traj[i] for i in top_idx]
    # Second report uses plain (unweighted) reward sums of the selection.
    demo_sums = [raw_return(episode) for episode in demos]
    print(f"demo slice returns {demo_sums}, mean {np.mean(demo_sums)}")
    return demos, scores, top_idx


def find_above_average_trajs(offline_dataset_name, offline_traj, k):
    """Return indices of trajectories scoring at or above the median.

    Despite the function name, the threshold is the *median* of the scores
    (inclusive), not the mean. A trajectory's score is the sum of its step
    rewards; for dataset names containing "antmaze" that sum is divided by
    `1e-4 + norm(first_step.observation[:2])`.

    Args:
        offline_dataset_name: dataset identifier; only checked for the
            substring "antmaze".
        offline_traj: sequence of trajectories (sequences of steps exposing
            `reward` and, for antmaze, `observation`).
        k: unused; kept so the signature matches `find_top_k_trajs`.

    Returns:
        List of integer indices into `offline_traj`, in ascending order.
    """
    if "antmaze" in offline_dataset_name:
        returns = [
            sum([t.reward for t in traj]) /
            (1e-4 + np.linalg.norm(traj[0].observation[:2]))
            for traj in offline_traj
        ]
    else:
        returns = [sum([t.reward for t in traj]) for traj in offline_traj]

    median_return = np.median(returns)
    return [i for i, ret in enumerate(returns) if ret >= median_return]

def otr_relabel(original_offline_traj,offline_dataset_name,config,expert_demo):
    """Relabel trajectories with OTIL rewards and return merged transitions.

    Steps, in order: pick an episode length, build the squashing function
    selected by `config.squashing_fn`, construct an `OTILRewarder` from
    `expert_demo`, relabel a deep copy of every trajectory, then merge the
    result and apply `reward * scale + bias`.

    Args:
        original_offline_traj: list of trajectories to relabel; it is
            deep-copied, so the input is not modified.
        offline_dataset_name: dataset identifier. "maze2d" switches the
            episode length to `max trajectory length + 10` (otherwise 1000);
            "antmaze" switches the reward bias to -2.0 (otherwise 0.0).
        config: config object providing `squashing_fn`, `alpha`, and for the
            exponential variant `beta` and optionally `normalize_by_atom`.
        expert_demo: expert trajectories handed to the rewarder.

    Returns:
        The merged transitions with scaled and shifted rewards.

    Raises:
        ValueError: if `config.squashing_fn` is neither 'linear' nor
            'exponential'.
    """
    max_length = max((len(traj) for traj in original_offline_traj), default=0)
    print('max length:',max_length)
    episode_length = (
        max_length + 10 if "maze2d" in offline_dataset_name else 1000)

    if config.squashing_fn == 'linear':
        squashing_fn = functools.partial(
            rewarder_lib.squashing_linear, alpha=config.alpha)
    elif config.squashing_fn == 'exponential':
        # TODO: Make config key required for OTIL
        normalize_by_atom = config.get("normalize_by_atom", True)
        atom_size = (
            expert_demo[0][0].observation.shape[0]
            if normalize_by_atom else 1.0)
        scaled_beta = config.beta * episode_length / atom_size
        squashing_fn = functools.partial(
            rewarder_lib.squashing_exponential,
            alpha=config.alpha,
            beta=scaled_beta)
    else:
        raise ValueError(f'Unknown squashing fn {config.squashing_fn}')

    rewarder = rewarder_lib.OTILRewarder(
        expert_demo, episode_length=episode_length, squashing_fn=squashing_fn)

    # Work on a private copy so the caller's trajectories stay untouched.
    working_trajs = copy.deepcopy(original_offline_traj)
    relabeled_trajectories = [
        relabel_rewards(rewarder, working_trajs[index])
        for index in tqdm.trange(len(working_trajs))  # pylint: disable=all
    ]

    # The scale is computed the same way for every dataset; only the bias
    # depends on the dataset family.
    reward_scale = compute_iql_reward_scale(relabeled_trajectories)
    reward_bias = -2.0 if "antmaze" in offline_dataset_name else 0.0

    merged = dataset_utils.merge_trajectories(relabeled_trajectories)
    return merged._replace(
        reward=merged.reward * reward_scale + reward_bias)






def main(_):
    """Train and periodically evaluate an IQL agent on a reward-relabeled dataset.

    High-level flow, as visible in this function:
      1. Read the config/flags and build the logger factory.
      2. Load the offline dataset (D4RL via gym, or a metaworld .npy dump).
      3. Build a `RewardLearner` and initialise its reward model.
      4. Load a pre-computed reward array from a hard-coded path and use it
         as the reward of the training dataset.
      5. Build the IQL learner, an evaluation actor and an evaluation loop.
      6. Alternate `config.evaluate_every` learner steps with evaluation
         until `config.max_steps` is reached.

    Args:
        _: positional argument supplied by `app.run`; unused.
    """
    config = _CONFIG.value
    offline_dataset_name = config.offline_dataset_name
    print(offline_dataset_name)
    workdir = _WORKDIR.value
    log_to_wandb = config.log_to_wandb
    # NOTE(review): `reward_file` is read here but not used anywhere below.
    reward_file = _REWARD_FILE.value

    wandb_kwargs = {
            'project': config.wandb_project,
            'entity': config.wandb_entity,
            'config': config.to_dict(),
    }

    logger_factory = experiment_utils.LoggerFactory(
            workdir=workdir,
            log_to_wandb=log_to_wandb,
            wandb_kwargs=wandb_kwargs,
            learner_time_delta=10,
            evaluator_time_delta=0)

    # ---- Dataset loading -------------------------------------------------
    if 'metaworld' not in config.offline_dataset_name:
        env = gym.make(config.offline_dataset_name)
        if "maze" in config.offline_dataset_name:
            # NOTE(review): both arms of the conditional below are False, so
            # `interval` is always False regardless of the dataset name.
            original_dataset = dataset_utils.qlearning_dataset_with_timeouts(env,
                                                                             interval=False if 'maze2d' in config.offline_dataset_name else False)
            # dataset = relabel_rewards(env,dataset,name)
        else:
            original_dataset = d4rl.qlearning_dataset(env)
        # Oracle relabeling is disabled; the dataset is used as loaded.
        dataset_with_relabelled_terminals_and_rewards = original_dataset#oracle_relabel_rewards(env, original_dataset,
                                                        #                       config.offline_dataset_name)
    else:
        # The task name is the second '_'-separated token of the dataset name.
        dataset_name = config.offline_dataset_name.split('_')[1]
        import metaworld
        ml1 = metaworld.MT1(dataset_name, seed=1337)  # Construct the benchmark, sampling tasks

        env = ml1.train_classes[dataset_name]()  # Create an environment with task
        # print(ml1.train_tasks)
        env.train_tasks = ml1.train_tasks
        # task = 0
        task = ml1.train_tasks[0]
        env.set_task(task)
        env._freeze_rand_vec = False
        # Load the pre-collected dataset from a path relative to the working
        # directory. allow_pickle=True will unpickle arbitrary objects, so the
        # file must come from a trusted source.
        dataset_with_relabelled_terminals_and_rewards = np.load('./data/'+dataset_name+'/data_randgoal_08_50_08_batch.npy',allow_pickle=True).tolist()#
        # dataset_with_relabelled_terminals_and_rewards = np.load(
        #         '/data2/zj/Offline-MetaRL/dataset/' + dataset_name + '_dataset_finalevalsample.npy', allow_pickle=True).tolist()

        # dataset_with_relabelled_terminals_and_rewards = np.load(
        #     '/data2/zj/Offline-MetaRL/dataset/' + dataset_name + '_dataset.npy', allow_pickle=True).tolist()
        # dataset_with_relabelled_terminals_and_rewards['rewards'] = dataset_with_relabelled_terminals_and_rewards['rewards'][:,0]
        # dataset_with_relabelled_terminals_and_rewards['terminals'] = dataset_with_relabelled_terminals_and_rewards[
        #                                                                'terminals'][:, 0]

        # indices = np.zeros_like(dataset_with_relabelled_terminals_and_rewards['terminals'])
        # for i in range(indices.shape[0]):
        #     if (i+1)%500==0:
        #         # print(i)
        #         dataset_with_relabelled_terminals_and_rewards['terminals'][i]=1
    # original_dataset = env.get_dataset()#data_randgoalharderer



    # NOTE(review): `dones_float` is only referenced by the commented-out
    # block below; it is unused in the active code.
    dones_float = np.zeros_like(dataset_with_relabelled_terminals_and_rewards['rewards'])

    # if 'metaworld' not in config.offline_dataset_name:
    #     for i in range(len(dones_float) - 1):
    #         distance = np.linalg.norm(dataset_with_relabelled_terminals_and_rewards['observations'][i + 1] -
    #                                   dataset_with_relabelled_terminals_and_rewards['next_observations'][i]
    #                                   )
    #         if distance > 1e-6 or dataset_with_relabelled_terminals_and_rewards['terminals'][i] == 1.0:  # or ('maze2d' in name and distance<1e-6):
    #             dones_float[i] = 1
    #         else:
    #             dones_float[i] = 0
    #     dones_float[-1] = 1
    #
    #     dataset_with_relabelled_terminals_and_rewards['terminals'] = dones_float

    # masks = 1 - terminal flag, preferring 'realterminals' when present.
    if 'realterminals' in  dataset_with_relabelled_terminals_and_rewards:
        # We updated terminals in the dataset, but continue using
        # the old terminals for consistency with original IQL.
        masks = 1.0 - dataset_with_relabelled_terminals_and_rewards['realterminals'].astype(np.float32)
    else:
        masks = 1.0 - dataset_with_relabelled_terminals_and_rewards['terminals'].astype(np.float32)


    # ---- Reward learning -------------------------------------------------
    # trajs_with_relabelled_terminals_and_rewards =
    num_ensembles = 1
    prefix='50_5'
    reward_learner = Rewardlearner.RewardLearner(env_name=config.offline_dataset_name,seed=config.seed,num_queries_per_iter=5,num_ensembles=num_ensembles,env=env,dataset=dataset_with_relabelled_terminals_and_rewards,traj_length=50,prefix=prefix)

    initial_reward_arry = reward_learner.init_reward_model(initial_pairs=5,num_iter=20,retrain_num_iter=20)
    current_learned_reward = initial_reward_arry
    # current_learned_reward = np.random.rand(*current_learned_reward.shape)
    # import matplotlib
    # import tkinter
    # matplotlib.use('TkAgg')
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.scatter(initial_reward_arry,dataset_with_relabelled_terminals_and_rewards['rewards'])
    # plt.show()
    # print('plot!')



    # NOTE(review): range(0) is empty, so this query/retrain loop is currently
    # disabled. If re-enabled, be aware the inner loop reuses the name `i`.
    for i in range(0):

        learned_trajs = dataset_utils.split_into_trajectories(
                observations=dataset_with_relabelled_terminals_and_rewards['observations'].astype(np.float32),
                actions=dataset_with_relabelled_terminals_and_rewards['actions'].astype(np.float32),
                rewards=current_learned_reward[0].astype(np.float32),
                masks=masks,
                dones_float=dataset_with_relabelled_terminals_and_rewards['terminals'].astype(np.float32),
                next_observations=dataset_with_relabelled_terminals_and_rewards['next_observations'].astype(np.float32))

        # top_k_trajs, learned_returns = find_top_k_trajs(config.offline_dataset_name,learned_trajs,config.k)
        # otr_dataset = otr_relabel(original_offline_traj=learned_trajs, offline_dataset_name=config.offline_dataset_name,config=config,expert_demo=top_k_trajs)
        # print(otr_dataset.reward)
        # print(current_learned_reward.shape,otr_dataset.reward.shape)
        # new_queries = reward_learner.find_new_queries(current_learned_reward,otr_dataset.reward)
        query_list = []
        # new_queries = reward_learner.find_new_queries_disagreement(current_learned_reward, np.zeros_like(current_learned_reward))
        for i in range(num_ensembles):
            new_queries = reward_learner.find_new_queries(current_learned_reward[i,:], np.zeros_like(current_learned_reward[i,:]))
            query_list.append(new_queries)
        current_learned_reward = reward_learner.learn_reward_later(query_list)


    # current_learned_reward = np.zeros_like(current_learned_reward)
    # dataset = get_demonstration_dataset(config)


    # Mean over axis 0 plus 0.3 * std over axis 0.
    # NOTE(review): this value is overwritten by the np.load() just below, so
    # the reward produced by the learner above is not what gets trained on.
    current_learned_reward = np.mean(current_learned_reward,0)+0.3*np.std(current_learned_reward,0)
    # current_learned_reward = np.load(
    #     '/data3/zj/oprl/reward_learning/rewards2/ensemble_' + dataset_name + '_initial_pairs_1_num_queries_3_num_iter_1_retrain_num_iter_1_voi_dis_seed_101371_round_num_1.npy',
    #     allow_pickle=True)
    # Hard-coded absolute path: the file must exist on the machine running
    # this script.
    current_learned_reward = np.load(
        '/data3/zj/tmp/oprl/reward_learning/rewards/ensemble_' + config.offline_dataset_name + '_initial_pairs_1_num_queries_1_num_iter_20_retrain_num_iter_20_voi_dis_seed_1011320_round_num_3.npy',
        allow_pickle=True)
    # Shift so the smallest reward is zero.
    current_learned_reward -= np.min(current_learned_reward)
    # current_learned_reward = np.load('./pt_rewards/'+config.offline_dataset_name+'.npy')

    # current_learned_reward = np.random.rand(*current_learned_reward.shape)*2-1


    # NOTE(review): `representations` is only referenced in commented-out
    # code; whether the call has needed side effects cannot be told from here.
    representations = reward_learner.compute_dataset_representation()
    # Trajectories carrying the dataset's own rewards.
    # NOTE(review): `learned_trajs_true` is only referenced in commented-out
    # code below.
    learned_trajs_true = dataset_utils.split_into_trajectories(
        observations=dataset_with_relabelled_terminals_and_rewards['observations'].astype(np.float32),#representations.astype(np.float32),#
        actions=dataset_with_relabelled_terminals_and_rewards['actions'].astype(np.float32),
        rewards=dataset_with_relabelled_terminals_and_rewards['rewards'].astype(np.float32),#current_learned_reward.astype(np.float32),################################
        masks=masks,
        dones_float=dataset_with_relabelled_terminals_and_rewards['terminals'].astype(np.float32),
        next_observations=dataset_with_relabelled_terminals_and_rewards['next_observations'].astype(np.float32),
        extras=np.zeros_like(dataset_with_relabelled_terminals_and_rewards['terminals'].astype(np.float32)))

    # Trajectories carrying the loaded (learned) rewards; these are merged
    # into the training dataset below.
    learned_trajs_learned = dataset_utils.split_into_trajectories(
        observations=dataset_with_relabelled_terminals_and_rewards['observations'].astype(np.float32),
        # representations.astype(np.float32),#
        actions=dataset_with_relabelled_terminals_and_rewards['actions'].astype(np.float32),
        rewards=current_learned_reward.astype(np.float32),
        # ################################
        masks=masks,
        dones_float=dataset_with_relabelled_terminals_and_rewards['terminals'].astype(np.float32),
        next_observations=dataset_with_relabelled_terminals_and_rewards['next_observations'].astype(np.float32),
    extras=np.zeros_like(dataset_with_relabelled_terminals_and_rewards['terminals'].astype(np.float32)))

    # returns = [sum([t.reward for t in traj]) for traj in learned_trajs_true]
    # print(len(returns),len(learned_trajs_true[0]),returns[0])
    # print(np.sum(dataset_with_relabelled_terminals_and_rewards['terminals'].astype(np.float32)),len(learned_trajs_true),len(learned_trajs_learned)) # learned_trajs_true
    # _, learned_returns,indices = find_top_k_trajs(config.offline_dataset_name, learned_trajs_learned, config.k)
    # above_average_indices = find_above_average_trajs(config.offline_dataset_name, learned_trajs_learned, config.k)
    # top_k_trajs = [learned_trajs_learned[i] for i in indices]
    # top_10_trajs = [learned_trajs_learned[i] for i in indices[-10:]]
    # #
    # All-zero mask; it is passed unchanged to the IQL learner below.
    expectile_mask = np.zeros_like(dataset_with_relabelled_terminals_and_rewards['rewards']).astype(np.float32)
    # for indice in indices:
    #     expectile_mask[indice*500:(indice+1)*500] = 1
    # top_k_dataset = dataset_utils.merge_trajectories(
    #     top_k_trajs)

    # dataset = top_k_dataset
    # dataset = dataset._replace(
    #     reward=np.ones_like(dataset.reward))


    # dataset = otr_relabel(original_offline_traj=learned_trajs_learned,
    #                           offline_dataset_name=config.offline_dataset_name, config=config, expert_demo=top_10_trajs)

    # ---- Training dataset assembly ----------------------------------------
    dataset =  dataset_utils.merge_trajectories(
        learned_trajs_learned)
    # dataset_with_relabelled_terminals_and_rewards['rewards'] = otr_rewards

    # import matplotlib
    # matplotlib.use('TkAgg')
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.scatter(otr_dataset.reward,dataset_with_relabelled_terminals_and_rewards['rewards'])
    # plt.xlabel('OTR reward')
    # plt.ylabel('True reward')
    # plt.show()

    # dataset = otr_dataset
    # dataset = dataset._replace(
    #     observation=dataset_with_relabelled_terminals_and_rewards['observations'])
    dataset = dataset._replace(
        reward=current_learned_reward.astype(np.float32))
    # NOTE(review): the reward chosen here (learned vs. dataset) is replaced
    # unconditionally further down by the `rewards` array loaded from disk,
    # so `config.use_dataset_reward` does not affect the training reward.
    if config.use_dataset_reward:
        dataset = dataset._replace(
                reward= dataset_with_relabelled_terminals_and_rewards['rewards'])
    # dataset = dataset._replace(
    #     extras=expectile_mask)
    # NOTE(review): `r`, `averaged` and `normalized_reward` are only consumed
    # by commented-out code; they are unused in the active path.
    r = copy.deepcopy(dataset_with_relabelled_terminals_and_rewards['rewards'])
    # r=copy.deepcopy(current_learned_reward)
    averaged = ((r-np.mean(r))>=0).astype(np.float32)
    # dataset =  dataset._replace(
    #     extras= averaged)
    normalized_reward = (r-np.min(r))/(np.max(r)-np.min(r))
    # normalized_reward = (np.exp(3*normalized_reward)-1)/(np.exp(3)-1)
    # normalized_discount = normalized_reward*0.49+0.5
    # dataset =  dataset._replace(
    #     extras= normalized_discount)
    print(dataset.observation.shape,dataset.action.shape,dataset.reward.shape)

    # dataset = dataset._replace(
    #         reward=current_learned_reward)


        # dataset = dataset_with_relabelled_terminals_and_rewards
    # dataset.reward = np.zeros_like(dataset.reward)
    # dataset.reward = np.zeros_like(dataset.reward)
    # dataset = dataset._replace(reward= np.ones_like(dataset.reward))

    # rewards = np.load(
        # '/data3/zj/oprl/reward_learning/rewards2/ensemble_' + dataset_name + '_initial_pairs_1_num_queries_3_num_iter_1_retrain_num_iter_1_voi_dis_seed_101371_round_num_1.npy',
        # allow_pickle=True)
    # Same file as loaded into `current_learned_reward` above, loaded again
    # and min-shifted the same way.
    rewards = np.load(
        '/data3/zj/tmp/oprl/reward_learning/rewards/ensemble_' + config.offline_dataset_name + '_initial_pairs_1_num_queries_1_num_iter_20_retrain_num_iter_20_voi_dis_seed_1011320_round_num_3.npy',
        allow_pickle=True)
    rewards-=np.min(rewards)
    # rewards = np.random.rand(*rewards.shape)*2-1
    # dataset_name = config.offline_dataset_name.split('_')[1]
    # opal_file ='./opal_rewards/ensemble_'+dataset_name+'_initial_pairs_1_num_queries_5_num_iter_20_retrain_num_iter_20_voi_dis_seed_285_round_num_9.npy'
    # opal_reward = np.load(opal_file,allow_pickle=True)
    # Final reward assignment: this is the reward the sampler below sees.
    dataset = dataset._replace(reward= rewards.astype(np.float32))
    # dataset = dataset._replace(reward=np.zeros_like(dataset.reward))
    # dataset = dataset._replace(reward=np.random.rand(*dataset.reward.shape)*2-1)
    print('use dataset reward?',config.use_dataset_reward,np.mean(dataset.reward))
    # Create dataset iterator for the relabeled dataset
    key = jax.random.PRNGKey(config.seed)
    key_learner, key_demo, key = jax.random.split(key, 3)
    print('use dataset reward?', config.use_dataset_reward, np.mean(dataset.reward))
    iterator = dataset_utils.JaxInMemorySampler(dataset, key_demo,
                                                                                            config.batch_size)
    print('use dataset reward?', config.use_dataset_reward, np.mean(dataset.reward))
    # Create an environment and grab the spec.
    environment = dataset_utils.make_environment(
            offline_dataset_name, seed=config.seed)
    # Create the networks to optimize.
    spec = acme.make_environment_spec(environment)
    networks = iql.make_networks(
            spec, hidden_dims=config.hidden_dims, dropout_rate=config.dropout_rate)

    counter = counting.Counter(time_delta=0.0)
    print('use dataset reward?', config.use_dataset_reward, np.mean(dataset.reward))
    if config.opt_decay_schedule == "cosine":
        # The schedule starts at -actor_lr; presumably the sign is negated
        # here because this chain has no separate step that flips the update
        # direction (verify against optax.adam's own composition).
        schedule_fn = optax.cosine_decay_schedule(-config.actor_lr,
                                                                                            config.max_steps)
        policy_optimizer = optax.chain(optax.scale_by_adam(),
                                                                     optax.scale_by_schedule(schedule_fn))
    else:
        policy_optimizer = optax.adam(config.actor_lr)

    # Create the learner.
    learner_counter = counting.Counter(counter, "learner", time_delta=0.0)
    learner = iql.IQLLearner(
            networks=networks,
            random_key=key_learner,
            dataset=iterator,
            policy_optimizer=policy_optimizer,
            critic_optimizer=optax.adam(config.critic_lr),
            value_optimizer=optax.adam(config.value_lr),
            **config.iql_kwargs,
            logger=logger_factory('learner', learner_counter.get_steps_key(), 0),
            counter=learner_counter,
        expectile_mask=expectile_mask
    )
    print('use dataset reward?', config.use_dataset_reward, np.mean(dataset.reward))
    def evaluator_network(params, key, observation):
        """Evaluation policy: return the mode of the action distribution."""
        # The random key is deliberately discarded; the mode needs no sampling.
        del key
        action_distribution = networks.policy_network.apply(
                params, observation, is_training=False)
        return action_distribution.mode()

    eval_actor = actors.GenericActor(
            actor_core_lib.batched_feed_forward_to_actor_core(evaluator_network),
            random_key=key,
            variable_client=variable_utils.VariableClient(
                    learner, "policy", device="cpu"),
            backend="cpu",
    )
    print('use dataset reward?', config.use_dataset_reward, np.mean(dataset.reward))
    eval_counter = counting.Counter(counter, "eval_loop", time_delta=0.0)
    eval_loop = evaluation.D4RLEvalLoop(
            environment,
            eval_actor,
            counter=eval_counter,
            logger=logger_factory('eval_loop', eval_counter.get_steps_key(), 0),
    )
    print('use dataset reward?', config.use_dataset_reward, np.mean(dataset.reward))
    # Run the environment loop.
    # Training proceeds in chunks of `evaluate_every` learner steps, each
    # followed by an evaluation run; the last chunk may overshoot max_steps
    # when max_steps is not a multiple of evaluate_every.
    steps = 0
    while steps < config.max_steps:
        for _ in range(config.evaluate_every):
            learner.step()
        steps += config.evaluate_every
        eval_loop.run(config.evaluation_episodes)
    print('use dataset reward?', config.use_dataset_reward, np.mean(dataset.reward))

# Script entry point: hand main() to absl's app runner.
if __name__ == '__main__':
    app.run(main)
