import numpy as np
import subprocess
import argparse

from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.termination import get_termination
from pymoo.optimize import minimize
from pymoo.core.problem import ElementwiseProblem

def replace_terms_in_script(script_filepath : str, replacements : dict) -> None:
    """Rewrite a file in place with every ``old_term`` swapped for its ``new_term``.

    The pairs are applied one at a time, in the mapping's iteration order, and
    each substitution replaces all occurrences in the text produced by the
    previous one. The file is read completely before it is reopened for writing.

    Args:
        script_filepath (str): path of the file to rewrite
        replacements (dict): ``{"old_term": "new_term"}`` pairs to substitute
    """
    with open(script_filepath, 'r') as source:
        text = source.read()

    for target in replacements:
        text = text.replace(target, replacements[target])

    with open(script_filepath, 'w') as destination:
        destination.write(text)
        
def revert_replace_terms_in_script(script_filepath : str, replacements : dict) -> None:
    """Undo a previous ``replace_terms_in_script`` call on the same file.

    Every occurrence of each ``new_term`` is turned back into its ``old_term``.
    The pairs are processed in the reverse of the order used by the forward
    replacement, so chained replacements (e.g. ``{"a": "b", "b": "c"}``) are
    unwound in the correct sequence.

    This is a purely textual inverse. Any occurrence of a ``new_term`` that was
    already present in the file before the forward replacement is also
    converted to ``old_term``.

    Args:
        script_filepath (str): filepath of the script
        replacements (dict): the same ``{"old_term": "new_term"}`` mapping that
            was used for the forward replacement
    """
    # Read the contents of the script file
    with open(script_filepath, 'r') as file:
        contents = file.read()

    # The inverse of "apply r1, then r2" is "undo r2, then undo r1".
    for old_term, new_term in reversed(list(replacements.items())):
        contents = contents.replace(new_term, old_term)

    # Write the restored script back to the file
    with open(script_filepath, 'w') as file:
        file.write(contents)
        
def run_shell(file):
    """Execute ``file`` directly and return the last line of its standard output.

    The file is passed as a one-element argument list, so no shell is
    involved and the file itself must be executable.

    Args:
        file: path of the executable to run

    Returns:
        str: the last line of stdout, after surrounding whitespace is stripped
        from the whole output, with every ``'`` replaced by ``"``. Returns an
        empty string if the program printed nothing.

    Raises:
        subprocess.CalledProcessError: if the program exits with a non-zero status.
    """
    output = subprocess.check_output([file], universal_newlines=True)
    last_line = output.strip().split('\n')[-1]
    # NOTE(review): presumably converts Python-repr style quoting to double
    # quotes for downstream parsing; confirm against the scripts' output.
    return last_line.replace("'", "\"")

def write_n_states(lst):
    """Run ``./glue_obtain_hidden_states.sh`` configured for ``TASK`` and ``lst``.

    The placeholders ``TASK="task"`` and ``STATES="0"`` in the script are
    temporarily substituted with the module-level ``TASK`` and ``lst``. The
    script is then executed and reverted afterwards.

    The revert runs in a ``finally`` clause, so the placeholders are restored
    even if the run fails or is interrupted. Otherwise later calls would find
    no placeholders to substitute and would silently reuse the stale values.

    Args:
        lst: list of hidden-state counts (as built by ``P_Problem._evaluate``);
            it is interpolated into the script via f-string formatting.

    Returns:
        str: the last line of the script's stdout, as returned by ``run_shell``.
    """
    file_name = './glue_obtain_hidden_states.sh'

    replacements = {
        'TASK="task"': f'TASK="{TASK}"',
        'STATES="0"': f'STATES="{lst}"',
    }

    replace_terms_in_script(file_name, replacements)
    try:
        last_line = run_shell(file_name)
    finally:
        # Always restore the placeholders, even if the script failed.
        revert_replace_terms_in_script(file_name, replacements)
    return last_line

def evaluate_model(lst, bits):
    """Run ``./glue_compression.sh`` for one configuration and return its score.

    Placeholders are temporarily substituted in two files:

    * ``./glue_compression.sh``: ``TASK="task"``, ``STATES="0"`` and ``BITS="0"``.
    * ``./transformers/models/bert/modeling_p_bert.py``: ``TASK="task"``,
      ``STATES=0`` and ``BITS=0``.

    The script is then executed and both files are reverted. Each revert runs
    in a ``finally`` clause, so a failed or interrupted run does not leave the
    files holding stale values with their placeholders gone.

    Args:
        lst: list of hidden-state counts, interpolated via f-string formatting.
        bits: list of bit widths, interpolated via f-string formatting.

    Returns:
        float: the last line of the script's stdout, parsed with ``float``.

    Raises:
        subprocess.CalledProcessError: if the script exits non-zero (via ``run_shell``).
        ValueError: if the last output line is not a number.
    """
    sh = './glue_compression.sh'
    py = './transformers/models/bert/modeling_p_bert.py'

    sh_replacements = {
        'TASK="task"': f'TASK="{TASK}"',
        'STATES="0"': f'STATES="{lst}"',
        'BITS="0"': f'BITS="{bits}"',
    }

    # The .py file uses unquoted placeholders for STATES and BITS.
    py_replacements = {
        'TASK="task"': f'TASK="{TASK}"',
        'STATES=0': f'STATES={lst}',
        'BITS=0': f'BITS={bits}',
    }

    replace_terms_in_script(sh, sh_replacements)
    try:
        replace_terms_in_script(py, py_replacements)
        try:
            last_line = run_shell(sh)
        finally:
            revert_replace_terms_in_script(py, py_replacements)
    finally:
        revert_replace_terms_in_script(sh, sh_replacements)

    return float(last_line)

def calculate_inverted_computational_complexity_ratio(hidden_states, bits):
    """Return ``1165824`` divided by a weighted cost of the configuration.

    A running count starts at 768. For each position ``k = 1..11``,
    ``hidden_states[k-1]`` is first subtracted from the count, and then
    ``k * count * bits[k-1]`` is added to the cost. Only the first 11 entries
    of each argument are read; shorter inputs raise ``IndexError``. For a
    positive cost, a cheaper configuration yields a larger value.
    """
    remaining = 768
    weighted_cost = 0
    for position in range(1, 12):
        remaining = remaining - hidden_states[position - 1]
        weighted_cost = weighted_cost + position * remaining * bits[position - 1]
    return 1165824 / weighted_cost

class P_Problem(ElementwiseProblem):
    """Two-objective pymoo problem over hidden-state counts and bit-width indices.

    The 22 decision variables are split into two halves:

    * ``x[0:11]``: hidden-state counts, bounded to ``[1, 69]``.
    * ``x[11:22]``: indices into ``bits_options`` (``[4, 8, 16, 32]``),
      bounded to ``[0, 3]``.

    pymoo minimizes, so both objectives are negated. The first is the score
    returned by ``evaluate_model``; the second is
    ``calculate_inverted_computational_complexity_ratio``.
    """

    def __init__(self):
        super().__init__(
            n_var=22,
            n_obj=2,
            n_constr=0,
            xl=np.concatenate((np.ones(11), np.zeros(11))),
            xu=np.concatenate((np.full(11, 69), np.full(11, 3))),
        )
        # Candidate bit widths, selected by the last 11 decision variables.
        self.bits_options = [4, 8, 16, 32]

    def _evaluate(self, x, out, *args, **kwargs):
        # NOTE(review): astype(int) truncates toward zero. With the continuous
        # bounds above, the top values (69 states, bit index 3 -> 32 bits) are
        # produced only when a variable sits exactly on its upper bound.
        # Consider integer-aware operators (e.g. a rounding repair) if the full
        # range should be reachable. Left unchanged because main() reports
        # res.X with the same astype(int) truncation.
        int_array = x.astype(int)
        # tolist() yields native Python ints. Under NumPy >= 2, a list of numpy
        # scalars stringifies as "[np.int64(5), ...]", and that text is
        # interpolated into the shell/Python scripts.
        num_hidden_states_list = int_array[:11].tolist()
        bits = [self.bits_options[idx] for idx in int_array[11:22].tolist()]
        write_n_states(num_hidden_states_list)
        score = evaluate_model(num_hidden_states_list, bits)
        out["F"] = [-score, -calculate_inverted_computational_complexity_ratio(num_hidden_states_list, bits)]
        
def main():
    """Run NSGA-II on ``P_Problem`` and append the optimal solutions to ``run.txt``.

    Reads the module-level ``INITIAL_POP_SIZE`` and ``NUM_OF_GENERATIONS``,
    which are set from the command line. Two entries are appended to the log:

    * the decision vectors (``res.X``), truncated to int;
    * the objective values (``res.F``), with the sign flipped back to the
      actual metric and ratio.
    """
    import sys  # sys.maxsize: disable array2string's summarization

    # Configure the Optimization Algorithm
    algorithm = NSGA2(pop_size=INITIAL_POP_SIZE)

    # Configure Termination Criteria
    termination = get_termination("n_gen", NUM_OF_GENERATIONS)

    # Run the Optimization
    problem = P_Problem()
    res = minimize(problem, algorithm, termination, save_history=True)

    with open('run.txt', 'a') as file:
        file.write("######## Number of Hidden States and Quantization Bit Index of the Optimal Final Population ######## \n")
        # By default array2string replaces the middle of arrays with more than
        # 1000 elements by "...", which would silently drop solutions.
        file.write(np.array2string(res.X.astype(int), threshold=sys.maxsize))
        file.write("\n")

        file.write("######## Optimal Final Population's [Evaluation Metric, Inverted Computational Complexity Ratio] ######## \n")
        # P_Problem minimizes the negated objectives; negate again so the
        # logged values match the header.
        arr = -res.F
        file.write(np.array2string(arr, threshold=sys.maxsize))
        file.write("\n")

if __name__ == "__main__":
    # Parse the required CLI options. They are stored as module-level globals
    # because write_n_states, evaluate_model and main read them from there.
    cli = argparse.ArgumentParser()
    for flag, kind in (("--task", str), ("--population", int), ("--generations", int)):
        cli.add_argument(flag, type=kind, required=True)
    parsed = cli.parse_args()

    TASK, INITIAL_POP_SIZE, NUM_OF_GENERATIONS = (
        parsed.task,
        parsed.population,
        parsed.generations,
    )

    main()