import argparse
import os

import random
import gym
import torch
import d4rl
import numpy as np
import sys  
sys.path.append('../LEASE') # add path

from configs.iql import loaded_args

from dataset.load_offline_dataset import qlearning_dataset
from dataset.buffer import ReplayBuffer
from dataset.load_preference_dataset import load_preference_dataset
from dataset.generate_preference_data import collect_preference_data

from transition.network import EnsembleTransitionModel
from transition.transition_model import EnsembleTransition

from reward.reward_model import RewardModel

from policies.iql import IQLPolicy
from policies.models.nets import MLP
from policies.models.dist import DiagGaussian
from policies.models.actor_critic import ActorProb, Critic

from utils.scaler import StandardScaler
from utils.termination_fns import get_termination_fn
from utils.logger import Logger, make_log_dirs
from utils.policy_trainer import PolicyTrainer

def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    argparse with ``type=bool`` calls ``bool("False")``, which is True for
    every non-empty string, so boolean config entries need this parser.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _config_argument_kwargs(default_value):
    """Return ``add_argument`` ``type``/``nargs`` kwargs for a config entry.

    ``type(default_value)`` is wrong for three kinds of default:
    ``bool`` (any non-empty string would parse as True), ``list``/``tuple``
    (``list("256")`` splits into characters and only one token is accepted),
    and ``None`` (``NoneType`` cannot convert a string, so the entry could
    never be overridden). Every other default keeps ``type(default_value)``.
    """
    if isinstance(default_value, bool):
        return {"type": _str2bool}
    if isinstance(default_value, (list, tuple)):
        # Element type is taken from the first default element (str if empty);
        # the parsed override is a list of one or more tokens.
        element_type = type(default_value[0]) if default_value else str
        if element_type is bool:
            element_type = _str2bool
        return {"type": element_type, "nargs": "+"}
    if default_value is None:
        return {"type": str}
    return {"type": type(default_value)}


def get_args():
    """Parse command-line hyper-parameters.

    ``--task``, ``--seed`` and ``--device`` are parsed first. The task name
    then selects a default config from ``loaded_args``. Every key of that
    config becomes its own ``--<key>`` option, with the config value as its
    default, so any entry can be overridden from the command line.

    Returns:
        argparse.Namespace holding all fixed and config-derived options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="walker2d-medium-expert-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # Only the fixed options are known at this point; the remaining tokens are
    # config overrides that become parseable once the config keys are added.
    known_args, _ = parser.parse_known_args()
    default_args = loaded_args[known_args.task]
    for arg_key, default_value in default_args.items():
        parser.add_argument(
            f'--{arg_key}',
            default=default_value,
            **_config_argument_kwargs(default_value),
        )

    return parser.parse_args()


def train(args=None):
    """Run the offline training pipeline for one task and seed.

    The steps run in this order:
      1. Build the gym environment and its offline dataset.
      2. Seed every RNG.
      3. Construct the IQL actor/critics, the ensemble transition model and the
         reward model, loading saved weights when a path is given.
      4. Set up logging and the real/fake replay buffers.
      5. Obtain the offline preference dataset.
      6. Train the transition and reward models when they were not loaded.
      7. Call ``real_buffer.predict_reward`` with the reward model.
      8. Train the policy.

    Args:
        args: Parsed hyper-parameters (``argparse.Namespace``). When ``None``
            (the default), they are parsed from the command line with
            ``get_args()``. ``obs_shape``, ``action_dim`` and ``max_action``
            are written onto ``args``.
    """
    if args is None:
        # Parse lazily: a ``get_args()`` default would run once at import
        # time, parsing sys.argv whenever this module is merely imported.
        args = get_args()

    # create env and dataset
    env = gym.make(args.task)
    dataset = qlearning_dataset(env)
    args.obs_shape = env.observation_space.shape
    args.action_dim = int(np.prod(env.action_space.shape))
    args.max_action = env.action_space.high[0]
    obs_dim = np.prod(args.obs_shape)

    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    env.seed(args.seed)

    # create policy model (construction order is kept so seeded RNG draws match)
    actor_backbone = MLP(input_dim=obs_dim, hidden_dims=args.hidden_dims, dropout_rate=args.dropout_rate)
    critic_q1_backbone = MLP(input_dim=obs_dim + args.action_dim, hidden_dims=args.hidden_dims)
    critic_q2_backbone = MLP(input_dim=obs_dim + args.action_dim, hidden_dims=args.hidden_dims)
    critic_v_backbone = MLP(input_dim=obs_dim, hidden_dims=args.hidden_dims)
    dist = DiagGaussian(
        latent_dim=actor_backbone.output_dim,
        output_dim=args.action_dim,
        unbounded=False,
        conditioned_sigma=False,
        max_mu=args.max_action
    )
    actor = ActorProb(actor_backbone, dist, args.device)
    critic_q1 = Critic(critic_q1_backbone, args.device)
    critic_q2 = Critic(critic_q2_backbone, args.device)
    critic_v = Critic(critic_v_backbone, args.device)

    # orthogonal weights (gain sqrt(2)) and zero biases for every Linear layer
    for net in (actor, critic_q1, critic_q2, critic_v):
        for m in net.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                torch.nn.init.zeros_(m.bias)

    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic_q1_optim = torch.optim.Adam(critic_q1.parameters(), lr=args.critic_q_lr)
    critic_q2_optim = torch.optim.Adam(critic_q2.parameters(), lr=args.critic_q_lr)
    critic_v_optim = torch.optim.Adam(critic_v.parameters(), lr=args.critic_v_lr)

    # cosine learning-rate decay is applied to the actor optimizer only
    if args.lr_scheduler:
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(actor_optim, args.epoch)
    else:
        lr_scheduler = None

    # create transition model
    load_transition_model = bool(args.load_transition_path)
    transition_model = EnsembleTransitionModel(
        obs_dim=obs_dim,
        action_dim=args.action_dim,
        hidden_dims=args.transition_hidden_dims,
        num_transition_ensemble=args.n_transition_ensemble,
        num_elites=args.n_elites,
        weight_decays=args.transition_weight_decay,
        device=args.device
    )
    transition_optim = torch.optim.Adam(
        transition_model.parameters(),
        lr=args.transition_lr
    )
    scaler = StandardScaler()
    termination_fn = get_termination_fn(task=args.task)
    transition = EnsembleTransition(
        transition_model,
        transition_optim,
        scaler,
        termination_fn
    )
    if load_transition_model:
        transition.load(args.load_transition_path)

    # create reward model
    load_reward_model = bool(args.load_reward_path)
    reward = RewardModel(
        args=args,
        observation_dim=obs_dim,
        action_dim=args.action_dim,
        num_reward_ensemble=args.n_reward_ensemble,
        lr=args.reward_lr,
        activation="tanh",
        device=args.device
    )
    if load_reward_model:
        reward.load_model(args.load_reward_path)

    # create IQL policy
    policy = IQLPolicy(
        transition,
        actor,
        critic_q1,
        critic_q2,
        critic_v,
        actor_optim,
        critic_q1_optim,
        critic_q2_optim,
        critic_v_optim,
        action_space=env.action_space,
        tau=args.tau,
        gamma=args.gamma,
        expectile=args.expectile,
        temperature=args.temperature
    )

    # log: run name is "FRESH" without select_data, else "LEASE"/"FEWER" by update_reward
    name = "LEASE" if args.update_reward else "FEWER"
    name = name if args.select_data else "FRESH"
    log_dirs = make_log_dirs(
        args.task, args.algo_name, name, args.seed, vars(args)
    )
    # key: output file name, value: output handler type
    output_config = {
        "consoleout_backup": "stdout",
        "reward_training_progress": "csv",
        "transition_training_progress": "csv",
        "policy_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    # create buffers: one sized to the whole offline dataset, one for model rollouts
    real_buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )
    real_buffer.load_dataset(dataset)

    fake_buffer = ReplayBuffer(
        buffer_size=args.rollout_batch_size * args.rollout_length * args.model_retain_epochs,
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )

    # collect (optionally) and load the preference dataset
    preference_kwargs = dict(
        args=args,
        dataset=dataset,
        num_query=args.num_query,
        len_query=args.len_query,
    )
    if args.collect_preference_data:
        print("-----------------------------------------------")
        print("*******start collecting preference data********")
        print("-----------------------------------------------")
        collect_preference_data(**preference_kwargs, human_label=False)
        offline_preference_dataset = load_preference_dataset(**preference_kwargs, human_label=False)
    else:
        # human_label is left to load_preference_dataset's own default here
        offline_preference_dataset = load_preference_dataset(**preference_kwargs)

    # train transition model
    # NOTE(review): training only happens when update_reward is set; with it
    # unset and no checkpoint loaded the transition model stays untrained —
    # presumably unused in that mode, confirm.
    if not load_transition_model and args.update_reward:
        print("---------------------------------------------")
        print("*******start training transition model*******")
        print("---------------------------------------------")
        transition.train(
            real_buffer.sample_all(),
            logger,
            max_epochs_since_update=args.max_epochs_since_update,
            max_epochs=args.transition_max_epochs
        )

    # pretrain reward model on the offline preference dataset
    if not load_reward_model:
        print("---------------------------------------------")
        print("*******start pretraining reward model********")
        print("---------------------------------------------")
        reward.pretrain(
            init_pref_real_dataset=offline_preference_dataset,
            n_epochs=args.reward_pretrain_epoch,
            logger=logger,
            batch_size=args.reward_pretrain_batch_size
        )

    real_buffer.predict_reward(reward_model=reward)

    # train policy
    policy_trainer = PolicyTrainer(
        args=args,
        policy=policy,
        reward=reward,
        eval_env=env,
        real_buffer=real_buffer,
        fake_buffer=fake_buffer,
        logger=logger,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.policy_batch_size,
        lr_scheduler=lr_scheduler,
        rollout_freq=args.rollout_freq,
        num_query=args.rollout_batch_size * args.model_retain_epochs,
        len_query=args.rollout_length
    )
    print("------------------------------------------")
    print("*******start training policy model********")
    print("------------------------------------------")
    policy_trainer.train(offline_preference_dataset)

if __name__ == "__main__":
    # Script entry point: run the full pipeline with the hyper-parameters that
    # train() obtains through its default `args`.
    train()