import sys
import json
import re
import cv2
import base64
import copy
import numpy as np
from SYSTEM_MESSAGE import SYSTEM_TEXT_IMAGE_EXIST_FOR_REFINE, SYSTEM_TEXT_IMAGE_UNIQUENESS_FOR_REFINE, SYSTEM_TEXT_IMAGE_EXIST

from call_llms import response_generator, response_generator_text
from config import model_name, ScanNet_Frame_ROOT, ScanRefer_filtered_path

# Path template for one colour frame; filled with (scene_id, frame_file_name).
frames_root = ScanNet_Frame_ROOT + '/{}/color/{}'
# Model identifier; used as the prefix of the per-item result keys below.
model = model_name   # model = 'gpt_4o'  gpt-4.1-mini

# ScanRefer annotations are parsed once at import time. The context manager
# closes the file handle as soon as parsing finishes instead of leaking it.
with open(ScanRefer_filtered_path) as _scanrefer_file:
    ScanRefer_data = json.load(_scanrefer_file)
# =====================
def item_filter(all_data_list, filter_key):
    """Partition items by whether ``filter_key`` already holds a usable value.

    Args:
        all_data_list: iterable of mapping-like items.
        filter_key: key whose presence/value decides the partition.

    Returns:
        A ``(filtered_list, good_list)`` tuple. ``filtered_list`` holds the
        items where the key is missing or equal to the ``-100`` sentinel;
        ``good_list`` holds every other item. Input order is preserved in
        both lists and the items themselves are not copied.
    """
    pending, finished = [], []
    for entry in all_data_list:
        needs_work = filter_key not in entry or entry[filter_key] == -100
        (pending if needs_work else finished).append(entry)
    return pending, finished

def item_uniqueness_filter(all_data_list, filter_key):
    """Select items that still lack ``filter_key`` but have numeric existence scores.

    An item is kept when all of the following hold:
      * ``filter_key`` is not present in the item;
      * the item contains the ``'<model>_positive_image_existence'`` key;
      * every element stored under that key is an ``int`` or ``float``
        (an empty sequence also passes, because ``all()`` of nothing is True).

    Args:
        all_data_list: iterable of mapping-like items.
        filter_key: key marking items that were already processed.

    Returns:
        List of the selected items, in input order (items are not copied).
    """
    # NOTE(review): image_exist() with use_summarize=True stores
    # (rationale, score) pairs, which never pass the all-numeric test below —
    # confirm which existence format this filter is meant to accept.
    filtered_list = []
    for d in all_data_list:
        if filter_key in d:
            continue
        existence_key = f'{model}_positive_image_existence'
        if existence_key not in d:
            continue
        try:
            all_numeric = all(isinstance(x, (int, float)) for x in d[existence_key])
        except TypeError:
            # The stored value is not iterable (e.g. a bare number or None):
            # skip the item, but no longer hide unrelated errors.
            continue
        if all_numeric:
            filtered_list.append(d)
    return filtered_list
# =====================

def get_llama_score_image(description, image_path, system_message, appendix=None, lock=None, max_token=1000):
    """Ask the vision model to score ``description`` against an image.

    The user prompt contains the description and, when ``appendix`` is truthy,
    an extra "Additional Information" section. The score is taken from the
    last ``Score: <digits>`` occurrence in the answer; if there is none, the
    last run of digits found in the final three lines is used; if that also
    fails the score is ``-1``.

    Args:
        description: text placed in the prompt.
        image_path: forwarded unchanged to ``response_generator``.
        system_message: system prompt forwarded to ``response_generator``.
        appendix: optional extra text appended to the prompt.
        lock: forwarded to ``response_generator``.
        max_token: forwarded to ``response_generator`` as ``max_token``.

    Returns:
        ``(score, second_to_last_line, full_answer_text)``.

    Raises:
        IndexError: when the answer has fewer than two lines, because the
            second-to-last line is returned as the summary. Callers in this
            module catch the failure and record a sentinel instead.
    """
    if appendix:
        user_text = ('The **DESCRIPTION** is {}.\n '
                     'Other **Additional Information** is: \n "{}"').format(description, appendix)
    else:
        user_text = 'The **DESCRIPTION** is {}.\n '.format(description)

    reply = response_generator(system_message, user_text, image_path, max_token=max_token, lock=lock)
    answer_text = reply.content

    tagged_scores = re.findall(r"Score:\s*(\d+)", answer_text)
    answer_lines = answer_text.split('\n')
    if tagged_scores:
        score = int(tagged_scores[-1])
    else:
        # No explicit "Score:" tag: fall back to any digits in the last three lines.
        trailing_digits = [num for line in answer_lines[-3:] for num in re.findall(r"\d+", line)]
        score = int(trailing_digits[-1]) if trailing_digits else -1

    return score, answer_lines[-2], answer_text

def pair_images_build(scene, frames, pair_list):
    """Build side-by-side JPEG composites for pairs of frames.

    For every pair, the two frames are read from disk, the right image is
    rescaled (keeping its aspect ratio) to the height of the left image when
    the heights differ, and both are joined horizontally with a 20-pixel
    white gap between them.

    Args:
        scene: scene identifier used to fill the ``frames_root`` template.
        frames: sequence of frame file names.
        pair_list: iterable of index pairs; ``pair[0]`` is the left frame and
            ``pair[1]`` the right frame, both indexing into ``frames``.

    Returns:
        List of base64-encoded (UTF-8 text) JPEG images, one per pair, in the
        order of ``pair_list``.

    Raises:
        FileNotFoundError: if a frame cannot be read. ``cv2.imread`` reports
            this by returning ``None``, which previously surfaced as an
            ``AttributeError`` on ``.shape`` without naming the file.
        IndexError: if a pair index is outside ``frames``.
    """
    gap_width = 20
    c_images = []
    for pair in pair_list:
        left_path = frames_root.format(scene, frames[pair[0]])
        right_path = frames_root.format(scene, frames[pair[1]])
        img1 = cv2.imread(left_path)
        img2 = cv2.imread(right_path)
        for path, image in ((left_path, img1), (right_path, img2)):
            if image is None:
                raise FileNotFoundError('Cannot read frame image: {}'.format(path))

        if img1.shape[0] != img2.shape[0]:
            # Match heights so the two images can be stacked horizontally.
            new_width = int(img2.shape[1] * img1.shape[0] / img2.shape[0])
            img2 = cv2.resize(img2, (new_width, img1.shape[0]))

        gap = 255 * np.ones((img1.shape[0], gap_width, 3), dtype=np.uint8)
        combined = np.hstack((img1, gap, img2))
        _, buffer = cv2.imencode('.jpg', combined)

        c_images.append(base64.b64encode(buffer).decode('utf-8'))
    return c_images

# =============
def image_exist(data_item, system_message, job_key, pos_mis='positive_frames', frame_number=3, max_token=1000, use_summarize=True, lock=None):
    """Score the item's description against its first ``frame_number`` frames.

    Each frame is scored with ``get_llama_score_image``. Scoring is
    best-effort: when a frame fails for any ordinary error, a sentinel is
    recorded for it and the remaining frames are still processed.

    Args:
        data_item: mapping with ``'scene_id'``, ``'description'`` and the
            frame list stored under ``pos_mis``. Updated in place.
        system_message: system prompt forwarded to the scorer.
        job_key: key under which the result list is stored in ``data_item``.
        pos_mis: key of the frame-name list to score.
        frame_number: maximum number of frames (taken from the front).
        max_token: forwarded to the scorer.
        use_summarize: when True each result is ``(summary, score)``,
            otherwise just ``score``.
        lock: forwarded to the scorer.

    Returns:
        ``data_item`` itself, with ``data_item[job_key]`` set to the result
        list. Failed frames appear as ``(None, -100)`` or ``-100``.
    """
    failure = (None, -100) if use_summarize else -100
    result_list = []

    scene_name = data_item['scene_id']
    for frame_name in data_item[pos_mis][:frame_number]:
        try:
            frame_path = frames_root.format(scene_name, frame_name)
            score, summarize, _ = get_llama_score_image(
                data_item['description'], frame_path, system_message, lock=lock, max_token=max_token)
        except Exception:
            # Deliberately broad: any scoring failure becomes a sentinel.
            # KeyboardInterrupt / SystemExit are no longer swallowed, so a
            # long-running job can still be interrupted.
            result_list.append(failure)
        else:
            result_list.append((summarize, score) if use_summarize else score)

    data_item.update({job_key: result_list})
    return data_item

def image_exist_refine(data_item, frame, system_message, appendix=None, lock=None):
    """Re-score the item's description against a single image, best-effort.

    Args:
        data_item: mapping providing ``'description'``.
        frame: image reference forwarded to ``get_llama_score_image``.
        system_message: system prompt forwarded to the scorer.
        appendix: optional extra prompt text forwarded to the scorer.
        lock: forwarded to the scorer.

    Returns:
        ``(score, summary)``; ``(-100, None)`` when scoring fails.
    """
    try:
        score, summarize, _ = get_llama_score_image(
            data_item['description'], frame, system_message, appendix=appendix, lock=lock)
    except Exception:
        # Any ordinary failure maps to the sentinel; KeyboardInterrupt and
        # SystemExit propagate instead of being swallowed.
        return -100, None
    return score, summarize

# =============
def image_uniqueness(data_item, system_message, job_key, pos_mis='misleading_frames', frame_number=3, max_token=1000, use_summarize=True, lock=None):
    """Score uniqueness on the frames whose existence score was high enough.

    The existence scores are read from the second element of each entry in
    ``data_item['<model>_positive_image_existence']``. A frame is scored when
    its existence score is at least 7 or equals the maximum existence score;
    every other position receives ``None`` so the result stays aligned with
    the existence list.

    Args:
        data_item: mapping with ``'scene_id'``, ``'description'``, the
            existence list and the frame list under ``pos_mis``. Updated in
            place.
        system_message: system prompt forwarded to the scorer.
        job_key: key under which the result list is stored.
        pos_mis: key of the frame-name list; indexed with the position of the
            existence entry.
        frame_number: unused.
        max_token: forwarded to the scorer (previously accepted but ignored;
            the default matches the scorer's own default).
        use_summarize: when True each scored result is ``(summary, score)``,
            otherwise just ``score``.
        lock: forwarded to the scorer.

    Returns:
        ``data_item`` itself with ``data_item[job_key]`` set. Failed frames
        appear as ``(None, -100)`` or ``-100``. An empty existence list now
        yields an empty result instead of raising ``ValueError`` from ``max``.
    """
    failure = (None, -100) if use_summarize else -100
    result_list = []

    scene_name = data_item['scene_id']
    existing_judge = [x[1] for x in data_item[f'{model}_positive_image_existence']]
    max_exist = max(existing_judge, default=None)
    for e_idx, e_score in enumerate(existing_judge):
        if not (e_score >= 7 or e_score == max_exist):
            result_list.append(None)
            continue

        frame_name = data_item[pos_mis][e_idx]
        try:
            frame_path = frames_root.format(scene_name, frame_name)
            score, summarize, _ = get_llama_score_image(
                data_item['description'], frame_path, system_message, lock=lock, max_token=max_token)
        except Exception:
            # Best-effort per frame; KeyboardInterrupt / SystemExit propagate.
            result_list.append(failure)
        else:
            result_list.append((summarize, score) if use_summarize else score)

    data_item.update({job_key: result_list})
    return data_item

def image_uniqueness_refine(data_item, frame, system_message, appendix=None, lock=None):
    """Re-score uniqueness of the item's description on one image, best-effort.

    Args:
        data_item: mapping providing ``'description'``.
        frame: image reference forwarded to ``get_llama_score_image``.
        system_message: system prompt forwarded to the scorer.
        appendix: optional extra prompt text forwarded to the scorer.
        lock: forwarded to the scorer.

    Returns:
        ``(score, summary)``; ``(-100, None)`` when scoring fails.
    """
    try:
        score, summarize, _ = get_llama_score_image(
            data_item['description'], frame, system_message, appendix=appendix, lock=lock)
    except Exception:
        # Any ordinary failure maps to the sentinel; KeyboardInterrupt and
        # SystemExit propagate instead of being swallowed.
        return -100, None
    return score, summarize

# =============
def sentence_judge(data_item, system_message, job_key, pos_mis='misleading_frames', frame_number=3, max_token=1000, use_summarize=False, lock=None):
    """Score the item's description with a text-only model call.

    The score is the last ``Score: <digits>`` match in the answer; otherwise
    the last run of digits in the final three lines; otherwise ``-100``.
    The second-to-last answer line is stored as the rationale, or
    ``'No related reviews'`` when the answer has fewer than two lines.

    Args:
        data_item: mapping providing ``'description'``. Updated in place.
        system_message: system prompt for ``response_generator_text``.
        job_key: key for the score; the rationale goes under
            ``job_key + '_Rationale'``.
        pos_mis, frame_number, use_summarize: unused here.
        max_token: forwarded to ``response_generator_text``.
        lock: forwarded to ``response_generator_text``.

    Returns:
        ``data_item`` itself.
    """
    user_text = 'The **description** needs analysis is: "{}".'.format(data_item['description'])
    reply = response_generator_text(system_message, user_text, max_token=max_token, lock=lock)
    answer_text = reply.content

    tagged_scores = re.findall(r"Score:\s*(\d+)", answer_text)
    answer_lines = answer_text.split('\n')
    if tagged_scores:
        score = int(tagged_scores[-1])
    else:
        trailing_digits = [num for line in answer_lines[-3:] for num in re.findall(r"\d+", line)]
        score = int(trailing_digits[-1]) if trailing_digits else -100

    rationale = answer_lines[-2] if len(answer_lines) >= 2 else 'No related reviews'
    data_item.update({job_key: score})
    data_item.update({job_key + '_Rationale': rationale})
    return data_item

def target_searching_and_group(data_item):
    """Collect the other ScanRefer descriptions of the same object.

    Scans ``ScanRefer_data`` for entries with the same ``scene_id`` and
    ``object_id`` as ``data_item`` whose description text differs from the
    item's own description.

    Args:
        data_item: mapping with ``'scene_id'``, ``'object_id'``,
            ``'description'`` and ``'object_name'``.

    Returns:
        ``(json_text, descriptions)`` when at least three other descriptions
        exist, where ``json_text`` encodes the object category and the group
        keyed ``id_1``, ``id_2``, ...; otherwise ``(None, None)``.
    """
    scene_name = data_item['scene_id']
    obj_id = data_item['object_id']
    own_description = data_item['description']

    siblings = [
        entry['description']
        for entry in ScanRefer_data
        if entry['scene_id'] == scene_name
        and entry['object_id'] == obj_id
        and entry['description'] != own_description
    ]
    if len(siblings) < 3:
        return None, None

    payload = {
        "object_category": data_item['object_name'],
        "description group G": {
            'id_{}'.format(position): text for position, text in enumerate(siblings, start=1)
        },
    }
    return json.dumps(payload), siblings

def new_group(data_item, group_list):
    """Serialise the item's description together with a description group.

    Args:
        data_item: mapping with ``'object_name'`` and ``'description'``.
        group_list: sequence of description strings.

    Returns:
        JSON text with the object category, the item's own description under
        ``"description D"`` and the group keyed ``id_1``, ``id_2``, ... in
        the order of ``group_list``.
    """
    numbered_group = {
        'id_{}'.format(position): text for position, text in enumerate(group_list, start=1)
    }
    payload = {
        "object_category": data_item['object_name'],
        "description D": data_item['description'],
        "description group G": numbered_group,
    }
    return json.dumps(payload)

def sentence_group_judge(data_item, system_message, job_key, pos_mis='misleading_frames', frame_number=3, max_token=1000, use_summarize=False, lock=None):
    """Score the item's description against its sibling descriptions.

    Two text-model calls are made. The first (``system_message[0]``) returns
    one number per sibling description on its last line; siblings with a
    number greater than zero are kept. The second (``system_message[1]``)
    scores the item's description against the kept siblings.

    When no sibling group is available, or the first answer does not contain
    exactly one number per sibling, the score defaults to ``8`` with the
    rationale ``'No related reviews'``.

    Args:
        data_item: item mapping; updated in place.
        system_message: indexable holding the two system prompts.
        job_key: key for the score; the rationale goes under
            ``job_key + '_Rationale'``.
        pos_mis, frame_number, use_summarize: unused here.
        max_token: forwarded to ``response_generator_text``.
        lock: forwarded to ``response_generator_text``.

    Returns:
        ``data_item`` itself.
    """
    def _store_default():
        data_item.update({job_key: 8})
        data_item.update({job_key + '_Rationale': 'No related reviews'})
        return data_item

    # ===== part 1: decide which sibling descriptions to keep =====
    message, group_list = target_searching_and_group(data_item)
    if message is None:
        return _store_default()

    first_prompt = 'The ***JSON FORM MESSAGE** needs analysis is: "{}".'.format(message)
    first_answer = response_generator_text(system_message[0], first_prompt, max_token=max_token, lock=lock).content
    keep_flags = re.findall(r"\s*(\d+)", first_answer.split('\n')[-1])
    if len(keep_flags) != len(group_list):
        return _store_default()

    # ===== part 2: score the description against the kept siblings =====
    kept_group = [text for text, flag in zip(group_list, keep_flags) if int(flag) > 0]
    second_prompt = 'The ***JSON FORM MESSAGE** needs analysis is: "{}".'.format(new_group(data_item, kept_group))
    second_answer = response_generator_text(system_message[1], second_prompt, max_token=max_token, lock=lock).content

    tagged_scores = re.findall(r"Score:\s*(\d+)", second_answer)
    answer_lines = second_answer.split('\n')
    if tagged_scores:
        score = int(tagged_scores[-1])
    else:
        trailing_digits = [num for line in answer_lines[-3:] for num in re.findall(r"\d+", line)]
        score = int(trailing_digits[-1]) if trailing_digits else -100

    rationale = answer_lines[-2] if len(answer_lines) >= 2 else 'No related reviews'
    data_item.update({job_key: score})
    data_item.update({job_key + '_Rationale': rationale})
    return data_item

# ===========================================
""" Corroborative Refinement """
def refer_refine(data_item, system_message, job_key, pos_mis='misleading_frames', frame_number=3, max_token=1500, use_summarize=False, lock=None):
    """Refine earlier per-frame judgements with a follow-up model pass.

    ``job_key`` selects which earlier result list is refined:

    * ``'<model>_refine_exist_pos'``  <- ``'<model>_positive_image_existence'``
    * ``'<model>_refine_unique_pos'`` <- ``'<model>_positive_uniqueness'``
    * ``'<model>_refine_exist_mis'``  <- ``'<model>_misleading_image_existence'``

    A copy of the selected list is stored under ``job_key``. A text-model call
    then names the frames to re-check, and the new ``summary, score`` values
    are appended to the copied entry of each re-checked frame.

    * Positive jobs: the numbers on the last non-empty-of-digits line (among
      the final three) are read as (left, right) frame-index pairs. Each pair
      is rendered side by side and re-scored; results are appended to the
      left frame's entry.
    * Misleading job: the numbers on the last line are frame indices that are
      re-scored one by one.

    Entries are expected to be ``[summary, score]`` sequences (``None``
    entries are tolerated and skipped).

    Args:
        data_item: item mapping; updated in place.
        system_message: system prompt for the text-model call.
        job_key: one of the three keys listed above.
        pos_mis, frame_number, use_summarize: unused here.
        max_token: forwarded to ``response_generator_text``.
        lock: forwarded to every model call.

    Returns:
        ``data_item`` itself, with ``job_key`` and ``job_key + '_list'`` set.

    Raises:
        ValueError: if ``job_key`` is not one of the supported keys
            (previously an ``UnboundLocalError``).
    """
    # job_key -> (source key, works on positive frames, also refines existence)
    job_specs = {
        f'{model}_refine_exist_pos': (f'{model}_positive_image_existence', True, True),
        f'{model}_refine_unique_pos': (f'{model}_positive_uniqueness', True, False),
        f'{model}_refine_exist_mis': (f'{model}_misleading_image_existence', False, True),
    }
    if job_key not in job_specs:
        raise ValueError('Unsupported refine job_key: {!r}'.format(job_key))
    source_key, on_positive, refine_existence = job_specs[job_key]

    description = data_item['description']
    org_message = data_item[source_key]

    # Entries produced in-process are tuples; turn them into lists so the
    # refinement results can be appended (lists loaded from JSON are unchanged).
    refined = [list(entry) if isinstance(entry, tuple) else entry
               for entry in copy.deepcopy(org_message)]
    data_item.update({job_key: refined})
    data_item.update({job_key + '_list': []})

    # ============
    USER_TEXT = 'Description: ' + description + '\n' + 'Messages: ' + str([{x: [org_message[x]]} for x in range(len(org_message))])

    llama_answer = response_generator_text(system_message, USER_TEXT, max_token=max_token, lock=lock)
    answer_lines = llama_answer.content.split('\n')

    if on_positive:
        # Walk back from the last line, at most three lines, until one holds
        # digits. Slicing keeps this safe for answers shorter than three lines.
        score_list = []
        for candidate in reversed(answer_lines[-3:]):
            score_list = re.findall(r"\s*(\d+)", candidate)
            if score_list:
                break
        score_list = [int(x) for x in score_list]

        score_list_pair = []
        if len(score_list) % 2 == 0 and len(score_list) >= 2:
            score_list_pair = list(zip(score_list[::2], score_list[1::2]))
            # NOTE(review): this records the LAST answer line, which is not
            # necessarily the line the pairs were parsed from — confirm intent.
            data_item[job_key + '_list'].append(answer_lines[-1])

        if score_list_pair:
            frames = data_item['positive_frames']
            # The indices come from model output: drop pairs that point outside
            # the frame/entry lists or at entries that were never scored.
            usable = min(len(frames), len(org_message))
            score_list_pair = [
                (left, right) for left, right in score_list_pair
                if left < usable and right < usable
                and org_message[left] is not None and org_message[right] is not None
            ]

        if score_list_pair:
            constructed_images = pair_images_build(data_item['scene_id'], frames, score_list_pair)

            for (left, right), pair_image in zip(score_list_pair, constructed_images):
                # A failed earlier scoring leaves the summary as None; send an
                # empty rationale for that view instead of raising TypeError.
                rationale = ('left view' + (org_message[left][0] or '')
                             + ' ## right view' + (org_message[right][0] or ''))
                target = data_item[job_key][left]

                if refine_existence:
                    refine_score, summarize = image_exist_refine(
                        data_item, pair_image, SYSTEM_TEXT_IMAGE_EXIST_FOR_REFINE, rationale, lock=lock)
                    target.append(summarize)
                    target.append(refine_score)
                    # Existence jobs run the uniqueness check without the rationale.
                    unique_score, summarize = image_uniqueness_refine(
                        data_item, pair_image, SYSTEM_TEXT_IMAGE_UNIQUENESS_FOR_REFINE, lock=lock)
                else:
                    unique_score, summarize = image_uniqueness_refine(
                        data_item, pair_image, SYSTEM_TEXT_IMAGE_UNIQUENESS_FOR_REFINE, rationale, lock=lock)
                target.append(summarize)
                target.append(unique_score)

    else:
        score_list = [int(x) for x in re.findall(r"\s*(\d+)", answer_lines[-1])]
        data_item[job_key + '_list'].append(answer_lines[-1])

        for refine_i in score_list:
            # Guard both the frame list and the copied entry list: they can
            # differ in length, and the index comes from model output.
            if refine_i < len(data_item['misleading_frames']) and refine_i < len(data_item[job_key]):
                frame_path = frames_root.format(data_item['scene_id'], data_item['misleading_frames'][refine_i])
                # lock is now forwarded here as in every other model call.
                refine_score, summarize = image_exist_refine(
                    data_item, frame_path, SYSTEM_TEXT_IMAGE_EXIST, lock=lock)
                data_item[job_key][refine_i].append(summarize)
                data_item[job_key][refine_i].append(refine_score)

    return data_item
# ===========================================

# =============
def refer_judge(data_item, system_message, job_key, pos_mis='misleading_frames', frame_number=3, max_token=400, use_summarize=False, lock=None):
    """Combine all earlier judgements into one JSON prompt and store a final score.

    Inputs are read from ``data_item`` under the ``'<model>_...'`` keys:

    * ``refine_exist_pos`` entries: a two-element entry is read as
      ``[rationale, score]`` with no uniqueness information; a longer entry
      supplies the existence rationale/score at indices 2/3 and the
      uniqueness rationale/score as its last two elements.
    * ``refine_unique_pos`` entries: a non-``None`` entry overrides the
      uniqueness value with twice its last element and the rationale with
      its second-to-last element; for a ``None`` entry, a missing uniqueness
      value (``-1``) becomes ``None``.
    * ``refine_exist_mis`` entries: the last two elements are used unless the
      last one is ``-100``, in which case the first two are used.
    * ``sentence_logical`` / ``group_logical`` scores and rationales.

    The score is the last ``Score: <digits>`` match in the answer; otherwise
    the last run of digits in the final three lines; otherwise ``-100``.

    Args:
        data_item: item mapping; updated in place.
        system_message: system prompt for ``response_generator_text``.
        job_key: key under which the final score is stored.
        pos_mis, frame_number, use_summarize: unused here.
        max_token: forwarded to ``response_generator_text``.
        lock: forwarded to ``response_generator_text``.

    Returns:
        ``data_item`` itself.
    """
    description = data_item['description']

    # ---- positive frames: existence and (optional) uniqueness ----
    pos_exist, pos_exist_r, pos_unique, pos_unique_r = [], [], [], []
    for entry in data_item[f'{model}_refine_exist_pos']:
        if len(entry) == 2:
            pos_exist.append(entry[-1])
            pos_exist_r.append(entry[0])
            pos_unique.append(-1)
            pos_unique_r.append(None)
        else:
            pos_exist.append(entry[3])
            pos_exist_r.append(entry[2])
            pos_unique.append(entry[-1])
            pos_unique_r.append(entry[-2])

    for idx, uni in enumerate(data_item[f'{model}_refine_unique_pos']):
        if uni is not None:
            # NOTE(review): the factor 2 presumably rescales this score to the
            # range of the others — confirm.
            pos_unique[idx] = uni[-1] * 2
            pos_unique_r[idx] = uni[-2]
        elif pos_unique[idx] == -1:
            pos_unique[idx] = None

    # ---- misleading frames: prefer the refined result unless it failed ----
    misleading_exist, misleading_exist_r = [], []
    for entry in data_item[f'{model}_refine_exist_mis']:
        refined_ok = entry[-1] != -100
        misleading_exist.append(entry[-1] if refined_ok else entry[1])
        misleading_exist_r.append(entry[-2] if refined_ok else entry[0])

    sentence_logical = data_item[f'{model}_sentence_logical']
    sentence_logical_r = data_item[f'{model}_sentence_logical_Rationale']

    sentence_group_logical = data_item[f'{model}_group_logical']
    sentence_group_logical_r = data_item[f'{model}_group_logical_Rationale']

    obj_number = data_item['same_obj_in_scene']

    assert len(pos_exist) == len(pos_unique)
    match_and_unique = [{exist: unique} for exist, unique in zip(pos_exist, pos_unique)]

    judge_messages = {
        'Distinguishability': match_and_unique,
        'Ambiguity': misleading_exist,
        'Logical': sentence_logical,
        'Consistency': sentence_group_logical,
    }
    basic_messages = {
        'Sentence': description,
        'Dense': len(data_item['misleading_frames']),
        'Complex': obj_number
    }
    rationale_messages = {
        'Existence': pos_exist_r,
        'Uniqueness 1': pos_unique_r,
        'Uniqueness 2': misleading_exist_r,
        'Logic': sentence_logical_r,
        'Consistency': sentence_group_logical_r,
    }
    json_message = {
        'JUDGE MESSAGES': judge_messages,
        'BASIC MESSAGES': basic_messages,
        'RATIONALE MESSAGES': rationale_messages,
    }

    user_text = 'The **JSON MESSAGE** needs analysis is: "{}".'.format(json.dumps(json_message, indent=4))
    answer_text = response_generator_text(system_message, user_text, max_token=max_token, lock=lock).content

    tagged_scores = re.findall(r"Score:\s*(\d+)", answer_text)
    if tagged_scores:
        score = int(tagged_scores[-1])
    else:
        trailing_digits = [num for line in answer_text.split('\n')[-3:] for num in re.findall(r"\d+", line)]
        score = int(trailing_digits[-1]) if trailing_digits else -100

    data_item.update({job_key: score})
    return data_item
