import ast
import json
import sys, traceback
import argparse
import requests
from tqdm import tqdm
from typing import Dict, Union, List
from pathlib import Path
from PIL import Image
from io import BytesIO
import pandas as pd
import time
from tenacity import retry, wait_exponential, stop_after_attempt
import warnings

from utils import get_header

# Silence every Python warning for the whole run of this script.
# NOTE(review): this also hides deprecation/future warnings raised by imported
# libraries; consider narrowing the filter if those become relevant.
warnings.filterwarnings('ignore')


@retry(wait=wait_exponential(multiplier=1, min=1, max=20), stop=stop_after_attempt(5))
def generate_image(
    header: Dict[str, str],
    api_metadata: Dict[str, str],
    task_payload: Dict,
    timeout: float = 45,
):
    '''
    Request one image from the text-to-image endpoint and download it.

    The tenacity decorator retries the whole call on any exception, with exponential
    backoff between 1 and 20 seconds and at most 5 attempts. After the last failed
    attempt tenacity raises ``tenacity.RetryError``.

    Parameters:
    header (Dict[str, str]): HTTP headers for the request (built by utils.get_header in main).
    api_metadata (Dict[str, str]): Endpoint metadata for the selected API; must contain 'tti_endpoint'.
    task_payload (Dict): JSON body for the generation request (main stores the prompt under 'prompt').
    timeout (float): Seconds to wait on each HTTP request (generation and image download).

    Returns:
    PIL.Image.Image: The generated image.

    Raises (per attempt, before retrying):
    ValueError: If the API response body contains an 'error' object.
    requests.HTTPError: If downloading the generated image returns an error status.
    '''
    response = requests.post(api_metadata['tti_endpoint'], headers=header, json=task_payload, timeout=timeout)
    # Parse the body once; a non-JSON body raises here and is retried by the decorator.
    body = response.json()

    # Surface errors reported by the API and raise so the decorator retries.
    if 'error' in body:
        print('failed TTI request')
        message = body['error']['message']
        print(message)
        raise ValueError('Returned error: \n' + message)

    # Download the image referenced by the first result. Bound the wait so a stalled
    # download cannot hang the run, and fail loudly on an HTTP error page.
    image_url = body['data'][0]['url']
    image_response = requests.get(image_url, timeout=timeout)
    image_response.raise_for_status()
    return Image.open(BytesIO(image_response.content))


def save_results(results_df: pd.DataFrame, results_file: str=None):
    '''
    Write the results table to CSV, omitting the DataFrame index.

    Parameters:
    results_df (pd.DataFrame): The table to write.
    results_file (str): Destination path. When None or empty, a file named
    ``results_<time.time()>.csv`` is written to the current working directory.
    '''
    destination = results_file or f'results_{time.time()}.csv'
    results_df.to_csv(destination, index=False)


def parse_args() -> argparse.Namespace:
    '''
    Read this script's command-line options from ``sys.argv``.

    Returns:
    argparse.Namespace: One attribute per option (task, task_file, results_file, api_file,
    task_payload, n_trials, api, log_interval, sleep).
    '''
    # (flag, add_argument keyword arguments), listed in the order shown by --help.
    option_table = (
        ('--task', dict(type=str, required=True, choices=['counting', 'binding'],
                        help='The name of the task.')),
        ('--task_file', dict(type=str, required=True,
                             help='The file containing the task metadata.')),
        ('--results_file', dict(type=str, default=None,
                                help='The file to save the results to.')),
        ('--api_file', dict(type=str, default='api_metadata.json',
                            help='Location of the file containing api keys and endpoints.')),
        ('--task_payload', dict(type=str, default='payloads/dalle3.json',
                                help='The path to the task payload JSON file.')),
        ('--n_trials', dict(type=int, default=None,
                            help='The number of trials to run.')),
        ('--api', dict(type=str, default='azure',
                       help='Which API to use for the requests.')),
        ('--log_interval', dict(type=int, default=10,
                                help='The interval at which to save the results.')),
        ('--sleep', dict(type=int, default=0,
                         help='The time to sleep between requests.')),
    )

    cli = argparse.ArgumentParser(description='Run trials for the specified task.')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def main():
    '''
    Generate an image for every not-yet-completed trial and record progress to CSV.

    Steps:
    - Load the request payload and API metadata.
    - Load trials from --results_file (to resume) or else from --task_file.
    - Shuffle the trials, optionally keeping only --n_trials of them.
    - For each trial whose ``completed`` value is falsy, send its ``prompt`` to the
      text-to-image endpoint and save the returned image to the trial's ``path``.
      Successful trials are marked ``completed`` in the table.
    - Write the table every --log_interval rows, and once more when the loop ends,
      including on an error or interruption.
    '''
    # Parse command line arguments.
    args = parse_args()
    print('Running trials for task:', args.task)

    # Load the request payload and API metadata; context managers close the files.
    with open(args.task_payload, 'r') as payload_file:
        task_payload = json.load(payload_file)
    with open(args.api_file, 'r') as api_file:
        api_metadata = json.load(api_file)

    # Build the request header from the full metadata, then keep only the selected API's entry.
    header = get_header(api_metadata, model=args.api)
    api_metadata = api_metadata[args.api]

    # Resume from a previous results file when possible; otherwise start from the task metadata.
    # pd.read_csv(None) raises ValueError, so an omitted --results_file also falls back here.
    try:
        results_df = pd.read_csv(args.results_file)
    except (FileNotFoundError, ValueError):
        results_df = pd.read_csv(args.task_file)

    # Resolve the output path once. Periodic saves then overwrite a single file instead of
    # each creating a new timestamped file.
    results_file = args.results_file or f'results_{time.time()}.csv'

    # Shuffle the trials, keeping n_trials of them if specified. Clamp n so sample() cannot
    # raise when more trials are requested than exist.
    # NOTE(review): with --n_trials only the sampled subset is written back, so resuming from
    # a larger results file shrinks it to that subset — confirm this is intended.
    if args.n_trials:
        n_keep = min(args.n_trials, len(results_df))
        results_df = results_df.sample(n=n_keep).reset_index(drop=True)
    else:
        results_df = results_df.sample(frac=1).reset_index(drop=True)

    try:
        # Run all the trials.
        for i, trial in tqdm(results_df.iterrows(), total=len(results_df)):

            # Only run the trial if it hasn't been run before.
            if not trial.completed:
                try:
                    # Generate and save the image, creating its directory if needed.
                    task_payload['prompt'] = trial.prompt
                    image = generate_image(header, api_metadata, task_payload)
                    Path(trial.path).parent.mkdir(parents=True, exist_ok=True)
                    image.save(trial.path)
                    # iterrows() yields copies, so record completion on the DataFrame itself;
                    # otherwise saved results never show progress and a resume re-runs everything.
                    results_df.at[i, 'completed'] = True
                    time.sleep(args.sleep)
                except Exception as e:
                    print(f'Failed on trial {i} with error: {e}')
                    traceback.print_exc()
                    break  # Stop the loop on error; the finally block saves progress.

            # Save the progress at log_interval. A 0 interval disables periodic saves
            # instead of raising ZeroDivisionError.
            if args.log_interval and i % args.log_interval == 0:
                save_results(results_df, results_file)
    finally:
        # Save the final results, even if the loop was interrupted (e.g. KeyboardInterrupt).
        save_results(results_df, results_file)

# Run the trial loop only when executed as a script, not when imported.
if __name__ == '__main__':
    main()