import json
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import LogitsProcessor, LogitsProcessorList
from watermark_alg import WatermarkLogitsProcessor, WatermarkContext
import argparse
import os
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from utils.config import MODEL_DIR, DATA_DIR
import random
from transformers import TopKLogitsWarper, TemperatureLogitsWarper

def _load_model_and_tokenizer(base_model, model_path, device):
    """Load the causal LM and its tokenizer for a supported ``base_model`` name.

    Args:
        base_model: one of ``'gpt2'``, ``'llama'``, ``'opt13b'``, ``'opt27b'``.
        model_path: local directory holding the pretrained weights and tokenizer.
        device: torch device the model is moved to. It is not used for ``'llama'``,
            whose weight placement is delegated to ``device_map='auto'``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: for an unsupported ``base_model``.
    """
    if base_model == 'gpt2':
        model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
        tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    elif base_model == 'llama':
        model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto')
        tokenizer = LlamaTokenizer.from_pretrained(model_path)
    elif base_model in ('opt13b', 'opt27b'):
        model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    else:
        raise ValueError(f"Unsupported base_model: {base_model!r}")
    return model, tokenizer


def _generate(model, input_ids, logits_processor, decode_method):
    """Continue ``input_ids`` for up to 300 new tokens with the watermark processor applied.

    ``decode_method`` selects multinomial sampling (``'sample'``) or 5-beam search
    (``'beam'``). Both modes forbid repeated 4-grams.

    Raises:
        ValueError: for an unsupported ``decode_method``.
    """
    processors = LogitsProcessorList([logits_processor])
    if decode_method == "sample":
        return model.generate(input_ids, logits_processor=processors, max_new_tokens=300, do_sample=True, no_repeat_ngram_size=4)
    if decode_method == "beam":
        print('beam')
        return model.generate(input_ids, max_new_tokens=300, num_beams=5, do_sample=False, no_repeat_ngram_size=4, logits_processor=processors)
    raise ValueError(f"Unsupported decode_method: {decode_method!r}")


def main(args):
    """Generate watermarked continuations for JSONL prompts and score them.

    Each line of ``DATA_DIR/args.data_path`` is parsed as JSON and its ``'text'``
    field is read. For texts of 200-400 whitespace-separated words, the first 30
    words form the prompt. The model continues the prompt for up to 300 new tokens
    with the watermark logits processor applied.

    An output is kept if it is longer than 150 tokens. Kept outputs are stored with
    the watermark detection scores of the original and the generated text.
    Collection stops once ``args.generate_number`` samples are kept. The results are
    written as a JSON list to ``DATA_DIR/args.output_path``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if ``args.base_model`` or ``args.decode_method`` is unsupported.
    """
    # Fail fast, before loading any (possibly very large) model. Previously an
    # unknown value only surfaced later as a NameError on `model` / `outputs`.
    if args.base_model not in ('gpt2', 'llama', 'opt13b', 'opt27b'):
        raise ValueError(f"Unsupported base_model: {args.base_model!r}")
    if args.decode_method not in ('sample', 'beam'):
        raise ValueError(f"Unsupported decode_method: {args.decode_method!r}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = os.path.join(MODEL_DIR, args.base_model)
    model, tokenizer = _load_model_and_tokenizer(args.base_model, model_path, device)

    watermark_model = WatermarkContext(device, args.chunk_size, tokenizer, delta=args.delta, transform_model_path=args.transform_model, mapping_file=args.mapping_file, embedding_model=args.embedding_model)
    logits_processor = WatermarkLogitsProcessor(watermark_model)

    # Explicit UTF-8: the platform default encoding is not UTF-8 everywhere (e.g. Windows).
    with open(os.path.join(DATA_DIR, args.data_path), 'r', encoding='utf-8') as f:
        lines = f.readlines()

    output = []
    torch.manual_seed(0)
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing newline) in the JSONL file
        data = json.loads(line)
        text = data['text']
        words = text.split()

        # Only documents of 200-400 words are used; their first 30 words are the prompt.
        if not 200 <= len(words) <= 400:
            continue
        begin_text = ' '.join(words[:30])

        input_ids = tokenizer.encode(begin_text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = _generate(model, input_ids, logits_processor, args.decode_method)

        # NOTE(review): for decoder-only models `generate` returns prompt + continuation,
        # so this 150-token threshold also counts the prompt tokens. Confirm this is intended.
        if len(outputs[0]) > 150:
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Detection is only run for samples that are actually kept.
            with torch.no_grad():
                z_score_generated = watermark_model.detect(generated_text)
                z_score_origin = watermark_model.detect(text)
            output.append({
                'original_text': text,
                'generated_text': generated_text,
                'z_score_origin': z_score_origin,
                'z_score_generated': z_score_generated,
            })

        print(len(output))
        if len(output) >= args.generate_number:
            break

    with open(os.path.join(DATA_DIR, args.output_path), 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=4, ensure_ascii=False)


def build_parser():
    """Build the command-line parser for this watermarked-generation script.

    Input and output paths are interpreted relative to ``DATA_DIR`` by ``main``.
    ``--watermark_type`` and ``--window_size`` are accepted, but ``main`` in this
    file does not read them.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='Generate text using watermark model')
    parser.add_argument('--watermark_type', type=str, default="window")
    # `choices` rejects unsupported values with a usage error instead of a later crash in main().
    parser.add_argument('--base_model', type=str, default="gpt2",
                        choices=['gpt2', 'llama', 'opt13b', 'opt27b'],
                        help='model directory name under MODEL_DIR')
    parser.add_argument('--window_size', type=int, default=0)
    parser.add_argument('--generate_number', type=int, default=100,
                        help='number of samples to collect')
    parser.add_argument('--delta', type=float, default=1.0,
                        help='passed to WatermarkContext as delta')
    parser.add_argument('--chunk_size', type=int, default=10,
                        help='passed to WatermarkContext')
    parser.add_argument('--data_path', type=str, default="dataset/c4_train_sample.jsonl",
                        help='JSONL input file, relative to DATA_DIR')
    parser.add_argument('--output_path', type=str, default="text_gpt2_top10.json",
                        help='JSON output file, relative to DATA_DIR')
    parser.add_argument('--transform_model', type=str, default="transform_model_cbert6.pth",
                        help='passed to WatermarkContext as transform_model_path')
    parser.add_argument('--embedding_model', type=str, default="c-bert",
                        help='passed to WatermarkContext')
    parser.add_argument('--mapping_file', type=str, default="mapping/300_mapping_gpt2_2.json",
                        help='passed to WatermarkContext')
    parser.add_argument('--decode_method', type=str, default="sample",
                        choices=['sample', 'beam'],
                        help='multinomial sampling or beam search')
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())
