import os
import json
import torch
from math import sqrt
from tqdm import tqdm
from .processor import kgw, watme, unigram, exp, sweet, morphmark, upv
from .processor import KgwDelta, UniDelta, UpvDelta
from .processor import patternbase, patternlen
from water_utils import load_model,load_data
from tqdm import tqdm

# Registry of watermark detectors, keyed by the ``args.water_mode`` string that
# selects them. Insertion order mirrors the order the schemes were added.
processor_dict = dict(
    kgw=kgw.KGWWatermarkDetector,
    KgwDelta=KgwDelta.KGWWatermarkDetector,
    watme=watme.WatMEWatermarkDetector,
    unigram=unigram.UnigramWatermarkDetector,
    UniDelta=UniDelta.UnigramWatermarkDetector,
    sweet=sweet.SWEETWatermarkDetector,
    morphmark=morphmark.MorphmarkWatermarkDetector,
    upv=upv.UPVWatermarkDetector,
    UpvDelta=UpvDelta.UPVWatermarkDetector,
    patternbase=patternbase.PatternBaseWatermarkDetector,
    patternlen=patternlen.PatternLengthWatermarkDetector,
    exp=exp.EXPDetect,
)


def load_detect_processor(args, tokenizer, device="cuda"):
    """Instantiate the watermark detector selected by ``args.water_mode``.

    Args:
        args: run configuration; ``args.water_mode`` selects the detector class
            from ``processor_dict`` and ``args`` itself is forwarded to it.
        tokenizer: tokenizer forwarded to the detector constructor.
        device: device string forwarded to the detector (defaults to "cuda",
            the value previously hard-coded here).

    Returns:
        The constructed detector instance.

    Raises:
        KeyError: if ``args.water_mode`` is not a registered detector name.
    """
    try:
        detector_cls = processor_dict[args.water_mode]
    except KeyError:
        raise KeyError(
            f"unknown water_mode {args.water_mode!r}; "
            f"expected one of {sorted(processor_dict)}"
        ) from None

    # Every detector, "exp" included, is built with the same keyword arguments,
    # so no per-mode branching is needed here.
    return detector_cls(tokenizer=tokenizer, args=args, device=device)
    

def run_processor(ids, save_path, tokenizer, device, args, model):
    """Score each token-id sequence in ``ids`` and write the z-scores to ``save_path``.

    Args:
        ids: iterable of token-id sequences; each is converted with
            ``torch.tensor`` onto ``device``.
        save_path: output JSON file; receives a list with one z-score per
            element of ``ids``, in the same order.
        tokenizer: forwarded to the detector constructor.
        device: device the token tensors are placed on.
        args: run configuration; ``args.water_mode`` selects the detector.
        model: language model, passed to ``detect`` only for the "sweet" mode.

    Sequences whose length minus one does not exceed the detector's
    ``min_prefix_len`` are not scored; a 0.0 placeholder is recorded instead so
    the output stays index-aligned with ``ids``.
    """
    detect_processor = load_detect_processor(args, tokenizer)
    # Only the SWEET detector's detect() takes the language model.
    needs_model = args.water_mode == "sweet"

    z_score_list = []
    # Detection is inference-only: disable autograd so any model forward pass a
    # detector performs does not build a gradient graph.
    # NOTE(review): assumes no detector relies on gradients in detect() — confirm.
    with torch.no_grad():
        for seq in tqdm(ids):
            input_ids = torch.tensor(seq, device=device)

            if len(input_ids) - 1 <= detect_processor.min_prefix_len:
                print("Error", "string too short to compute metrics")
                z_score_list.append(0.0)
                continue

            if needs_model:
                out_dict = detect_processor.detect(tokenized_text=input_ids, model=model)
            else:
                out_dict = detect_processor.detect(tokenized_text=input_ids)
            z_score_list.append(out_dict['z_score'])

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(z_score_list, f)


def _build_key_name(args):
    """Build the run identifier from the hyper-parameters set on ``args``.

    Each attribute below that exists on ``args`` and is not None contributes
    ``"<attr name without underscores>-<value>"``; the pieces are joined by "_".
    """
    parts = []
    for attr in ('water_mode', 'gamma', 'delta', 'alpha', 'topp', 'token_cut_length'):
        value = getattr(args, attr, None)
        if value is not None:
            parts.append(f"{attr.replace('_', '')}-{value}")
    return "_".join(parts)


def _load_ids(path):
    """Return the JSON-decoded token-id sequences stored in ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def run_detect(args):
    """Compute detection z-scores for unwatermarked and watermarked generations.

    Reads generated token ids from ``./results/gen_ids[_attacked]/...`` and
    writes z-score lists to ``./results/zscore[_attacked]/...``. The
    unwatermarked result (``<key>_wo_watermark.json``) is only computed when it
    does not already exist; the watermarked result (``<key>.json``) is always
    recomputed.

    Args:
        args: run configuration; must provide ``dataset_name``, ``model_name``
            and ``model_size``, and may provide ``attack_suffix`` plus the
            hyper-parameters used to build the run key.

    Raises:
        FileNotFoundError: if either input id file is missing. This is checked
            before the model is loaded.
    """
    key_name = _build_key_name(args)

    attack_suffix = getattr(args, 'attack_suffix', None)
    run_dir = f"{args.dataset_name}/{args.model_name}_{args.model_size}/"
    if attack_suffix:
        base_path = f'./results/zscore_attacked/{attack_suffix}/{run_dir}'
        gen_ids_base = f"./results/gen_ids_attacked/{attack_suffix}/{run_dir}"
    else:
        base_path = f'./results/zscore/{run_dir}'
        gen_ids_base = f"./results/gen_ids/{run_dir}"

    clean_ids_path = f"{gen_ids_base}wo_watermark.json"
    wm_ids_path = f"{gen_ids_base}{key_name}.json"
    # Validate inputs before load_model(): loading the model is the expensive
    # step, so a missing file should be reported before paying for it.
    for path in (clean_ids_path, wm_ids_path):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{path} does not exist")

    os.makedirs(base_path, exist_ok=True)
    model, tokenizer, device = load_model(args)

    # The unwatermarked scores are reused when already computed.
    clean_save_path = os.path.join(base_path, f'{key_name}_wo_watermark.json')
    if not os.path.exists(clean_save_path):
        run_processor(_load_ids(clean_ids_path), clean_save_path, tokenizer, device, args, model)

    wm_ids = _load_ids(wm_ids_path)
    wm_save_path = os.path.join(base_path, f"{key_name}.json")
    print(wm_ids_path)
    run_processor(wm_ids, wm_save_path, tokenizer, device, args, model)