import argparse

import torch
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, GPT2TokenizerFast, get_scheduler
from accelerate import Accelerator

from promptrl.trainer import Trainer, VizTrainer, VHomeTrainer
from promptrl.agent import IterPromptAgent, RankPromptAgent, CaptionPromptAgent, BeamAffordanceAgent
from promptrl.task import make_env, make_dataset, make_vhome_envs
from promptrl.utils import Logger
from load import load_models, load_set_optim

# Module-level command-line parser. The add_argument calls in this file
# register options on it at import time; it is parsed in the __main__ guard.
parser = argparse.ArgumentParser()

def str2bool(s):
    """Convert a command-line token to a bool (for use as an argparse ``type``).

    Args:
        s: Either an actual ``bool`` (returned unchanged) or a string that is
            a non-empty, case-insensitive prefix of ``'true'`` or ``'false'``
            (e.g. ``'t'``, ``'True'``, ``'fal'``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``s`` is empty or is not a prefix of
            ``'true'`` / ``'false'``.
    """
    if isinstance(s, bool):
        return s
    token = s.lower()
    # Every string starts with '', so an empty token must be rejected
    # explicitly; otherwise it would silently parse as True.
    if token:
        if 'true'.startswith(token):
            return True
        if 'false'.startswith(token):
            return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Same default parameters as run_clm_no_trainer.py in transformers
# https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
parser.add_argument('--seed', type=int, default=42,
                    help='Random seed.')
parser.add_argument('--epochs', type=int, default=1,
                    help='Num training epochs')
# The non-string default bypasses the type=int conversion, so the value stays
# float('inf') unless given on the command line. main() overwrites
# args.train_samples with the dataset size once a training dataset is loaded.
parser.add_argument('--train-samples', type=int, default=float('inf'),
                    help='Number of training samples.')
parser.add_argument('--val-split-ratio', type=float, default=None,
                    help='Ratio of train/val split. None for no val split.')
parser.add_argument('--eval-samples', type=int, default=10,
                    help='Number of eval samples.')
parser.add_argument('--eval-demo-samples', type=int, default=200,
                    help='Number of eval demonstrations to evaluate on (no interactive)')
# Used as the DataLoader batch size for both the train and eval loaders.
parser.add_argument('--batch-size', type=int, default=16,
                    help='Batch size.')
parser.add_argument('--grad-accum-steps', type=int, default=1,
                    help='Number of batches before backward pass.')
parser.add_argument('--lm-learning-rate', type=float, default=5e-5,
                    help='LM learning rate.')
parser.add_argument('--lm-weight-decay', type=float, default=0.001,
                    help='LM weight decay.')
parser.add_argument('--max-grad-norm', type=float, default=None,
                    help='Clip grad norm at max')
parser.add_argument('--learning-rate', type=float, default=0.01,
                    help='Learning rate.')
parser.add_argument('--weight-decay', type=float, default=0.01,
                    help='Weight decay.')
parser.add_argument('--lr-scheduler-type', type=str, default='linear',
                    help='Learning rate scheduler type')
parser.add_argument('--num-warmup-steps', type=int, default=0,
                    help='Number of learning rate warmup steps.')
parser.add_argument('--checkpoint-every-epoch', type=int, default=-1,
                    help='How often to checkpoint model (-1 is no checkpointing)')
parser.add_argument('--load-checkpoint', type=str, default=None,
                    help='Load checkpoint from (path)')

# Misc: run-mode switches, logging / evaluation cadence, and output locations.
# main() skips training when --eval-only is set and skips the final
# evaluation when --train-only is set.
parser.add_argument(
    '--eval-only',
    default=False,
    type=str2bool,
    help='Only evaluate, no training',
)
parser.add_argument(
    '--train-only',
    default=False,
    type=str2bool,
    help='Only train, no evaluate',
)
parser.add_argument(
    '--log-every',
    default=100,
    type=int,
    help='Log every _ steps',
)
parser.add_argument(
    '--eval-every',
    default=500,
    type=int,
    help='Evaluate every _ steps',
)
parser.add_argument(
    '--eval-every-epoch',
    default=1,
    type=int,
    help='Evaluate every _ epochs',
)
parser.add_argument(
    '--no-cuda',
    default=False,
    action='store_true',
    help='Disable CUDA training.',
)
parser.add_argument(
    '--exp-id',
    default=None,
    type=str,
    help='Experiment ID for logging',
)
parser.add_argument(
    '--no-log-csv',
    default=False,
    action='store_true',
    help='Disable csv logging',
)
parser.add_argument(
    '--no-wandb',
    default=False,
    action='store_true',
    help='Disable wandb logging.',
)
parser.add_argument(
    '--log-dir',
    default=None,
    type=str,
    help='Path to log results. Default: None (no log directory)',
)

# Task args
# main() derives args.task_kind from the prefix of --task-type
# ('alf...' or 'virtualhome...'); any other value yields no environments.
parser.add_argument('--task-type', type=str, default='alf-all',
                    help='Task type')
parser.add_argument('--task-partition', type=int, default=None,
                    help='Partition eval tasks (to evaluate in parallel)')
parser.add_argument('--limit-frames', type=int, default=1,
                    help='Limit observation frames to X.')
parser.add_argument('--obs-type', type=str, default='lang',
                    help='Observation type for predicting prompt')
parser.add_argument('--num-aux-samples', type=int, default=1000,
                    help='If aux obs type, how many samples to use for auxiliary tasks.')
parser.add_argument('--mix-ood-aux-samples', type=str2bool, default=False,
                    help='Add auxiliary task samples from ood to train')
parser.add_argument('--mix-id-aux-samples', type=str2bool, default=False,
                    help='Add auxiliary task samples from id to train')
parser.add_argument('--prim-lw', type=str, default='unif.100',
                    help='Loss weight schedule for primary task (action pred)')
parser.add_argument('--aux-lw', type=str, default='unif.100',
                    help='Loss weight schedule for aux tasks')
parser.add_argument('--min-action-burn-in', type=int, default=5,
                    help='Minimum previous actions/observations before LM makes predictions')
# NOTE(review): the help text below restates the flag name only; the previous
# text was copied from --finetune-lm. Confirm the exact semantics against the
# code that consumes args.dynamic_length_eval.
parser.add_argument('--dynamic-length-eval', type=str2bool, default=True,
                    help='Use dynamic-length evaluation')
# NOTE(review): previous help text was copied from --min-action-burn-in.
parser.add_argument('--max-context-tokens', type=int, default=1000,
                    help='Maximum number of context tokens')
# --rank-admissible and --beam-affordance must not both be enabled; main()
# rejects that combination.
parser.add_argument('--rank-admissible', type=str2bool, default=False,
                    help='Predict actions by ranking admissible actions')
parser.add_argument('--beam-affordance', type=str2bool, default=False,
                    help='Predict actions by ranking top k (generated with beam search), score=logits x affordance value (saycan). affordance from trained affordance function')
parser.add_argument('--beam-search-k', type=int, default=30,
                    help='beam affordance k value')
parser.add_argument('--affordance-epsilon', type=float, default=1e-2,
                    help='Minimum affordance prob')

# Prompt tuning args: soft-prompt sizes, prompt architecture, and how the
# language model is treated during training.
parser.add_argument(
    '--n-prompt-tokens',
    default=20,
    type=int,
    help='Size of soft prompt',
)
parser.add_argument(
    '--n-deep-prompt-tokens',
    default=0,
    type=int,
    help='Size of deep soft prompt',
)
parser.add_argument(
    '--n-obs-embed-tokens',
    default=5,
    type=int,
    help='Size of soft prompt embedded from observation',
)
# main() selects CaptionPromptAgent when this is 'unit' and --obs-type starts
# with 'img' (unless a ranking / beam agent was requested).
parser.add_argument(
    '--prompt-arch',
    default='unit',
    type=str,
    help='Prompt prediction architecture',
)
parser.add_argument(
    '--shared-prompt-proj',
    default=False,
    type=str2bool,
    help='Used shared projector for each prompt token',
)
parser.add_argument(
    '--prompt-dropout-prob',
    default=0.0,
    type=float,
    help='Dropout probability for soft prompts',
)
parser.add_argument(
    '--init-from-vocab',
    default=False,
    type=str2bool,
    help='Initialize soft prompts with pretrained word embeddings',
)
parser.add_argument(
    '--lm-base',
    default='gpt2',
    type=str,
    help='Language model to use',
)
parser.add_argument(
    '--finetune-lm',
    default=False,
    type=str2bool,
    help='Finetune the language model',
)
parser.add_argument(
    '--lm-eval-mode',
    default=False,
    type=str2bool,
    help='Set LM to eval always',
)
parser.add_argument(
    '--detach-obs-forward',
    default=False,
    type=str2bool,
    help='Remove gradient for obs encoder on forward tasks (predict next action)',
)

def _infer_task_kind(task_type):
    """Return the task family ('alf' or 'virtualhome') that prefixes *task_type*, else None."""
    for kind in ('alf', 'virtualhome'):
        if task_type.startswith(kind):
            return kind
    return None


def _select_trainer_cls(args):
    """Pick the trainer class from the observation type and task family."""
    if args.obs_type.startswith('img'):
        if args.task_kind == 'alf':
            return VizTrainer
        if args.task_kind == 'virtualhome':
            return VHomeTrainer
    return Trainer


def _select_agent_cls(args):
    """Pick the agent class from the action-selection flags and observation type."""
    if args.rank_admissible:
        return RankPromptAgent
    if args.beam_affordance:
        return BeamAffordanceAgent
    if args.obs_type.startswith('img') and args.prompt_arch == 'unit':
        return CaptionPromptAgent
    return IterPromptAgent


def main(args):
    """Build models, data and environments from *args*, then train and/or evaluate.

    Args:
        args: Namespace produced by the module-level ``parser``. This function
            also writes ``args.task_kind`` and, when a training dataset is
            loaded, overwrites ``args.train_samples`` with its length.

    Raises:
        ValueError: If both ``--rank-admissible`` and ``--beam-affordance``
            are enabled.
    """
    # Validate up front with an explicit raise: an `assert` is stripped under
    # `python -O`, and checking here fails before any model or data loading.
    if args.rank_admissible and args.beam_affordance:
        raise ValueError('--rank-admissible and --beam-affordance are mutually exclusive.')

    # task_kind is set before the Logger is built because the Logger receives args.
    args.task_kind = _infer_task_kind(args.task_type)
    logger = Logger(args, use_wandb=not args.no_wandb, log_csv=not args.no_log_csv, log_dir=args.log_dir)
    accelerator = Accelerator(cpu=args.no_cuda)
    torch.manual_seed(args.seed)

    tokenizer, model, prompt_model = load_models(logger, args)

    per_obs_tokens = prompt_model._per_obs_tokens
    if not args.eval_only:
        logger.info('Loading dataset...')
        train_dataset, eval_datasets = make_dataset(args, model, tokenizer, per_obs_tokens=per_obs_tokens)

        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=train_dataset.get_collator())
        eval_dataloader_dict = {
            _name: DataLoader(_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=_dataset.get_collator())
            for _name, _dataset in eval_datasets.items()
        }
    else:
        # Eval-only runs load no datasets; evaluation then relies on the environments.
        train_dataset, eval_datasets = None, {}
        train_dataloader = None
        eval_dataloader_dict = {}

    # Guided evaluation is currently disabled; None is handed through to the trainer.
    guided_eval_dataloader = None
    if args.task_kind == 'alf':
        envs = make_env(args.task_type, args.obs_type, tokenizer)
    elif args.task_kind == 'virtualhome':
        envs = make_vhome_envs(args.task_type, args.task_partition)
    else:
        envs = None
    if train_dataset is not None:
        logger.info(f'Train size: {len(train_dataset)}.')
        logger.info(f'Eval sizes: {[(_name, len(_dataset)) for _name, _dataset in eval_datasets.items()]}')
        args.train_samples = len(train_dataset)

    logger.info('Preparing model...')
    # Models are prepared first because the optimizer is built from the prepared models.
    model, prompt_model = accelerator.prepare(model, prompt_model)
    optimizer, lr_scheduler = load_set_optim(args, model, prompt_model)
    train_dataloader, guided_eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        train_dataloader, guided_eval_dataloader, optimizer, lr_scheduler
    )
    # Prepare all eval dataloaders in one call. With a single argument,
    # prepare() returns the bare object rather than a 1-tuple, so re-wrap it.
    eval_keys = list(eval_dataloader_dict)
    if eval_keys:
        eval_prepared = accelerator.prepare(*(eval_dataloader_dict[key] for key in eval_keys))
        if len(eval_keys) == 1:
            eval_prepared = (eval_prepared,)
        eval_dataloaders = dict(zip(eval_keys, eval_prepared))
    else:
        eval_dataloaders = {}

    logger.info('Training model...')
    trainer_cls = _select_trainer_cls(args)
    agent_cls = _select_agent_cls(args)

    agent = agent_cls(accelerator, args, logger, tokenizer, model, prompt_model)
    trainer = trainer_cls(accelerator, args, logger, tokenizer, agent, optimizer, lr_scheduler)
    if args.load_checkpoint is not None:
        trainer.load(args.load_checkpoint)
    if not args.eval_only:
        trainer.train(train_dataloader, eval_dataloaders, guided_eval_dataloader, envs)
    if not args.train_only:
        trainer.eval_only(eval_dataloaders, envs)
    logger.exit()

if __name__ == '__main__':
    args = parser.parse_args()

    # Only request CUDA when it is both allowed and actually present, so a
    # machine without a GPU falls back to CPU instead of carrying a 'cuda'
    # device that fails on first use.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device('cuda' if use_cuda else 'cpu')
    main(args)
