import os

def text_summarization_undetectable_exp(output_dir, reweight_type, skip_generation):
    """Run the text summarization experiment end to end.

    The steps are: (optionally) generate outputs, evaluate them, then run the
    perplexity evaluation. Every file lives directly inside ``output_dir``,
    which is created when it does not exist yet:

    * ``text_summarization.txt``        -- passed to the pipelines as ``output_path``
    * ``text_summarization_result.txt`` -- passed as ``score_save_path``
    * ``text_summarization_ppl.txt``    -- passed as ``ppl_save_path``

    Args:
        output_dir: Directory that holds the three files listed above.
        reweight_type: Forwarded to the generation pipeline only; the two
            evaluation pipelines do not receive it.
        skip_generation: When truthy, the generation step is skipped and the
            evaluation steps run against whatever already sits at the output
            path.
    """
    # Imported lazily so that merely importing this module does not load the
    # summarization package.
    from . import text_summarization as ts

    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, 'text_summarization.txt')
    score_save_path = os.path.join(output_dir, 'text_summarization_result.txt')
    ppl_save_path = os.path.join(output_dir, 'text_summarization_ppl.txt')

    if not skip_generation:
        print("ts.get_output.undetectable_exp_pipeline()")
        ts.get_output.undetectable_exp_pipeline(
            output_path=output_path, reweight_type=reweight_type
        )

    print("ts.evaluate.undetectable_exp_pipeline()")
    ts.evaluate.undetectable_exp_pipeline(
        output_path=output_path, score_save_path=score_save_path
    )

    print("ts.evaluate_ppl.undetectable_exp_pipeline()")
    ts.evaluate_ppl.undetectable_exp_pipeline(
        output_path=output_path, ppl_save_path=ppl_save_path
    )

    print("finish text summarization.")


def text_generation_undetectable_exp(res_dir,
                                     eps,
                                     model_str,
                                     reweight_type,
                                     dataset_name,
                                     skip_generation=False,
                                     use_other_wp=False,
                                     other_reweight_type=None,
                                     skip_evaluation=False,
                                     suffix="",):
    """Run (or resume) the text generation experiment for one configuration.

    Results are stored under
    ``<res_dir>/<dataset_name>/<model>/<reweight_type>/`` where ``<model>`` is
    the last ``/``-separated component of ``model_str`` with ``-`` replaced by
    ``_``. The directory is created when missing.

    Files inside that directory:

    * ``text_generation{suffix}.txt`` -- generated text (``output_path``).
    * ``score{suffix}.txt`` -- scores when ``eps == 0``.
    * ``score{suffix}_evaluated_by_{other_reweight_type}.txt`` -- scores when
      ``eps == 0`` and ``use_other_wp`` is set.
    * ``eps_{eps}{suffix}_o.txt`` -- scores when ``eps != 0`` (``.`` in the
      number is replaced by ``_``).

    If the score file already exists the job is skipped, except for the
    ``..._evaluated_by_...`` file, which is always recomputed.

    Args:
        res_dir: Root directory for all results.
        eps: Number in ``[0, 1]`` forwarded to the scoring pipeline. ``0``
            selects the generate-then-score path; any other value scores an
            already generated output file, which must exist.
        model_str: Model identifier, forwarded to both pipelines.
        reweight_type: Reweight type used for generation; also the name of the
            innermost result directory.
        dataset_name: Dataset identifier, forwarded to both pipelines.
        skip_generation: When truthy, do not run the generation pipeline.
        use_other_wp: Forwarded to the scoring pipeline; also selects the
            ``..._evaluated_by_...`` score file when ``eps == 0``.
        other_reweight_type: Passed to the scoring pipeline as its
            ``reweight_type`` argument.
        skip_evaluation: When truthy, do not run the scoring pipeline.
        suffix: Text inserted into the file names listed above.

    Raises:
        ValueError: If ``eps`` lies outside ``[0, 1]``.
        FileNotFoundError: If ``eps != 0`` and the generated text file is
            missing.
    """
    # Explicit exceptions instead of ``assert`` so the checks survive
    # ``python -O``.
    if not 0 <= eps <= 1:
        raise ValueError(f"eps must be within [0, 1], got {eps!r}")

    model_dir = model_str.split('/')[-1].replace('-', '_')
    sub_dir = os.path.join(res_dir, dataset_name, model_dir, reweight_type)
    os.makedirs(sub_dir, exist_ok=True)

    output_path = os.path.join(sub_dir, f'text_generation{suffix}.txt')
    cross_eval_name = f'score{suffix}_evaluated_by_{other_reweight_type}.txt'

    if eps == 0:
        score_name = cross_eval_name if use_other_wp else f'score{suffix}.txt'
    else:
        # A non-zero eps never triggers generation below, so the text must
        # already be on disk.
        if not os.path.exists(output_path):
            raise FileNotFoundError(
                f"eps={eps!r} requires previously generated text at {output_path}"
            )
        eps_str = str(eps).replace('.', '_')
        score_name = f'eps_{eps_str}{suffix}_o.txt'
    score_save_path = os.path.join(sub_dir, score_name)

    # Cross-evaluated score files are never treated as "done": the substring
    # test lets them fall through and be recomputed on every run.
    if os.path.exists(score_save_path) and cross_eval_name not in score_save_path:
        print('Found exisiting score_save_path:')
        print(score_save_path)
        print('Job skipped.')
        return

    # Deferred until a pipeline is actually needed, so invalid arguments and
    # already finished jobs return without importing the generation package.
    from . import text_generation as tg

    if eps == 0 and not skip_generation:
        print("tg.get_output.undetectable_exp_pipeline()", flush=True)
        tg.get_output.undetectable_exp_pipeline(output_path=output_path,
                                                model_str=model_str,
                                                reweight_type=reweight_type,
                                                dataset_name=dataset_name)

    if not skip_evaluation:
        print("tg.evaluate_beta_score.pipeline()", flush=True)
        tg.evaluate_beta_score.pipeline(
            output_path=output_path,
            score_save_path=score_save_path,
            eps=eps,
            model_str=model_str,
            dataset_name=dataset_name,
            use_other_wp=use_other_wp,
            reweight_type=other_reweight_type,
        )

    print("finish text generation.")
    

def machine_translation_exp(output_dir, reweight_type, skip_generation):
    """Run the machine translation experiment end to end.

    The steps are: (optionally) generate translations, evaluate them, run the
    perplexity evaluation and finally compute BLEU. The files used, all
    directly inside ``output_dir``:

    * ``machine_translation.txt``        -- passed as ``output_path``
    * ``machine_translation_result.txt`` -- passed as ``score_save_path``
    * ``machine_translation_ppl.txt``    -- passed as ``ppl_save_path``
    * ``machine_translation_bleu.txt``   -- passed as ``bleu_save_path``

    Args:
        output_dir: Directory that holds the files listed above. This
            function does not create it.
            NOTE(review): the text summarization counterpart calls
            ``os.makedirs`` first -- confirm whether the pipelines create the
            directory themselves.
        reweight_type: Accepted for signature parity with the other
            experiment entry points; not used by this function.
        skip_generation: When truthy, the generation step is skipped and the
            evaluation steps run against whatever already sits at the output
            path.
    """
    # Imported lazily so that merely importing this module does not load the
    # translation package.
    from . import machine_translation as mt

    output_path = os.path.join(output_dir, 'machine_translation.txt')
    score_save_path = os.path.join(output_dir, 'machine_translation_result.txt')
    ppl_save_path = os.path.join(output_dir, 'machine_translation_ppl.txt')
    bleu_save_path = os.path.join(output_dir, 'machine_translation_bleu.txt')

    if not skip_generation:
        print("mt.get_output.pipeline()")
        mt.get_output.pipeline(output_path=output_path)

    print("mt.evaluate.pipeline")
    mt.evaluate.pipeline(output_path=output_path, score_save_path=score_save_path)

    print("mt.evaluate_ppl.pipeline")
    mt.evaluate_ppl.pipeline(output_path=output_path, ppl_save_path=ppl_save_path)

    print("mt.evaluate_bleu.compute_bleu")
    mt.evaluate_bleu.compute_bleu(output_path=output_path, bleu_save_path=bleu_save_path)

    print('finish machine translation')


def add_watermark_exp():
    """Parse the command line and dispatch to the selected experiment.

    ``--exp_type`` selects the experiment: ``ts`` (text summarization),
    ``tg`` (text generation) or ``mt`` (machine translation).

    For ``tg`` the options ``--tg_res_dir``, ``--tg_eps``, ``--model_str``,
    ``--reweight_type`` and ``--dataset_name`` have no defaults and are all
    needed to build the result path, so a missing one is reported as a usage
    error instead of failing later on a ``None`` value.

    Raises:
        SystemExit: On invalid or incomplete command-line arguments (raised by
            ``argparse``).
        NotImplementedError: If ``--exp_type`` is not given.
    """
    import argparse

    reweight_choices = ['ITS', 'EXP', 'baselines', 'nmark', 'test', 'main_exp',
                        'GumbelMax', 'SynthID', 'None']

    parser = argparse.ArgumentParser()

    parser.add_argument("--exp_type", type=str, choices=["ts", "tg", "mt"],
                        help='ts: text summarization; tg: text generation; '
                             'mt: machine translation')
    parser.add_argument('--model_str', type=str,
                        choices=['Qwen/Qwen2.5-3B-Instruct',
                                 'mistralai/Mistral-7B-Instruct-v0.3',
                                 'meta-llama/Llama-3.2-3B-Instruct',
                                 'meta-llama/Llama-2-7b-chat-hf',
                                 'microsoft/Phi-3.5-mini-instruct'],
                        help='Model path for text generation.')
    parser.add_argument('--dataset_name', type=str, help='Dataset name for text generation')
    parser.add_argument('--reweight_type', type=str, choices=reweight_choices)

    # NOTE(review): --ts_output_path is handed to the experiment as its
    # *output directory* even though the default looks like a file path.
    # --ts_ppl_save_path and --ts_score_save_path are parsed but never read
    # below; they are kept so existing command lines keep working.
    parser.add_argument(
        "--ts_output_path", type=str, default="ts_results/text_summarization.txt"
    )
    parser.add_argument(
        "--ts_ppl_save_path", type=str, default="ts_results/text_summarization_ppl.txt"
    )
    parser.add_argument(
        "--ts_score_save_path",
        type=str,
        default="ts_results/text_summarization_result.txt",
    )

    parser.add_argument('--tg_res_dir', type=str)
    parser.add_argument("--tg_eps", type=float)

    # NOTE(review): same situation as the ts options above: --mt_output_path
    # is used as the output directory, the other three are parsed but unused.
    parser.add_argument(
        "--mt_output_path", type=str, default="mt_result_baselines/machine_translation.txt"
    )
    parser.add_argument(
        "--mt_ppl_save_path", type=str, default="mt_result_baselines/machine_translation_ppl.txt"
    )
    parser.add_argument(
        "--mt_score_save_path",
        type=str,
        default="mt_result_baselines/machine_translation_result.txt",
    )
    parser.add_argument(
        "--mt_bleu_save_path",
        type=str,
        default="mt_result_baselines/machine_translation_bleu.txt"
    )

    parser.add_argument('--skip_generation', action='store_true', default=False,
                        help='Skip generation')
    parser.add_argument('--use_other_wp', action='store_true', default=False,
                        help='Score the generated text with --other_reweight_type; '
                             'results go to score<suffix>_evaluated_by_<other_reweight_type>.txt')
    parser.add_argument('--other_reweight_type', type=str,
                        choices=reweight_choices, default=None)

    parser.add_argument('--skip_evaluation', action='store_true', default=False,
                        help='Skip evaluation')
    parser.add_argument('--suffix', type=str, default="")

    args = parser.parse_args()

    if args.exp_type == "tg":
        # None of these has a default, and each one is needed to build the
        # result path or to validate eps.
        tg_required = ("tg_res_dir", "tg_eps", "model_str", "reweight_type", "dataset_name")
        missing = [f"--{name}" for name in tg_required if getattr(args, name) is None]
        if missing:
            parser.error("--exp_type tg requires: " + ", ".join(missing))

        text_generation_undetectable_exp(
            res_dir=args.tg_res_dir,
            eps=args.tg_eps,
            model_str=args.model_str,
            reweight_type=args.reweight_type,
            dataset_name=args.dataset_name,
            skip_generation=args.skip_generation,
            use_other_wp=args.use_other_wp,
            other_reweight_type=args.other_reweight_type,
            skip_evaluation=args.skip_evaluation,
            suffix=args.suffix,
        )
    elif args.exp_type == "ts":
        text_summarization_undetectable_exp(
            output_dir=args.ts_output_path,
            reweight_type=args.reweight_type,
            skip_generation=args.skip_generation,
        )
    elif args.exp_type == 'mt':
        machine_translation_exp(
            output_dir=args.mt_output_path,
            reweight_type=args.reweight_type,
            skip_generation=args.skip_generation,
        )
    else:
        # Reached only when --exp_type is omitted; argparse already rejects
        # values outside the declared choices.
        print("Unknown exp_type.")
        raise NotImplementedError


# Script entry point: announce the run, then parse the command line and
# dispatch. The experiment functions use package-relative imports
# (``from . import ...``), so this module is presumably meant to be started
# from within its package (e.g. ``python -m <package>.<module>``) -- TODO confirm.
if __name__ == "__main__":
    print("Add watermark experiment")
    add_watermark_exp()
