import json
import subprocess

# Model configurations to replicate, grouped by task name.
#
# Layout, as consumed by the __main__ section of this file:
#   * outer key   -- task name; substituted into the TASK placeholder of the
#                    shell script and used to select the evaluation metric.
#   * inner key   -- "<states>_<bit indices>": two 11-element integer lists
#                    (in their printed form) joined by a single underscore.
#                    The first list is passed on as STATES; each entry of the
#                    second list is an index into the bit-width options
#                    [4, 8, 16, 32] (0 -> 4, 1 -> 8, 2 -> 16) and the
#                    resulting widths are passed on as BITS.
#   * inner value -- [alpha, beta, temperature] hyperparameters.
#
# NOTE(review): the task names look like GLUE tasks (the driver script is
# named glue_p_bert.sh) -- confirm.
results = {
    "rte": {
        '[12, 58, 49, 52, 26, 27, 26, 29, 66, 18, 5]_[0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0]': [0.5, 0.1, 10.0],
        '[12, 58, 49, 52, 26, 27, 26, 67, 67, 18, 5]_[0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0]': [0.1, 0.1, 9.0],
        '[13, 23, 20, 53, 15, 21, 36, 25, 37, 65, 5]_[2, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1]': [0.1, 0.5, 15.0],
        '[34, 24, 57, 17, 57, 66, 58, 56, 68, 34, 55]_[1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1]': [0.1, 0.1, 5.0],
        '[34, 39, 2, 30, 32, 8, 52, 29, 44, 39, 13]_[2, 2, 1, 1, 1, 1, 1, 0, 2, 0, 2]': [0.1, 0.1, 8.0],
        '[55, 24, 62, 56, 57, 66, 59, 67, 68, 64, 57]_[1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1]': [0.2, 1.0, 5.0],
    }, 
    "mrpc": {
        '[4, 4, 2, 7, 25, 2, 57, 4, 31, 38, 36]_[1, 1, 2, 1, 2, 1, 1, 0, 0, 0, 1]': [0.5, 0.1, 20.0],
        '[22, 3, 2, 42, 27, 7, 16, 49, 56, 38, 39]_[1, 1, 1, 1, 0, 1, 1, 2, 0, 1, 1]': [0.1, 8.0, 9.0],
        '[22, 3, 2, 42, 27, 7, 17, 49, 56, 38, 40]_[1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1]': [0.5, 3.0, 1.0],
        '[22, 3, 3, 42, 27, 61, 61, 66, 56, 38, 40]_[1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1]': [0.1, 0.5, 1.0],
        '[23, 31, 4, 4, 21, 64, 37, 45, 37, 12, 45]_[1, 2, 1, 1, 0, 1, 1, 0, 1, 0, 1]': [0.5, 3.0, 1.0],
        '[25, 21, 47, 23, 15, 51, 67, 67, 37, 35, 62]_[1, 1, 1, 1, 2, 1, 2, 1, 0, 0, 1]': [0.5, 3.0, 8.0],
        '[45, 31, 4, 23, 15, 5, 62, 49, 38, 13, 61]_[1, 1, 1, 1, 2, 1, 2, 0, 1, 0, 1]': [0.1, 10.0, 8.0],
        '[54, 57, 3, 14, 54, 59, 61, 1, 40, 19, 34]_[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]': [0.5, 0.5, 1.0],
        '[54, 57, 3, 14, 54, 59, 61, 1, 41, 5, 34]_[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]': [0.5, 3.0, 12.0],
        '[54, 57, 3, 14, 54, 59, 61, 66, 35, 5, 63]_[1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1]': [0.5, 3.0, 1.0],
    }, 
    "stsb": {
        '[7, 27, 63, 42, 36, 31, 44, 4, 44, 2, 24]_[1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0]': [0.1, 0.01, 0.1],
        '[8, 27, 63, 42, 36, 31, 44, 6, 44, 14, 23]_[1, 2, 1, 2, 1, 0, 0, 0, 0, 0, 1]': [0.03, 0.01, 0.1],
        '[9, 40, 13, 42, 35, 33, 63, 6, 7, 28, 64]_[1, 1, 1, 2, 2, 0, 0, 0, 0, 0, 0]': [0.01, 0.09, 0.1],
        '[10, 9, 20, 1, 15, 59, 21, 14, 44, 65, 48]_[1, 2, 0, 1, 2, 0, 0, 0, 0, 0, 2]': [0.03, 0.01, 0.1],
        '[14, 46, 34, 28, 56, 13, 46, 43, 23, 26, 53]_[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]': [0.1, 0.3, 0.1],
        '[15, 63, 33, 42, 35, 33, 63, 7, 7, 28, 64]_[1, 1, 1, 2, 2, 0, 0, 0, 0, 0, 0]': [0.03, 0.03, 0.1],
        '[15, 63, 33, 42, 56, 32, 63, 44, 7, 28, 64]_[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]': [0.03, 0.03, 0.1],
        '[19, 1, 58, 45, 35, 44, 22, 58, 64, 57, 4]_[1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2]': [0.02, 0.01, 0.1],
        '[48, 51, 11, 6, 20, 48, 50, 3, 9, 27, 26]_[1, 1, 1, 1, 0, 0, 0, 2, 0, 0, 0]': [0.03, 0.01, 0.1],
        '[54, 7, 20, 1, 15, 48, 52, 14, 67, 7, 39]_[1, 2, 2, 1, 1, 0, 0, 0, 2, 0, 2]': [0.02, 0.03, 0.1],
    }, 
    "cola": {
        '[1, 24, 7, 13, 55, 65, 19, 40, 2, 12, 18]_[0, 0, 2, 0, 1, 1, 1, 0, 0, 0, 0]': [0.6, 1.0, 10.0],
        '[4, 16, 17, 15, 21, 3, 49, 33, 53, 24, 2]_[1, 2, 2, 0, 1, 1, 0, 0, 1, 0, 1]': [0.6, 1.0, 5.0],
        '[4, 18, 6, 13, 21, 6, 5, 33, 49, 24, 38]_[1, 2, 2, 1, 1, 1, 0, 0, 1, 0, 1]': [0.6, 1.0, 10.0],
        '[5, 25, 63, 11, 11, 8, 11, 42, 4, 2, 18]_[0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]': [0.1, 10.0, 20.0],
        '[5, 26, 63, 11, 10, 64, 11, 41, 1, 67, 24]_[0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]': [0.1, 10.0, 20.0],
        '[11, 18, 19, 64, 22, 53, 50, 46, 3, 1, 7]_[2, 0, 1, 2, 0, 1, 2, 1, 0, 0, 0]': [0.5, 1.0, 1.0] ,
        '[11, 18, 19, 64, 22, 53, 50, 46, 3, 1, 21]_[2, 0, 1, 2, 1, 1, 0, 1, 0, 0, 0]': [0.5, 1.0, 15.0],
        '[13, 18, 17, 8, 4, 3, 50, 31, 3, 1, 4]_[1, 0, 1, 0, 0, 1, 2, 1, 0, 0, 0]': [0.5, 1.0, 5.0],
        '[13, 18, 17, 8, 4, 44, 5, 31, 1, 1, 40]_[1, 0, 1, 1, 0, 1, 2, 1, 0, 0, 0]': [0.4, 1.0, 20.0] ,
        '[23, 18, 4, 15, 23, 5, 50, 52, 2, 5, 4]_[1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0]' : [0.5, 1.0, 20.0],
        '[42, 18, 18, 62, 21, 53, 47, 46, 1, 1, 21]_[2, 0, 1, 2, 1, 1, 0, 1, 0, 0, 0]': [0.5, 1.0, 1.0],
        '[45, 23, 15, 17, 12, 64, 11, 41, 15, 68, 38]_[0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0]': [0.5, 1.0, 1.0], 
    }
}

def replace_terms_in_script(script_filepath : str, replacements : dict) -> None:
    """Rewrite a script in place, substituting every occurrence of each term.

    The file is read in full, each mapping entry is applied in the
    dictionary's iteration order (so a later entry sees the result of the
    earlier ones), and the result is written back over the same file.

    Args:
        script_filepath (str): path of the file to modify in place
        replacements (dict): mapping of {"old_term" : "new_term"}
    """
    with open(script_filepath, 'r') as source:
        text = source.read()

    # Plain substring substitution, one mapping entry at a time.
    for needle in replacements:
        text = text.replace(needle, replacements[needle])

    with open(script_filepath, 'w') as target:
        target.write(text)
        
def revert_replace_terms_in_script(script_filepath : str, replacements : dict) -> None:
    """Undo a previous term replacement by substituting new terms with old ones.

    Takes the same mapping that was used for the forward replacement and
    applies it backwards: every occurrence of a "new_term" in the file is
    replaced by its "old_term". Entries are undone in the reverse of the
    mapping's iteration order, so that chained replacements (where one
    entry's new term is another entry's old term) unwind correctly.

    This is a textual, best-effort revert: any text in the file that happens
    to equal a "new_term" is replaced as well, whether or not it was produced
    by the forward replacement.

    Args:
        script_filepath (str): filepath of the script, modified in place
        replacements (dict): the forward mapping, where {"old_term" : "new_term"}
    """
    # Read the contents of the script file
    with open(script_filepath, 'r') as file:
        contents = file.read()

    # Undo the replacements last-applied-first; list() keeps this working on
    # Python versions where dict views are not directly reversible.
    for old_term, new_term in reversed(list(replacements.items())):
        contents = contents.replace(new_term, old_term)

    # Write the restored script back to the file
    with open(script_filepath, 'w') as file:
        file.write(contents)

def get_metric_name(task : str):
    """Return the key under which a task's evaluation metric is reported.

    Args:
        task (str): task name

    Returns:
        str: "eval_pearson" for "stsb", "eval_matthews_correlation" for
        "cola", and "eval_accuracy" for every other task.
    """
    return (
        "eval_pearson" if task == "stsb"
        else "eval_matthews_correlation" if task == "cola"
        else "eval_accuracy"
    )

def _parse_metrics_line(line : str):
    """Parse one line of script output into a metrics dictionary.

    Tries progressively more lenient interpretations so that both real JSON
    and a printed Python dict are accepted.
    """
    import ast

    line = line.strip()

    # 1. Strict JSON. Tried first so that apostrophes inside valid JSON
    #    strings are not mangled by the quote substitution below.
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        pass

    # 2. A Python literal such as "{'eval_accuracy': 0.7, 'ok': True}".
    #    literal_eval only evaluates literals; it does not execute code.
    try:
        parsed = ast.literal_eval(line)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # 3. Last resort: swap single quotes for double quotes and parse as JSON.
    #    Raises json.JSONDecodeError when the line cannot be parsed at all.
    return json.loads(line.replace("'", "\""))


def run_shell_with_task(script_filepath : str, task : str):
    """Runs a shell script and returns the task's metric from its last output line.

    The script is executed without arguments. The last line of its standard
    output (after stripping surrounding whitespace) must be a dictionary,
    written either as JSON or as a Python dict literal; the value stored
    under the metric name for ``task`` is returned.

    Args:
        script_filepath (str): Filepath of shell script.
        task (str): Task name, used to pick which metric to read.

    Returns:
        The value of the task's metric in the parsed output dictionary.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
        json.JSONDecodeError: if the last output line cannot be parsed.
        KeyError: if the parsed dictionary lacks the task's metric.
    """
    output = subprocess.check_output([script_filepath], universal_newlines=True)

    # Only the final line is expected to carry the metrics dictionary
    last_line = output.strip().split('\n')[-1]

    data = _parse_metrics_line(last_line)

    metric_name = get_metric_name(task)

    return data[metric_name]

def inverted_computational_complexity_ratio(hidden_states, bits):
    """Return 1165824 divided by a layer-weighted cost of the configuration.

    Starting from 768 states, each of the first 11 entries of
    ``hidden_states`` is subtracted in turn; after each subtraction the
    running remainder is multiplied by the 1-based layer number and by that
    layer's entry in ``bits`` and added to the cost.

    NOTE(review): 768 is presumably the model's hidden size and 1165824 the
    cost of a reference configuration -- confirm.

    Args:
        hidden_states: indexable with at least 11 numbers
        bits: indexable with at least 11 numbers

    Returns:
        float: 1165824 / cost
    """
    remaining = 768
    weighted_cost = 0
    for layer in range(1, 12):
        position = layer - 1
        remaining = remaining - hidden_states[position]
        weighted_cost = weighted_cost + layer * remaining * bits[position]
    return 1165824 / weighted_cost

if __name__ == "__main__":
    # For every configuration in `results`: patch the placeholders in the
    # shell script and the model file, run the script, restore both files,
    # and append one CSV row with the outcome.
    filename = "replicate.txt"
    headers = 'task,InvertedComputationalComplexityRatio,score,alpha,beta,temperature'

    # Loop-invariant settings
    sh = './glue_p_bert.sh'
    py = './transformers/models/bert/modeling_p_bert.py'
    bits_options = [4, 8, 16, 32]

    # Start a fresh output file containing only the header row
    with open(filename, 'w') as file:
        file.write(headers + '\n')

    for task, models in results.items():
        for model, hyperparameters in models.items():
            # Keys look like "[s0, ..., s10]_[b0, ..., b10]"
            parts = model.split('_')
            states = list(map(int, parts[0].strip('[]').split(', ')))
            bits_index = list(map(int, parts[1].strip('[]').split(', ')))
            alpha, beta, temperature = hyperparameters

            # Translate bit-width indices into actual bit widths
            bits = [bits_options[int(i)] for i in bits_index]

            sh_replacements = {
                'TASK="task"': f'TASK="{task}"',
                'ALPHA="0.0"': f'ALPHA="{alpha}"',
                'BETA="0.0"': f'BETA="{beta}"',
                'TEMPERATURE="0.0"': f'TEMPERATURE="{temperature}"',
                'STATES="0"': f'STATES="{states}"',
                'BITS="0"': f'BITS="{bits}"',
            }

            py_replacements = {
                'TASK="task"': f'TASK="{task}"',
                'STATES=0': f'STATES={states}',
                'BITS=0': f'BITS={bits}',
            }

            # The placeholders must be restored even when the run fails;
            # otherwise the patched files no longer contain the placeholder
            # text and every later configuration would silently reuse the
            # values of the failed one.
            replace_terms_in_script(sh, sh_replacements)
            try:
                replace_terms_in_script(py, py_replacements)
                try:
                    score = run_shell_with_task(sh, task)
                finally:
                    revert_replace_terms_in_script(py, py_replacements)
            finally:
                revert_replace_terms_in_script(sh, sh_replacements)

            iccr = inverted_computational_complexity_ratio(states, bits)

            row = [task, iccr, score, alpha, beta, temperature]

            # Append immediately so finished rows survive a later failure
            with open(filename, 'a') as file:
                file.write(','.join(map(str, row)) + '\n')