"""
Prepares dataset for hindsight relabeling.
"""
import itertools
import pandas as pd
import difflib
import numpy as np
import random
# Seed the stdlib and NumPy global generators at import time so repeated runs
# are reproducible; run() below draws from np.random when sampling copy examples.
random.seed(42)
np.random.seed(42)


def get_minimal_diff(code1, code2, return_lines: bool = False) -> "str | list[str]":
    """Return only the changed lines between two programs.

    The programs are compared line by line with ``difflib.unified_diff``
    using zero lines of context. The ``---``/``+++`` file headers and the
    ``@@ ... @@`` hunk headers are removed, so the result contains just the
    ``-``/``+`` prefixed lines, each with surrounding whitespace stripped.

    Args:
        code1: Source text of the "before" program.
        code2: Source text of the "after" program.
        return_lines: If true, return the changed lines as a list instead
            of a single newline-joined string.

    Returns:
        The changed lines as a list (``return_lines=True``) or joined with
        ``"\\n"``. Identical inputs give an empty list / empty string.
    """
    diff = difflib.unified_diff(
        code1.splitlines(keepends=True), code2.splitlines(keepends=True), n=0
    )
    # unified_diff emits exactly two file-header lines ("---" and "+++")
    # before the first hunk, so skip them by position. Matching on their text
    # would also discard real changes such as a removed line starting with "--".
    hunks = itertools.islice(diff, 2, None)
    # Inside a hunk every content line is prefixed with "+", "-" or " ", so a
    # line that *starts* with "@@" can only be a hunk header. A substring test
    # would wrongly drop source lines that merely contain "@@" or "---".
    changed_lines = [line.strip() for line in hunks if not line.startswith("@@")]

    if return_lines:
        return changed_lines

    return "\n".join(changed_lines)


# Maps each example type to the instruction comment that is appended after
# the input program when a training prompt is built.
example_type_to_instruction = {
    "fast": "# Optimize the above program for faster performance",
    "slow": "# Introduce delays or reduce efficiency in the above program",
    "bugfix": "# Identify and fix any bugs present in the above program",
    "bugadd": "# Intentionally introduce bugs into the above program",
    "copy": "# Create an exact copy of the above program",
}

def make_pair(a, b):
    """Build training examples from a pair of programs.

    Each program is a dict with at least the keys ``code``, ``acc`` and
    ``time``.

    * If either program has empty code, no examples are produced.
    * If exactly one program is fully correct (``acc == 1``) and the other
      has ``acc < 1``, a bug-add / bug-fix pair of examples is produced.
    * If neither program is fully correct, no examples are produced: there
      is no correct side to use as the "fixed" program.
    * If both are correct, a speed-up / slow-down pair is produced when one
      is faster than the other by the ``is_2_faster_than_1`` threshold;
      otherwise nothing is produced.

    Returns:
        A list with zero or two formatted examples.
    """
    if not a['code'] or not b['code']:
        return []

    a_correct = a['acc'] == 1
    b_correct = b['acc'] == 1
    if a['acc'] < 1 or b['acc'] < 1:
        if not (a_correct or b_correct):
            # Both programs are wrong; labelling either one as the "correct"
            # target would create a mislabeled bug-fix example.
            return []
        correct, incorrect = (a, b) if a_correct else (b, a)
        return _make_acc_example(correct, incorrect)

    # Reaching this point with an accuracy other than exactly 1 (e.g. > 1 or
    # NaN) indicates malformed input.
    assert a_correct and b_correct
    if is_2_faster_than_1(a['time'], b['time']):
        faster, slower = b, a
    elif is_2_faster_than_1(b['time'], a['time']):
        faster, slower = a, b
    else:
        # Runtimes are too close to call one program faster.
        return []
    return _make_speedup_example(faster, slower)



def is_2_faster_than_1(time_1, time_2, speedup: float = 1.25) -> bool:
    """Tell whether the second runtime beats the first by more than ``speedup``.

    Args:
        time_1: Runtime of the first program.
        time_2: Runtime of the second program.
        speedup: Minimum ratio ``time_1 / time_2`` that must be exceeded.

    Returns:
        True when ``time_1 / time_2 > speedup``.
    """
    # Written as a product rather than a ratio so a zero ``time_2`` does not
    # raise ZeroDivisionError. Equivalent to the ratio test for positive times
    # (runtimes are assumed to be non-negative).
    return time_1 > speedup * time_2


def _make_acc_example(correct, incorrect):
    """Create the bug-add and bug-fix examples for a correct/incorrect pair.

    Returns a two-element list: first correct -> incorrect with the "bugadd"
    instruction, then incorrect -> correct with the "bugfix" instruction.
    """
    add_instruction = example_type_to_instruction["bugadd"]
    fix_instruction = example_type_to_instruction["bugfix"]
    return [
        format_input_output(correct, add_instruction, incorrect),
        format_input_output(incorrect, fix_instruction, correct),
    ]
    

def _make_speedup_example(faster, slower):
    """Create the speed-up and slow-down examples for a faster/slower pair.

    Returns a two-element list: first slower -> faster with the "fast"
    instruction, then faster -> slower with the "slow" instruction.
    """
    speed_up_instruction = example_type_to_instruction["fast"]
    slow_down_instruction = example_type_to_instruction["slow"]
    return [
        format_input_output(slower, speed_up_instruction, faster),
        format_input_output(faster, slow_down_instruction, slower),
    ]


def format_input_output(input, instruction, target):
    """Assemble one training record from a source program and a target program.

    The prompt is the source code followed by the instruction, each followed
    by a blank line. The record also carries the time, accuracy and type of
    both programs, plus the raw instruction.
    """
    blank_line = "\n\n"
    record = {
        "input": input["code"] + blank_line + instruction + blank_line,
        "target": target["code"],
    }
    # Paired metadata: input value first, then the matching target value.
    for field in ("time", "acc"):
        record["input_" + field] = input[field]
        record["target_" + field] = target[field]
    record["instruction"] = instruction
    record["input_type"] = input["type"]
    record["target_type"] = target["type"]
    return record

def run(path: str) -> pd.DataFrame:
    """Build the relabeled training set from a JSON-lines results file.

    Only rows where both the input and the reference program are fully
    correct (``input_acc == 1.0`` and ``reference_acc == 1.0``) are used.
    For each such row, the input, reference and greedily generated programs
    are paired up in every combination and turned into examples by
    ``make_pair``. Copy examples are added when the generated program is
    identical to the input, and additionally for a randomly chosen program
    with 10% probability.

    Args:
        path: Path to a JSON-lines file, one record per line.

    Returns:
        A DataFrame with one row per example; each row also carries the
        ``user_id``, ``problem_id`` and submission ids of its source record.
    """
    df = pd.read_json(path, orient="records", lines=True)
    df = df[(df['input_acc'] == 1.0) & (df['reference_acc'] == 1.0)]
    copy_instruction = example_type_to_instruction["copy"]

    examples = []
    for _, row in df.iterrows():
        input_eg = {
            "code": row['input'],
            "acc": row['input_acc'],
            "time": row['input_time_mean'],
            "type": "input",
        }
        ref_eg = {
            "code": row['target'],
            "acc": row['reference_acc'],
            "time": row['reference_time_mean'],
            "type": "reference",
        }
        gen_eg = {
            "code": row['greedy_generated_target_from_input'],
            "acc": row['greedy_generated_target_from_input_acc'],
            "time": row['greedy_generated_target_from_input_time_mean'],
            "type": "generated",
        }
        meta = {
            "user_id": row['user_id'],
            "problem_id": row['problem_id'],
            "submission_id_v0": row['submission_id_v0'],
            "submission_id_v1": row['submission_id_v1'],
        }
        candidates = [input_eg, ref_eg, gen_eg]

        # Skip rows where any of the three programs is shorter than 10 characters.
        if any(len(eg['code']) < 10 for eg in candidates):
            continue

        # Iterate over all unordered pairs of the three programs.
        examples_for_row = []
        for pair in itertools.combinations(candidates, 2):
            examples_for_row.extend(make_pair(*pair))

        # The generated program reproduces the input exactly: add a copy
        # example. Plain string equality is used instead of checking for an
        # empty get_minimal_diff, which drops changed lines containing
        # "---", "+++" or "@@" and so could report differing programs as equal.
        if input_eg['code'] == gen_eg['code']:
            examples_for_row.append(format_input_output(input_eg, copy_instruction, gen_eg))

        # With 10% probability, add a copy example for a random program.
        # The order of the np.random calls is kept so seeded runs stay reproducible.
        if np.random.rand() < 0.1:
            eg_to_use = np.random.choice(candidates)
            examples_for_row.append(format_input_output(eg_to_use, copy_instruction, eg_to_use))

        examples.extend({**eg, **meta} for eg in examples_for_row)

    return pd.DataFrame(examples)


def fast_ins_format(path):
    """Load a JSON-lines split and append the "fast" instruction to each input.

    Every value in the ``input`` column gets a blank line, the optimization
    instruction, and another blank line appended. The modified DataFrame is
    returned.
    """
    frame = pd.read_json(path, orient="records", lines=True)

    def _with_instruction(code):
        # Same layout as the training prompts: code, blank line, instruction, blank line.
        return code + "\n\n" + example_type_to_instruction["fast"] + "\n\n"

    frame['input'] = frame['input'].apply(_with_instruction)
    return frame

if __name__ == '__main__':
    # Script entry point: build the relabeled training split from --path,
    # append the "fast" instruction to the validation and test splits, and
    # write all three as JSON-lines files into --outdir.
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--outdir", type=str, required=True)
    # The validation/test locations used to be hard-coded; they are now
    # optional flags whose defaults are the original paths.
    parser.add_argument(
        "--val_path",
        type=str,
        default="problem_id/2023-01-13_12-56pm/seq2seq_splits/val.jsonl",
        help="JSON-lines validation split to reformat.",
    )
    parser.add_argument(
        "--test_path",
        type=str,
        default="problem_id/2023-01-13_12-56pm/seq2seq_splits/test.jsonl",
        help="JSON-lines test split to reformat.",
    )
    args = parser.parse_args()

    # Create the output directory before doing any work so a missing
    # directory does not fail the run only at write time.
    os.makedirs(args.outdir, exist_ok=True)

    df_train = run(args.path)
    df_val = fast_ins_format(args.val_path)
    df_test = fast_ins_format(args.test_path)

    df_train.to_json(os.path.join(args.outdir, "train.jsonl"), orient="records", lines=True)
    df_val.to_json(os.path.join(args.outdir, "val.jsonl"), orient="records", lines=True)
    df_test.to_json(os.path.join(args.outdir, "test.jsonl"), orient="records", lines=True)
    
    
    
    
    
"""
export CUDA_VISIBLE_DEVICES=0 && export SRC_LEN=300 && export TGT_LEN=360 && export BSZ=16 && export MODEL_NAME="Salesforce/codegen-350M-mono" && export OUTDIR="problem_id/2023-01-13_12-56pm/seq2seq_splits_hir_1/" && nohup python -u seq2seq_pl/src/finetune.py --num_workers 32 --learning_rate 2e-5 --gpus 1 --do_train --do_predict --warmup_prop 0.05 --n_val 10000 --default_root_dir lightning_logs/codenet_py_using_codegen_cpp --val_check_interval 0.2 --data_dir ${OUTDIR} --train_batch_size ${BSZ} --eval_batch_size ${BSZ} --output_dir ${OUTDIR}/outputs/ --max_source_length ${SRC_LEN} --max_target_length ${TGT_LEN} --val_max_target_length ${TGT_LEN} --test_max_target_length ${TGT_LEN} --model_name_or_path ${MODEL_NAME} --seed 10708 --val_metric loss --save_top_k 1 --config_name ${MODEL_NAME} --mode next-token-prediction --src_key input --tgt_key target --add_lr_scheduler --logger_name wandb_shared --accumulate_grad_batches 2 --accelerator gpu --precision 16 --max_steps 30000 --gradient_clip_val 1.0 > logs/codenet_py_350m.txt  2>&1 &    
"""