try:
    import babyai
except:
    print('BabyAI env is not installed')

try:
    import lorl_env
except:
    print('Lorl env not installed')
import gym
import d4rl
import numpy as np
import torch
import wandb
import datetime
import os
import ast
import hydra
import random
from common.logger import Logger
from omegaconf import DictConfig, OmegaConf
from common.dataset import ImagePairDataset

from transformers import DistilBertTokenizer
from torch.utils.data import DataLoader

from common.expert_dataset import ExpertDataset
from hrl_model import HRLModel
from trainer import Trainer


def evaluate(cfg):
    """Restore a trained HRLModel from ``cfg.checkpoint_path`` and evaluate it.

    The run configuration is the one stored inside the checkpoint
    (``checkpoint['config']``). Only ``eval``, ``render`` and
    ``checkpoint_path`` are overridden from ``cfg``. The device and the
    ``iq`` settings are read from ``cfg`` directly.

    Args:
        cfg: Hydra config of the current invocation (as returned by
            :func:`get_args`).
    """
    # load saved arguments
    checkpoint = torch.load(cfg.checkpoint_path)
    args = checkpoint['config']
    max_length = checkpoint['train_dataset_max_length']
    args.eval = cfg.eval
    args.render = cfg.render
    args.checkpoint_path = cfg.checkpoint_path
    device = cfg.trainer.device

    # Set num train_trajs to something small
    args.train_dataset.num_trajectories = 1000
    print(OmegaConf.to_yaml(args))

    args.method = args.model.name

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    print('=' * 50)
    print(f'Starting evaluation: {args.env.name}')
    print(f'{args.trainer.num_eval_episodes} trajectories')
    print('=' * 50)

    state_dim = args.env.state_dim
    action_dim = args.env.action_dim
    if isinstance(state_dim, str):
        # Tuple-shaped (image) state dims may be stored as strings in the config.
        state_dim = ast.literal_eval(state_dim)

    if isinstance(state_dim, tuple):
        assert not args.trainer.state_il, "Cannot do state imitation learning with an image input"

    # env / env_name are always handed to the Trainer below, so they must be
    # bound even when evaluation is offline and no gym environment is created.
    env = None
    env_name = args.env.name
    if not args.env.eval_offline:
        env = gym.make(env_name)
        env.seed(args.seed)

    if 'BabyAI' in env_name:
        # NOTE(review): presumably a 4-way direction encoding appended to the
        # state when use_direction is set — confirm against ExpertDataset.
        state_dim += 4*args.env.use_direction

    train_dataset_args = dict(args.train_dataset)
    batch_size = args.batch_size

    # Same dataset selection as train(); every environment not matched
    # explicitly is treated as image-based.
    if 'BabyAI' in env_name:
        train_dataset = ExpertDataset(**train_dataset_args, use_direction=args.env.use_direction)
    elif 'Lorl' in env_name:
        # train_dataset_args also contains a split here for the validation data size
        train_dataset = ExpertDataset(**train_dataset_args, use_state=args.env.use_state)
    elif 'Hopper' in env_name:
        train_dataset = ExpertDataset(**train_dataset_args)
    else:
        train_dataset = ImagePairDataset(train_dataset_args["expert_location"], env_name=env_name)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=True, num_workers=4)

    if args.method == 'traj_option':
        args.option_selector.option_transformer.max_length = int(max_length)
        args.option_selector.option_transformer.max_ep_len = args.env.eval_episode_factor * \
            int(max_length)

    option_selector_args = dict(args.option_selector)
    option_selector_args['state_dim'] = state_dim
    option_selector_args['option_dim'] = args.option_dim
    option_selector_args['codebook_dim'] = args.codebook_dim
    state_reconstructor_args = dict(args.state_reconstructor)
    lang_reconstructor_args = dict(args.lang_reconstructor)
    # NOTE(review): train() additionally passes predict_q / no_states /
    # no_actions and omits action_tanh / n_inner here; confirm the architecture
    # built below matches the one stored in the checkpoint.
    decision_transformer_args = {'state_dim': state_dim,
                                 'action_dim': action_dim,
                                 'option_dim': args.option_dim,
                                 'discrete': args.env.discrete,
                                 'hidden_size': args.dt.hidden_size,
                                 'use_language': args.method == 'vanilla',
                                 'use_options': args.method != 'vanilla',
                                 'option_il': args.dt.option_il,
                                 'max_length': max_length if args.method != 'traj_option' else args.model.K,
                                 'max_ep_len': args.env.eval_episode_factor*max_length,
                                 'action_tanh': False,
                                 'n_layer': args.dt.n_layer,
                                 'n_head': args.dt.n_head,
                                 'n_inner': 4*args.dt.hidden_size,
                                 'activation_function': args.dt.activation_function,
                                 'n_positions': args.dt.n_positions,
                                 'n_ctx': args.dt.n_positions,
                                 'resid_pdrop': args.dt.dropout,
                                 'attn_pdrop': args.dt.dropout,
                                 }
    hrl_model_args = dict(args.model)

    iq_args = cfg.iq

    # state_dim / action_dim are passed exactly as train() passes them.
    model = HRLModel(option_selector_args, state_reconstructor_args,
                     lang_reconstructor_args, decision_transformer_args, iq_args, device,
                     state_dim=state_dim, action_dim=action_dim, **hrl_model_args)

    print(model)
    model = model.to(device=device)

    trainer_args = dict(args.trainer)

    trainer = Trainer(
        args=args,
        model=model,
        tokenizer=tokenizer,
        optimizer=None,
        train_loader=train_loader,
        env=env,
        env_name=env_name,
        val_loader=None,
        scheduler=None,
        skip_words=args.env.skip_words,
        **trainer_args
    )

    # Restore trainer from checkpoint
    trainer.load(args.checkpoint_path)
    trainer.evaluate(iter_num=0, render=args.render, max_ep_len=500, render_path=args.render_path)


def _build_dataset(dataset_args, env_cfg):
    """Instantiate the dataset that matches the configured environment.

    Args:
        dataset_args: plain dict of constructor keyword arguments (built from
            ``args.train_dataset`` or ``args.val_dataset``).
        env_cfg: the ``args.env`` config node. ``name`` selects the dataset
            type; ``use_direction`` (BabyAI) and ``use_state`` (Lorl) are
            forwarded to ``ExpertDataset``.

    Returns:
        An ``ExpertDataset`` for BabyAI, Lorl and Hopper environments, and an
        ``ImagePairDataset`` loaded from ``dataset_args["expert_location"]``
        for every other environment.
    """
    env_name = env_cfg.name
    if 'BabyAI' in env_name:
        return ExpertDataset(**dataset_args, use_direction=env_cfg.use_direction)
    if 'Lorl' in env_name or 'lorel' in env_name:
        # Both spellings have been used for the Lorl environment names.
        return ExpertDataset(**dataset_args, use_state=env_cfg.use_state)
    if 'Hopper' in env_name:
        return ExpertDataset(**dataset_args)
    # Every remaining environment (e.g. kitchen, Grid) is image-based.
    return ImagePairDataset(dataset_args["expert_location"], env_name=env_name)


def _strip_module_prefix(name):
    """Return ``name`` without the leading ``module.`` added by ``torch.nn.DataParallel``."""
    prefix = 'module.'
    return name[len(prefix):] if name.startswith(prefix) else name


def _build_optimizer(model, args):
    """Create the AdamW optimizer with per-submodule learning rates.

    Trainable parameters are grouped by their top-level submodule name, in the
    order [other, ``option_selector.``, ``bc_policy.``]:

    * other parameters: ``args.learning_rate`` and ``args.weight_decay``
    * ``option_selector.``: ``args.os_learning_rate`` and ``args.weight_decay``
    * ``bc_policy.``: ``args.bc_learning_rate`` with AdamW's default weight decay

    Parameters under ``lm.`` are placed in no group, so this optimizer never
    updates them.

    Args:
        model: the model, optionally wrapped in ``torch.nn.DataParallel``.
        args: config providing the learning-rate / weight-decay fields above.
    """
    other, option_selector, bc_policy = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # DataParallel prefixes every name with "module."; strip it so the
        # grouping is identical with and without args.parallel.
        name = _strip_module_prefix(name)
        if name.startswith('lm.'):
            # NOTE(review): LM parameters are excluded from the optimizer
            # (args.lm_learning_rate is unused); confirm this freeze is intended.
            continue
        if name.startswith('option_selector.'):
            option_selector.append(param)
        elif name.startswith('bc_policy.'):
            bc_policy.append(param)
        else:
            other.append(param)
    return torch.optim.AdamW([
        {'params': other, 'lr': args.learning_rate, 'weight_decay': args.weight_decay},
        {'params': option_selector, 'lr': args.os_learning_rate, 'weight_decay': args.weight_decay},
        {'params': bc_policy, 'lr': args.bc_learning_rate},
    ])


def train(args):
    """Train an HRLModel according to the Hydra config ``args``.

    The function runs these steps in order:

    1. Build the train loader and, when ``args.env.eval_offline`` is set, a
       validation loader.
    2. Create the gym environment, except for Crafter.
    3. Build the model.
    4. Optionally restore from ``args.checkpoint_path``: the full model when
       ``args.resume`` is set, or only the ``option_selector.Z`` weights when
       ``args.load_options`` is set.
    5. Run ``args.max_iters`` training iterations, logging scalar outputs and
       saving checkpoints under ``args.savepath``.

    ``args`` is mutated in place:

    * ``method`` and ``savepath`` are set;
    * ``model.horizon`` / ``model.K`` may be resolved from ``'max'``;
    * the option-transformer lengths are set for ``traj_option``;
    * ``warmup_steps`` is zeroed on resume.
    """
    device = args.trainer.device

    args.method = args.model.name
    exp_name = f'{args.project_name}-{args.train_dataset.num_trajectories}-{args.method}'
    args.savepath = f'{args.hydra_base_dir}/{args.savedir}/{exp_name}-{datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'

    if args.wandb:
        wandb.init(
            name=exp_name,
            group=args.method,
            project=f'hrl_{args.env.name}',
            config=dict(args),
            entity='language-rl'
        )

    os.makedirs(args.savepath, exist_ok=True)

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    logger = Logger(os.path.join("log"), args.project_name, args.seed)
    batch_size = args.batch_size

    train_dataset_args = dict(args.train_dataset)
    train_dataset = _build_dataset(train_dataset_args, args.env)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=False, pin_memory=True, drop_last=False)

    print('=' * 50)
    print(f'Starting new experiment: {args.env.name} {args.train_dataset.num_trajectories}')
    print('=' * 50)

    state_dim = args.env.state_dim
    action_dim = args.env.action_dim
    if isinstance(state_dim, str):
        # Tuple-shaped (image) state dims may be stored as strings in the config.
        state_dim = ast.literal_eval(state_dim)

    if isinstance(state_dim, tuple):
        assert not args.trainer.state_il, "Cannot do state imitation learning with an image input"

    # Crafter is not created through gym here.
    if 'Crafter' not in args.env.name:
        env = gym.make(args.env.name)
    else:
        env = None

    env_name = args.env.name
    if args.env.eval_offline:
        val_dataset = _build_dataset(dict(args.val_dataset), args.env)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                                shuffle=False, pin_memory=True, drop_last=False)
    else:
        val_loader = None

    if 'BabyAI' in args.env.name:
        state_dim += 4*args.env.use_direction

    if args.method == 'traj_option':
        args.option_selector.option_transformer.max_length = int(train_dataset.max_length)
        args.option_selector.option_transformer.max_ep_len = args.env.eval_episode_factor * \
            int(train_dataset.max_length)

    # 'max' means "use the longest trajectory in the training set".
    if args.model.horizon == 'max':
        args.model.horizon = int(train_dataset.max_length)
    if args.model.K == 'max':
        args.model.K = int(train_dataset.max_length)

    option_selector_args = dict(args.option_selector)
    option_selector_args['state_dim'] = state_dim
    option_selector_args['option_dim'] = args.option_dim
    option_selector_args['codebook_dim'] = args.codebook_dim
    state_reconstructor_args = dict(args.state_reconstructor)
    lang_reconstructor_args = dict(args.lang_reconstructor)
    decision_transformer_args = {'state_dim': state_dim,
                                 'action_dim': action_dim,
                                 'option_dim': args.option_dim,
                                 'discrete': args.env.discrete,
                                 'hidden_size': args.dt.hidden_size,
                                 'use_language': args.method == 'vanilla',
                                 'use_options': args.method != 'vanilla',
                                 'option_il': args.dt.option_il,
                                 'predict_q': args.use_iq,
                                 'max_length': train_dataset.max_length if 'option' not in args.method else args.model.K,   # used to be K
                                 'max_ep_len': args.env.eval_episode_factor*train_dataset.max_length,
                                 'n_layer': args.dt.n_layer,
                                 'n_head': args.dt.n_head,
                                 'activation_function': args.dt.activation_function,
                                 'n_positions': args.dt.n_positions,
                                 'n_ctx': args.dt.n_positions,
                                 'resid_pdrop': args.dt.dropout,
                                 'attn_pdrop': args.dt.dropout,
                                 'no_states': args.dt.no_states,
                                 'no_actions': args.dt.no_actions,
                                 }
    hrl_model_args = dict(args.model)
    iq_args = args.iq

    model = HRLModel(option_selector_args, state_reconstructor_args,
                     lang_reconstructor_args, decision_transformer_args, iq_args, device, state_dim=state_dim,
                     action_dim=action_dim, **hrl_model_args)

    start_iter = 1
    if args.resume:
        args.warmup_steps = 0
        checkpoint = torch.load(args.checkpoint_path)
        # Checkpoints saved from a DataParallel-wrapped model carry a
        # "module." prefix on every key; the bare model expects none.
        checkpoint['model'] = {_strip_module_prefix(key): value
                               for key, value in checkpoint['model'].items()}
        model.load_state_dict(checkpoint['model'])

        start_iter = checkpoint['iter_num'] + 1
        assert train_dataset.max_length == checkpoint[
            'train_dataset_max_length'], f"Expected max length of dataset to be {train_dataset.max_length} but got {checkpoint['train_dataset_max_length']}"

    if args.load_options:
        # Load only the option codebook weights (option_selector.Z.*).
        checkpoint = torch.load(args.checkpoint_path)
        checkpoint = checkpoint['model']
        state_dict = {k: v for k, v in checkpoint.items() if k.startswith('option_selector.Z')}
        loaded = model.load_state_dict(state_dict, strict=False)
        assert loaded.unexpected_keys == []   ## simple check
        if args.freeze_loaded_options:
            for name, param in model.named_parameters():
                if name.startswith('option_selector.Z'):
                    param.requires_grad = False
            assert not model.option_selector.Z.project_out.bias.requires_grad   ## simple check

    if args.parallel:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model = model.to(device=device)

    optimizer = _build_optimizer(model, args)

    def adjust_lr(steps):
        """LR multiplier: linear warmup, then step decay every ``decay_steps``."""
        if steps < args.warmup_steps:
            return min((steps + 1) / args.warmup_steps, 1)
        num_decays = (steps + 1) // args.decay_steps
        return args.lr_decay ** (num_decays)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, adjust_lr)

    trainer_args = dict(args.trainer)

    trainer = Trainer(
        args=args,
        model=model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        train_loader=train_loader,
        env=env,
        env_name=env_name,
        val_loader=val_loader,
        scheduler=scheduler,
        eval_episode_factor=2,
        skip_words=args.env.skip_words,
        **trainer_args
    )

    # Training loop
    for iter_num in range(start_iter, start_iter + args.max_iters):
        outputs = trainer.train_iteration(
            iter_num=iter_num, llm_image_save_dir=args["llm_segment_image"]["save_dir"], print_logs=True, eval_render=args.render)

        # Only scalar outputs are written to the logger.
        for log_name, log_value in outputs.items():
            if isinstance(log_value, (float, int)):
                logger.log_var(log_name, log_value, iter_num)
                logger.log_str("{}:{}".format(log_name, log_value))
        logger.log_str('indices:{}'.format(str(outputs["training/indices"])))
        logger.log_str("==" * 25 + str(iter_num) + "==" * 25)

        if iter_num % args.save_interval == 0 or iter_num % trainer.llm_update_epoch == 0:
            trainer.save(iter_num, f'{args.savepath}/model_{iter_num}.ckpt', args)


def get_args(cfg: DictConfig):
    """Fill in the runtime-only fields of ``cfg`` and return the same object.

    ``cfg`` is modified in place:

    * ``trainer.device`` becomes ``"cuda:0"`` when CUDA is available and
      ``"cpu"`` otherwise.
    * ``hydra_base_dir`` records the current working directory.

    NOTE(review): under Hydra the cwd is presumably the run's output
    directory — confirm.
    """
    has_cuda = torch.cuda.is_available()
    cfg.trainer.device = "cuda:0" if has_cuda else "cpu"

    cfg.hydra_base_dir = os.getcwd()
    return cfg


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: seed the RNGs, optionally make cuDNN deterministic,
    then either evaluate a checkpoint (``cfg.eval``) or train a new model.
    """
    args = get_args(cfg)

    # Seed every RNG source the script draws from.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = torch.device(args.trainer.device)
    deterministic_cuda = (device.type == 'cuda'
                          and torch.cuda.is_available()
                          and args.cuda_deterministic)
    if deterministic_cuda:
        # Disable autotuning and force deterministic cuDNN kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    print("--> Running in ", os.getcwd())

    if args.eval:
        evaluate(cfg)
    else:
        print(OmegaConf.to_yaml(cfg))
        train(args)


if __name__ == "__main__":
    # main is wrapped by @hydra.main, which supplies cfg, so it is called with no arguments.
    main()
