from opto import trace
import subprocess
import re

from opto.trace.operators import call_llm

@trace.model
class SimpleDecryptorAgent:
   """Agent that analyzes an LCG-based encryption script and tries to undo it.

   On construction the agent loads two resource files as strings: `source.py`
   (described by the original author as implementing a Linear Congruential
   Generator and an `Encryptor` class) and `output.txt` (data produced by that
   script). `run` asks an LLM for a vulnerability analysis and passes the
   result, together with both file contents, to the trainable `decrypt` step.

   Attributes:
      source_file (str): Full text of `source.py`.
      output (str): Full text of `output.txt`.
   """

   def __init__(self):
      """Read `source.py` and `output.txt` entirely into memory.

      Raises:
         FileNotFoundError: If either resource file does not exist at the
            hard-coded path.
      """
      # Context managers close each handle as soon as the read finishes
      # instead of leaving that to garbage collection.
      with open("/Users/anonymous/trace-security/security/crypto/slcg/resources/source.py") as source_handle:
         self.source_file = source_handle.read()
      with open("/Users/anonymous/trace-security/security/crypto/slcg/resources/output.txt") as output_handle:
         self.output = output_handle.read()

   def run(self):
      """Run the agent end to end: analyze first, then attempt decryption.

      Returns:
         The result of `decrypt`, called with the analysis report, the source
         text and the output text.
      """
      print("RUN METHOD INITIATED")
      analysis_report = self.analyze_source_file(self.source_file, self.output)
      return self.decrypt(analysis_report, self.source_file, self.output)

   def analyze_source_file(self, source_file, execution_output):
      """Ask the LLM to analyze the source file for vulnerabilities.

      Args:
         source_file (str): Contents of the source code under analysis.
         execution_output (str): Output produced by that source code; it is
            appended to the source text to give the model extra context.

      Returns:
         The value returned by `call_llm` for the analysis prompt.
      """
      print("ANALYZING SOURCE FILE")
      # The prompt text is runtime input to the model and is kept verbatim.
      analyze_prompt = trace.node("You are a cybersecurity agent, analyze the vunerability in this source file. Give me chain of thought.")
      return call_llm(analyze_prompt, source_file + '\n\nExecution output:' + execution_output)

   def output_extraction(self, output_file):
      """Extract the integer ciphertext values assigned to `ct`.

      Args:
         output_file (str): Contents of the output file as one string.

      Returns:
         list[int]: The integers inside the first `ct = [...]` assignment,
         or an empty list when no such assignment is found or the list is
         empty.

      Raises:
         ValueError: If a non-blank entry of the list is not an integer.
      """
      # DOTALL lets the bracketed list span several lines.
      match = re.search(r'ct\s*=\s*\[(.*?)\]', output_file, re.DOTALL)
      if not match:
         return []
      # Blank tokens come from an empty list (`ct = []`) or a trailing comma;
      # skip them rather than letting int('') raise. int() already ignores
      # surrounding whitespace.
      return [int(token) for token in match.group(1).split(',') if token.strip()]

   # NOTE(review): everything from the decorator below to the end of `decrypt`
   # is left exactly as written, including its indentation. `trainable=True`
   # suggests the function source itself is the parameter being optimized
   # (presumably what the script prints via `agent.parameters()[0].data` --
   # confirm against opto.trace), so reformatting it would alter that
   # parameter. As written the body is only a scaffold: the three
   # `crack_unknown_*` helpers reference names that are never assigned
   # (`increment`, `multiplier`, `modulus`) and would raise NameError if
   # called, and the function body always evaluates `flag` to None.
   @trace.bundle(trainable=True)
   def decrypt(self, analysis_report, source_file, execution_output):
    """
    Attempts to decrypt an encrypted message using extracted LCG parameters.

    This function analyzes the execution output to extract LCG parameters and attempts 
    to reverse the encryption performed by the `Encryptor` class in `source.py`. 
    
    Args:
        analysis_report (str): The report generated from the analysis process, potentially 
                                containing metadata or insights relevant to encryption.
        source_file (str): The Python source file that contains the encryption logic.
        execution_output (str): The output generated during execution, from which 
                                encryption parameters are extracted.

    Returns:
        str: The decrypted flag if successful; otherwise, returns None.
        
    Process:
        1. Extracts the LCG parameters (modulus, multiplier, increment, seed).
        2. Attempts to reconstruct the random sequence used in encryption.
        3. Applies the inverse transformation to retrieve the original plaintext.

    Note:
        - The success of this function depends on the predictability of the LCG sequence.
        - Avoid modifying helper functions outside this method, as it may disrupt execution.
    """
    import os
    from functools import reduce
    from math import gcd
    def crack_unknown_increment(states, modulus, multiplier):
       return modulus, multiplier, increment


    def crack_unknown_multiplier(states, modulus):
       return crack_unknown_increment(states, modulus, multiplier)


    def crack_unknown_modulus(states):
       return crack_unknown_multiplier(states, modulus)

    ct = self.output_extraction(execution_output)
    flag = None
    return flag

# Driver script: these statements run at import time, in this exact order.

# Build the agent; its constructor reads both resource files from disk.
agent = SimpleDecryptorAgent()
# Show the data of the agent's first parameter (presumably the source of the
# trainable `decrypt` bundle -- confirm against opto.trace).
print(agent.parameters()[0].data)
# Write the agent to "agent.pkl", then load it back from that same file.
agent.save("agent.pkl")
agent.load("agent.pkl")

# Run the analyze-then-decrypt pipeline and call backward on its result with
# visualization enabled.
predicted_flag = agent.run()
predicted_flag.backward(visualize=True)