import os
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from argparse import ArgumentParser
import random
from tqdm import tqdm
import transformers

# Command-line interface: --gpus is the string written verbatim to
# CUDA_VISIBLE_DEVICES (default '0,1').
parser = ArgumentParser()
parser.add_argument('--gpus', type=str, default='0,1')
args = parser.parse_args()

# Restrict the visible CUDA devices before the model is loaded below.
gpus = args.gpus
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
model_path = 'meta-llama/Meta-Llama-3-8B-Instruct'
# Text-generation pipeline used by get_reply(); device_map="auto" leaves
# device placement to the library.
pipeline = transformers.pipeline(
    'text-generation',
    model=model_path,
    device_map="auto",
)
# Separate tokenizer instance, used by construct_length() to truncate inputs.
tokenizer = AutoTokenizer.from_pretrained(model_path)


def construct_length(text, max_length=1024):
    """Clip ``text`` to ``max_length`` tokens and then ``max_length`` characters.

    The text is encoded with the module-level ``tokenizer``, the resulting
    id sequence is cut to its first ``max_length`` entries, and those ids are
    decoded back to a string with special tokens skipped.  The decoded string
    is finally sliced to its first ``max_length`` characters.

    NOTE(review): the last slice counts characters rather than tokens, so it
    is presumably the tighter of the two limits in practice -- confirm that
    this double truncation is intended.

    Args:
        text: The passage to shorten.
        max_length: Limit applied first to token ids, then to characters.

    Returns:
        The truncated passage as a string.
    """
    token_ids = tokenizer.encode(text)[:max_length]
    clipped = tokenizer.decode(token_ids, skip_special_tokens=True)
    return clipped[:max_length]


@torch.no_grad()
def get_reply(prompt):
    """Send ``prompt`` as a single user turn and return the model's answer.

    Generation is run through the module-level ``pipeline`` with sampling
    disabled and at most 1000 new tokens.  Decoding stops at either the
    tokenizer's EOS token or the ``<|eot_id|>`` token, and the EOS id is
    also used as the padding id.

    Args:
        prompt: The full instruction text to place in the user message.

    Returns:
        The ``content`` of the last message in the generated conversation.
    """
    tok = pipeline.tokenizer
    # Stop on either the regular EOS token or the end-of-turn marker.
    stop_ids = [tok.eos_token_id, tok.convert_tokens_to_ids("<|eot_id|>")]
    conversation = [{"role": "user", "content": prompt}]

    generations = pipeline(
        conversation,
        max_new_tokens=1000,
        eos_token_id=stop_ids,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    # The pipeline returns the whole chat; the final entry is the reply.
    final_turn = generations[0]["generated_text"][-1]
    return final_turn['content']


def main():
    """Generate LLM rewrites of every claim in every original dataset.

    For each ``*.json`` file in ``../datasets/original`` and each prompt
    style (paraphrase, rewriting, open-ended), every item's ``claim`` is
    truncated with ``construct_length``, inserted into the prompt template
    and sent to ``get_reply``.  The generated text is stored together with
    the item's original ``label`` in
    ``../datasets/llm-generation/{dataset}_{way}.json``.

    An output file that already exists is skipped, so an interrupted run can
    be restarted without redoing finished dataset/style combinations.
    """
    prompts = {
        'paraphrase': 'Given a passage, please paraphrase it. The content should be the same. The passage is: {}',
        'rewriting': 'Given a passage, Please rewrite it to make it more convincing. The content should be the same. '
                     'The style should be serious, calm and informative. The passage is: {}',
        'open-ended': 'Given a sentence, please write a piece of news. The sentence is: {}'
    }
    source_dir = '../datasets/original'
    target_dir = '../datasets/llm-generation'
    suffix = '.json'

    # Create the output directory up front so hours of generation are not
    # lost to a missing folder when the first result is written.
    os.makedirs(target_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort for a stable run.
    for filename in sorted(os.listdir(source_dir)):
        # Ignore stray entries (e.g. hidden files) that are not datasets.
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension, not every '.json' substring.
        dataset = filename[:-len(suffix)]

        with open(os.path.join(source_dir, filename), encoding='utf-8') as src:
            data = json.load(src)

        # The context manager closes the bar even if generation raises.
        with tqdm(total=len(prompts) * len(data), leave=False) as pbar:
            pbar.set_description_str(dataset)
            for way, template in prompts.items():
                save_path = os.path.join(target_dir, f'{dataset}_{way}.json')
                if os.path.exists(save_path):
                    # Account for the skipped items so the bar can reach 100%.
                    pbar.update(len(data))
                    continue

                res = []
                for item in data:
                    claim = construct_length(item['claim'])
                    new_claim = get_reply(template.format(claim))
                    res.append({
                        'claim': new_claim,
                        'label': item['label']
                    })
                    pbar.update()

                # Close (and thereby flush) the file deterministically.
                with open(save_path, 'w', encoding='utf-8') as dst:
                    json.dump(res, dst)


# Start the generation job only when this file is executed as a script.
if __name__ == '__main__':
    main()
