import json
import pandas as pd
import os
import torch
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import base64
from scipy.io import loadmat

def TID13_data(root='/root/IQA/IQA-Agent/datasets/tid2013'):
    """Load TID2013 image path pairs together with their MOS values.

    Reads ``<root>/mos_with_names.txt``. Every non-blank line is expected to
    contain exactly two whitespace-separated fields: a MOS value followed by
    the file name of a distorted image.

    Args:
        root: Directory of the TID2013 dataset. It must contain
            ``mos_with_names.txt``; the returned paths point into its
            ``distorted_images`` and ``reference_images`` sub-directories.

    Returns:
        A tuple ``(img_paths, mos_list)``. ``img_paths`` is a list of
        ``[distorted_path, reference_path]`` pairs and ``mos_list`` holds the
        matching MOS values as floats, in file order.

    Raises:
        ValueError: If a non-blank line does not hold exactly two fields or
            its first field is not a valid float.
    """
    img_paths = []
    mos_list = []
    with open(f"{root}/mos_with_names.txt", 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing on the two-value unpack below.
                continue
            mos, image_name = fields
            # The reference file is named after the part of the distorted
            # file name before the first underscore, upper-cased, with a
            # ".BMP" extension.
            ref_name = image_name.split("_")[0].upper() + ".BMP"
            img_paths.append([
                f"{root}/distorted_images/{image_name}",
                f"{root}/reference_images/{ref_name}",
            ])
            mos_list.append(float(mos))
    return img_paths, mos_list

def LIVE_data(root='/root/IQA/IQA-Agent/datasets/LIVEIQA_release2'):
    """Load LIVE IQA image path pairs together with their DMOS values.

    Reads ``<root>/LIVE_IQA_DMOS.csv`` and uses its ``"Distorted Image"``,
    ``"Reference Image"`` and ``"DMOS"`` columns.

    Args:
        root: Directory of the LIVE release. Distorted image paths are joined
            directly onto it; reference images are looked up in its
            ``refimgs`` sub-directory.

    Returns:
        A tuple ``(img_paths, mos_list)``. ``img_paths`` is a list of
        ``[distorted_path, reference_path]`` pairs and ``mos_list`` is the
        ``"DMOS"`` column as a list, in row order.
    """
    df = pd.read_csv(os.path.join(root, 'LIVE_IQA_DMOS.csv'))
    # Iterate the two columns directly rather than via df.iterrows(), which
    # builds a Series object for every row.
    img_paths = [
        [
            os.path.join(root, distorted),
            # Reference images are stored in the "refimgs" directory.
            os.path.join(root, "refimgs", reference),
        ]
        for distorted, reference in zip(df["Distorted Image"], df["Reference Image"])
    ]
    mos_list = df["DMOS"].tolist()
    return img_paths, mos_list

def QBench_data(json_path="/root/IQA/IQA-Agent/datasets/llvisionqa_test.json",
                image_root="/root/IQA/IQA-Agent/datasets/llvisionqa_images",
                limit=100):
    """Load multiple-choice question items from a Q-Bench style JSON file.

    The JSON file must contain a list of objects. Each object provides the
    keys ``"question"``, ``"candidates"``, ``"correct_ans"``, ``"type"`` and
    ``"concern"``, plus the image file name under ``"image"`` (or, if that
    key is absent, ``"img_path"``).

    Args:
        json_path: Path of the JSON annotation file.
        image_root: Directory the image file names are joined onto.
        limit: Maximum number of leading items to load; ``None`` loads all
            of them. Defaults to 100.

    Returns:
        A tuple of six parallel lists
        ``(img_paths, questions, choices_list, correct_choices, types, concerns)``:

        * ``img_paths``: one single-element list ``[image_path]`` per item.
        * ``questions``: the ``"question"`` value of each item.
        * ``choices_list``: the ``"candidates"`` value of each item.
        * ``correct_choices``: the letter (``"A"``, ``"B"``, ...) of the
          position of ``"correct_ans"`` within the candidates, or ``None``
          when the answer is not among them.
        * ``types``: the ``"type"`` value of each item.
        * ``concerns``: the ``"concern"`` value of each item.

    Raises:
        KeyError: If an item lacks one of the required keys.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        # Slicing with ``None`` keeps the whole list.
        data = json.load(f)[:limit]

    img_paths, questions, choices_list = [], [], []
    correct_choices, types, concerns = [], [], []

    for item in data:
        image = item.get("image", item.get("img_path", ""))
        question = item["question"]
        choices = item["candidates"]
        correct_ans = item["correct_ans"]

        # Map the position of the correct answer to its option letter.
        if correct_ans in choices:
            correct_choice = chr(ord("A") + choices.index(correct_ans))
        else:
            correct_choice = None

        img_paths.append([os.path.join(image_root, image)])
        questions.append(question)
        choices_list.append(choices)
        correct_choices.append(correct_choice)
        types.append(item["type"])
        concerns.append(item["concern"])
    return img_paths, questions, choices_list, correct_choices, types, concerns

def BID_data(mat_path='/root/IQA/IQA-Agent/datasets/BID/BID/imdb.mat',
             image_root='/root/IQA/IQA-Agent/datasets/BID/BID/ImageDatabase'):
    """Load BID image paths and their quality labels from a MATLAB file.

    The ``.mat`` file must hold an ``images`` struct with a ``name`` field
    (a cell array of file names) and a ``label`` field (numeric array).

    Args:
        mat_path: Path of the ``imdb.mat`` annotation file.
        image_root: Directory the image file names are joined onto.

    Returns:
        A tuple ``(img_paths, mos_list)``. ``img_paths`` contains one
        single-element list ``[image_path]`` per image and ``mos_list`` is a
        flat list with the values of the ``label`` field.
    """
    images = loadmat(mat_path)['images'][0][0]
    # Row of name cells (1x586 for the full dataset, per the original note).
    names = images['name'][0]
    # Flatten with ravel() rather than squeeze(): squeeze() collapses a
    # single-entry array to 0-d, making tolist() return a bare scalar
    # instead of a list.
    mos_list = images['label'].ravel().tolist()

    # Each cell holds a one-element string array; unwrap it to a plain str.
    img_paths = [[os.path.join(image_root, str(name[0]))] for name in names]
    return img_paths, mos_list

def AGIQA_data(csv_path='/root/IQA/IQA-Agent/datasets/AGIQA-3K/AGIQA-3k-Database/data.csv',
               image_root='/root/IQA/IQA-Agent/datasets/AGIQA-3K/AGIQA_images'):
    """Load AGIQA-3K image paths and their quality scores from a CSV file.

    The CSV must provide a ``name`` column (image file names) and a
    ``mos_quality`` column.

    Args:
        csv_path: Path of the annotation CSV file.
        image_root: Directory the image file names are joined onto.

    Returns:
        A tuple ``(img_paths, mos_list)``. ``img_paths`` contains one
        single-element list ``[image_path]`` per row and ``mos_list`` is the
        ``mos_quality`` column as a list, in row order.
    """
    # sep=None makes pandas detect the delimiter itself, which is only
    # supported by the Python parsing engine.
    df = pd.read_csv(csv_path, sep=None, engine='python')

    img_paths = [[os.path.join(image_root, name)] for name in df['name']]
    mos_list = df['mos_quality'].tolist()
    return img_paths, mos_list

def QBench_description_data(json_path="/root/IQA/IQA-Agent/datasets/llvisionqa_test.json",
                            image_root="/root/IQA/IQA-Agent/datasets/llvisionqa_images",
                            limit=100):
    """Load multiple-choice question items from a Q-Bench style JSON file.

    This loader reads the same default files and returns the same structure
    as ``QBench_data``.

    The JSON file must contain a list of objects. Each object provides the
    keys ``"question"``, ``"candidates"``, ``"correct_ans"``, ``"type"`` and
    ``"concern"``, plus the image file name under ``"image"`` (or, if that
    key is absent, ``"img_path"``).

    Args:
        json_path: Path of the JSON annotation file.
        image_root: Directory the image file names are joined onto.
        limit: Maximum number of leading items to load; ``None`` loads all
            of them. Defaults to 100.

    Returns:
        A tuple of six parallel lists
        ``(img_paths, questions, choices_list, correct_choices, types, concerns)``:

        * ``img_paths``: one single-element list ``[image_path]`` per item.
        * ``questions``: the ``"question"`` value of each item.
        * ``choices_list``: the ``"candidates"`` value of each item.
        * ``correct_choices``: the letter (``"A"``, ``"B"``, ...) of the
          position of ``"correct_ans"`` within the candidates, or ``None``
          when the answer is not among them.
        * ``types``: the ``"type"`` value of each item.
        * ``concerns``: the ``"concern"`` value of each item.

    Raises:
        KeyError: If an item lacks one of the required keys.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        # Slicing with ``None`` keeps the whole list.
        data = json.load(f)[:limit]

    img_paths, questions, choices_list = [], [], []
    correct_choices, types, concerns = [], [], []

    for item in data:
        image = item.get("image", item.get("img_path", ""))
        question = item["question"]
        choices = item["candidates"]
        correct_ans = item["correct_ans"]

        # Map the position of the correct answer to its option letter.
        if correct_ans in choices:
            correct_choice = chr(ord("A") + choices.index(correct_ans))
        else:
            correct_choice = None

        img_paths.append([os.path.join(image_root, image)])
        questions.append(question)
        choices_list.append(choices)
        correct_choices.append(correct_choice)
        types.append(item["type"])
        concerns.append(item["concern"])
    return img_paths, questions, choices_list, correct_choices, types, concerns
