from opto import trace
import subprocess
import re

from opto.trace.operators import call_llm

@trace.model
class SimpleDecryptorAgent:
    """Trace agent that tries to recover the plaintext encrypted by `source.py`.

    `source.py` is described as implementing a Linear Congruential Generator
    (`LCG`) and an `Encryptor` class built on it. The agent reads that source
    and the recorded `output.txt`, asks an LLM for a vulnerability analysis,
    and then runs the trainable `decrypt` bundle. Because `decrypt` is marked
    `trainable=True`, it is the part of this class the Trace optimizer updates.

    Attributes:
        source_file (str): Full text of `source.py`.
        output (str): Full text of `output.txt`; `output_extraction` expects
            it to contain a `ct = [...]` list of integers.
    """

    def __init__(
        self,
        source_path="/Users/anonymous/trace-security/security/crypto/slcg/resources/source.py",
        output_path="/Users/anonymous/trace-security/security/crypto/slcg/resources/output.txt",
    ):
        """Read the challenge files into memory as strings.

        Args:
            source_path (str): Path of the encryption source file. Defaults to
                the location that was previously hard-coded.
            output_path (str): Path of the recorded encryption output. Defaults
                to the location that was previously hard-coded.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        # `with` closes each handle right after reading; the previous bare
        # `open(...).read()` calls left the file objects open.
        with open(source_path, encoding="utf-8") as source_fh:
            self.source_file = source_fh.read()
        with open(output_path, encoding="utf-8") as output_fh:
            self.output = output_fh.read()

    def run(self):
        """Run the agent: analyze the challenge source, then attempt decryption.

        The LLM analysis report is passed into `decrypt` so that it becomes
        part of the traced computation graph.

        Returns:
            The result of the `decrypt` bundle. This is a Trace node; the
            module-level script calls `.backward()` on it.
        """
        print("RUN METHOD INITIATED")
        analysis_report = self.analyze_source_file(self.source_file, self.output)
        return self.decrypt(analysis_report, self.source_file, self.output)

    def analyze_source_file(self, source_file, execution_output):
        """Ask an LLM to describe vulnerabilities in the challenge source.

        Args:
            source_file (str): Contents of the source code file to analyze.
            execution_output (str): Output recorded from running the source,
                appended to the prompt as extra context.

        Returns:
            The response of `call_llm` for the analysis prompt plus the
            concatenated source code and execution output.
        """
        print("ANALYZING SOURCE FILE")
        analyze_prompt = trace.node("You are a cybersecurity agent, analyze the vunerability in this source file. Give me chain of thought.")
        return call_llm(analyze_prompt, source_file + '\n\nExecution output:' + execution_output)

    def output_extraction(self, output_file):
        """Extract the integers assigned to `ct` in the output text.

        Args:
            output_file (str): Contents of the output file.

        Returns:
            list[int]: The integers inside the first `ct = [...]` list (which
            may span several lines), or an empty list when no such assignment
            is found.

        Raises:
            ValueError: If a list element is not an integer literal.
        """
        # `re` is already imported at module level. The `\b` anchor keeps
        # names that merely end in "ct" (e.g. `act = [...]`) from matching.
        match = re.search(r"\bct\s*=\s*\[(.*?)\]", output_file, re.DOTALL)
        if not match:
            return []
        tokens = (token.strip() for token in match.group(1).split(","))
        # Skip empty fields so `ct = []` or a trailing comma does not make
        # int("") raise.
        return [int(token) for token in tokens if token]

    @trace.bundle(trainable=True)
    def decrypt(self, analysis_report, source_file, execution_output):
        """
        Attempts to decrypt an encrypted message using extracted LCG parameters.

        This function analyzes the execution output to extract LCG parameters and attempts
        to reverse the encryption performed by the `Encryptor` class in `source.py`.

        Args:
            analysis_report (str): The report generated from the analysis process, potentially
                                   containing metadata or insights relevant to encryption.
            source_file (str): The contents of the Python source file that contains the
                               encryption logic.
            execution_output (str): The output generated during execution, from which
                                    the ciphertext values are extracted.

        Returns:
            str: The decrypted flag if any characters were recovered; otherwise, returns None.

        Process:
            1. Extracts the LCG parameters (modulus, multiplier, increment, seed).
            2. Attempts to reconstruct the random sequence used in encryption.
            3. Applies the inverse transformation to retrieve the original plaintext.

        Note:
            - The success of this function depends on the predictability of the LCG sequence.
            - Avoid modifying helper functions outside this method, as it may disrupt execution.
            - This body is an incomplete starting point for the optimizer:
              `encrypted_chars_bits`, `lcg`, `mod`, `mult` and `add` are not defined yet,
              and the `calc_*` helpers are stubs that return None.
        """
        from resources import source
        from resources.source import LCG
        ct = self.output_extraction(execution_output)
        flag = None
        # Bits of the known first plaintext character 'g' (0b1100111).
        bin_char = list(map(int, list(f'{ord("g"):b}')))

        # get all values that we know comes from the LCG that
        # is used if the bit is 1

        # calculate the modulus, the multiple and the additive value
        # TODO(optimizer): these helpers are stubs and always return None.
        def calc_mod(values):
            return None
        def calc_multiplicator(value1, value2):
            return None
        def calc_summand(value1, value2, value3):
            return None

        # create a variable for the known LCG and set it to none

        # Build a list of lists that hold the "ciphertext" for every bit

        flag_bits = []

        # NOTE(review): `encrypted_chars_bits`, `lcg`, `mod`, `mult` and `add` are
        # never assigned above, so this loop raises NameError/UnboundLocalError
        # until the placeholders are filled in.
        for encrypted_char_bit in encrypted_chars_bits:
            for encrypted_bit in encrypted_char_bit:
                # if it is the first block, then we need to create the LCG
                if not lcg:
                    lcg = LCG(mod=mod, mult=mult, add=add, seed=encrypted_bit)
                    # the first bit was 1, therefore we can set it
                    flag_bits.append(1)
                    continue

                # in case we are wrong, we need to restore the previous seed
                seed = lcg.value
                rng_next_val = next(lcg)

                # check if the next ciphertext value is the same as the value that
                # the LCG would have generated
                if rng_next_val == encrypted_bit:
                    flag_bits.append(1)
                else:
                    # if the values did not match, we know that the plaintext at
                    # this position was a 0
                    flag_bits.append(0)
                    # restore the previous seed to not get out of sync with the
                    # LCG used during encryption
                    lcg.value = seed

            # Replace the LCG after each character with one built from its next
            # four outputs (positional order presumably mod, mult, add, seed as in
            # the keyword call above; confirm against source.py).
            lcg = LCG(next(lcg), next(lcg), next(lcg), next(lcg))

        # Regroup the recovered bits into 7-bit characters and decode them.
        chars_bits = [flag_bits[i*7: (i+1)*7] for i in range(len(flag_bits)//7)]
        decoded_chars = []
        for char_bits in chars_bits:
            result = 0
            for bit in char_bits:
                result = result << 1 | bit
            decoded_chars.append(chr(result))
        decoded = "".join(decoded_chars)

        # Prints the same console output as before (decoded text plus a newline),
        # but the text is now also returned instead of an unconditional None.
        print(decoded)
        if decoded:
            flag = decoded
        return flag

# Checkpoint file used to round-trip the agent's parameters.
_CHECKPOINT_PATH = "agent.pkl"

agent = SimpleDecryptorAgent()

# Show the data of the first parameter; presumably this is the code of the
# `decrypt` bundle, the only `trainable=True` member of the class.
print(agent.parameters()[0].data)

# Save and immediately reload the parameters (presumably a serialization
# sanity check before running).
agent.save(_CHECKPOINT_PATH)
agent.load(_CHECKPOINT_PATH)

# Execute the agent and propagate backward through the traced graph of the
# returned node.
predicted_flag = agent.run()
predicted_flag.backward(visualize=True)