from opto import trace
import subprocess
import re
import crypto

from opto.trace.operators import call_llm
from sympy.ntheory.modular import crt
@trace.model
class SimpleDecryptorAgent:
    """Agent that tries to recover the flag of the `permuted` crypto challenge.

    On construction it reads the challenge's ``source.py`` and ``output.txt``
    into memory. ``run`` first asks an LLM for a vulnerability analysis, then
    calls ``decrypt``, a ``trace.bundle(trainable=True)`` method whose code an
    optimizer may rewrite.

    Attributes:
        source_file (str): Full contents of ``source.py``.
        output (str): Full contents of ``output.txt``.

    Note:
        The file paths are hard-coded absolute paths. If either file is
        missing, construction raises ``FileNotFoundError``.
    """

    def __init__(self):
        """Read the challenge source and its recorded output into memory."""
        # Use context managers so the file handles are closed right away.
        with open("/Users/anonymous/trace-security/security/crypto/permuted/resources/source.py") as f:
            self.source_file = f.read()
        with open("/Users/anonymous/trace-security/security/crypto/permuted/resources/output.txt") as f:
            self.output = f.read()

    def run(self):
        """Run the agent: analyze the source, then attempt decryption.

        Returns:
            Whatever the traced ``decrypt`` bundle returns. Presumably this is
            a trace node wrapping the decrypted flag bytes, or ``None`` on
            failure. The caller invokes ``.backward`` on it.
        """
        print("RUN METHOD INITIATED")
        analysis_report = self.analyze_source_file(self.source_file, self.output)
        return self.decrypt(analysis_report, self.source_file, self.output)

    def analyze_source_file(self, source_file, execution_output):
        """Ask the LLM for a chain-of-thought vulnerability analysis.

        Args:
            source_file (str): Contents of the source code to analyze.
            execution_output (str): Output produced by running that source.
                It is appended to the user prompt as extra context.

        Returns:
            The result of ``call_llm`` (the model's analysis).
        """
        print("ANALYZING SOURCE FILE")
        analyze_prompt = trace.node("You are a cybersecurity agent, analyze the vulnerability in this source file. Give me chain of thought.")
        return call_llm(analyze_prompt, source_file + '\n\nExecution output:' + execution_output)

    def output_extraction(self, output_file):
        """Return the Diffie-Hellman values ``g``, ``A``, ``B`` and ciphertext ``c``.

        Args:
            output_file (str): Path to the challenge's ``output.txt``. This is
                a path passed to ``open``, not the file contents.

        Returns:
            tuple: ``(g, A, B, c)``. ``g``, ``A`` and ``B`` are lists of ints;
            ``c`` is the ciphertext as ``bytes``.

        Note:
            The regex-based extraction is currently disabled and the values
            are hard-coded. NOTE(review): ``g``, ``A`` and ``B`` look like
            short placeholders rather than the real challenge values. Confirm
            before relying on them.
        """
        def extract_list_values(filename, variable_name):
            """Parse ``<variable_name> = [i0, i1, ...]`` from ``filename``.

            Returns the list of ints, or ``None`` if no such assignment exists.
            """
            with open(filename, 'r') as file:
                content = file.read()
            # \b stops e.g. "g" from matching the tail of another name.
            pattern = rf'\b{re.escape(variable_name)}\s*=\s*\[([^\]]+)\]'
            match = re.search(pattern, content, re.DOTALL)
            if match is None:
                return None
            items = (value.strip() for value in match.group(1).split(','))
            # Skip empty items so a trailing comma does not break int().
            return [int(item) for item in items if item]

        # Extraction is disabled. To re-enable it, replace the literals with
        # e.g. extract_list_values(output_file, "g").
        g = [5, 4]
        A = [3, 6]
        B = [2, 9]
        c = b'\x89\xba1J\x9c\xfd\xe8\xd0\xe5A*\xa0\rq?!wg\xb0\x85\xeb\xce\x9f\x06\xcbG\x84O\xed\xdb\xcd\xc2\x188\x0cT\xa0\xaaH\x0c\x9e9\xe7\x9d@R\x9b\xbd'
        return g, A, B, c

    @trace.bundle(trainable=True)
    def decrypt(self, analysis_report, source_file, execution_output):
        """
        Decrypts the encrypted flag by solving a discrete logarithm in the
        symmetric group. The Diffie-Hellman generator g and the public value
        A = g**a are permutations, so the exponent a can be recovered one
        g-cycle at a time. The per-cycle results are combined with the
        Chinese Remainder Theorem.

        It uses `self.output_extraction()` to retrieve g (generator),
        A (public key g**a), B (public key, presumably g**b) and the
        ciphertext c. The shared secret B**a is turned into an AES-CBC key:
        tuple(mapping) -> builtin hash -> long_to_bytes -> sha256, bytes 16..32.

        Args:
            analysis_report (str): The report generated by the analysis step,
                possibly containing insights relevant to the encryption.
            source_file (str): The Python source file containing the encryption logic.
            execution_output (str): The output produced when that source ran.

        Returns:
            bytes: The decrypted flag if successful; otherwise, returns None.

        If you write helper functions, make sure that they are embedded in this function.
        Do not add them to the class. Otherwise, the program may not run correctly.
        Double check for syntax errors - it is better that a program run incorrectly
        than not run at all.

        Do not assign to the name `hash` inside this function: `hash(...)` below is
        the builtin. Rebinding it anywhere in the function makes it a local name,
        which raises UnboundLocalError.

        Make sure you avoid the following error:
        (SyntaxError) unexpected character after line continuation character
        """
        from sympy.ntheory.modular import crt
        from hashlib import sha256
        from Crypto.Util.number import long_to_bytes
        from Crypto.Cipher import AES
        from resources.source import Permutation

        def dlp(g, h):
            """Return x with g**x == h (modulo the order of g), or None if none exists.

            Assumes g.cycles() and h.cycles() list cycles in the same
            orientation. TODO confirm against Permutation.cycles().
            """
            g_cycles = g.cycles()
            h_cycles = h.cycles()

            # Map each element to (index of its g-cycle, offset in that cycle).
            # This is a single O(n) pass instead of a list scan per element.
            position = {}
            for j, cyc in enumerate(g_cycles):
                for k, elem in enumerate(cyc):
                    position[elem] = (j, k)

            # g**x shifts every element x steps along its own g-cycle. Two
            # consecutive elements of an h-cycle are therefore x apart in their
            # g-cycle, modulo that g-cycle's length. The h-cycle length is only
            # a divisor of it and would lose information.
            # Keep one congruence per g-cycle.
            congruences = {}
            for cyc in h_cycles:
                j, k_first = position[cyc[0]]
                _, k_second = position[cyc[1 % len(cyc)]]
                modulus = len(g_cycles[j])
                congruences[j] = (modulus, (k_second - k_first) % modulus)

            if not congruences:
                return 0
            moduli = [m for m, _ in congruences.values()]
            residues = [r for _, r in congruences.values()]
            # crt returns None when the congruences are inconsistent.
            solution = crt(moduli, residues)
            if solution is None:
                return None
            return int(solution[0])

        output_path = "/Users/anonymous/trace-security/security/crypto/permuted/resources/output.txt"
        g, A, B, c = self.output_extraction(output_path)
        g = Permutation(g)
        A = Permutation(A)
        B = Permutation(B)

        a = dlp(g, A)
        if a is None:
            print('Discrete log failed: cycle congruences are inconsistent')
            return None

        # Shared secret: B**a (== g**(a*b) if B == g**b).
        C = B**a

        # Key derivation. `hash` is the builtin here.
        # NOTE(review): hash() of a tuple is not guaranteed stable across Python
        # versions / 32- vs 64-bit builds. It must match the interpreter that
        # encrypted the flag.
        sec = long_to_bytes(hash(tuple(C.mapping)))
        hasher = sha256()
        hasher.update(sec)
        key = hasher.digest()[16:32]

        iv = b"mg'g\xce\x08\xdbYN2\x89\xad\xedlY\xb9"
        cipher = AES.new(key, AES.MODE_CBC, iv)
        decrypted = cipher.decrypt(c)
        print('Flag:', decrypted)
        return decrypted

if __name__ == "__main__":
    # Guarded so that importing this module does not read files, call the
    # LLM, or write agent.pkl as a side effect.
    agent = SimpleDecryptorAgent()
    # Print the data of the first trainable parameter. Presumably this is the
    # code of the `decrypt` bundle.
    print(agent.parameters()[0].data)
    # Persist and reload the agent's parameters via agent.pkl.
    agent.save("agent.pkl")
    agent.load("agent.pkl")

    predicted_flag = agent.run()
    # Back-propagate through the traced graph and visualize it.
    predicted_flag.backward(visualize=True)