import sys
import os.path

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
import random
from tqdm import tqdm
import concurrent.futures
import json
import argparse
import main
import generator
import api_utils as utils
import prompt_optimization.prompts as prompts


def select_k_from_n_excluding_i(n, k, i):
    """Randomly draw ``k`` distinct integers from ``range(n)``, never including ``i``.

    Args:
        n: Size of the candidate range ``0 .. n-1``.
        k: How many distinct values to draw.
        i: Value to exclude (ignored if it is not in the range).

    Returns:
        list[int]: ``k`` values sampled without replacement via ``random.sample``.

    Raises:
        ValueError: If fewer than ``k`` candidates remain after excluding ``i``.
    """
    pool = [idx for idx in range(n) if idx != i]
    if len(pool) < k:
        raise ValueError("Cannot select k numbers from the remaining numbers.")
    return random.sample(pool, k)


def parse_tagged_text(text, start_tag, end_tag):
    """Extract the unique contents enclosed by ``start_tag`` ... ``end_tag`` pairs.

    ``text`` is scanned left to right for non-overlapping tag pairs. Each
    enclosed span is whitespace-stripped. Spans that, after additionally
    stripping surrounding double quotes and backticks, equal ``"` and `"`` or
    ``"and"`` are discarded as formatting artifacts.

    Args:
        text: String to scan.
        start_tag: Opening delimiter, e.g. ``"<choose>"``.
        end_tag: Closing delimiter, e.g. ``"</choose>"``.

    Returns:
        list[str]: Unique enclosed strings in order of first occurrence. A
        trailing opening tag without a matching closing tag is ignored.
    """
    texts = []
    while True:
        start_index = text.find(start_tag)
        if start_index == -1:
            break
        content_start = start_index + len(start_tag)
        # Look for the closing tag only after the opening tag ends; otherwise
        # symmetric delimiters (e.g. "**"/"**") would match the opener itself.
        end_index = text.find(end_tag, content_start)
        if end_index == -1:
            break
        content = text[content_start:end_index].strip()
        select = content.strip('"').strip('`')
        if select != "` and `" and select != "and":
            texts.append(content)
        text = text[end_index + len(end_tag):]
    # Deduplicate while keeping first-occurrence order. list(set(...)) made the
    # order, and hence callers' ``[0]``, depend on string hash randomization.
    return list(dict.fromkeys(texts))


def prompt_spo_compare(ex, pos_idx, neg_idx, attr_pos, attr_neg, task_name='CUB_cuckoo', model_name='gemini'):
    """Ask the judge model whether ``attr_pos`` describes ``ex``'s image better than ``attr_neg``.

    The task-specific base prompt from ``prompts`` is extended with both
    descriptions, shown in a random order to counter position bias. The judge's
    choice is parsed from the ``<choose>`` tag of its reply.

    Args:
        ex: Example; ``ex['img_path']`` is sent as the image.
        pos_idx: Index of the "positive" prompt; returned unchanged.
        neg_idx: Index of the "negative" prompt; returned unchanged.
        attr_pos: Description produced by the positive prompt.
        attr_neg: Description produced by the negative prompt.
        task_name: Name of the attribute in ``prompts`` holding the base prompt.
        model_name: Not used by this function; the judge is always ``utils.gpt4o``.

    Returns:
        tuple: ``(answer, ex, pos_idx, neg_idx)``. ``answer`` is 1 if the judge
        unambiguously picked ``attr_pos``. Otherwise it is 0, which covers a
        choice of ``attr_neg``, an ambiguous answer, or a reply with no
        ``<choose>`` tag.

    Raises:
        Exception: If ``task_name`` is not a supported task.
    """
    supported_tasks = ('iNat_butterfly', 'iNat_grass', 'CUB_cuckoo', 'CUB_vireo',
                       'CUB_oriole', 'Stanford_terrier', 'vegfru_1', 'vegfru_2')
    if task_name not in supported_tasks:
        raise Exception(f"Unsupported task: {task_name}")
    pred_prompt = getattr(prompts, task_name)

    # Randomize which description is shown first to counter position bias.
    pos_first = random.randint(0, 1) == 0
    if pos_first:
        first_text, second_text = attr_pos, attr_neg
        pos_label, neg_label = 'first', 'second'
    else:
        first_text, second_text = attr_neg, attr_pos
        pos_label, neg_label = 'second', 'first'

    pred_prompt += f"\nThere are two descriptions of the given images generated by two different prompts:\n\nText 1:\n{first_text}\n\nText 2:\n{second_text}\n\nWhich description better describes the image? The first text or the second text?"
    # Leading space: the question above ends with "?" and previously ran
    # straight into "Provide ..." with no separator.
    pred_prompt += " Provide your analysis and the choice you believe is better, using XML tags to encapsulate your response.\n\n<analyse>Some analysis</analyse>\n<choose>First/Second (the better answer in your opinion)</choose>\n"
    response = utils.gpt4o(pred_prompt, [ex['img_path']], max_tokens=1024, temperature=0.7)[0]

    choices = parse_tagged_text(response, "<choose>", "</choose>")
    # A reply without a <choose> tag is scored like an ambiguous answer (0)
    # instead of raising IndexError and aborting a whole parallel run.
    prediction = choices[0].lower() if choices else ''
    answer = 1 if pos_label in prediction and neg_label not in prediction else 0

    return answer, ex, pos_idx, neg_idx


def run_evaluate(exs, pos_idx, neg_idx, attr_pos, attr_neg, task_name, model_name='gemini', max_workers=8):
    """Run ``prompt_spo_compare`` on every example in a process pool.

    Args:
        exs: Examples to compare on.
        pos_idx: Index of the positive prompt, forwarded to ``prompt_spo_compare``.
        neg_idx: Index of the negative prompt, forwarded to ``prompt_spo_compare``.
        attr_pos: Mapping from ``str(ex)`` to the positive prompt's description.
        attr_neg: Mapping from ``str(ex)`` to the negative prompt's description.
        task_name: Task name, forwarded to ``prompt_spo_compare``.
        model_name: Model name, forwarded to ``prompt_spo_compare``.
        max_workers: Number of worker processes (default 8, as before).

    Returns:
        tuple[list, list]: ``(examples, preds)``, aligned with each other and
        with the order of ``exs``. ``preds[j]`` is the 0/1 answer for
        ``examples[j]``.

    Raises:
        KeyError: If an example has no entry in ``attr_pos`` or ``attr_neg``.
        Any exception raised inside a worker is re-raised here.
    """
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(prompt_spo_compare, ex, pos_idx, neg_idx, attr_pos[f'{ex}'], attr_neg[f'{ex}'],
                                   task_name, model_name) for ex in exs]

        # Advance the progress bar as jobs finish; result() re-raises worker
        # errors as soon as they are observed.
        for future in tqdm(concurrent.futures.as_completed(futures),
                           total=len(futures), desc='comparing two prompts (parallel)'):
            future.result()

    # Collect in submission order so the outputs are deterministic and line up
    # with ``exs`` (as_completed yields in nondeterministic completion order).
    results = [future.result() for future in futures]
    examples = [ex for _, ex, _, _ in results]
    preds = [answer for answer, _, _, _ in results]
    return examples, preds


def get_args():
    """Define the command-line options for the pairwise prompt comparison and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    task_choices = ['iNat_butterfly', 'iNat_grass', 'Stanford_terrier',
                    'CUB_cuckoo', 'CUB_oriole', 'CUB_vireo', 'vegfru_1', 'vegfru_2']
    model_choices = ['gemini', 'gpt4o', 'sglang_qwen']

    cli = argparse.ArgumentParser()
    add = cli.add_argument

    add('--task_name', default='CUB_cuckoo', choices=task_choices)
    add('--model', default='gemini', choices=model_choices)
    add('--out_num', default='0')
    add('--max_threads', type=int, default=8)
    add('--data_dir', default='/datasets')
    add('--result_folder', default='spo')
    add('--mode', default='train')
    add('--seed', type=int, default=42)
    add('--exp', type=int, default=15)
    add('--prompt_idx0', type=int, default=0)
    add('--prompt_idx1', type=int, default=1)
    add('--n_test', type=int, default=30)

    # Boolean switches: off unless the flag is passed.
    for flag in ('--generate', '--parallel'):
        add(flag, action='store_true', default=False)

    return cli.parse_args()


if __name__ == '__main__':
    # Script entry: compare the descriptions produced by two prompts
    # (selected by --prompt_idx0/--prompt_idx1) on a subset of examples,
    # then print the fraction of examples where prompt_idx0's description won.
    args = get_args()
    result_dir = f"../{args.result_folder}/results/{args.exp}_{args.task_name}"
    prompt_test_address = f"{result_dir}/{args.exp}_test_attr.json"
    prompt_address = f"{result_dir}/{args.exp}_{args.mode}_attr.json"

    # The test-split attribute file fixes the ordering of prompt keys that
    # --prompt_idx0/--prompt_idx1 index into.
    with open(prompt_test_address, 'r') as json_file:
        prompt_keys = list(json.load(json_file).keys())
    key0 = prompt_keys[args.prompt_idx0]
    key1 = prompt_keys[args.prompt_idx1]

    configs = vars(args)
    task = main.get_task_class(args)
    # NOTE(review): `gpt4` is never used below; the call is kept in case
    # constructing the predictor has side effects. Confirm and remove.
    gpt4 = main.get_predictor(configs)
    gpt_generator = generator.AttrGredictor(configs)

    exs = task.get_even_exs(args.mode, args.n_test)
    if args.generate:
        # Regenerate descriptions for both prompts instead of loading them.
        attr_all = {key0: {}, key1: {}}
        attr_all = generator.parallel_generate(gpt_generator, key0, exs, attr_all,
                                               args.max_threads)
        attr_all = generator.parallel_generate(gpt_generator, key1, exs, attr_all,
                                               args.max_threads)
    else:
        # Only read the stored descriptions when they are actually used.
        with open(prompt_address, 'r') as json_file:
            attr_all = json.load(json_file)

    attr0 = attr_all[key0]
    attr1 = attr_all[key1]
    if args.parallel:
        examples, preds = run_evaluate(exs, args.prompt_idx0, args.prompt_idx1,
                                       attr0, attr1, args.task_name, model_name=args.model)
    else:
        preds = []
        for ex in tqdm(exs, desc="comparing two prompts (Single)"):
            answer, *_ = prompt_spo_compare(ex, args.prompt_idx0, args.prompt_idx1,
                                            attr0[f'{ex}'], attr1[f'{ex}'],
                                            args.task_name, model_name=args.model)
            # Previously `preds.append(preds)` ran once after the loop,
            # discarding every answer and crashing in sum().
            preds.append(answer)

    if preds:
        print(sum(preds) / len(preds))
    else:
        print('No examples were compared.')
    print('Done!')
