from pathlib import Path
import gym, d4rl, os, copy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from ATAC.policy import GaussianPolicy
from ATAC.value_functions import TwinQ
from ATAC.util import Log, set_seed
from ATAC.util import evaluate_policy, sample_batch, traj_data_to_qlearning_data, tuple_to_traj_data, DEFAULT_DEVICE, cat_data_dicts, mlp
from armor.bp import BehaviorPretraining
from armor.util import to_transition_batch
from armor.gaussian_mlp_world_model import GaussianMLPWorldModel
from armor.one_dim_trt_model import OneDTransitionRewardTerminationModel
from armor.simple_armor import SimpleARMOR
from urllib.error import HTTPError

# Small margin used when clipping dataset actions to [-1+EPS, 1-EPS], which keeps
# them strictly inside (-1, 1) (the clipping call sites note "due to tanh").
EPS=1e-6



def eval_agent(*, env, agent, discount, n_eval_episodes, max_episode_steps=1000,
               deterministic_eval=True, normalize_score=None):
    """Evaluate `agent` on `env` for several episodes and summarize the returns.

    Each episode is run through `evaluate_policy`; column 0 of its result is
    treated as the episode return and column 1 as the discounted return.

    Args:
        env: Environment passed through to `evaluate_policy`.
        agent: Policy passed through to `evaluate_policy`.
        discount: Discount factor passed through to `evaluate_policy`.
        n_eval_episodes: Number of evaluation episodes to run.
        max_episode_steps: Step limit passed through to `evaluate_policy`.
        deterministic_eval: Passed through to `evaluate_policy`.
        normalize_score: Optional callable mapping the array of returns to
            normalized scores; when given, their mean/std are reported too.

    Returns:
        A dict with "return mean", "return std", "discounted returns", and,
        if `normalize_score` is provided, "normalized return mean" and
        "normalized return std".
    """
    episode_results = []
    for _ in range(n_eval_episodes):
        episode_results.append(
            evaluate_policy(env, agent, max_episode_steps, deterministic_eval, discount))
    episode_results = np.array(episode_results)

    undiscounted = episode_results[:, 0]
    discounted = episode_results[:, 1]

    summary = {}
    summary["return mean"] = undiscounted.mean()
    summary["return std"] = undiscounted.std()
    summary["discounted returns"] = discounted.mean()

    if normalize_score is None:
        return summary

    scores = normalize_score(undiscounted)
    summary["normalized return mean"] = scores.mean()
    summary["normalized return std"] = scores.std()
    return summary

def get_dataset(env):
    """Fetch the dataset of `env`, retrying until the download succeeds.

    Args:
        env: An environment exposing `get_dataset()` (a d4rl env at the call
            sites in this file).

    Returns:
        The dataset converted to trajectories and back to q-learning format
        via `tuple_to_traj_data` / `traj_data_to_qlearning_data`, which makes
        sure next_observation is added.
    """
    while True:
        try:
            dataset = env.get_dataset()
        except (HTTPError, OSError):
            print('Unable to download dataset. Retry.')
            # Go back to the top of the loop. Without this, control would fall
            # through to the return below with `dataset` still unassigned.
            continue
        return traj_data_to_qlearning_data(tuple_to_traj_data(dataset))  # make sure next_observation is added

def get_env_and_dataset(env_name):
    """Create the gym environment `env_name` and load its dataset.

    For the three kitchen-v0 environments the rewards are shifted so that
    V(done)=0 is correct, and the terminal/timeout flags are rebuilt per
    trajectory.

    Args:
        env_name: Name passed to `gym.make`.

    Returns:
        A tuple `(env, dataset)`.
    """
    kitchen_envs = ('kitchen-complete-v0', 'kitchen-partial-v0', 'kitchen-mixed-v0')

    env = gym.make(env_name)  # d4rl ENV
    dataset = get_dataset(env)
    if env_name not in kitchen_envs:
        return env, dataset

    # Shift rewards by the number of task elements so that they are <= 0 and
    # V(done)=0 is correct.
    n_task_elements = len(env.TASK_ELEMENTS)
    assert n_task_elements >= dataset['rewards'].max()
    assert env.TERMINATE_ON_TASK_COMPLETE
    dataset['rewards'] -= n_task_elements

    # Fix the terminal flags: a step is terminal exactly when its shifted
    # reward is zero; only the last step of a trajectory may be a timeout,
    # and only if that step is not terminal.
    trajectories = tuple_to_traj_data(dataset)
    for trajectory in trajectories:
        trajectory['terminals'] = trajectory['rewards']==0
        trajectory['timeouts'] = np.zeros_like(trajectory['timeouts'], dtype=bool)
        trajectory['timeouts'][-1] = not trajectory['terminals'][-1]

    return env, traj_data_to_qlearning_data(trajectories)

def succesfully_load_model(model, path):
    """Try to load the state dict stored at `path` into `model`.

    Args:
        model: Object exposing `load_state_dict`.
        path: Location of a checkpoint readable by `torch.load`.

    Returns:
        True when loading succeeded; False when a FileNotFoundError or
        RuntimeError was raised while reading or applying the checkpoint.
    """
    try:
        state = torch.load(path)
        model.load_state_dict(state)
    except (FileNotFoundError, RuntimeError):
        return False
    return True


def main(args):
    """Run the ARMOR experiment: set up, pretrain (or load), train, and evaluate.

    Steps, in order: build the env/dataset, set up logging and the model
    directory, construct the world model, policy and Q-functions, optionally
    train an expert reference policy, load or run behavior pretraining,
    construct `SimpleARMOR`, then alternate evaluation and updates for
    `args.n_steps` steps and save the final state to `<log dir>/final.pt`.

    Args:
        args: Namespace of options as produced by `get_parser()`.
            `args.use_data_terminals`, `args.learn_termination`, and
            `args.margin` are normalized in place.

    Returns:
        The "normalized return mean" of the most recent evaluation, or None
        if no evaluation ran (i.e. `args.n_steps` is 0).
    """
    # ------------------ Initialization ------------------ #
    torch.set_num_threads(4)
    env, dataset = get_env_and_dataset(args.env_name)
    set_seed(args.seed, env=env)
    torch_rng = torch.Generator(device=DEFAULT_DEVICE)

    # Parse and process arguments
    exp_name = args.env_name if not args.policy_improvement_exp else args.env_name+'_PI'
    # Not learning termination implies the data terminals have to be used.
    args.use_data_terminals = args.use_data_terminals or args.not_learn_termination
    args.learn_termination = not args.not_learn_termination
    if isinstance(args.margin, str):
        args.margin = float(args.margin)

    # The run directory name encodes the hyperparameters of this run.
    run_name = ('beta' + str(args.beta) +
                '_reg' + str(args.reg_coeff) +
                # model rollout and learning
                '_horizon' + str(args.rollout_horizon) +
                '_t_batch_size' + str(args.terminal_batch_size) +
                '_m_batch_size' + str(args.model_batch_size) +
                '_m_buffer_size' + str(args.model_buffer_size) +
                # value learning
                '_margin' + str(args.margin) +
                '_clip_v' + str(args.clip_v) +
                # model spec
                '_det' + str(args.deterministic_model) +
                '_e_size' + str(args.model_ensemble_size) +
                '_m_hidden_dim' + str(args.model_hidden_dim) +
                '_m_n_layers' + str(args.model_num_layers) +
                '_tloss_weight' + str(args.term_loss_weight) +
                # debug
                '_learn_term' + str(args.learn_termination) +
                '_hybrid' + str(args.hybrid) +
                '_ignore_model' + str(args.ignore_model_prediction) +
                '_atac_mode' + str(args.atac_mode) +
                '_no_model_grad' + str(args.no_model_grad))
    log_path = Path(args.log_dir) / args.env_name / run_name

    # Log the config
    log = Log(log_path, vars(args))
    log(f'Log dir: {log.dir}')
    writer = SummaryWriter(log.dir)

    # Set up model saving directory
    if args.model_dir is None or args.model_dir in ('None', 'none'):
        model_dir = log.dir  # just save to the experiment directory
    else:  # save under model_dir/exp_name/
        model_dir = os.path.join(args.model_dir, exp_name)
        os.makedirs(model_dir, exist_ok=True)

    # Assume vector observation and action
    obs_dim, act_dim = dataset['observations'].shape[1], dataset['actions'].shape[1]
    obs_max, obs_min = np.amax(dataset['observations'], axis=0), np.amin(dataset['observations'], axis=0)
    dataset['actions'] = np.clip(dataset['actions'], -1+EPS, 1-EPS)  # due to tanh
    dataset = {k:v for k,v in dataset.items() if not any([ig in k for ig in ("metadata","info",)])}  # remove some fields
    # Value bounds derived from the extreme rewards scaled by 1/(1-discount); unbounded unless clip_v is set.
    Vmin = min(0,dataset['rewards'].min() * 1/(1-args.discount)) if args.clip_v else -float('inf')
    Vmax = max(0,dataset['rewards'].max() * 1/(1-args.discount)) if args.clip_v else  float('inf')

    #region ============== Initialize the MDP model ==============
    model = GaussianMLPWorldModel(
        obs_dim + act_dim, # input size
        obs_dim + int(args.learn_rewards), # output size
        device=DEFAULT_DEVICE,
        num_layers=args.model_num_layers,
        hid_size=args.model_hidden_dim,
        activation_fn_cfg={"_target_": "torch.nn.SiLU"},
        ensemble_size=args.model_ensemble_size,
        deterministic=args.deterministic_model,
        propagation_method=args.model_propagation_method,
        learn_logvar_bounds=args.model_learn_logvar_bounds,
        learn_termination=args.learn_termination,
        term_loss_weight=args.term_loss_weight,
        )

    model_wrapper = OneDTransitionRewardTerminationModel(
        model,
        target_is_delta=args.model_target_is_delta,
        normalize=True,
        learned_rewards=args.learn_rewards,
        learned_terminals=args.learn_termination,
        normalize_double_precision=False)

    model_wrapper.update_normalizer(to_transition_batch(dataset))
    #endregion

    #region ============== Initialize policy and qfs ==============
    qf = TwinQ(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden).to(DEFAULT_DEVICE)
    target_qf = copy.deepcopy(qf).requires_grad_(False)
    policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden, use_tanh=True, std_type='diagonal').to(DEFAULT_DEVICE)
    ref_policy = None
    if args.policy_improvement_exp:
        # train reference policy on expert data
        ref_policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden, use_tanh=True, std_type='diagonal').to(DEFAULT_DEVICE)
        # Build the expert env name from the first and last '-'-separated parts of env_name.
        version = args.env_name.split('-')[-1]
        name = args.env_name.split('-')[0]
        expert_env_name = name + '-expert-' + version
        expert_env = gym.make(expert_env_name)
        expert_dataset = get_dataset(expert_env)
        expert_dataset['actions'] = np.clip(expert_dataset['actions'], -1+EPS, 1-EPS)  # due to tanh
        expert_dataset = {k:v for k,v in expert_dataset.items() if not any([ig in k for ig in ("metadata","info",)])}  # remove some fields

        def expert_bp_log_fun(metrics, step):
            print(step, metrics)
            for k, v in metrics.items():
                writer.add_scalar('ExpertBehaviorPretraining/'+k, v, step)
        ebp = BehaviorPretraining(policy=ref_policy, lr=args.fast_lr).to(DEFAULT_DEVICE)
        ebp_path = os.path.join(model_dir, 'ebp.pt')
        if not succesfully_load_model(ebp, ebp_path):
            _ = ebp.train(expert_dataset, args.n_warmstart_steps, log_fun=expert_bp_log_fun)
            if args.save_models:
                torch.save(ebp.state_dict(), ebp_path)
    #endregion

    #region ============== Pretrain ==============
    bp = BehaviorPretraining(qf=qf if not args.policy_improvement_exp else None, policy=policy,
                             target_qf=target_qf if not args.policy_improvement_exp else None, reference=ref_policy, model=model_wrapper,
                             lr=args.fast_lr, discount=args.discount, Vmax=Vmax, Vmin=Vmin,
                             td_weight=args.wt, rs_weight=args.ws, fixed_alpha=None, action_shape=act_dim).to(DEFAULT_DEVICE)
    if args.policy_improvement_exp:  # This is okay, as there is no optimizer initialized in the previous step.
        bp.qf = qf; bp.target_qf = target_qf

    model_filename = 'mdp_model'+('_det' if args.deterministic_model else '_sto') + '_'+str(args.model_ensemble_size) + \
                     ('t' if args.learn_termination else '')
    mdp_model_path = os.path.join(model_dir, model_filename+'.pt')
    bp_model_path = os.path.join(model_dir, args.bp_filename)

    print('Loading models...')
    learning_from_scratch = args.learning_from_scratch
    if not learning_from_scratch:
        if not succesfully_load_model(bp, bp_model_path):
            try: # Try to fix it manually (this is for local debugging mostly, for testing different model parameters.)
                # mdp model is loaded successfully
                model_wrapper.load_state_dict(torch.load(mdp_model_path))
                model_dict = model_wrapper.state_dict()
                bp_dict = torch.load(bp_model_path)  # XXX hacking now; TODO find a better way of saving
                bp_dict.update({'_model.'+k:v for k,v in model_dict.items() })
                bp_dict.update({'model.'+k:v for k,v in model_dict.items() })
                if args.policy_improvement_exp:
                    ebp.reference = ebp.policy  # need to set the reference from ebp
                    for k, v in  ebp.state_dict().items():
                        if 'reference' in k:
                            bp_dict[k] = v
                bp.load_state_dict(bp_dict)
            except Exception as err:  # Train from scratch after exhausting all attempts.
                # Catch Exception rather than everything, so that Ctrl-C
                # (KeyboardInterrupt) and SystemExit still stop the program.
                print(f'Failed to load pretrained models: {err!r}')
                learning_from_scratch = True

    if learning_from_scratch:
        print('Models not loaded. Learning from scratch...')
        def bp_log_fun(metrics, step):
            print(step, metrics)
            for k, v in metrics.items():
                writer.add_scalar('BehaviorPretraining/'+k, v, step)
            # Periodically checkpoint the pretraining state.
            if step % args.pretrain_checkpoint_freq ==0 and args.save_models and step>0:
                torch.save(bp.state_dict(), os.path.join(model_dir, 'bp_'+str(step)+'.pt'))
        _ = bp.train(dataset, args.n_warmstart_steps, log_fun=bp_log_fun)
        if args.save_models:
            torch.save(bp.state_dict(), os.path.join(model_dir,'bp.pt'))
            torch.save(bp._model.state_dict(), mdp_model_path)
    #endregion

    #region ============== Initialize ARMOR ==============
    if ref_policy is None:
        print('No reference is found. Use the BC policy as the reference.')
        ref_policy = copy.deepcopy(bp.policy) # At this point, bp is set up.
    ref_policy.requires_grad_(False)
    traj_data = tuple_to_traj_data(dataset)
    init_observations = np.vstack([traj['observations'][0] for traj in traj_data]) # NOTE Do we want to include other observations?
    armor = SimpleARMOR(
        policy=policy,
        qf=qf,
        target_qf=target_qf,
        optimizer=torch.optim.Adam,
        discount=args.discount,
        action_shape=act_dim,
        obs_highs=torch.tensor(obs_max, device=DEFAULT_DEVICE),
        obs_lows=torch.tensor(obs_min, device=DEFAULT_DEVICE),
        buffer_batch_size=args.batch_size,
        policy_lr=args.slow_lr,
        qf_lr=args.fast_lr,
        # ATAC main parameters
        beta=args.beta, # the regularization coefficient in front of the Bellman error
        # Armor parameters
        reference=ref_policy,
        model=model_wrapper,
        model_lr=args.fast_lr,
        rollout_horizon=args.rollout_horizon,
        model_buffer_size=args.model_buffer_size,
        model_batch_size=args.model_batch_size,
        reg_coeff=args.reg_coeff,
        rng=torch_rng,
        wt=args.wt,
        ws=args.ws,
        bellman_model_grad_type=args.bellman_model_grad_type,
        use_data_terminals=args.use_data_terminals,
        ignore_model_prediction=args.ignore_model_prediction,
        use_model_grad= not args.no_model_grad,
        hybrid=args.hybrid,
        atac_mode=args.atac_mode,
        margin=args.margin,
        Vmax=Vmax, Vmin=Vmin,
    ).to(DEFAULT_DEVICE)
    #endregion

    del dataset['timeouts']  # not used
    # Subset of the data restricted to terminal transitions (None if there are none).
    terminal_dataset = {k:v[dataset['terminals']]  for k,v in dataset.items()}
    if len(terminal_dataset['rewards'])<1:
        terminal_dataset = None

    # Main Training
    eval_metrics = None  # stays None if the loop never runs (n_steps == 0)
    for step in trange(args.n_steps):
        # Evaluation
        if step % args.eval_period == 0 or step==args.n_steps-1:
            eval_metrics = eval_agent(
                env=env, agent=policy, discount=args.discount, n_eval_episodes=args.n_eval_episodes,
                normalize_score=lambda returns: d4rl.get_normalized_score(args.env_name, returns)*100.0)
            log.row(eval_metrics)
            for k, v in eval_metrics.items():
                writer.add_scalar('Eval/'+k, v, step)
        # Update
        batch = sample_batch(dataset, args.batch_size)
        if terminal_dataset is not None:
            terminal_batch = sample_batch(terminal_dataset, args.terminal_batch_size)
            batch = cat_data_dicts(batch, terminal_batch)
        train_metrics = armor.update(**batch)
        # Logging (ten times as often as evaluation)
        if step % max(int(args.eval_period/10),1) == 0 or step==args.n_steps-1:
            print(train_metrics)
            for k, v in train_metrics.items():
                writer.add_scalar('Train/'+k, v, step)

    # Final processing
    torch.save(armor.state_dict(), log.dir/'final.pt')
    log.close()
    writer.close()
    if eval_metrics is None:
        return None
    return eval_metrics['normalized return mean']


def _str_to_bool(value):
    """Convert a command-line token to a bool, for use as an argparse `type`.

    `type=bool` would map every non-empty string (including "False") to True,
    so the accepted spellings are parsed explicitly here.

    Raises:
        argparse.ArgumentTypeError: If `value` is not a recognized spelling.
    """
    from argparse import ArgumentTypeError
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept as False because bool('') was the only way to
    # obtain False from the command line before.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_parser():
    """Build the command-line argument parser for this script.

    Returns:
        An `argparse.ArgumentParser`; `--env_name` and `--log_dir` are
        required, every other option has a default.
    """
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--env_name', required=True)
    parser.add_argument('--log_dir', required=True)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--eval_period', type=int, default=5000)
    parser.add_argument('--n_eval_episodes', type=int, default=10)
    # optimization parameters
    parser.add_argument('--n_warmstart_steps', type=int, default=100*10**3)
    parser.add_argument('--n_steps', type=int, default=10**6)
    parser.add_argument('--batch_size', type=int, default=125)
    parser.add_argument('--model_batch_size', type=int, default=125)
    parser.add_argument('--terminal_batch_size', type=int, default=0)
    parser.add_argument('--fast_lr', type=float, default=5e-4)
    parser.add_argument('--slow_lr', type=float, default=5e-7)
    # Armor parameters
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--ws', type=float, default=0.5)
    parser.add_argument('--wt', type=float, default=0.5)
    parser.add_argument('--margin', type=float, default=float('inf'))
    parser.add_argument('--clip_v', action='store_true')
    parser.add_argument('--reg_coeff', type=float, default=100.0)
    parser.add_argument('--rollout_horizon', type=int, default=20)
    parser.add_argument('--model_buffer_size', type=int, default=10**6)
    # Armor config
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--bp_filename', type=str, default='bp.pt')
    parser.add_argument('--save_models', action='store_true')
    parser.add_argument('--learning_from_scratch', action='store_true')
    parser.add_argument('--policy_improvement_exp', action='store_true')
    parser.add_argument('--pretrain_checkpoint_freq',type=int, default=500000)
    # Policy, qf parameters
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--n_hidden', type=int, default=3)
    # Model parameters
    parser.add_argument('--model_num_layers', type=int, default=3)
    parser.add_argument('--model_hidden_dim', type=int, default=512)
    parser.add_argument('--model_ensemble_size', type=int, default=1)
    parser.add_argument('--term_loss_weight', type=float, default=100.0)
    # Debug flags: Below should be considered fixed.
    parser.add_argument('--bellman_model_grad_type', type=str, default='tdrs')
    parser.add_argument('--deterministic_model', action='store_true')
    # Boolean-valued options use _str_to_bool so that "False"/"0"/"no" really disable them.
    parser.add_argument('--model_target_is_delta', type=_str_to_bool, default=True)
    parser.add_argument('--model_propagation_method', type=str, default='expectation')
    parser.add_argument('--model_learn_logvar_bounds', type=_str_to_bool, default=True)
    parser.add_argument('--learn_rewards', type=_str_to_bool, default=True)
    parser.add_argument('--not_learn_termination', action='store_true')
    parser.add_argument('--use_data_terminals', action='store_true')
    parser.add_argument('--ignore_model_prediction', action='store_true')
    parser.add_argument('--no_model_grad', action='store_true')
    parser.add_argument('--hybrid', action='store_true')
    parser.add_argument('--atac_mode', action='store_true')
    return parser


if __name__ == '__main__':
    # Script entry point: parse the command line, then run the experiment.
    parser = get_parser()
    cli_args = parser.parse_args()
    main(cli_args)
