import hashlib
import importlib
from argparse import ArgumentParser
import os
import gym
import d4rl
import numpy as np
from ml_collections import ConfigDict
import cloudpickle as pickle
from JaxCQL.conservative_sac import ConservativeSAC, BehaviorCloning
from JaxCQL.jax_utils import batch_to_jax
from JaxCQL.model import TanhGaussianPolicy, FullyConnectedQFunction, SamplerPolicy
from JaxCQL.replay_buffer import get_d4rl_dataset, ReplayBuffer, PrioritizedReplayBuffer
from JaxCQL.sampler import TrajSampler, StepSampler
from JaxCQL.utils import (
    Timer, set_random_seed, prefix_metrics
)
from cql_finetune.utils import MyWandBLogger
from logger.logging import logger, setup_logger
from tqdm import tqdm
import optax
from flax.training import train_state
from flax.serialization import to_bytes, from_bytes
import jax.numpy as jnp
import jax


def encode(name: str, length=64):
    """Return a truncated MD5 hex digest of ``name``.

    Args:
        name: Text to hash; it is encoded with the default string encoding.
        length: Maximum number of leading hex characters to keep. An MD5 hex
            digest has 32 characters, so any value of 32 or more (including
            the default) returns the complete digest.

    Returns:
        The first ``length`` characters of the hexadecimal MD5 digest.
    """
    digest = hashlib.md5(name.encode()).hexdigest()
    return digest[:length]


def get_config():
    """Build the run configuration from a config module and the command line.

    The module ``cql_finetune.finetune_configs.<config-name>`` must expose a
    ``config`` mapping. Command-line values are merged on top of it, the
    result is converted to a ``ConfigDict``, and a ``logging`` sub-config for
    ``MyWandBLogger`` is attached.

    Options that were not given on the command line (argparse leaves them as
    ``None``) do not overwrite a value the config module already defines;
    they are only inserted, as ``None``, when the key is otherwise missing so
    that later attribute access still works.

    Returns:
        ConfigDict: The merged configuration, including ``config.logging``.
    """
    parser = ArgumentParser()
    parser.add_argument('-c', '--config-name', type=str, required=True, help='The name of config files, e.g., base')
    parser.add_argument('--env', type=str, default='halfcheetah-medium-v2')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--save-model', action='store_true')
    parser.add_argument('--load-model', type=str)
    parser.add_argument('--wandb-offline', action='store_true')
    parser.add_argument('--wandb-project', type=str)
    parser.add_argument('--wandb-entity', type=str)
    parser.add_argument('--wandb-output-dir', type=str, default='./experiment_output')
    parser.add_argument('--gamma', type=float)
    parser.add_argument('--l', type=float)

    args = parser.parse_args()

    config_module = importlib.import_module(
        f'cql_finetune.finetune_configs.{args.config_name}')
    config = config_module.config

    # An option the user did not pass is None in ``args``; blindly updating
    # with it would erase the value provided by the config module. Keep the
    # None only when the config has no such key at all.
    cli_overrides = {
        key: value
        for key, value in vars(args).items()
        if value is not None or key not in config
    }
    config.update(cli_overrides)
    config = ConfigDict(config, convert_dict=True)

    # The hash identifies the (env, config) pair; the seed is kept readable
    # in the experiment name so runs of one group can be told apart.
    exp_encoding = encode(f'CQL_finetune_{config.env}_{config.config_name}')
    exp_name = f'seed_{config.seed}_{config.env}_{exp_encoding}'
    group_name = f"{config.env}_{exp_encoding}"

    config.logging = MyWandBLogger.get_default_config(updates=dict(
        online=not config.wandb_offline,
        exp_name=exp_name,
        group_name=group_name,
        project=config.wandb_project,
        entity=config.wandb_entity,
        output_dir=config.wandb_output_dir,
    ))

    return config

def load_checkpoint(checkpoint_dir, model, optimizer):
    """Restore a ``TrainState`` from a serialized checkpoint file.

    Args:
        checkpoint_dir: Path of the checkpoint file to read. Despite the
            name, it is opened as a file, not a directory.
        model: Object whose ``apply`` attribute becomes the state's
            ``apply_fn``.
        optimizer: Gradient transformation passed to the state as ``tx``.

    Returns:
        A freshly created ``train_state.TrainState`` whose parameters are the
        ``'params'`` entry of the deserialized checkpoint.
    """
    with open(checkpoint_dir, 'rb') as ckpt_file:
        raw_bytes = ckpt_file.read()

    # No target structure is supplied, so the restored object is indexed
    # directly by key.
    restored = from_bytes(None, raw_bytes)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=restored['params'],
        tx=optimizer,
    )

@jax.jit
def compute_sample_weights(agent_state, batch, alpha, l):
    """Compute replay priorities from the behavior-cloning action error.

    For every row of the batch the priority is
    ``(mean((a_bc - a) ** 2) / l + 1) ** alpha``, where ``a_bc`` is the
    action produced by ``agent_state.apply_fn`` for the row's observation
    and ``a`` is the stored action.

    Args:
        agent_state: State exposing ``apply_fn`` and a ``params`` mapping
            that holds the model parameters under the ``"params"`` key.
        batch: Mapping with ``'observations'`` and ``'actions'`` arrays; the
            error is averaged over axis 1.
        alpha: Exponent applied to the shifted error.
        l: Divisor applied to the mean squared error before the shift.

    Returns:
        Array of priorities, one per leading-axis entry of the batch.
    """
    bc_variables = {'params': agent_state.params["params"]}
    predicted_actions = agent_state.apply_fn(bc_variables, batch['observations'])
    residual = predicted_actions - batch['actions']
    mean_sq_error = jnp.mean(jnp.square(residual), axis=1)
    return jnp.power(mean_sq_error / l + 1, alpha)

@jax.jit
def compute_q_weights(agent_state, batch, gamma):
    """Compute per-sample weights that decay with the BC action error.

    For every row of the batch the weight is
    ``exp(-mean((a_bc - a) ** 2) / gamma)``, where ``a_bc`` is the action
    produced by ``agent_state.apply_fn`` for the row's observation and ``a``
    is the stored action.

    Args:
        agent_state: State exposing ``apply_fn`` and a ``params`` mapping
            that holds the model parameters under the ``"params"`` key.
        batch: Mapping with ``'observations'`` and ``'actions'`` arrays; the
            error is averaged over axis 1.
        gamma: Divisor applied to the mean squared error before
            exponentiation.

    Returns:
        Array of weights, one per leading-axis entry of the batch.
    """
    bc_variables = {'params': agent_state.params["params"]}
    predicted_actions = agent_state.apply_fn(bc_variables, batch['observations'])
    residual = predicted_actions - batch['actions']
    scaled_error = jnp.mean(jnp.square(residual), axis=1) / gamma
    return jnp.exp(-scaled_error)

def main(config):
    """Fine-tune a pre-trained CQL agent online, weighting samples with a BC model.

    The function builds the samplers and a prioritized replay buffer seeded
    with the D4RL dataset of ``config.env``, restores the CQL agent from
    ``config.load_model`` and a behavior-cloning model from a fixed
    checkpoint path, then alternates exploration and training for a fixed
    number of epochs. Every ``eval_interval`` training updates the policy is
    evaluated and the accumulated results are rewritten to
    ``eval_returns.txt`` inside ``config.wandb_output_dir`` as
    ``(update count, mean return, mean normalized score)`` rows.

    Args:
        config: Configuration as returned by ``get_config``.

    Raises:
        ValueError: If ``config.load_model`` is empty or ``None``.
    """
    # --load-model is optional for the parser but mandatory here; fail before
    # any environment or dataset is built instead of with an opaque error
    # from ``open(None)``.
    if not config.load_model:
        raise ValueError(
            'A pre-trained CQL checkpoint is required; pass it with --load-model.'
        )

    num_epochs = 8
    eval_interval = 2000  # training updates between two evaluations

    os.makedirs(config.wandb_output_dir, exist_ok=True)

    set_random_seed(config.seed)

    eval_sampler = TrajSampler(gym.make(config.env).unwrapped, config.max_traj_length)
    expl_sampler = StepSampler(gym.make(config.env).unwrapped, config.max_traj_length)

    dataset = get_d4rl_dataset(eval_sampler.env)
    dataset['rewards'] = dataset['rewards'] * config.reward_scale + config.reward_bias
    dataset['actions'] = np.clip(dataset['actions'], -config.clip_action, config.clip_action)
    # Convert the dataset to a replay buffer.
    # NOTE(review): rewards are rescaled above and reward_scale/reward_bias
    # are also handed to the buffer -- confirm the buffer does not apply them
    # to ``data`` a second time.
    replay_buffer = PrioritizedReplayBuffer(
        max_size=config.replay_buffer_max_size, data=dataset,
        reward_scale=config.reward_scale, reward_bias=config.reward_bias,
    )

    observation_dim = eval_sampler.env.observation_space.shape[0]
    action_dim = eval_sampler.env.action_space.shape[0]

    policy = TanhGaussianPolicy(
        observation_dim, action_dim, config.policy_arch, config.orthogonal_init,
        config.policy_log_std_multiplier, config.policy_log_std_offset
    )
    qf = FullyConnectedQFunction(observation_dim, action_dim, config.qf_arch, config.orthogonal_init)

    # A non-negative target entropy is replaced by the negated number of
    # action dimensions.
    if config.cql.target_entropy >= 0.0:
        config.cql.target_entropy = -np.prod(eval_sampler.env.action_space.shape).item()

    # Load the pre-trained CQL agent.
    sac = ConservativeSAC(config.cql, policy, qf)
    # SECURITY: unpickling can execute arbitrary code; only load checkpoints
    # from a trusted source.
    with open(config.load_model, 'rb') as fin:
        saved_model = pickle.load(fin)['sac'].train_params

    sac.load(saved_model)

    # Load the behavior-cloning model used to weight samples.
    agent_bc = BehaviorCloning(state_dim=observation_dim, action_dim=action_dim, hidden_dim=256)
    optimizer = optax.adamw(learning_rate=1e-3)
    checkpoint_dir = os.path.join(
        '..', 'models', 'bc_model', config.env, 'checkpoints', 'checkpoint_2000000.ckpt'
    )
    agent_state = load_checkpoint(checkpoint_dir, agent_bc, optimizer)

    sampler_policy = SamplerPolicy(sac.policy, sac.train_params['policy'])
    eval_returns = []

    use_cql = True
    i = 0  # total number of training updates performed so far
    for _ in tqdm(range(num_epochs), ncols=80):
        # The timer result is currently not reported anywhere.
        with Timer():
            for _ in range(config.num_train_loop_per_epoch):
                expl_samples = expl_sampler.sample(
                    sampler_policy.update_params(sac.train_params['policy']),
                    n_steps=config.num_expl_steps_per_train_loop,
                    deterministic=False
                )

                # New transitions enter the buffer with priorities derived
                # from how far their actions are from the BC model's.
                priorities = compute_sample_weights(agent_state, expl_samples, replay_buffer.alpha, config.l)
                replay_buffer.add_batch(expl_samples, priorities=priorities)

                for _ in range(config.num_trains_per_train_loop):
                    # The sampled indices are not needed here.
                    batch_s, _ = replay_buffer.sample(config.batch_size)
                    sample_weights = compute_q_weights(agent_state, batch_s, config.gamma)
                    batch = batch_to_jax(batch_s)
                    sac.train_our(batch, sample_weights, use_cql=use_cql, bc=False)

                    i += 1

                    if i % eval_interval == 0:
                        trajs = eval_sampler.sample(
                            sampler_policy.update_params(sac.train_params['policy']),
                            config.eval_n_trajs, deterministic=True
                        )

                        # Sum each trajectory once and reuse it for both the
                        # raw and the normalized statistics.
                        traj_returns = [np.sum(t['rewards']) for t in trajs]
                        normalized_scores = [
                            eval_sampler.env.get_normalized_score(ret)
                            for ret in traj_returns
                        ]
                        eval_returns.append(
                            (i, np.mean(traj_returns), np.mean(normalized_scores))
                        )

                        # The whole history is rewritten at every evaluation.
                        np.savetxt(
                            os.path.join(config.wandb_output_dir, 'eval_returns.txt'),
                            eval_returns,
                            fmt=['%d', '%.1f', '%.4f'],
                        )

if __name__ == '__main__':
    # Script entry point: build the configuration from the command line and
    # the selected config module, then run fine-tuning with it.
    run_config = get_config()
    main(config=run_config)