"""
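Decode GSM8K-style math word problems (the dataset/split is chosen by the ``test_data`` config) with a Flan-T5 model, using the prompt read from ``prompt_path`` followed by a chain-of-thought cue, and write the generations to a text file.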
TODO: adaptive batch size, such that max_len * batch_size = const 
"""

import time 
import torch
import re
import argparse
import os
import pytz 
import hydra
import json
import pickle

import numpy as np
import torch.nn.functional as F

from datetime import datetime
from tqdm import tqdm
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration
from omegaconf import DictConfig, OmegaConf
from src.utils import tprint, parse_pred_ans
from test_distill_multiple import load_test_data
#from Flan_memory import MemoryT5
from transformers import T5Config

# GSM8K_VALIDATION_INDEX_PATH = 'lib_prompt/validation_index.npy'
# MULTIARITH_PATH = 'data/multiarith/MultiArith.json'
# MULTIARITH_VALIDATION_INDEX_PATH = 'data/multiarith/validation_index.npy'


# def load_test_data(test_data):
#     # TODO: add multiarith/ other math datasets
#     if(test_data == 'gsm8k_dev'):
#         gsm8k = load_dataset('gsm8k', 'main')
#         validation_index = np.load(GSM8K_VALIDATION_INDEX_PATH)
#         data = gsm8k['train'].select(validation_index)
#         data_ = []
#         for q, a in zip(data['question'], data['answer']): 
#             data_.append({'question': q, 'answer': a})
#     elif(test_data == 'gsm8k_test'):
#         gsm8k = load_dataset('gsm8k', 'main')
#         data = gsm8k['test']
#         data_ = []
#         for q, a in zip(data['question'], data['answer']): 
#             data_.append({'question': q, 'answer': a})
#     elif(test_data == 'multiarith_test'):
#         dataset = json.load(open(MULTIARITH_PATH))
#         dev_ind = np.load(MULTIARITH_VALIDATION_INDEX_PATH)
#         # dev_data = [dataset[i] for i in dev_ind]
#         test_data = [d for i, d in enumerate(dataset) if i not in dev_ind]
#         data_ = []
#         for d in test_data:
#             data_.append({'question': d['sQuestion'][1:-1], 'answer': d['lSolutions']})
#     else:
#         raise ValueError('Invalid test data: %s' % test_data)
#     return data_


@hydra.main(version_base=None, config_path="src/conf", config_name="config_inference")
def main(args : DictConfig):
    """Decode ``args.test_data`` with a Flan-T5 model and write the generations.

    Each example's input is the prompt read from ``args.prompt_path``, then
    ``"\\nQ: <question>\\n"``, then the cue ``"Let's think step by step\\n"``.
    The generation is written next to the reference answer in
    ``<output_path><test_data>_<model name>.txt``. That file is then passed to
    ``parse_pred_ans``.

    Config fields read: ``batch_size``, ``batch_size_fixed``, ``gpu_id``,
    ``test_data``, ``base_model`` (overwritten below), ``model_size``,
    ``device_map`` (only when ``model_size == '11b'``), ``prompt_path`` and
    ``output_path``.
    """
    print(OmegaConf.to_yaml(args))
    # Any value other than -1 in batch_size_fixed overrides batch_size.
    if args.batch_size_fixed != -1:
        args.batch_size = args.batch_size_fixed

    # Restrict CUDA to the requested GPU(s); this must run before the first
    # CUDA call. Environment values must be str, and hydra parses a single id
    # such as `gpu_id: 0` as an int, hence the str().
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    # load the dataset
    dataset = load_test_data(args.test_data)
    # NOTE: hard-coded override of the configured base_model (original intent:
    # run Flan-T5 zero-shot over the data to see what it gets wrong).
    args.base_model = "google/flan-t5-xl"

    # load the model
    tprint('Loading the model from %s' % args.base_model)
    start_time = time.time()
    tokenizer = T5Tokenizer.from_pretrained(args.base_model)
    model = T5ForConditionalGeneration.from_pretrained(args.base_model)

    if args.model_size == '11b':
        model.parallelize(args.device_map)
    else:
        # CUDA_VISIBLE_DEVICES (set above) renumbers the selected GPU(s) from 0
        # inside this process, so the requested GPU is always 'cuda:0'.
        # 'cuda:<gpu_id>' would be an invalid ordinal for any gpu_id != 0.
        model.to('cuda:0')

    tprint('Model loaded in %.1f seconds.' % (time.time() - start_time))

    # load the prompt (the with-block closes the file handle)
    with open(args.prompt_path, encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()

    # decode the dataset
    tprint('Start decoding ... ')
    output_path = args.output_path + args.test_data + '_' + args.base_model.split('/')[-1] + '.txt'
    tprint('Model output to: %s' % output_path)

    batch_size = args.batch_size
    num_examples = len(dataset)
    # Ceiling division: the final, possibly partial, batch is included.
    num_batches = -(-num_examples // batch_size)

    with open(output_path, 'w') as fd:
        for start in tqdm(range(0, num_examples, batch_size), total=num_batches):
            stop = min(start + batch_size, num_examples)
            examples = [dataset[j] for j in range(start, stop)]
            q_batch = [example['question'] for example in examples]
            a_batch = [example['answer'] for example in examples]
            questions = [
                prompt + '\nQ: ' + q + '\n' + "Let's think step by step\n"
                for q in q_batch
            ]

            inputs = tokenizer(questions, padding=True, return_tensors="pt")
            with torch.no_grad():
                outputs = model.generate(
                    inputs['input_ids'].to(model.device),
                    attention_mask=inputs['attention_mask'].to(model.device),
                    max_length=256,
                )

            for q, a, output_ids in zip(q_batch, a_batch, outputs):
                # Drop the padding tokens emitted around shorter generations.
                ans = tokenizer.decode(output_ids).replace('<pad>', '').strip()
                fd.write('Q: %s\nA_model:\n%s\nA:\n%s\n\n' % (q, ans, a))

    _, _, _, _ = parse_pred_ans(output_path)
    return

# Script entry point: hydra parses the config and invokes main().
if __name__ == '__main__':
    main()