import glob
import os
import time
import warnings
from pathlib import Path

import d4rl
import gym
import numpy as np
import torch
from lightATAC.bp import BehaviorPretraining
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange

from CQL.SimpleSAC.utils import Timer
from iql.src.util import Log, return_range, set_seed
from src.add_lambda_heuristic import add_lambda
from src.gate import logi_gate
from src.hu_atac import ATAC
from src.hu_cql import ConservativeSAC
from src.hu_iql import ImplicitQLearning
from src.hu_td3_bc import TD3_BC
from src.util import (DEFAULT_DEVICE, evaluate_policy, sample_batch, torchify,
                      traj_data_to_qlearning_data, tuple_to_traj_data)
from src.value_functions import TwinQ, ValueFunction


def get_env_and_dataset(env_name, max_episode_steps, discount, max_retries=100, retry_wait=300):
    """Create the environment and load its offline dataset.

    Args:
        env_name: Either a gym/d4rl environment id, or a path containing
            ``'hdf5'``, in which case loading is delegated to
            ``src.mw_util.get_mw_env_and_data``.
        max_episode_steps: Episode horizon; forwarded to the hdf5 loader and
            used to rescale locomotion rewards.
        discount: Discount factor; only forwarded to the hdf5 loader.
        max_retries: Maximum number of ``env.get_dataset()`` attempts.
        retry_wait: Seconds to sleep between failed attempts.

    Returns:
        A tuple ``(env, dataset, env_name)``. On the hdf5 path ``env_name`` is
        whatever the loader returns; otherwise it is the input unchanged.

    Raises:
        RuntimeError: If the dataset could not be fetched in ``max_retries``
            attempts.
    """
    if 'hdf5' in env_name:
        from src.mw_util import get_mw_env_and_data
        env, dataset, env_name = get_mw_env_and_data(
            dataset_path=env_name, discount=discount, max_episode_steps=max_episode_steps)
        return env, dataset, env_name

    env = gym.make(env_name)
    # If the d4rl server is down, wait and retry. Warnings are promoted to
    # errors during the download so that a warning also triggers a retry.
    warnings.filterwarnings("error")
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    warnings.filterwarnings("ignore", category=ResourceWarning)
    last_error = None
    try:
        for attempt in range(max_retries):
            try:
                dataset = env.get_dataset()
                break
            except Exception as err:  # not a bare except: keep Ctrl-C working
                last_error = err
                if attempt + 1 < max_retries:
                    time.sleep(retry_wait)
        else:
            # Fail loudly here instead of hitting an unbound `dataset` below.
            raise RuntimeError(
                f'Could not load the dataset for {env_name!r} after {max_retries} attempts.'
            ) from last_error
    finally:
        warnings.filterwarnings("default")

    if any(s in env_name for s in ('halfcheetah', 'hopper', 'walker2d')):
        # Rescale rewards by the dataset's return range, then by the horizon.
        dataset_q = d4rl.qlearning_dataset(env)
        min_ret, max_ret = return_range(dataset_q, max_episode_steps)
        dataset['rewards'] /= (max_ret - min_ret)
        dataset['rewards'] *= max_episode_steps
        del dataset_q
    elif 'antmaze' in env_name:
        dataset['rewards'] -= 1.
    return env, dataset, env_name




def eval_agent(*, env, agent, discount, n_eval_episodes, max_episode_steps=1000,
               deterministic_eval=True, normalize_score=None):
    """Roll out `agent` for `n_eval_episodes` episodes and summarize the results.

    Each rollout is produced by `evaluate_policy`; its first four entries are
    read as (return, discounted return, success indicator, number of steps).
    NOTE(review): that column meaning is taken from how the values are named
    and reported here -- confirm against `evaluate_policy`.

    Returns:
        A dict of mean/std statistics. When `normalize_score` is given, the
        mean and std of the normalized returns are included as well.
    """
    rollouts = []
    for _ in range(n_eval_episodes):
        rollouts.append(
            evaluate_policy(env, agent, max_episode_steps, deterministic_eval, discount))
    stats = np.array(rollouts)

    # One column per statistic, one row per episode.
    returns, discounted, successes, lengths = (stats[:, col] for col in range(4))

    summary = dict()
    summary["return mean"] = returns.mean()
    summary["return std"] = returns.std()
    summary["discounted returns"] = discounted.mean()
    summary["success rates"] = successes.mean()
    summary["n steps test"] = lengths.mean()

    if normalize_score is None:
        return summary

    scaled = normalize_score(returns)
    summary["normalized return mean"] = scaled.mean()
    summary["normalized return std"] = scaled.std()
    return summary



def main(args):
    """Run one offline-RL training job described by `args`.

    Steps: load the environment and dataset, build the learner selected by
    `args.base_method` ('iql', 'td3_bc', 'atac', 'bc' or 'cql'), optionally
    warm-start with `BehaviorPretraining`, attach heuristic lambda values with
    `add_lambda`, then train for `args.n_steps` updates with periodic
    evaluation. The final learner state is saved to `final.pt` in the log dir.

    Args:
        args: Parsed command-line namespace (see the parser in the
            `__main__` section of this file).

    Returns:
        The 'normalized return mean' of the last evaluation, or None when no
        evaluation was run (e.g. `n_steps` smaller than `eval_period`).

    Raises:
        NotImplementedError: If `args.base_method` is not supported.
    """
    # ------------------ Initialization ------------------ #
    print(args.env_name)
    torch.set_num_threads(1)
    env, dataset, env_name = get_env_and_dataset(args.env_name, args.max_episode_steps, args.discount)

    # Extra dataset statistics that are appended to every evaluation row.
    # They are only collected when the loader returned a different name than
    # the one given (the hdf5 / file-based dataset path).
    dataset_stats = {}
    if env_name == args.env_name:
        re_folder_name = env_name
    else:
        file_name = os.path.basename(args.env_name)
        # file name without extension
        re_folder_name = os.path.splitext(file_name)[0]

        for key, label in (('successes', 'data_success_rates'),
                           ('n_step_rewards', 'data_n_step_rewards'),
                           ('n_step_discount_rewards', 'data_n_step_discount_rewards')):
            dataset_stats[label] = dataset[key].mean()
            del dataset[key]  # removed so later stages only see the remaining fields

    run_name = (f'{args.base_method}{args.lambda_method}{args.method}'
                f'heuristic_discount{args.heuristic_discount}'
                f'warmstart{args.n_warmstart_steps}alpha{args.alpha}')
    log_path = Path(args.log_dir) / re_folder_name / run_name
    log = Log(log_path, vars(args))
    log(f'Log dir: {log.dir}')
    writer = SummaryWriter(log.dir)

    obs_dim = dataset['observations'].shape[1]
    act_dim = dataset['actions'].shape[1]   # this assume continuous actions
    max_action = float(env.action_space.high[0])
    set_seed(args.seed, env=env)

    # ------------------ Learner setup ------------------ #
    if args.base_method == 'iql':
        td_weight = 1.0; rs_weight = 0.0; fixed_alpha = 0
        from src.policy import DeterministicPolicy, GaussianPolicy
        Policy = DeterministicPolicy if args.deterministic_policy else GaussianPolicy
        policy = Policy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden)
        rl = ImplicitQLearning(
            qf=TwinQ(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden),
            vf=ValueFunction(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden),
            policy=policy,
            gate=logi_gate(obs_dim+act_dim, 1),
            optimizer_factory=lambda params: torch.optim.Adam(params, lr=args.learning_rate),
            max_steps=args.n_steps,
            tau=args.tau,
            beta=args.beta,
            alpha=args.alpha,
            discount=args.discount,
            temperature=args.temperature,
            method=args.method,
            gate_threshold=args.gate_threshold,
        )
        networks = dict(qf=rl.qf, vf=rl.vf, policy=rl.policy)
        agent = policy  # for evaluation.

    elif args.base_method == 'td3_bc':
        td_weight = 1.0; rs_weight = 0.0; fixed_alpha = 0
        # NOTE(review): discount is hard-coded to 0.99 here rather than taken
        # from args.discount -- confirm this is intended.
        rl = TD3_BC(
            state_dim=obs_dim,
            action_dim=act_dim,
            max_action=max_action,
            temperature=args.temperature,
            method=args.method,
            discount=0.99,
            tau=0.005,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2,
            alpha=args.alpha,
            lr=args.learning_rate,
        )
        vf = ValueFunction(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden).to(DEFAULT_DEVICE)
        qf = lambda observations, actions: sum(rl.critic(observations, actions)).squeeze()/2.0
        policy = lambda observations: rl.actor(observations)
        networks = dict(qf=qf, vf=vf, policy=policy)
        agent = rl  # for evaluation.

    elif args.base_method == 'atac' or args.base_method == 'bc':
        td_weight = 0.5; rs_weight = 0.5; fixed_alpha = None

        from lightATAC.policy import GaussianPolicy

        qf = TwinQ(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden).to(DEFAULT_DEVICE)
        vf = ValueFunction(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden).to(DEFAULT_DEVICE)
        policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden,
                                use_tanh=True, std_type='diagonal').to(DEFAULT_DEVICE)
        dataset['actions'] = np.clip(dataset['actions'], -1+1e-6, 1-1e-6)  # due to tanh
        rl = ATAC(
            policy=policy,
            qf=qf,
            optimizer=torch.optim.Adam,
            discount=args.discount,
            action_shape=act_dim,
            buffer_batch_size=args.batch_size,
            target_update_tau=5e-3,
            policy_lr=5e-7,
            qf_lr=5e-4,
            # ATAC parameters
            beta=args.beta,
            heuristic_method=args.method,
            heuristic_temperature=args.temperature,
        )
        rl.to(DEFAULT_DEVICE)
        networks = dict(qf=qf, vf=vf, policy=policy)
        agent = policy  # for evaluation.

    elif args.base_method == 'cql':
        td_weight = 1.0; rs_weight = 0.0; fixed_alpha = 0
        # NOTE(review): discount is hard-coded to 0.99 here rather than taken
        # from args.discount -- confirm this is intended.
        rl = ConservativeSAC(
            state_dim=obs_dim,
            action_dim=act_dim,
            cql_min_q_weight=args.cql_min_q_weight,
            temperature=args.temperature,
            method=args.method,
            discount=0.99,
        )
        vf = ValueFunction(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden).to(DEFAULT_DEVICE)
        # Average of the two critics. The builtin sum() must not be applied to
        # the added tensor: it would reduce over the tensor's first axis.
        qf = lambda observations, actions: (rl.qf1(observations, actions) + rl.qf2(observations, actions)).squeeze()/2.0
        policy = lambda observations: rl.policy(observations)[0]
        networks = dict(qf=qf, vf=vf, policy=policy)
        agent = rl
        dataset['actions'] = np.clip(dataset['actions'], -0.999, 0.999)
    else:
        raise NotImplementedError

    # ------------------ Pretraining ------------------ #
    # Train policy and value to fit the behavior data
    if args.n_warmstart_steps > 0:
        pretrain_q_and_policy = args.base_method == 'atac' or args.base_method == 'bc'
        # if not atac, no pretraining on q or policy.
        bp_networks = dict(vf=networks['vf'])
        if pretrain_q_and_policy:
            bp_networks.update(qf=networks['qf'], policy=networks['policy'])
        bp = BehaviorPretraining(**bp_networks,
                                 lr=args.learning_rate, discount=args.discount, lambd=0.99,
                                 td_weight=td_weight, rs_weight=rs_weight, fixed_alpha=fixed_alpha,
                                 action_shape=act_dim).to(DEFAULT_DEVICE)
        if pretrain_q_and_policy:
            rl._target_qf = bp.target_qf

        dataset = bp.train(dataset, args.n_warmstart_steps, log_freq=5000,
                           log_fun=lambda x, i: print(i, x))  # This ensures "next_observations" is in `dataset`.
        traj_data = tuple_to_traj_data(dataset)

        for data in traj_data:  # Update the return based on the learned vf
            vs = networks['vf'](torchify(data['observations'])).to('cpu').detach().numpy()
            data['returns'] = vs

    else:
        traj_data = tuple_to_traj_data(dataset)
        BehaviorPretraining.preprocess_traj_data(traj_data, discount=args.discount)  # based on truncated MC estimates

    # Compute heuristic lambda value
    add_lambda(traj_data, args.heuristic_discount, args.discount, args.lambda_method)

    # Setup Evaluation
    def normalize_score(returns):  # for evaluation
        try:
            return d4rl.get_normalized_score(args.env_name, returns) * 100.0
        except Exception:
            # No d4rl reference score available: report the raw returns.
            return returns

    def evaluate_and_log(step):
        """Evaluate `agent`, append dataset statistics, and log the row."""
        metrics = eval_agent(env=env,
                             agent=agent,
                             max_episode_steps=args.max_episode_steps,
                             discount=args.discount,
                             n_eval_episodes=args.n_eval_episodes,
                             normalize_score=normalize_score)
        # if it is mw, we also want to log out some mw data results.
        metrics.update(dataset_stats)

        log.row(metrics)
        for k, v in metrics.items():
            writer.add_scalar('Eval/'+k, v, step)
        return metrics

    # Main Training
    dataset = traj_data_to_qlearning_data(traj_data)
    eval_metrics = None  # stays None if no evaluation is ever triggered
    train_log_period = max(int(args.eval_period/10), 1)
    for step in trange(args.n_steps):
        bc_first_step = step == 0 and args.base_method == 'bc'
        if bc_first_step:
            # For 'bc', evaluate once before the first update ...
            eval_metrics = evaluate_and_log(step)
        train_metrics = rl.update(**sample_batch(dataset, args.batch_size))
        if (step+1) % train_log_period == 0:
            for k, v in train_metrics.items():
                writer.add_scalar('Train/'+k, v, step)
        # ... and again right after it (kept from the original schedule;
        # NOTE(review): confirm the second step-0 evaluation is wanted).
        if (step+1) % args.eval_period == 0 or bc_first_step:
            eval_metrics = evaluate_and_log(step)
    torch.save(rl.state_dict(), log.dir/'final.pt')

    log.close()
    writer.close()
    if eval_metrics is None:
        return None
    return eval_metrics['normalized return mean']



if __name__ == '__main__':
    from argparse import ArgumentParser

    # Command-line options as (flag, add_argument keyword options), listed in
    # the order in which they are registered on the parser.
    cli_options = (
        ('--env-name', dict(required=True)),
        ('--log-dir', dict(required=True)),
        ('--method', dict(type=str, default='softmax')),
        ('--seed', dict(type=int, default=0)),
        ('--discount', dict(type=float, default=0.99)),
        ('--hidden-dim', dict(type=int, default=256)),
        ('--n-hidden', dict(type=int, default=2)),
        ('--n-steps', dict(type=int, default=10**6)),
        ('--batch-size', dict(type=int, default=256)),
        ('--learning-rate', dict(type=float, default=3e-4)),
        ('--alpha', dict(type=float, default=0.005)),
        ('--tau', dict(type=float, default=0.7)),
        ('--beta', dict(type=float, default=3.0)),
        ('--expectile', dict(type=float, default=0.5)),
        ('--deterministic-policy', dict(action='store_true')),
        ('--eval-period', dict(type=int, default=5000)),
        ('--base-method', dict(type=str, default='iql')),
        ('--n-eval-episodes', dict(type=int, default=10)),
        ('--max-episode-steps', dict(type=int, default=1000)),
        ('--temperature', dict(type=float, default=1)),
        ('--heuristic-discount', dict(type=float, default=1)),
        ('--n-warmstart-steps', dict(type=int, default=0)),
        ('--gate-threshold', dict(type=int, default=0)),
        ('--lambda-method', dict(type=str, default='None')),
        ('--root-dir', dict(type=str, default='None')),
        ('--cql-min-q-weight', dict(type=float, default=5.0)),
    )

    parser = ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    # Parse the command line and launch training.
    main(parser.parse_args())