from transformers import T5ForConditionalGeneration, T5Tokenizer
import json
from tqdm import tqdm
import os
import torch
import shutil
import argparse


def add_guid(source_dir: str, target_dir: str, taks_num: int, modes=('train', 'dev', 'test')):
    """
    Copy every task split from ``source_dir`` to ``target_dir``, stamping each
    example with a ``guid`` that is unique across all tasks and splits.

    Args:
        source_dir: directory containing ``task_0`` .. ``task_{taks_num - 1}``.
        target_dir: output directory; task sub-directories are created as needed.
        taks_num: number of tasks to process (name kept as-is for keyword callers).
        modes: split names to process per task, in the order guids are assigned.

    Each ``{mode}.json`` must exist and hold a JSON list of dict examples
    (a missing file raises ``FileNotFoundError``). ``tables.json`` is copied
    only when present.
    """
    guid = 0
    for task_id in range(taks_num):
        src_task_dir = os.path.join(source_dir, f'task_{task_id}')
        dst_task_dir = os.path.join(target_dir, f'task_{task_id}')
        os.makedirs(dst_task_dir, exist_ok=True)

        tables_path = os.path.join(src_task_dir, 'tables.json')
        if os.path.exists(tables_path):
            shutil.copy(tables_path, os.path.join(dst_task_dir, 'tables.json'))

        for mode in modes:
            with open(os.path.join(src_task_dir, f'{mode}.json'), 'r', encoding='utf-8') as f:
                ex_lst = json.load(f)

            # The counter is never reset, so guids are unique over all tasks/splits.
            for ex in ex_lst:
                ex['guid'] = guid
                guid += 1

            # The context manager guarantees the output is flushed and closed
            # before later steps re-read this file.
            with open(os.path.join(dst_task_dir, f'{mode}.json'), 'w', encoding='utf-8') as f:
                json.dump(ex_lst, f, indent=4)

def is_float(str):
    """
    Return True if the text is a plain non-negative decimal number such as
    ``"3"`` or ``"3.8"``: at most one dot, with a non-empty run of decimal
    digits on each side of it. Such text is always accepted by ``float()``.

    ``isdecimal`` is used rather than ``isdigit``: ``isdigit`` also accepts
    characters like superscript two (``"²"``) that ``float()`` rejects with
    ``ValueError``. Signs, exponents, ``"nan"``/``"inf"``, and forms like
    ``".5"`` or ``"3."`` are rejected.

    The parameter keeps its original (builtin-shadowing) name so keyword
    callers keep working.
    """
    parts = str.split('.')
    return len(parts) <= 2 and all(part.isdecimal() for part in parts)


class ContextRetriever:
    """
    Rank candidate examples as in-context demonstrations for each source example.

    Similarity is scored with T5's STS-B prompt
    (``"stsb sentence1: ... sentence2: ..."``). The generated text is parsed
    as a float; text that does not parse scores 0.0.
    """

    def __init__(self, source_set: list, candidate_set: list, t5_dir: str, batch_size: int = 128,
                 use_context=None, device: str = 'cuda'):
        """
        Args:
            source_set: examples waiting to retrieve their demonstrations.
            candidate_set: candidate context examples.
            t5_dir: path/name passed to ``from_pretrained`` for tokenizer and model.
            batch_size: number of prompts sent to T5 per ``generate`` call.
            use_context: whether to prefix questions with ``ex['context']``;
                ``None`` keeps the original behaviour of reading the
                script-level ``args.use_context`` flag.
            device: torch device the model and inputs are moved to.
        """
        self.source_set = source_set        # examples waiting to retrieve their demonstrations
        self.candidate_set = candidate_set  # candidate context examples
        self.t5_dir = t5_dir
        self.batch_size = batch_size
        self.use_context = use_context
        self.device = device
        self.tokenizer = T5Tokenizer.from_pretrained(self.t5_dir)
        self.plm = T5ForConditionalGeneration.from_pretrained(self.t5_dir).to(self.device)

    def batch(self, iterable):
        """Yield consecutive slices of ``iterable`` of at most ``batch_size`` items."""
        size = self.batch_size
        for start in range(0, len(iterable), size):
            # Slicing past the end is safe, so no explicit min() is needed.
            yield iterable[start:start + size]

    def _use_context_enabled(self):
        """Resolve the context flag: explicit constructor value, else CLI ``args``."""
        if self.use_context is not None:
            return self.use_context
        # Backward-compatible fallback: ``args`` is the module-level namespace
        # created by the ``__main__`` block and does not exist on plain import.
        return args.use_context

    def make_sentence(self, ex):
        """
        Build the text compared by T5 for ``ex``: ``"<context> | <question>"``
        when the example has a ``context`` and context use is enabled,
        otherwise just ``ex['question']``.
        """
        if 'context' in ex and self._use_context_enabled():
            return ' | '.join([ex['context'], ex['question']])
        return ex['question']

    def _score(self, t5_inputs):
        """Run T5 over ``t5_inputs`` in batches and return one float per input."""
        decoded = []
        with torch.no_grad():
            for input_batch in self.batch(t5_inputs):
                encoded = self.tokenizer(input_batch, return_tensors="pt", padding=True).to(self.device)
                output = self.plm.generate(**encoded)
                decoded += self.tokenizer.batch_decode(output, skip_special_tokens=True)
        # Generations that are not a plain decimal number count as "not similar".
        return [float(x) if is_float(x) else 0.0 for x in decoded]

    def retrieve_demonstrations(self):
        """
        Score every candidate against every source example.

        Returns:
            One list per source example (same order as ``source_set``), each a
            list of ``[candidate_example, score]`` pairs with ``score > 0``,
            sorted by descending score; ties keep ``candidate_set`` order.
        """
        # Candidate sentences do not depend on the source example: build them once.
        candidate_sents = [self.make_sentence(candidate_ex) for candidate_ex in self.candidate_set]

        demonstrations = []
        for ex in tqdm(self.source_set):
            sent = self.make_sentence(ex)
            t5_input_pairs = [f'stsb sentence1: {sent} sentence2: {candidate_sent}'
                              for candidate_sent in candidate_sents]
            scores = self._score(t5_input_pairs)

            # sorted() is stable, so equal scores keep their candidate order.
            ranked = sorted(zip(self.candidate_set, scores), key=lambda pair: pair[1], reverse=True)
            demonstrations.append([[candidate, score] for candidate, score in ranked if score > 0])
        return demonstrations


if __name__ == '__main__':

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dataset', type=str, default="spider")
    arg_parser.add_argument('--use_context', action='store_true')
    arg_parser.add_argument('--gpu', type=str, default="3")
    arg_parser.add_argument('--t5_dir', type=str, default='../../../resources/t5-base')
    arg_parser.add_argument('--batch_size', type=int, default=1024)
    # Resume point: training splits of tasks with a smaller id are skipped
    # (dev/test splits are always processed). The default reproduces the
    # previously hard-coded value.
    arg_parser.add_argument('--train_start_task', type=int, default=14)
    args = arg_parser.parse_args()

    source_dir = f'./data/task_splits/{args.dataset}'
    target_dir = f'./data/task_splits/{args.dataset}_context'

    # Must be set before the first CUDA call (the model is loaded further down).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Count only the task_<k> directories so stray files do not inflate the count.
    task_num = sum(
        1 for name in os.listdir(source_dir)
        if name.startswith('task_') and os.path.isdir(os.path.join(source_dir, name))
    )
    print("Task Number:", task_num)

    if not os.path.exists(target_dir):
        add_guid(source_dir, target_dir, task_num)

    def split_path(task_id, mode):
        """Path of one split file in the guid-stamped target directory."""
        return f'{target_dir}/task_{task_id}/{mode}.json'

    def load_split(task_id, mode):
        """Load one split file (a JSON list of examples) from the target directory."""
        with open(split_path(task_id, mode), 'r', encoding='utf-8') as f:
            return json.load(f)

    # NOTE: requires a mapper to map task_id to original_task_id (actual task files):
    #   task_id_mapper = {task_id: original_task_id, ...}

    # Load T5 once and reuse it for every split instead of reloading per task.
    retriever = ContextRetriever(
        source_set=[],
        candidate_set=[],
        t5_dir=args.t5_dir,
        batch_size=args.batch_size,
    )

    for mode in ["train", "dev", "test"]:
        for task_id in range(task_num):
            # When processing a training set, the candidate set contains the source
            # examples themselves; such duplicates should ideally be removed.
            if task_id < args.train_start_task and mode == 'train':
                continue

            path = split_path(task_id, mode)
            source_set = load_split(task_id, mode)
            print(path)

            # Candidates: training sets of all previous tasks plus the current one
            # (never future tasks).
            candidate_set = []
            for pre_task_id in range(task_id + 1):
                candidate_set += load_split(pre_task_id, 'train')

            retriever.source_set = source_set
            retriever.candidate_set = candidate_set
            demonstrations = retriever.retrieve_demonstrations()

            # Store only the candidate guid and its similarity score.
            for ex, demos in zip(source_set, demonstrations):
                ex['demonstration_examples'] = [[candidate['guid'], score] for candidate, score in demos]

            with open(path, 'w', encoding='utf-8') as f:
                json.dump(source_set, f, indent=4)
