import os
import re
import subprocess
import time

import openai

# OpenAI API key. Prefer exporting OPENAI_API_KEY in the environment over
# pasting a key into this file, where it is easy to commit by accident.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

# Set up OpenAI API key. An empty string is not assigned, so the openai
# library's default key configuration is left untouched instead of being
# overridden by "".
if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY

# Define the model and settings
model = "gpt-4o"
temperature = 0.2

# Upper bound for each rank, in the same order as the rank array:
# (Mode 1, Mode 2), (Mode 1, Mode 3), (Mode 2, Mode 3).
max_rank_list = [176, 176, 176]

# Hyperparameters for early stopping: stop after `patience` consecutive
# iterations without improvement. `min_delta` is meant to be the minimum
# decrease of the loss function that counts as an improvement.
no_improvement_count = 0
patience = 5
min_delta = 0

# Define file paths: the shell that runs the decomposition script, the script
# itself, and the results file that is read back after each run.
bash_call = 'bash'
script_path = './run_FCTN_decomposition_image.sh'
text_file_path = './results.txt'

# System message: sets the model's role and spells out the loss being minimized.
system_message = """
You are a RGB image expert specialized in tensor decomposition. Your task is to analyze RGB image tensors and suggest the optimal ranks for a fully connected tensor network decomposition (FCTND) applied to them. You need to provide an array of ranks which minimizes the loss function, which is the natural log of the sum of the compression rate and 10 times the approximation error. The compression ratio is calculated as the number of parameters in the compressed FCTND format divided by the original number of parameters of the uncompressed tensor and the approximation error is the relative square error between the original and approximate tensor. Work your suggestions out step-by-step based on rigorous reasoning and RGB image domain knowledge. Explain your final suggestions in a logical, concise manner.
"""

# Initial user prompt. It is an f-string so that the rank upper bounds stated
# to the model always match `max_rank_list`, which is also used to clamp the
# ranks the model returns. The prompt must describe a 3rd-order tensor
# consistently: three modes, three pairwise ranks.
initial_prompt = f"""
We are working with a fully connected 3rd-order tensor representing RGB image data with the following modes:
- Mode 1 of size 144 : The width of the RGB image. There are 144 pixels in the width of the RGB image.
- Mode 2 of size 176 : The height of the RGB image. There are 176 pixels in the height of the RGB image.
- Mode 3 of size 3 : The RGB channels of the RGB image. There are three RGB channels, each representing red, green, and blue.

There are 3 ranks for such an order-3 tensor to set in total. The maximum allowed ranks are {max_rank_list}, listed in the same order as the Rank Array below.

Your task is to suggest the optimal rank for each connection in a fully connected tensor network decomposition. The loss function to minimize is a natural log of the sum of the compression rate and 10 times the tensor approximation error which is the relative square error between the original and approximate tensor. Provide your response in the following format:
1. Take a deep breath and reason step-by-step about the intrinsic interactions between every pair of modes based on your understanding of the relationships between the width, the height, and the RGB channels in RGB image data. It is important to reason about those intrinsic interactions based on interpretable factors.
2. Based on your reasoning, output an array of numbers where each number represents the rank for the connection between every pair of modes (the width, the height, and the RGB channels in RGB image data). 


Output format:
Reasoning: Reason about the intrinsic interactions between every pair of modes based on your understanding of RGB image data.

Rank Array: [Rank for the connection between (Mode 1, Mode 2), Rank for the connection between (Mode 1, Mode 3), Rank for the connection between (Mode 2, Mode 3).] End the output message with an array of numerical values for these 3 ranks. All the entries should be greater than or equal to 1 where a value of 1 means no connection between the modes. Ensure that none of the ranks exceed the given constraints.
"""


#Function definitions
def generate_response(system_message, prompt, model, temperature, max_tokens=1000):
    """Send one system + user message pair to the chat completions API.

    Parameters
    ----------
    system_message : str
        Content of the system message.
    prompt : str
        Content of the user message.
    model : str
        Chat model name, e.g. ``"gpt-4o"``.
    temperature : float
        Sampling temperature passed through to the API.
    max_tokens : int, optional
        Upper bound on the length of the reply (default 1000). A truncated
        reply can lose the rank array the prompts ask for at the end, so
        raise this if the reasoning gets long.

    Returns
    -------
    str
        Text of the first choice, or an empty string if the API returned no
        text content.
    """
    response = openai.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # `message.content` may be None (e.g. a refusal). Callers run regexes on
    # the result, so always hand back a string.
    return response.choices[0].message.content or ""


def call_bash_script(script_path):
    """Run a command (e.g. ``"bash ./run.sh 20 3 3"``) and wait for it to finish.

    Parameters
    ----------
    script_path : str or sequence of str
        Despite its name, this is the full command line. A string is run
        through the shell, as before. A list/tuple of arguments is executed
        directly without a shell.

    Returns
    -------
    int or None
        The command's exit status, or None if it could not be launched.
        A non-zero status is reported on stdout, so a failed decomposition
        run is no longer silent.
    """
    use_shell = isinstance(script_path, str)
    try:
        return_code = subprocess.call(script_path, shell=use_shell)
    except Exception as e:
        # Best-effort, as in the original: report and let the caller go on.
        print(f"Error calling bash script: {e}")
        return None

    if return_code != 0:
        print(f"Bash script exited with status {return_code}: {script_path}")
    return return_code

def get_last_objective_function(file_path):
    """Parse the scores from the last non-empty line of a results file.

    The line is split on commas; each part containing one of the keys
    ``Objective Function``, ``Approximation Error`` or ``Compression Rate``
    must look like ``key = value``, for example::

        Objective Function = -1.2, Approximation Error = 0.03, Compression Rate = 0.2

    Parameters
    ----------
    file_path : str
        Path of the results file.

    Returns
    -------
    list[float] or None
        ``[objective_function, approximation_error, compression_rate]`` as
        soon as all three fields have been parsed. Zero is a valid value for
        any field. Returns None if the file cannot be read, has no non-empty
        line, lacks a field, or a value is not a number.
    """
    # Position of each field in the returned list.
    field_index = {
        "Objective Function": 0,
        "Approximation Error": 1,
        "Compression Rate": 2,
    }

    try:
        with open(file_path, 'r') as file:
            lines = file.readlines()
    except OSError as e:
        print(f"Could not read results file {file_path}: {e}")
        return None

    # Get the last non-empty line
    last_line = next((line.strip() for line in reversed(lines) if line.strip()), "")
    if not last_line:
        return None

    # None marks "not found yet"; 0.0 could be a real score, so it cannot be
    # used as a sentinel.
    output = [None, None, None]
    try:
        for part in last_line.split(','):
            for key, index in field_index.items():
                if key in part:
                    output[index] = float(part.split('=')[1].strip())
            if all(value is not None for value in output):
                return output
    except (IndexError, ValueError):
        print("Error in parsing the loss Function.")

    return None


# ---------------------------------------------------------------------------
# Helpers shared by the initial run and the refinement loop below.
# ---------------------------------------------------------------------------

# Captures a comma-separated triple of integers such as "12, 3, 3", optionally
# surrounded by brackets/backslashes. The LAST match in a reply is used
# because the prompts ask the model to end its message with the rank array.
RANK_ARRAY_PATTERN = re.compile(r'[\[\]\\]*\s*([\d]+,\s*[\d]+,\s*[\d]+)\s*[\[\]\\]*')


def extract_rank_list(response_text, max_ranks):
    """Parse the rank array from an LLM reply and clamp it to valid bounds.

    Parameters
    ----------
    response_text : str or None
        The model's reply.
    max_ranks : sequence of int
        Upper bound for each rank, in rank-array order.

    Returns
    -------
    list[int] or None
        The last integer triple in the reply, with each entry clamped to
        ``1 <= rank <= max_ranks[i]``. A rank of 1 means "no connection", so
        0 is never valid. Returns None when the reply contains no triple.
    """
    matches = RANK_ARRAY_PATTERN.findall(response_text or "")
    if not matches:
        return None
    ranks = [int(value) for value in matches[-1].split(',')]
    return [max(1, min(rank, cap)) for rank, cap in zip(ranks, max_ranks)]


def evaluate_ranks(ranks):
    """Run the decomposition script for *ranks* and return the parsed scores.

    Returns ``[objective_function, approximation_error, compression_rate]``
    as read from ``text_file_path``, or None if they could not be read.
    """
    # e.g. "bash ./run_FCTN_decomposition_image.sh 20 3 3". Every rank is an
    # int, so no free-form LLM text reaches the shell.
    command = ' '.join(map(str, [bash_call, script_path] + ranks))
    call_bash_script(command)
    # NOTE(review): if the script fails without writing a new result line,
    # this returns the previous run's scores. Confirm that the script always
    # writes one line per run.
    return get_last_objective_function(text_file_path)


# ---------------------------------------------------------------------------
# Main loop - iteration 0: ask for an initial rank array and score it.
# ---------------------------------------------------------------------------
print("Running initial prompt:")
initial_response = generate_response(system_message, initial_prompt, model, temperature)
print(initial_response)

min_list = extract_rank_list(initial_response, max_rank_list)
if min_list is None:
    raise SystemExit("Could not find a rank array in the initial response.")

returned_scores = evaluate_ranks(min_list)
if returned_scores is None:
    # The refinement prompt needs a last and a best score, so without a
    # baseline there is nothing to refine: stop with a clear message instead
    # of failing later with a NameError.
    raise SystemExit("Could not retrieve the loss function for the initial rank array.")

last_objective_function, last_approximation_error, last_compression_rate = returned_scores
print(f"Loss Function at iteraten 0: {last_objective_function}")
print(f"last_approximation_error at iteraten 0: {last_approximation_error}")
print(f"last_compression_rate at iteraten 0: {last_compression_rate}")
best_objective_function = last_objective_function
best_approximation_error = last_approximation_error
best_compression_rate = last_compression_rate
best_list = min_list

# ---------------------------------------------------------------------------
# Refinement loop: feed the last and best results back to the model.
# A reply without a rank array, or a run whose scores cannot be read, counts
# as "no improvement", so repeated failures also trigger early stopping.
# ---------------------------------------------------------------------------
for iteraten in range(500):
    # Define Iterative Prompt
    iterative_prompt = f"""
    The last rank array is {min_list} with a total loss function of {last_objective_function}, which is the natural log of the sum of the current compression rate of {last_compression_rate} and ten times the current approximation error of {last_approximation_error}. The lowest total loss function of {best_objective_function}, which is the natural log of the sum of the compression rate of {best_compression_rate} and ten times the approximation error of {best_approximation_error}, is found by using the rank array {best_list}. The loss function to minimize is a natural log of the sum of the compression rate and 10 times the tensor approximation error which is the relative square error between the original and approximate tensor. Take a deep breath, refine the rank suggestions to make the loss function smaller, and justify any changes in ranks. Keep in mind that increasing the ranks significantly decreses the approximation error, while it increases the compression rate. However, if the compression rate is already very low compared to the approximation error (for example, the compression rate is smaller than half of the approximation error), increasing the ranks (such as doubling it) to decrease the approximation error will usually lead to a lower loss function. Also, if the compression rate is already very high compared to the approximation error (for example, the compression rate is larger than two times the approximation error), reducing the ranks (such as cutting it in half) to decrease the compression rate will usually lead to a lower loss function.

    When refining the ranks, consider how each mode (width, height, RGB channels) interacts with the others and how reducing or increasing the rank will affect the overall decomposition accuracy. You are encouraged to be explorative to try small and large rank value changes in this process. You should never try the same set of ranks more than once. Trying the same set of ranks more than once wastes computation resources and will not lead to a different outcome.

    Provide the adjusted ranks and reason for the changes in the following format:

    Output format:
    Detailed Reasoning: Take a deep breath to revise your previous reasoning and proposed changes, reason explicitly step-by-step about the possible factors that could impact between the intrinsic interactions between every pair of modes based on your understanding of RGB image data.

    Rank Array: [Rank for the connection between (Mode 1, Mode 2), Rank for the connection between (Mode 1, Mode 3), Rank for the connection between (Mode 2, Mode 3).] 
    End the output message with an array of numerical values for these 3 ranks. All the entries should be greater than or equal to 1 where a value of 1 means no connection between the modes. 
    """

    print(f"\nRunning iteraten {iteraten + 1} of iterative prompt:")
    iteraten_response = generate_response(system_message, iterative_prompt, model, temperature)
    print(iterative_prompt)
    print(iteraten_response)

    candidate_list = extract_rank_list(iteraten_response, max_rank_list)
    current_scores = None
    if candidate_list is None:
        print("Could not find a rank array in the response.")
    else:
        print('current rank array is ', candidate_list)
        current_scores = evaluate_ranks(candidate_list)

    improved = False
    if current_scores is None:
        print("Could not retrieve the loss function.")
    else:
        current_objective_function, current_approximation_error, current_compression_rate = current_scores
        print(f"Loss Function at iteraten {iteraten + 1}: {current_objective_function}")

        # An improvement must beat the best loss by more than min_delta.
        improved = current_objective_function < best_objective_function - min_delta
        if improved:
            best_objective_function = current_objective_function
            best_approximation_error = current_approximation_error
            best_compression_rate = current_compression_rate
            best_list = candidate_list

        # Only a successfully scored attempt becomes the "last" one reported
        # in the next prompt, so the rank array and its scores always match.
        min_list = candidate_list
        last_objective_function = current_objective_function
        last_approximation_error = current_approximation_error
        last_compression_rate = current_compression_rate

    if improved:
        no_improvement_count = 0  # reset count
    else:
        no_improvement_count += 1
        if no_improvement_count >= patience:
            print(f"Stopping early at iteraten {iteraten + 1} due to no improvement.")
            print(f"Best loss function: {best_objective_function}")
            print(f"Best rank array: {best_list}")
            break