import os
import json
import argparse
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from tqdm import tqdm

# Short model names (the values accepted by --model) mapped to the model
# identifier that main() hands to `from_pretrained`.
MODEL_PATH_DICT = {'gpt-neox': 'EleutherAI/gpt-neox-20b'}

def generate(input_seq, model, tokenizer, stop_token_id, args):
    """Generate continuations of ``input_seq`` and return them as strings.

    The prompt is tokenized once and ``model.generate`` is called repeatedly,
    asking for at most ``args.batch_size`` sequences per call, until
    ``args.num_return_sequences`` sequences have been produced. The final call
    requests only the remainder, so the result always holds exactly
    ``args.num_return_sequences`` entries, even when that count is smaller
    than, or not a multiple of, the batch size.

    Args:
        input_seq: Prompt text to continue.
        model: Object exposing a ``generate(**kwargs)`` method that returns
            an iterable of token-id sequences (prompt tokens included).
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``eos_token_id``.
        stop_token_id: Token id passed as ``eos_token_id`` so generation
            stops when it is produced.
        args: Namespace providing ``device``, ``num_return_sequences``,
            ``batch_size``, ``generation_length``, ``num_beams``, ``top_k``,
            ``top_p`` and ``typical_p``; an optional boolean ``sample``
            attribute forces sampling.

    Returns:
        A list of decoded continuations (prompt tokens stripped, newlines
        removed), one per requested sequence.

    Raises:
        ValueError: If ``args.batch_size`` is smaller than 1.
    """
    if args.batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(args.batch_size))

    encoder_input = tokenizer(input_seq, truncation=True, return_tensors='pt')
    input_length = len(encoder_input.input_ids[0])

    # Loop-invariant: move the prompt tensors to the device only once.
    input_ids = encoder_input.input_ids.to(args.device)
    attention_mask = encoder_input.attention_mask.to(args.device)

    # Sample when explicitly requested or when a sampling filter is active.
    do_sample = bool(
        getattr(args, 'sample', False) or args.top_k > 0 or args.top_p < 1.0
    )
    temperature = 1.0 if args.top_k == 0 else 0.7

    outputs = []
    remaining = args.num_return_sequences
    while remaining > 0:
        # The last call may need fewer sequences than a full batch.
        current_batch = min(args.batch_size, remaining)
        with torch.no_grad():
            pred_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_length + args.generation_length,
                eos_token_id=stop_token_id,
                pad_token_id=tokenizer.eos_token_id,
                early_stopping=True,
                num_return_sequences=current_batch,
                num_beams=args.num_beams,
                do_sample=do_sample,
                top_p=args.top_p,
                top_k=args.top_k,
                typical_p=args.typical_p,
                use_cache=True,
                temperature=temperature,
            )

        # Drop the prompt tokens and flatten each continuation to one line.
        outputs.extend(
            tokenizer.decode(sequence[input_length:], skip_special_tokens=True).replace('\n', '')
            for sequence in pred_ids
        )
        remaining -= current_batch

    return outputs

def _format_question_and_answer(instance):
    """Return the multiple-choice question text and the gold answer string.

    The question is taken from ``instance['question']`` (falling back to
    ``instance['context']``) and followed by an ``Answer Choices:`` header and
    one ``(a) ...``, ``(b) ...`` line per choice. The answer is the choice at
    index ``instance['answer']``.

    Raises:
        NotImplementedError: If the instance has no ``'choices'`` field.
    """
    question = instance['question'] if 'question' in instance else instance['context']
    if 'choices' not in instance:
        raise NotImplementedError

    choices = instance['choices']
    parts = [question, 'Answer Choices:']
    parts.extend(
        '({}) {}'.format(chr(ord('a') + choice_id), choice)
        for choice_id, choice in enumerate(choices)
    )
    return '\n'.join(parts), choices[instance['answer']]


def main(args):
    """Generate one explanation per dataset instance and append it to a JSONL file.

    For every split listed in ``args.eval_split`` the function loads
    ``../data/gpt/data/<dataset>/<split>.jsonl``, keeps the slice belonging to
    ``args.shard``, formats each instance into the prompt, calls ``generate``
    and appends the instance (with an added ``'explanation'`` list) to
    ``../data/gpt/output/<dataset>/<split>.<prompt>.<shard>.jsonl``.

    Args:
        args: Parsed command-line namespace. ``args.device`` is set here.

    Raises:
        ValueError: If the shard or GPU arguments are out of range.
        NotImplementedError: If an instance has no ``'choices'`` field.
    """
    # Check args (explicit raises so the checks survive `python -O`)
    if args.num_shards <= 0:
        raise ValueError('num_shards must be positive, got {}'.format(args.num_shards))
    if not 0 <= args.shard < args.num_shards:
        raise ValueError(
            'shard must be in [0, {}), got {}'.format(args.num_shards, args.shard))
    if not 0 <= args.gpu < torch.cuda.device_count():
        raise ValueError(
            'gpu must be in [0, {}), got {}'.format(torch.cuda.device_count(), args.gpu))

    # Set parameters
    set_seed(args.seed)
    torch.set_num_threads(args.num_threads)
    # NOTE(review): the model below is placed with device_map='auto' while the
    # inputs are moved to args.device -- confirm these agree on multi-GPU hosts.
    args.device = torch.device('cuda:{}'.format(args.gpu))

    # Initialize model
    model_path = MODEL_PATH_DICT[args.model]
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map='auto',
        cache_dir='../model/cache',
        torch_dtype='auto',
        offload_folder='../model/offload',
    )
    model.eval()

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir='../model/cache')
    # Second-to-last token of '\n\nQ' is used as the stop token for generation.
    stop_token_id = tokenizer.encode('\n\nQ')[-2]

    # Load prompts (the file holds JSON with a 'prompt' field)
    prompt_path = '../data/gpt/prompts/{}.explanation.txt'.format(args.prompt)
    with open(prompt_path, 'r', encoding='utf-8') as fr:
        prompt = json.load(fr)['prompt'].replace('Yes or no: ', '')
    print(prompt)

    for split in args.eval_split.split(','):
        # Load dataset, ignoring blank lines that would break json.loads
        data_path = '../data/gpt/data/{}/{}.jsonl'.format(args.dataset, split)
        with open(data_path, 'r', encoding='utf-8') as fr:
            all_lines = [line for line in fr if line.strip()]

        # Ceiling division so every line lands in exactly one shard.
        shard_size = -(len(all_lines) // -args.num_shards)
        dataset = all_lines[args.shard * shard_size:(args.shard + 1) * shard_size]
        if args.debug:
            dataset = dataset[:2]

        # Create output path; exist_ok avoids a race when several shard
        # processes try to create the same directory at once.
        output_dir = '../data/gpt/output/{}'.format(args.dataset)
        os.makedirs(output_dir, exist_ok=True)
        output_path = '{}/{}.{}.{}.jsonl'.format(output_dir, split, args.prompt, args.shard)

        for line in tqdm(dataset):
            # Get question and answer
            instance = json.loads(line)
            question, answer = _format_question_and_answer(instance)

            # Format input sequence
            input_seq = prompt.format(question, answer)
            if args.debug:
                print(input_seq)

            # Generate FTR (only the first returned sequence is kept)
            model_outputs = generate(input_seq, model, tokenizer, stop_token_id, args)
            generated_ftr = model_outputs[0]

            # Write output to file; reopening in append mode per instance means
            # finished results are on disk even if a later instance fails.
            output_instance = instance.copy()
            output_instance['explanation'] = [generated_ftr]
            with open(output_path, 'a', encoding='utf-8') as fw:
                fw.write(json.dumps(output_instance) + '\n')

if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().

    parser = argparse.ArgumentParser(description='Run main.')

    # Model args
    # Restricting to the known keys turns a late KeyError into a usage error.
    parser.add_argument('--model', '-m', type=str, default='gpt-neox',
                        choices=sorted(MODEL_PATH_DICT),
                        help='short name of the model to load')

    # Data and GPU args
    # --dataset and --prompt are interpolated into file paths, so a missing
    # value must be rejected up front instead of producing a ".../None/..." path.
    parser.add_argument('--dataset', '-d', type=str, required=True,
                        help='dataset directory name under ../data/gpt/data')
    parser.add_argument('--prompt', '-p', type=str, required=True,
                        help='prompt file name (without the .explanation.txt suffix)')
    parser.add_argument('--eval_split', type=str, default='test,dev,train',
                        help='comma-separated list of splits to process')
    parser.add_argument('--num_shards', type=int, default=1,
                        help='number of shards each split is divided into')
    parser.add_argument('--shard', type=int, default=0,
                        help='zero-based index of the shard to process')
    parser.add_argument('--num_threads', type=int, default=4,
                        help='value passed to torch.set_num_threads')
    parser.add_argument('--gpu', type=int, default=0,
                        help='index of the CUDA device that receives the inputs')

    # Decoding args
    parser.add_argument('--sample', action='store_true')
    parser.add_argument('--num_beams', type=int, default=1)
    parser.add_argument('--top_k', type=int, default=0)
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--typical_p', type=float, default=1.0)
    parser.add_argument('--batch_size', type=int, default=1,
                        help='sequences requested per model.generate call')
    parser.add_argument('--num_return_sequences', type=int, default=1,
                        help='total sequences to generate per instance')
    parser.add_argument('--generation_length', type=int, default=128,
                        help='tokens allowed beyond the prompt length')

    # Misc args
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--debug', action='store_true',
                        help='process only two instances per split and print inputs')

    # Parse args
    args = parser.parse_args()

    main(args)
