import os,glob, subprocess, re, argparse
# os.environ['HF_HOME'] = "checkpoints"
from huggingface_hub import login
# #
from datasets import load_dataset
from datasets import concatenate_datasets


class processFuns:
    """Prompt builders that turn dataset batches into formatted texts.

    Every ``processFunc_<dataset>`` function takes:

    * ``examples`` -- a batch mapping column names to equally long sequences
      of values (the columns read are listed in each function);
    * ``prompt`` -- a template with three positional ``{}`` slots, filled in
      order with the system instruction, the user message and the answer.

    Each function returns a list with one formatted string per example.

    The functions are static: they are looked up by name on the class and
    work without an instance.
    """

    @staticmethod
    def processFunc_sst2(examples, prompt):
        """Format sentiment examples from the 'sentence' and 'label' columns.

        A truthy label is rendered as 'positive', a falsy one as 'negative'.
        """
        system_prompt = '''You are a helpful, respectful and honest sentiment analysis assistant. \
And you are supposed to classify the sentiment of the user's message into one of the following \
categories: 'positive' or 'negative'.'''
        return [
            prompt.format(system_prompt, 'Sentence: ' + sent, 'positive' if label else 'negative')
            for sent, label in zip(examples['sentence'], examples['label'])
        ]

    @staticmethod
    def processFunc_cola(examples, prompt):
        """Format acceptability examples from 'sentence' and 'label'.

        A truthy label is rendered as 'acceptable', a falsy one as
        'unacceptable'.
        """
        system_prompt = '''You are a helpful, respectful and honest English grammar analysis assistant. \
And you are supposed to classify the grammer acceptence of the sentence of the user's message into one of the following \
categories: 'acceptable' or 'unacceptable'.'''
        return [
            prompt.format(system_prompt, 'Sentence: ' + sent, 'acceptable' if label else 'unacceptable')
            for sent, label in zip(examples['sentence'], examples['label'])
        ]

    @staticmethod
    def processFunc_rte(examples, prompt):
        """Format entailment examples from 'sentence1', 'sentence2', 'label'.

        A falsy label (0) is rendered as 'entailment', a truthy one as
        'not entailment'.
        """
        system_prompt = '''You are a helpful, respectful and honest English analysis assistant. \
And you are supposed to find out whether first sentence entails the second sentence in the user's message. \
Answer with one of the following: 'entailment' or 'not entailment'.'''
        return [
            prompt.format(
                system_prompt,
                f'First Sentence: {sent1} Second Sentence: {sent2}',
                'not entailment' if label else 'entailment',
            )
            for sent1, sent2, label in zip(
                examples['sentence1'], examples['sentence2'], examples['label']
            )
        ]

    @staticmethod
    def processFunc_qnli(examples, prompt):
        """Format question/answer examples from 'question', 'sentence', 'label'.

        A falsy label (0) is rendered as 'corresponding', a truthy one as
        'not corresponding'.
        """
        system_prompt = '''You are a helpful, respectful and honest English analysis assistant. \
And you are supposed to find out whether the answer in the user's message is corresponding to the user's question. \
Answer with one of the following: 'corresponding' or 'not corresponding'.'''
        return [
            prompt.format(
                system_prompt,
                f'Question: {ques} Answer: {sent}',
                'not corresponding' if label else 'corresponding',
            )
            for ques, sent, label in zip(
                examples['question'], examples['sentence'], examples['label']
            )
        ]

    @staticmethod
    def processFunc_mrpc(examples, prompt):
        """Format paraphrase examples from 'sentence1', 'sentence2', 'label'.

        A truthy label is rendered as 'equivalent', a falsy one as
        'not equivalent'.
        """
        system_prompt = '''You are a helpful, respectful and honest English analysis assistant. \
And you are supposed to find out whether first sentence and the second sentence in the user's message are semantically equivalent. \
Answer with one of the following: 'equivalent' or 'not equivalent'.'''
        return [
            prompt.format(
                system_prompt,
                f'First Sentence: {sent1} Second Sentence: {sent2}',
                'equivalent' if label else 'not equivalent',
            )
            for sent1, sent2, label in zip(
                examples['sentence1'], examples['sentence2'], examples['label']
            )
        ]

    @staticmethod
    def processFunc_wnli(examples, prompt):
        """Format entailment examples from 'sentence1', 'sentence2', 'label'.

        A truthy label (1) is rendered as 'entailment', a falsy one (0) as
        'not entailment'.
        """
        system_prompt = '''You are a helpful, respectful and honest English analysis assistant. \
And you are supposed to find out whether the paragraph entails the sentence in the user's message. \
Answer with one of the following: 'entailment' or 'not entailment'.'''
        # GLUE WNLI encodes 0 = not_entailment and 1 = entailment, i.e. the
        # opposite order from RTE/QNLI, so the mapping is flipped relative to
        # processFunc_rte.
        return [
            prompt.format(
                system_prompt,
                f'Paragraph: {sent1} Sentence: {sent2}',
                'entailment' if label else 'not entailment',
            )
            for sent1, sent2, label in zip(
                examples['sentence1'], examples['sentence2'], examples['label']
            )
        ]

    @staticmethod
    def processFunc_qqp(examples, prompt):
        """Format duplicate-question examples from 'question1', 'question2', 'label'.

        A truthy label is rendered as 'duplicate', a falsy one as
        'not duplicate', matching the answer vocabulary the system prompt
        asks for.
        """
        system_prompt = '''You are a helpful, respectful and honest English analysis assistant. \
And you are supposed to find out whether first question and the second question in the user's message are semantically duplicate. \
Answer with one of the following: 'duplicate' or 'not duplicate'.'''
        return [
            prompt.format(
                system_prompt,
                f'First Question: {ques1} Second Question: {ques2}',
                'duplicate' if label else 'not duplicate',
            )
            for ques1, ques2, label in zip(
                examples['question1'], examples['question2'], examples['label']
            )
        ]

    @staticmethod
    def processFunc_mnli(examples, prompt):
        """Format three-way NLI examples from 'premise', 'hypothesis', 'label'.

        The label is used as an index into
        ``['entailment', 'neutral', 'contradiction']``.
        """
        answers = ['entailment', 'neutral', 'contradiction']
        system_prompt = '''You are a helpful, respectful and honest English analysis assistant. \
And you are supposed to find out whether the premise entails the hypothesis (entailment), contradicts the hypothesis (contradiction), or neither (neutral) in the user's message. \
Answer with one of the following: 'entailment','neutral' or 'contradiction'.'''
        return [
            prompt.format(system_prompt, f'Premise: {premise} Hypothesis: {hypothesis}', answers[label])
            for premise, hypothesis, label in zip(
                examples['premise'], examples['hypothesis'], examples['label']
            )
        ]

    @staticmethod
    def processFunc_clone_detection(examples, prompt):
        """Format code-clone examples from 'func1', 'func2', 'label'.

        A truthy label is rendered as 'equivalent', a falsy one as
        'not equivalent'.
        """
        system_prompt = '''You are a helpful coding assistant. Your job is to analyze code. \
And you are supposed to find out whether the following Function1 and Function2 are semantically equivalent. \
Answer with one of the following: 'equivalent' or 'not equivalent'.'''
        return [
            prompt.format(
                system_prompt,
                f'Function1: {func1} Function2: {func2}',
                'equivalent' if label else 'not equivalent',
            )
            for func1, func2, label in zip(
                examples['func1'], examples['func2'], examples['label']
            )
        ]

    @staticmethod
    def processFunc_code_to_text(examples, prompt):
        """Format code-summarisation examples from 'code' and 'docstring'.

        The docstring column is used verbatim as the answer.
        """
        system_prompt = '''You are a helpful coding assistant. Your job is to analyze code. \
And you are supposed to generate English docstrings for following Code.'''
        return [
            prompt.format(system_prompt, f'Code: {code}', docstring)
            for code, docstring in zip(examples['code'], examples['docstring'])
        ]

    @staticmethod
    def processFunc_text_to_code(examples, prompt):
        """Format code-generation examples from 'nl' and 'code'.

        The 'nl' description is the user message and the 'code' column is
        used verbatim as the answer.
        """
        system_prompt = '''You are a helpful coding assistant. Your job is to generate code. \
And you are supposed to generate Java code from an English natural language description.'''
        return [
            prompt.format(system_prompt, f'Description: {nl}', code)
            for code, nl in zip(examples['code'], examples['nl'])
        ]

    @staticmethod
    def processFunc_defect_detection(examples, prompt):
        """Format defect-detection examples from 'func' and 'target'.

        A truthy target is rendered as 'defective', a falsy one as
        'not defective'.
        """
        system_prompt = '''You are a helpful coding assistant. Your job is to analyze code. \
And you are supposed to find out whether the following Function has a defect in it. \
Answer with one of the following: 'defective' or 'not defective'.'''
        return [
            prompt.format(system_prompt, f'Function: {func}', 'defective' if label else 'not defective')
            for func, label in zip(examples['func'], examples['target'])
        ]




def get_processFunc(dataset):
    """Return the prompt builder registered on ``processFuns`` for ``dataset``.

    The builder is resolved by name as ``processFunc_<dataset>``; an unknown
    dataset name surfaces as the ``AttributeError`` raised by ``getattr``.
    """
    builder_name = 'processFunc_{}'.format(dataset)
    return getattr(processFuns, builder_name)
    

def get_dataset(dataset):
    """Load the training and validation splits for ``dataset``.

    The four CodeXGLUE task names ('clone_detection', 'code_to_text',
    'text_to_code', 'defect_detection') are loaded from their own hub
    repositories; any other name is passed to "nyu-mll/glue" as the
    configuration name.

    Returns:
        A ``(train, validation)`` tuple of datasets.
    """
    if dataset == 'clone_detection':
        repo = "google/code_x_glue_cc_clone_detection_big_clone_bench"
        train_split = load_dataset(repo, split='train')
        val_split = load_dataset(repo, split='validation')
        # Keep only validation rows whose id is below 10000.
        return train_split, val_split.filter(lambda row: row["id"] < 10000)

    if dataset == 'code_to_text':
        repo = "google/code_x_glue_ct_code_to_text"
        languages = ['python', 'java']
        train_parts = [load_dataset(repo, lang, split='train') for lang in languages]
        val_parts = [load_dataset(repo, lang, split='validation') for lang in languages]
        train_split = concatenate_datasets(train_parts)
        val_split = concatenate_datasets(val_parts)
        # Keep only validation rows whose id is below 2000, then shuffle both
        # splits with a fixed seed so the languages are interleaved.
        val_split = val_split.filter(lambda row: row["id"] < 2000)
        return train_split.shuffle(seed=42), val_split.shuffle(seed=42)

    if dataset == 'text_to_code':
        repo = "google/code_x_glue_tc_text_to_code"
        train_split = load_dataset(repo, split='train')
        val_split = load_dataset(repo, split='validation')
        return train_split, val_split

    if dataset == 'defect_detection':
        repo = "google/code_x_glue_cc_defect_detection"
        train_split = load_dataset(repo, split='train')
        val_split = load_dataset(repo, split='validation')
        return train_split, val_split

    # Everything else is treated as a GLUE configuration name.
    glue_repo = "nyu-mll/glue"
    train_split = load_dataset(glue_repo, dataset, split='train')
    if dataset != 'mnli':
        return train_split, load_dataset(glue_repo, dataset, split='validation')

    # MNLI has two validation splits; they are concatenated, matched first.
    matched = load_dataset(glue_repo, dataset, split='validation_matched')
    mismatched = load_dataset(glue_repo, dataset, split='validation_mismatched')
    return train_split, concatenate_datasets([matched, mismatched])

def get_texts(dataset):
    """Build Alpaca-style records from the train split of a GLUE task.

    Args:
        dataset: GLUE configuration name; it must also have a matching
            ``processFunc_<dataset>`` prompt builder.

    Returns:
        A list of dicts with the keys "instruction", "input" and "output",
        one per training example.
    """
    ds_train = load_dataset("nyu-mll/glue", dataset, split='train')
    # The builder joins instruction, input and output with backticks so the
    # three parts can be separated again below.
    prompt = '''{}`{}`{}'''
    texts = get_processFunc(dataset)(ds_train, prompt)

    records = []
    for text in texts:
        # Dataset sentences may themselves contain backticks, so a plain
        # split('`') would cut the input short and put sentence text in the
        # output field. The GLUE instructions and answer labels in this file
        # contain no backtick, so the first and the last backtick are always
        # the real delimiters.
        instruction, _, remainder = text.partition('`')
        user_input, _, answer = remainder.rpartition('`')
        records.append({
            "instruction": instruction,
            "input": user_input,
            "output": answer,
        })
    return records

import json
if __name__ == '__main__':
    # Export every GLUE task as a JSON list in the Alpaca format used by
    # LLaMA-Factory, where each record looks like:
    #   {
    #     "instruction": "human instruction (required)",
    #     "input": "human input (optional)",
    #     "output": "model response (required)"
    #   }
    os.makedirs('jsons', exist_ok=True)
    for dataset in ['cola', 'sst2', 'rte', 'qnli', 'wnli', 'mnli', 'qqp', 'mrpc']:
        records = get_texts(dataset)
        # The path must be an f-string: without the prefix every task was
        # written to the same literal file 'alpaca_{dataset}.json', each one
        # overwriting the previous.
        out_path = f'jsons/alpaca_{dataset}.json'
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(records, f)
        # One summary line per task instead of echoing every example.
        print(f'{dataset}: wrote {len(records)} records to {out_path}')