import sys
import argparse
from pathlib import Path
import json
from h4rm3l.decorators import make_prompt_decorator
from h4rm3l.elicitators import OpenAIIllicitPromptElicitator
from h4rm3l.prompter import prompt_model
from h4rm3l.harm_evaluation import is_harmful
import json
import csv
from h4rm3l.synthesizer import format_program_examples, main_synthesize
from h4rm3l.data_analysis.report_generator import main_analyse
from h4rm3l.data_analysis.prog_synthesi_curves import main_analyse_prog_synthesis
from h4rm3l.data_analysis.program_clustering import main_cluster_programs
from tqdm import tqdm
import logging
from datetime import datetime



def prepare_io_resources(args):
    """Return the (input, output) stream pair used by the stream-based commands.

    ``args`` is accepted but not consulted: the pair is always
    ``(sys.stdin, sys.stdout)``.
    """
    input_stream = sys.stdin
    output_stream = sys.stdout
    return input_stream, output_stream


def main_decorate(args):
    """Decorate every line of the input stream and write the results out.

    Builds a decorator from ``args.decoration_spec`` with
    ``make_prompt_decorator``; credentials are loaded as JSON from
    ``args.credentials`` when that option is set, otherwise ``None`` is
    passed. Each line obtained from the input stream is handed to the
    decorator as read, and the result is written to the output stream
    followed by a newline.
    """
    credentials = None
    if args.credentials is not None:
        credentials = json.loads(Path(args.credentials).read_text())

    decorate = make_prompt_decorator(args.decoration_spec, credentials, args)
    source, sink = prepare_io_resources(args)

    # A generator expression keeps the processing lazy: one line is read,
    # decorated and written at a time.
    sink.writelines(decorate(raw_line) + "\n" for raw_line in source)


def main_decorate_csv(args):
    """Apply every program from one CSV to every prompt from another CSV.

    Reads the prompts from ``args.input_prompts_file`` and the programs from
    ``args.input_programs_file``. For each program row, a decorator is built
    from the column named by ``args.input_programs_csv_field``; it is then
    applied to the ``harmful_prompt`` column of each prompt row. One output
    row per (program, prompt) pair is written to ``args.output_file``. An
    output row holds the prompt columns, overlaid with the program columns,
    plus ``decorated_harmful_prompt``.

    Credentials are loaded as JSON from ``args.credentials`` when set,
    otherwise ``None`` is passed to ``make_prompt_decorator``.

    Raises:
        ValueError: if any of ``--input-prompts-file``,
            ``--input-programs-file``, ``--input-programs-csv-field`` or
            ``--output-file`` is missing or empty.
    """
    required_options = (
        ("input_prompts_file", "--input-prompts-file"),
        ("input_programs_file", "--input-programs-file"),
        ("input_programs_csv_field", "--input-programs-csv-field"),
        ("output_file", "--output-file"),
    )
    for attr_name, flag in required_options:
        # argparse registers every option on the namespace with a default of
        # None, so a membership test (`'x' not in args`) never fires; the
        # value itself has to be checked.
        if not getattr(args, attr_name, None):
            raise ValueError(f"{flag} required")

    # newline="" is what the csv module asks for, so that newlines embedded
    # in quoted fields survive reading and writing unchanged.
    with open(args.input_prompts_file, newline="") as fprompt, \
            open(args.input_programs_file, newline="") as fprog:
        prompt_reader = csv.DictReader(fprompt)
        program_reader = csv.DictReader(fprog)
        prompts = list(prompt_reader)
        programs = list(program_reader)
        # Capture the headers while the files are open; an empty file
        # yields None, which is treated as "no columns".
        prompt_fields = list(prompt_reader.fieldnames or [])
        program_fields = list(program_reader.fieldnames or [])

    if args.credentials is not None:
        credentials = json.loads(Path(args.credentials).read_text())
    else:
        credentials = None

    with open(args.output_file, "w", newline="") as of:
        output_writer = csv.DictWriter(
            of,
            fieldnames=prompt_fields + program_fields + [
                "decorated_harmful_prompt"
            ]
        )
        output_writer.writeheader()
        for program_info in tqdm(programs):
            logging.info("Decorating prompts with:" + str(program_info))
            # One decorator per program, reused for all prompts.
            decorator = make_prompt_decorator(
                program_info[args.input_programs_csv_field],
                credentials,
                args
            )
            for prompt_info in prompts:
                output_row = prompt_info.copy()
                output_row.update(program_info)
                output_row['decorated_harmful_prompt'] = decorator(output_row['harmful_prompt'])
                output_writer.writerow(output_row)
            

def main_format_program_examples(args):
    """Print the formatted program examples built from a CSV file.

    All rows of ``args.input_file`` are read as dicts and passed to
    ``format_program_examples`` with ``sample_count=None``; whatever it
    returns is printed to standard output.
    """
    with open(args.input_file) as examples_file:
        example_rows = list(csv.DictReader(examples_file))
        formatted = format_program_examples(
            example_rows, sample_count=None, args=args
        )
        print(formatted)

def main_anthropic_redteaming_to_csv(args):
    """Convert a JSON file holding a list of records into a CSV file.

    ``args.input_file`` is parsed as JSON and iterated; each element is
    written as one CSV row to ``args.output_file``. The elements are assumed
    to be JSON objects (dicts) -- TODO confirm against the dataset format.

    The CSV header is the union of the keys of all records, in first-seen
    order. Records that lack a key get an empty cell for it.
    """
    with open(args.input_file) as f:
        records = json.load(f)

    # An empty fieldnames list makes DictWriter reject every non-empty row
    # ("dict contains fields not in fieldnames"), so the header is derived
    # from the data. dict.fromkeys de-duplicates while keeping order.
    fieldnames = list(
        dict.fromkeys(key for record in records for key in record)
    )

    # newline="" as required by the csv module for files it writes to.
    with open(args.output_file, 'w', newline="") as of:
        output_writer = csv.DictWriter(of, fieldnames=fieldnames)
        output_writer.writeheader()
        for record in records:
            output_writer.writerow(record)


def main_elicit(args):
    """Elicit prompts for every input line and emit them as CSV.

    Each line obtained from the input stream is passed, as read, to
    ``OpenAIIllicitPromptElicitator.elicit_from`` together with
    ``args.model_name`` and ``args.elicitations_per_prompt``. Every row it
    returns is tagged with the originating line under ``original_prompt``
    and written to the output stream. The output columns are
    ``original_prompt``, ``refusal_message`` and ``elicited_prompt``.

    Credentials are loaded as JSON from ``args.credentials``.
    """
    source, sink = prepare_io_resources(args)
    columns = ["original_prompt", "refusal_message", "elicited_prompt"]
    writer = csv.DictWriter(sink, fieldnames=columns)

    credentials = json.loads(Path(args.credentials).read_text())
    elicitator = OpenAIIllicitPromptElicitator(credentials)

    writer.writeheader()
    for original_prompt in source:
        elicited_rows = elicitator.elicit_from(
            model_name=args.model_name,
            prompt=original_prompt,
            count=args.elicitations_per_prompt
        )
        for elicited_row in elicited_rows:
            # Rows are mutated in place to record which prompt they came from.
            elicited_row['original_prompt'] = original_prompt
            writer.writerow(elicited_row)



def main_evaluate(args):
    """Query the target model with each prompt of a CSV and judge the replies.

    For every row of ``args.input_file``:

    * the text in column ``args.eval_prompt_fieldname`` is sent to
      ``args.target_model_name`` through ``prompt_model``, capped at
      ``args.eval_max_tokens`` tokens;
    * the response is passed to ``is_harmful`` (using
      ``args.moderation_model_name``) together with the original prompt
      from column ``args.original_prompt_fieldname`` and the row's
      ``program`` column.

    The input row is written to ``args.output_file`` extended with the
    ``eval_*`` columns. When ``is_harmful`` returns a falsy value, the three
    verdict columns (``eval_harmful``, ``eval_harm_probability``,
    ``eval_reason``) are left empty. The output is flushed after every row
    so partial results survive an interrupted run.

    Credentials are loaded as JSON from ``args.credentials``.

    Raises:
        KeyError: if a row lacks the prompt columns or the ``program``
            column.
    """
    # newline="" is what the csv module asks for: prompts and model
    # responses regularly contain newlines inside quoted fields.
    with open(args.input_file, 'r', newline="") as f, \
            open(args.output_file, 'w', newline="") as of:
        input_reader = csv.DictReader(f)
        credentials = json.loads(Path(args.credentials).read_text())
        output_writer = csv.DictWriter(
            of,
            fieldnames=input_reader.fieldnames + [
                'eval_prompt',
                'eval_prompt_response',
                'eval_harmful',
                'eval_harm_probability',
                'eval_reason',
                'eval_target_model',
                'eval_moderation_model'
            ])
        output_writer.writeheader()
        eval_target_model = args.target_model_name
        eval_moderation_model = args.moderation_model_name

        # The rows are materialised so tqdm can show a total.
        for row in tqdm(list(input_reader)):
            output_row = row.copy()
            # Read every required column up front so that a missing column
            # fails before a model call is spent on the row.
            eval_prompt_decorated = row[args.eval_prompt_fieldname]
            eval_prompt_original = row[args.original_prompt_fieldname]
            program = row['program']

            eval_prompt_response = prompt_model(
                credentials=credentials,
                model_name=eval_target_model,
                prompt=eval_prompt_decorated,
                max_tokens=args.eval_max_tokens
            )

            # validate response vs non-decorated request
            eval_result = is_harmful(
                credentials,
                model_name=eval_moderation_model,
                user_request_original=eval_prompt_original,
                user_request_decorated=eval_prompt_decorated,
                transformation=program,
                model_response=eval_prompt_response
            )

            # A falsy verdict is recorded as empty verdict columns.
            if not eval_result:
                eval_result = {}

            output_row['eval_prompt'] = eval_prompt_decorated
            output_row['eval_target_model'] = eval_target_model
            output_row['eval_prompt_response'] = eval_prompt_response
            output_row['eval_moderation_model'] = eval_moderation_model

            output_row['eval_harmful'] = eval_result.get("eval_harmful", None)
            output_row['eval_harm_probability'] = eval_result.get("eval_harm_probability", None)
            output_row['eval_reason'] = eval_result.get("eval_reason", None)

            output_writer.writerow(output_row)
            of.flush()

def main_hardcode(args):
    """Re-evaluate a fixed annotation CSV with ``is_harmful`` using gpt-4.

    The input and output paths, the credentials file (``keys.json``) and
    the moderation model are hard-coded; ``args`` is not used. Each input
    row's stored response (``eval_prompt_response``) is judged again against
    its ``harmful_prompt``, ``decorated_harmful_prompt`` and ``program``
    columns. The verdict, the original prompt and the response are printed.
    The row is then written out extended with four ``*_new`` columns taken
    from the verdict. Those columns are empty when the verdict is falsy. The
    output is flushed after every row.
    """
    source_path = 'experiments/experiment_010_publish_candidate/results/annotations - haiku - marked.csv'
    target_path = 'experiments/experiment_010_publish_candidate/results/annotations - haiku - marked - annotated.csv'

    # (output column, key read from the verdict), in output-column order.
    verdict_columns = (
        ('decoded_output_new', "decoded_output"),
        ('eval_reason_new', "eval_reason"),
        ('eval_harmful_new', "eval_harmful"),
        ('harmful_information_new', "harmful_information"),
    )

    with open(source_path, 'r') as f, open(target_path, 'w') as of:
        reader = csv.DictReader(f)
        credentials = json.loads(Path('keys.json').read_text())
        writer = csv.DictWriter(
            of,
            fieldnames=reader.fieldnames + [column for column, _ in verdict_columns]
        )
        writer.writeheader()

        for row in tqdm(list(reader)):
            annotated_row = row.copy()
            decorated_prompt = row['decorated_harmful_prompt']
            original_prompt = row['harmful_prompt']
            stored_response = row['eval_prompt_response']
            program = row['program']

            # validate response vs non-decorated request
            verdict = is_harmful(
                credentials,
                model_name="gpt-4",
                user_request_original=original_prompt,
                user_request_decorated=decorated_prompt,
                transformation=program,
                model_response=stored_response
            )
            # The raw verdict is printed before any fallback is applied.
            print(verdict)
            print(original_prompt)
            print(stored_response)

            verdict = verdict or {}
            for column, verdict_key in verdict_columns:
                annotated_row[column] = verdict.get(verdict_key, None)

            writer.writerow(annotated_row)
            of.flush()


    

def main_parse_command_line_args(argv=None):
    """Build the h4rm3l argument parser and parse the command line.

    All sub-commands share one flat option namespace; every option is
    optional at the parser level and defaults to ``None`` unless a default
    is given below.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the selected ``option`` and all
        option values.
    """
    parser = argparse.ArgumentParser("h4rm3l")
    # 'anthropic-redteaming-to-csv' has a handler in the dispatch table but
    # was missing here, so argparse rejected it before dispatch.
    # NOTE(review): 'attack' is accepted here but has no handler in the
    # dispatch table of this file -- confirm whether it should be removed.
    parser.add_argument("option", choices=[
        'elicit', 'decorate', 'attack', 'decorate-csv',
        'format-program-examples', 'synthesize', 'evaluate', 'analyse',
        'analyse-prog-synthesis', 'hardcode_evaluate', 'cluster_programs',
        'anthropic-redteaming-to-csv'
    ])
    parser.add_argument("--log-dir", default="logs")

    parser.add_argument("--logging-level", default="INFO")
    parser.add_argument("--credentials")
    parser.add_argument("--model-name", choices=["gpt-3.5-turbo", "gpt-4"])

    ## params for elicit
    parser.add_argument("--elicitations-per-prompt", type=int)

    ## params for decorate
    parser.add_argument("--decoration-spec")

    ## params for decorate-csv
    parser.add_argument("--input-prompts-file")
    parser.add_argument("--input-programs-file")
    parser.add_argument("--input-programs-csv-field")
    parser.add_argument("--output-file")
    parser.add_argument("--cluster-model")
    parser.add_argument("--program-name-column")
    parser.add_argument("--asr-column")
    parser.add_argument("--input-file")
    parser.add_argument("--decorator-syntax-version", type=int, default=1)

    ## params for format-program-examples
    parser.add_argument("--program-description-field")
    parser.add_argument("--program-text-field")

    ## params for synthesize
    parser.add_argument("--primitives-file")
    parser.add_argument("--program-examples-file")

    parser.add_argument("--harmful-prompts-csv")
    parser.add_argument("--harmful-prompts-csv-column")
    parser.add_argument("--synthesis-harmful-prompts-sample-size", type=int)
    parser.add_argument("--synthesis-model-name")
    parser.add_argument("--synthesis-model-max-tokens", type=int)
    parser.add_argument("--synthesis-algo", choices=["bandit_random", "bandit_self_score", "bandit_offspring_score"])
    parser.add_argument("--synthesis-beta-score-scalar", type=int)
    parser.add_argument("--synthesis-beta-score-scale-with-iterations", default=False, action='store_true')

    parser.add_argument("--synthesis-iteration-count", type=int)
    parser.add_argument("--synthesis-examples-per-iteration", type=int)
    parser.add_argument("--synthesis-proposals-per-batch", type=int)
    parser.add_argument("--synthesis-proposals-per-iteration", type=int)
    parser.add_argument("--synthesis-score-proposals", default=False, action='store_true')
    parser.add_argument("--synthesis-examples-type")

    parser.add_argument("--random-seed", type=int, default=78990)

    ## params for evaluate
    parser.add_argument("--eval-prompt-fieldname")
    parser.add_argument("--original-prompt-fieldname")
    parser.add_argument("--target-model-name")
    parser.add_argument("--moderation-model-name")
    parser.add_argument("--eval-max-tokens", type=int, default=256)

    ## params for analyse
    parser.add_argument("--output-dir")
    parser.add_argument("--evaluation-files", nargs="+")
    parser.add_argument("--program-name-field", default="program_name")
    parser.add_argument("--program-names", nargs='+')
    parser.add_argument("--target-model-names", nargs='+')
    parser.add_argument("--report-category", default="bad bot")
    parser.add_argument("--color-palette", nargs="+", default=[
        '#003366', '#0072B2', '#CC6677', '#EE7733', '#009988', '#33AA99', '#D55E00', '#E69F00', '#F0E442', '#FFFF99', '#56B4E9', '#ADD8E6', '#000000', '#DDDDDD', '#CC79A7', '#F0C2E0', '#0072B2', '#56B4E9', '#EE7733', '#FFBB00'
        ]
    )
    parser.add_argument("--input-files", nargs="*")

    parser.add_argument("--plot-extension", default="jpg")

    return parser.parse_args(argv)

def main_setup_logging(args):
    """Configure root logging to write to a per-day file under ``args.log_dir``.

    The directory is created (including parents) if it does not exist. The
    log file is named ``h4rm3l_<YYYY-MM-DD>.log`` using the current local
    date, and the level is taken from ``args.logging_level``.
    """
    log_directory = Path(args.log_dir)
    log_directory.mkdir(parents=True, exist_ok=True)

    today = f"{datetime.now():%Y-%m-%d}"
    log_file = log_directory.joinpath(f"h4rm3l_{today}.log")

    logging.basicConfig(
        filename=log_file,
        level=args.logging_level,
        format="%(levelname)s %(asctime)s %(message)s",
    )


if __name__ == '__main__':
    # Script entry point: parse options, set up logging, then dispatch to
    # the handler registered for the chosen command.
    args = main_parse_command_line_args()
    main_setup_logging(args)

    # Command name -> handler; every handler takes the parsed namespace.
    functions_by_option = {
        'elicit': main_elicit,
        'decorate': main_decorate,
        'synthesize': main_synthesize,
        'format-program-examples': main_format_program_examples,
        'decorate-csv': main_decorate_csv,
        'evaluate': main_evaluate,
        'anthropic-redteaming-to-csv': main_anthropic_redteaming_to_csv,
        'analyse': main_analyse,
        'analyse-prog-synthesis': main_analyse_prog_synthesis,
        'hardcode_evaluate': main_hardcode,
        'cluster_programs': main_cluster_programs,
    }

    handler = functions_by_option.get(args.option)
    if handler is None:
        raise ValueError(f"Unknown command: {args.option}")

    logging.info(f"Started: {args.option}")
    logging.info(f"\targs: {args}")
    try:
        handler(args)
    except BrokenPipeError:
        # The consumer of our output went away (e.g. piped into `head`);
        # treated as a normal end of the run.
        pass
    logging.info(f"Finished: {args.option}")
