import os
import re
import json
import time
import numpy as np
from datetime import datetime
from tqdm import tqdm
from call_gpt import call_gpt
from prompts.time_series_prediction import standard_prompt, llm_ar_prompt
import argparse
import tiktoken

# Command-line interface. Every option has a default, so the script can be
# started without any arguments.
parser = argparse.ArgumentParser(description='Run time series prediction task')

parser.add_argument(
    '--dataset_dir',
    type=str,
    default='../dataset',
    help='Directory of the dataset',
)
parser.add_argument(
    '--model',
    type=str,
    default='gpt-4-vision-preview',
    help='Model to use',
)
parser.add_argument(
    '--steps',
    type=int,
    default=10,
    help='Number of values to predict',
)
parser.add_argument(
    '--max_input_length',
    type=int,
    default=100,
    help='Maximum number of input sequence length',
)
parser.add_argument(
    '--log_dir',
    type=str,
    default='log',
    help='Directory for logs',
)

args = parser.parse_args()

# Values fixed for this script (not configurable from the command line).
task = 'time_series_prediction'
method = 'mip'.lower()

# Module-level aliases for the parsed options used further down.
dataset_dir, model = args.dataset_dir, args.model
steps, max_input_length = args.steps, args.max_input_length

# log_dir_base = os.path.join(args.log_dir, task)

# current_time_str = datetime.now().strftime('%Y%m%d_%H%M%S')

# if not os.path.exists(log_dir_base):
#     os.makedirs(log_dir_base)

# log_filename = os.path.join(log_dir_base, f'{model}_{method}_{current_time_str}.csv')
# log_file = open(log_filename, 'w', encoding='utf8')
# log_file.write('predict_id,next_values,mae,mape\n')
# log_file.flush()


# Tokenizer used to measure prompt length in tokens.
try:
    encoding = tiktoken.encoding_for_model(model)
except KeyError:
    # tiktoken raises KeyError when it cannot map a model name to a
    # tokenizer. Since --model accepts any string, fall back to a generic
    # encoding instead of aborting at startup; token counts are then only
    # an approximation for that model.
    print(f"Warning: no tiktoken encoding registered for model '{model}'; "
          "falling back to 'cl100k_base'.")
    encoding = tiktoken.get_encoding('cl100k_base')


def get_avg_tokens_per_step(input_str, time_sep=','):
    """
    Compute the average number of tokens per time step of a serialized sequence.

    The text is tokenized with the module-level ``encoding``.

    :param input_str: Serialized sequence text.
    :param time_sep: Separator string placed between consecutive time steps.
    :return: Token count of ``input_str`` divided by the number of pieces
        obtained by splitting it on ``time_sep``, as a float.
    """
    token_count = len(encoding.encode(input_str))
    # str.split keeps empty pieces, so a trailing separator adds one step.
    step_count = len(input_str.split(time_sep))
    return token_count / step_count


def calculate_mae(actual, predicted):
    """
    Calculate Mean Absolute Error (MAE) between two lists of floats.

    :param actual: List of actual values.
    :param predicted: List of predicted values.
    :return: MAE as a float.
    :raises ValueError: If the two lists differ in length or are empty.
    """
    if len(actual) != len(predicted):
        raise ValueError("The length of actual and predicted lists must be the same.")
    # An empty input has no mean; fail with a clear message rather than an
    # incidental ZeroDivisionError from the division below.
    if len(actual) == 0:
        raise ValueError("Cannot compute MAE of empty lists.")

    total_abs_error = sum(abs(a - p) for a, p in zip(actual, predicted))
    return total_abs_error / len(actual)


def calculate_mape(actual, predicted):
    """
    Calculate Mean Absolute Percentage Error (MAPE) between two lists of floats.

    Positions whose actual value is zero are excluded, because the percentage
    error is undefined there. The mean is taken over the remaining positions
    only, so excluded positions do not dilute the result.

    :param actual: List of actual values.
    :param predicted: List of predicted values.
    :return: MAPE as a float, in percent. Returns 0.0 when every actual
        value is zero (no position has a defined percentage error).
    :raises ValueError: If the two lists differ in length or are empty.
    """
    if len(actual) != len(predicted):
        raise ValueError("The length of actual and predicted lists must be the same.")
    # An empty input has no mean; fail with a clear message rather than an
    # incidental ZeroDivisionError.
    if len(actual) == 0:
        raise ValueError("Cannot compute MAPE of empty lists.")

    relative_errors = [abs((a - p) / a) for a, p in zip(actual, predicted) if a != 0]
    if not relative_errors:
        return 0.0

    # Divide by the number of terms actually summed, not by len(actual):
    # counting skipped zero-actual positions would bias the error low.
    mape = sum(relative_errors) / len(relative_errors)
    return mape * 100


# Load the task description from <dataset_dir>/<task>/task.json.
metadata_path = os.path.join(dataset_dir, task, 'task.json')
with open(metadata_path, encoding='utf8') as f:
    metadata = json.load(f)

# Accumulators for per-sample results.
mae_list, mape_list, std_list = [], [], []

# Build and print the prompt for every sample in the task metadata.
for sid, item in enumerate(tqdm(metadata)):
    # Not read by any active statement in this loop body.
    temperature = 0.7

    # Keep only the last `max_input_length` values of the input series.
    input_sequence = item['input'][-max_input_length:]

    # Serialize as "v1,v2,...,vn," with a trailing separator.
    input_str = ','.join(map(str, input_sequence)) + ','
    prompt = standard_prompt.format(sequence=input_str)

    print(prompt)

    # # Post-processing
    # if not pred_sequence:
    #     pred_sequence = [input_sequence[-1]] * steps
    # elif len(pred_sequence) < steps:
    #     pred_sequence.extend([pred_sequence[-1]] * (steps - len(pred_sequence)))
    # elif len(pred_sequence) > steps:
    #     pred_sequence = pred_sequence[:steps]

    # ground_truth = item['output']
    # ground_truth = ground_truth[:steps]
    # pred_sequence = pred_sequence[:len(ground_truth)]
    # mae = calculate_mae(ground_truth, pred_sequence)
    # mape = calculate_mape(ground_truth, pred_sequence)
    # mae_list.append(mae)
    # mape_list.append(mape)

    # log_file.write(f'{sid},{" ".join(list(map(str, pred_sequence)))},{mae},{mape}\n')
    # log_file.flush()


# # Calculate and display the average MAE and MAPE
# avg_mae = sum(mae_list) / len(mae_list)
# avg_mape = sum(mape_list) / len(mape_list)

# # Print the performance report
# print(f"Performance Report - Model: {model}, Method: {method}")
# print(f"Total Samples: {len(mae_list)}")
# print(f"Average MAE: {avg_mae}")
# print(f"Average MAPE: {avg_mape}%")

# if method == 'llm-ar-sc':
#     avg_std = sum(std_list) / len(std_list)
#     print(f"Average standard deviation: {avg_std}")
