import os
from functools import partial
import json
import torch
from transformers import LogitsProcessorList
from .processor import kgw, watme, unigram, exp, sweet, morphmark, upv
from .processor import KgwDelta, UniDelta, UpvDelta
from .processor import patternbase, patternlen
from water_utils import load_model,load_data
from tqdm import tqdm

# Registry mapping ``args.water_mode`` to the logits-processor class that
# implements that watermark. "wo_watermark" and "exp" are intentionally absent:
# set_model() / exp_generate() handle those modes without a logits processor.
processor_dict = dict(
    kgw=kgw.KGWWatermarkLogitsProcessor,
    KgwDelta=KgwDelta.KGWWatermarkLogitsProcessor,
    watme=watme.WatMEWatermarkLogitsProcessor,
    unigram=unigram.UnigramLogitsProcessor,
    UniDelta=UniDelta.UnigramLogitsProcessor,
    sweet=sweet.SWEETWatermarkLogitsProcessor,
    morphmark=morphmark.MorphmarkWatermarkLogitsProcessor,
    upv=upv.UPVWatermarkLogitsProcessor,
    UpvDelta=UpvDelta.UPVWatermarkLogitsProcessor,
    patternbase=patternbase.PatternBaseWatermarkLogitsProcessor,
    patternlen=patternlen.PatternLengthWatermarkLogitsProcessor,
)


def load_logit_processor(args, tokenizer):
    """Build the watermark logits processor selected by ``args.water_mode``.

    Args:
        args: run configuration; ``args.water_mode`` selects the class from
            ``processor_dict`` and the whole ``args`` object is forwarded to it.
        tokenizer: forwarded to the processor constructor.

    Returns:
        A new instance of the selected processor class.

    Raises:
        KeyError: if ``args.water_mode`` is not a key of ``processor_dict``.
    """
    processor_cls = processor_dict[args.water_mode]
    return processor_cls(tokenizer=tokenizer, args=args)


def set_model(args, model, tokenizer):
    if (args.water_mode != "wo_watermark") and (args.water_mode != "exp"):
        watermark_processor = load_logit_processor(args,tokenizer)
    gen_kwargs = dict(max_new_tokens=args.max_new_tokens, min_new_tokens=50)

    if args.topp is not None:
        print("topp :", args.topp)
        gen_kwargs.update(dict(
            do_sample=True, 
            temperature=0.7,
            top_p=0.95,
        ))
    else:
        gen_kwargs.update(dict(
            num_beams=args.n_beams
        ))
    

    if args.water_mode == "wo_watermark":
        gen_model = partial(
            model.generate,
            pad_token_id=tokenizer.eos_token_id,
            **gen_kwargs
            )
    elif args.water_mode == "exp":
        gen_model = partial(
            model,
            pad_token_id=tokenizer.eos_token_id,
            **gen_kwargs
            )
    else:
        gen_model = partial(
            model.generate,
            logits_processor=LogitsProcessorList([watermark_processor]),
            pad_token_id=tokenizer.eos_token_id,
            **gen_kwargs
            )
    return gen_model


def generate(prompts, args, model, device, tokenizer):
    """Generate one continuation per prompt with the configured generator.

    Args:
        prompts: a list of prompt strings, or a single non-list prompt, which
            is wrapped into a one-element list.
        args: run configuration; forwarded to set_model() and provides
            ``generation_seed``.
        model, tokenizer: passed to set_model(); the tokenizer also encodes
            prompts and decodes outputs.
        device: target device for the encoded prompt tensors.

    Returns:
        ``(texts, ids)``: for each prompt, the decoded newly generated text
        (prompt tokens stripped, special tokens skipped) and the list of
        newly generated token ids, taken from the first output row.
    """
    if type(prompts) is not list:
        prompts = [prompts]

    texts, id_lists = [], []
    runner = set_model(args, model, tokenizer)
    for prompt in tqdm(prompts):
        encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
        prompt_len = encoded["input_ids"].shape[-1]

        # Reset the RNG before every prompt so each generation starts from the same seed.
        torch.manual_seed(args.generation_seed)
        # Keep only the tokens produced after the prompt.
        new_tokens = runner(**encoded)[:, prompt_len:]

        texts.append(tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0])
        id_lists.append(new_tokens.to('cpu')[0].tolist())

    return texts, id_lists

def exp_generate(prompts, args, model, device, tokenizer):
    """Generate watermarked text for each prompt using the EXP scheme.

    Args:
        prompts: a list of prompt strings, or a single non-list prompt, which
            is wrapped into a one-element list.
        args, model, tokenizer: used to build the wrapped model via
            set_model() and passed on to ``exp.EXPGenerate``.
        device: passed to ``exp.EXPGenerate``.

    Returns:
        ``(texts, ids)``: per prompt, the text and the token-id list (moved
        to CPU) returned by ``generate_watermarked_text``.
    """
    batch = prompts if type(prompts) is list else [prompts]

    generator = exp.EXPGenerate(set_model(args, model, tokenizer), tokenizer, device, args)

    texts = []
    id_lists = []
    for prompt in tqdm(batch):
        text, ids = generator.generate_watermarked_text(prompt)
        texts.append(text)
        id_lists.append(ids.to('cpu').tolist())

    return texts, id_lists

def save_results(prompts, text_list, ids_list, args):
    """Persist prompts, generated texts and generated token ids as JSON.

    Layout (relative to the current working directory)::

        ./results/gen_text/<dataset_name>/<model_name>_<model_size>/prompt.json
        ./results/gen_text/<dataset_name>/<model_name>_<model_size>/<run>.json
        ./results/gen_ids/<dataset_name>/<model_name>_<model_size>/<run>.json

    ``<run>`` is ``wo_watermark`` for the baseline. Otherwise it is built from
    the non-None values among ``water_mode``, ``gamma``, ``delta``, ``alpha``,
    ``topp`` and ``token_cut_length``, e.g. ``watermode-kgw_gamma-0.25``.

    Args:
        prompts: JSON-serialisable prompts. Written only if ``prompt.json``
            does not exist yet.
        text_list: generated texts, written with ``ensure_ascii=False``.
        ids_list: generated token-id lists.
        args: run configuration (see above for the attributes read).
    """
    # Build the text and id directories from the same suffix. The previous
    # str.replace("gen_text", "gen_ids") also rewrote dataset/model names that
    # happened to contain "gen_text".
    run_dir = f"{args.dataset_name}/{args.model_name}_{args.model_size}/"
    text_dir = f"./results/gen_text/{run_dir}"
    ids_dir = f"./results/gen_ids/{run_dir}"

    if args.water_mode == "wo_watermark":
        file_name = "wo_watermark.json"
    else:
        key_elements = []
        for attr in ('water_mode', 'gamma', 'delta', 'alpha', 'topp', 'token_cut_length'):
            value = getattr(args, attr, None)
            if value is not None:
                key_elements.append(attr.replace("_", "") + "-" + str(value))
        file_name = "_".join(key_elements) + ".json"

    os.makedirs(text_dir, exist_ok=True)
    prompt_path = os.path.join(text_dir, "prompt.json")
    # prompt.json is shared by every run in this directory; write it only once.
    if not os.path.exists(prompt_path):
        with open(prompt_path, 'w', encoding='utf-8') as f:
            json.dump(prompts, f, ensure_ascii=False, indent=2)

    with open(os.path.join(text_dir, file_name), 'w', encoding='utf-8') as f:
        json.dump(text_list, f, ensure_ascii=False, indent=2)

    os.makedirs(ids_dir, exist_ok=True)
    with open(os.path.join(ids_dir, file_name), 'w', encoding='utf-8') as f:
        json.dump(ids_list, f)


def run_watermark(args):
    """Run the full pipeline: load model and prompts, generate, save results.

    Args:
        args: run configuration passed to load_model(), load_data(), the
            chosen generation function and save_results().

    Returns:
        None.
    """
    model, tokenizer, device = load_model(args)
    prompts = load_data(args, tokenizer)

    # EXP mode has its own generation path; every other mode uses generate().
    runner = exp_generate if args.water_mode == "exp" else generate
    text_list, ids_list = runner(
        prompts,
        args,
        model=model,
        device=device,
        tokenizer=tokenizer,
    )

    save_results(prompts, text_list, ids_list, args)