import torch
from transformers import GPT2Tokenizer, GPT2TokenizerFast, get_scheduler

from promptrl.model import GPT2PromptInputLM, GPTNeoPromptTuningLM, GPTJPromptTuningLM, OPTPromptTuningLM
import promptrl.prompts as prompts
from promptrl.clipcap import load_pretrained as load_clipcap
from promptrl.agent import IterPromptAgent, RankPromptAgent

def load_base(args):
    """Load the base language model and its tokenizer, selected by ``args.lm_base``.

    Supported values of ``args.lm_base``:
      * names starting with ``'gpt2'``, or ``'distilgpt2'``: ``GPT2PromptInputLM``
        loaded from that name. The tokenizer is always the ``"gpt2"`` fast tokenizer.
      * ``'gptneo'``: ``GPTNeoPromptTuningLM`` from ``EleutherAI/gpt-neo-1.3B``.
      * ``'gptj'``: ``GPTJPromptTuningLM`` from ``EleutherAI/gpt-j-6B``.
      * names starting with ``'opt'``: ``OPTPromptTuningLM`` from ``facebook/<name>``,
        paired with the slow ``GPT2Tokenizer``.

    For every variant, the tokenizer's pad token and the model config's
    ``pad_token_id`` are both set to the EOS token. Unless ``args.finetune_lm``
    is truthy, every model parameter gets ``requires_grad = False``.

    Args:
        args: config namespace. Reads ``lm_base`` and ``finetune_lm``.

    Returns:
        ``(tokenizer, model)``.

    Raises:
        NotImplementedError: if ``args.lm_base`` matches none of the names above.
    """
    lm_name = args.lm_base

    # Pick the tokenizer/model classes and checkpoint paths; loading is shared below.
    if lm_name.startswith('gpt2') or lm_name == 'distilgpt2':
        tokenizer_cls, tokenizer_path = GPT2TokenizerFast, "gpt2"
        model_cls, model_path = GPT2PromptInputLM, lm_name
    elif lm_name == 'gptneo':
        tokenizer_cls, model_cls = GPT2TokenizerFast, GPTNeoPromptTuningLM
        tokenizer_path = model_path = "EleutherAI/gpt-neo-1.3B"
    elif lm_name == 'gptj':
        tokenizer_cls, model_cls = GPT2TokenizerFast, GPTJPromptTuningLM
        tokenizer_path = model_path = "EleutherAI/gpt-j-6B"
    elif lm_name.startswith('opt'):
        # OPT does not have fast tokenizer
        tokenizer_cls, model_cls = GPT2Tokenizer, OPTPromptTuningLM
        tokenizer_path = model_path = f'facebook/{lm_name}'
    else:
        raise NotImplementedError(f'LM {lm_name} not implemented')

    # Every variant reuses EOS as the padding token.
    tokenizer = tokenizer_cls.from_pretrained(tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    model = model_cls.from_pretrained(model_path)
    model.config.pad_token_id = model.config.eos_token_id

    if not args.finetune_lm:
        # Freeze the LM so only the prompt module receives gradients.
        for weight in model.parameters():
            weight.requires_grad = False

    return tokenizer, model

def get_num_tasks(args):
    """Return how many prompt tasks the observation type ``args.obs_type`` requires.

    Mapping:
      * observation types containing ``'aux'``:
          - ``'imgauxcap'``, ``'imgauxinvdyn'``, ``'imgauxgoalp'`` -> 2
          - ``'imgauxc2'``, ``'imgauxcg'`` -> 3
          - any other ``'aux'`` type raises ``NotImplementedError``
      * ``'imgcap'``, ``'imgadm'`` -> 2
      * anything else -> 1

    Args:
        args: config namespace. Reads ``obs_type``.

    Returns:
        The number of tasks as an ``int``.

    Raises:
        NotImplementedError: if ``obs_type`` contains ``'aux'`` but is not one
            of the known auxiliary types.
    """
    obs_type = args.obs_type

    # Auxiliary observation types are matched by substring first; an unknown
    # 'aux' variant is an error rather than silently falling back to 1 task.
    if 'aux' in obs_type:
        if obs_type in ('imgauxcap', 'imgauxinvdyn', 'imgauxgoalp'):
            return 2
        if obs_type in ('imgauxc2', 'imgauxcg'):
            return 3
        raise NotImplementedError(f'obs_type {obs_type} not implemented')

    if obs_type in ('imgcap', 'imgadm'):
        return 2
    return 1

def load_prompt(args, tokenizer, model):
    """Construct the prompt module selected by ``args.prompt_arch``.

    Architectures:
      * ``'unit'``: ``prompts.TextUnitPrompt``
      * ``'clip'`` / ``'clippatch'``: ``prompts.CLIPEmbedPrompt``, with
        ``patched=True`` only for ``'clippatch'``
      * ``'resnet'`` / ``'resnetpt'`` / ``'resnetfz'``: ``prompts.ResNetEmbedPrompt``
        with ``(pretrained, freeze_resnet)`` = ``(False, False)`` / ``(True, False)``
        / ``(True, True)``
      * ``'clip-goal'``: ``prompts.CLIPGoalPrompt`` (the only variant that is
        given ``tokenizer``)
      * ``'imgdummy'``: ``prompts.ImgDummyPrompt``

    Every constructor receives the task count from ``get_num_tasks(args)``
    and the language ``model``.

    Raises:
        NotImplementedError: for any other ``args.prompt_arch`` value.
    """
    num_tasks = get_num_tasks(args)
    arch = args.prompt_arch

    if arch == 'unit':
        return prompts.TextUnitPrompt(
            num_tasks, model, args.n_prompt_tokens,
            deep_prompt_dim=args.n_deep_prompt_tokens,
            obs_embed_dim=args.n_obs_embed_tokens,
            dropout_prob=args.prompt_dropout_prob,
            lm_embed_init=args.init_from_vocab,
        )

    if arch in ('clip', 'clippatch'):
        return prompts.CLIPEmbedPrompt(
            num_tasks, model,
            args.n_prompt_tokens, args.n_deep_prompt_tokens, args.n_obs_embed_tokens,
            shared_proj=args.shared_prompt_proj,
            dropout_prob=args.prompt_dropout_prob,
            lm_embed_init=args.init_from_vocab,
            patched=(arch == 'clippatch'),
            detach_obs_forward=args.detach_obs_forward,
        )

    if arch in ('resnet', 'resnetpt', 'resnetfz'):
        # 'resnet' trains from scratch; 'resnetpt' starts from pretrained weights;
        # 'resnetfz' additionally freezes them.
        return prompts.ResNetEmbedPrompt(
            num_tasks, model,
            args.n_prompt_tokens, args.n_deep_prompt_tokens, args.n_obs_embed_tokens,
            shared_proj=args.shared_prompt_proj,
            dropout_prob=args.prompt_dropout_prob,
            lm_embed_init=args.init_from_vocab,
            patched=False,
            detach_obs_forward=args.detach_obs_forward,
            pretrained=(arch != 'resnet'),
            freeze_resnet=(arch == 'resnetfz'),
        )

    if arch == 'clip-goal':
        return prompts.CLIPGoalPrompt(
            num_tasks, model, tokenizer,
            args.n_prompt_tokens, args.n_deep_prompt_tokens, args.n_obs_embed_tokens,
            shared_proj=args.shared_prompt_proj,
            dropout_prob=args.prompt_dropout_prob,
            lm_embed_init=args.init_from_vocab,
            detach_obs_forward=args.detach_obs_forward,
        )

    if arch == 'imgdummy':
        return prompts.ImgDummyPrompt(
            num_tasks, model, args.n_prompt_tokens, args.n_deep_prompt_tokens,
            dropout_prob=args.prompt_dropout_prob,
            lm_embed_init=args.init_from_vocab,
        )

    raise NotImplementedError(f'{arch} arch not implemented')

def load_set_optim(args, model, prompt_model):
    """Create the AdamW optimizer and learning-rate scheduler for training.

    Parameter groups:
      * the language model's parameters, only when ``args.finetune_lm`` is
        truthy, using ``args.lm_weight_decay`` and ``args.lm_learning_rate``;
      * the prompt model's parameters, using ``args.weight_decay`` and
        ``args.learning_rate``.

    The scheduler comes from ``get_scheduler`` with type
    ``args.lr_scheduler_type`` and ``args.num_warmup_steps`` warmup steps.
    Total training steps are computed as
    ``epochs * (train_samples // (batch_size * grad_accum_steps))``. Because of
    the floor division, leftover samples that do not fill a whole
    ``batch_size * grad_accum_steps`` window are not counted.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    param_groups = []
    if args.finetune_lm:
        lm_group = dict(
            params=model.parameters(),
            weight_decay=args.lm_weight_decay,
            lr=args.lm_learning_rate,
        )
        param_groups.append(lm_group)

    prompt_group = dict(
        params=prompt_model.parameters(),
        weight_decay=args.weight_decay,
        lr=args.learning_rate,
    )
    param_groups.append(prompt_group)

    optimizer = torch.optim.AdamW(param_groups)

    samples_per_update = args.batch_size * args.grad_accum_steps
    updates_per_epoch = args.train_samples // samples_per_update
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.epochs * updates_per_epoch,
    )
    return optimizer, lr_scheduler

def load_models(logger, args):
    """Load the tokenizer, language model and prompt model for ``args.prompt_arch``.

    For ``'clipcap'``, all three come from ``load_clipcap`` (given the task
    count and the prompt/deep-prompt token counts). For any other arch, the
    base LM is loaded via ``load_base`` and the prompt module via ``load_prompt``.
    Progress is reported through ``logger.info``.

    Returns:
        ``(tokenizer, model, prompt_model)``.
    """
    if args.prompt_arch != 'clipcap':
        logger.info('Loading language model...')
        tokenizer, model = load_base(args)
        logger.info('Loading prompt model...')
        prompt_model = load_prompt(args, tokenizer, model)
        return tokenizer, model, prompt_model

    n_tasks = get_num_tasks(args)
    logger.info('Loading clipcap...')
    tokenizer, model, prompt_model = load_clipcap(
        n_tasks, args.n_prompt_tokens, args.n_deep_prompt_tokens
    )
    return tokenizer, model, prompt_model

def load_checkpoint_agent(accelerator, logger, args):
    """Build the prompt agent and optionally restore it from a checkpoint.

    Models come from ``load_models`` and are wrapped with
    ``accelerator.prepare``. The agent class is ``RankPromptAgent`` when
    ``args.rank_admissible`` is truthy, otherwise ``IterPromptAgent``.
    When ``args.load_checkpoint`` is not None, that file is read with
    ``torch.load`` and its ``'agent'`` entry is passed to
    ``agent.load_state_dict``.

    Args:
        accelerator: distributed-training helper exposing ``prepare`` and
            ``device``. Presumably an ``accelerate.Accelerator``; confirm with
            the caller.
        logger: logger passed through to model loading and the agent.
        args: config namespace. Reads ``rank_admissible`` and
            ``load_checkpoint``, plus whatever ``load_models`` reads.

    Returns:
        The constructed (and possibly restored) agent.
    """
    tokenizer, model, prompt_model = load_models(logger, args)
    model, prompt_model = accelerator.prepare(model, prompt_model)

    agent_cls = RankPromptAgent if args.rank_admissible else IterPromptAgent
    agent = agent_cls(accelerator, args, logger, tokenizer, model, prompt_model)

    if args.load_checkpoint is not None:
        # Map tensors onto this process's device. Without map_location,
        # torch.load restores each tensor to the device it was saved from
        # (e.g. cuda:0). Every rank of a multi-GPU run then piles onto that
        # GPU, and a GPU checkpoint cannot be loaded on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.load_checkpoint, map_location=accelerator.device)
        agent.load_state_dict(checkpoint['agent'])

    return agent
