import argparse
import os
import sys
sys.path.append("")
import pandas as pd
import numpy as np
from typing import List, Dict, Any
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Module-level configuration constants.
# BASE_OUTPUT_DIR is the root that the __main__ loop joins with
# <subject_id>/<subdir>/"results"/<prompt> to locate the images to evaluate.
BASE_OUTPUT_DIR = ""

DATASET = "CELEBA" # must be "VGGFACE" or "CELEBA"; anything else makes the __main__ loop raise
# NOTE: this first SUBJECT_ID_SET is dead at runtime -- it is overwritten by the
# second SUBJECT_ID_SET assignment further down. The "n0000xx" ids are presumably
# the VGGFACE subjects (TODO confirm); comment out the later assignment to use them.
SUBJECT_ID_SET = [
    "n000050", "n000061", "n000076", "n000088", "n000097", "n000104", "n000138", "n000145", "n000154", "n000170", "n000176", "n000181", "n000187", "n000215", "n000221", "n000228", "n000238",
    "n000057", "n000063", "n000080", "n000089", "n000098", "n000105", "n000139", "n000146", "n000161", "n000171", "n000179", "n000184", "n000188", "n000217", "n000223", "n000234", "n000243",
    "n000058", "n000068", "n000087", "n000090", "n000103", "n000110", "n000142", "n000150", "n000164", "n000172", "n000180", "n000185", "n000190", "n000220", "n000225", "n000236",
]

# Longer numeric subject list, currently disabled:
# SUBJECT_ID_SET = [
#     '5', '17', '25', '26', '34', '44', '47', '49', '63', '65', '67', '77', '80',
#     '95', '103', '104', '108', '111', '112', '121', '122', '124', '125', '128', '129', '143',
#     '146', '158', '161', '162', '175', '177', '179', '180', '181', '182', '183', '188',
#     '195', '196', '198', '203', '204', '205', '206', '208', '213', '218', '223', '228'
# ]

# Effective subject list: this assignment replaces the one above. The ids are
# numeric strings (presumably the CELEBA subjects -- TODO confirm).
SUBJECT_ID_SET = [
    '5', '17', '25', '26', '34', '44', '47', '49', '63', '65', '67', '77', '80',
    '95', '103', '104', '108', '111', '112', '121', '122', '124', '125', '128', '129', '143',
    '146', '158', '161', '162', '175'
]

# Sub-directories evaluated under each subject directory.
SUBDIRS = ["sampledC"]


# Prompt directory names evaluated under <subject>/<subdir>/results/.
PROMPT_SET = ["a_photo_of_sks_person", "a_dslr_portrait_of_sks_person", "a_photo_of_sks_person_looking_at_the_mirror"]
# Alternative prompt sets, currently disabled:
# PROMPT_SET = ["a_photo_of_person", "a_dslr_portrait_of_person", "a_photo_of_person_looking_at_the_mirror"]
# PROMPT_SET = ["a_photo_of_<sks-person>", "a_dslr_portrait_of_<sks-person>", "a_photo_of_<sks-person>_looking_at_the_mirror"]

def get_unique_filepath(base_filepath: str) -> str:
    """
    Return a path that does not exist yet, derived from ``base_filepath``.

    When ``base_filepath`` is free it is returned unchanged. Otherwise a
    numeric suffix is inserted before the extension, counting up from 1
    until an unused name is found.

    Examples:
    - test_brisque.csv -> test_brisque_1.csv (if the original exists)
    - test_brisque.csv -> test_brisque_2.csv (if _1 exists as well)
    """
    # Split once up front; the stem/suffix are only needed when a collision occurs.
    stem, suffix = os.path.splitext(base_filepath)

    candidate = base_filepath
    attempt = 0
    while os.path.exists(candidate):
        attempt += 1
        candidate = f"{stem}_{attempt}{suffix}"
    return candidate

class TestResultManager:
    """Store evaluation metrics keyed by (subject_id, subdir, prompt).

    The backing ``results_df`` is indexed by the full cartesian product of the
    module-level ``SUBJECT_ID_SET``, ``SUBDIRS`` and ``PROMPT_SET``. Metric
    columns do not exist up front; they are created the first time
    :meth:`add_result` writes a metric with that name.
    """

    # Names of the MultiIndex levels, shared by construction and CSV loading.
    _INDEX_NAMES = ['subject_id', 'subdir', 'prompt']

    def __init__(self):
        # Start with an empty (column-less) frame whose MultiIndex already
        # contains every (subject, subdir, prompt) combination.
        self.results_df = pd.DataFrame(
            index=pd.MultiIndex.from_product(
                [SUBJECT_ID_SET, SUBDIRS, PROMPT_SET],
                names=list(self._INDEX_NAMES)
            )
        )

    def add_result(self, subject_id: str, subdir: str, prompt: str,
                  metrics: Dict[str, Any]) -> None:
        """
        Record the metrics of one (subject, subdir, prompt) combination.

        Args:
        - subject_id: subject identifier; must be in ``SUBJECT_ID_SET``
        - subdir: sub-directory name; must be in ``SUBDIRS``
        - prompt: prompt name; must be in ``PROMPT_SET``
        - metrics: mapping of metric name to value,
          e.g. {'loss': 0.5, 'accuracy': 0.95}

        Raises:
        - ValueError: if any of the three keys is not a configured value
        """
        # Validate the keys so a typo cannot silently append a new row.
        if subject_id not in SUBJECT_ID_SET:
            raise ValueError(f"Invalid subject_id: {subject_id}")
        if subdir not in SUBDIRS:
            raise ValueError(f"Invalid subdir: {subdir}")
        if prompt not in PROMPT_SET:
            raise ValueError(f"Invalid prompt: {prompt}")

        # Create or update one column per metric.
        row_key = (subject_id, subdir, prompt)
        for metric_name, value in metrics.items():
            self.results_df.loc[row_key, metric_name] = value

    def get_result(self, subject_id: str, subdir: str, prompt: str) -> pd.Series:
        """Return the metrics row of one (subject, subdir, prompt) combination."""
        return self.results_df.loc[(subject_id, subdir, prompt)]

    def get_subject_results(self, subject_id: str) -> pd.DataFrame:
        """Return every row that belongs to ``subject_id``."""
        return self.results_df.xs(subject_id, level='subject_id')

    def get_subdir_results(self, subdir: str) -> pd.DataFrame:
        """Return every row that belongs to ``subdir``."""
        return self.results_df.xs(subdir, level='subdir')

    def get_prompt_results(self, prompt: str) -> pd.DataFrame:
        """Return every row that belongs to ``prompt``."""
        return self.results_df.xs(prompt, level='prompt')

    def save_to_csv(self, filepath: str, avoid_overwrite: bool = True) -> str:
        """
        Write the results to a CSV file.

        Args:
        - filepath: requested destination path
        - avoid_overwrite: when True and ``filepath`` already exists, a
          numbered sibling path from ``get_unique_filepath`` is used instead

        Returns:
        - the path the file was actually written to
        """
        if avoid_overwrite:
            actual_filepath = get_unique_filepath(filepath)
        else:
            actual_filepath = filepath

        self.results_df.to_csv(actual_filepath)

        if actual_filepath != filepath:
            print(f"文件 {filepath} 已存在，保存为 {actual_filepath}")
        else:
            print(f"结果已保存到 {actual_filepath}")

        return actual_filepath

    def get_average_prompt_results(self) -> pd.DataFrame:
        """Return the mean of every metric column, grouped by prompt."""
        # Group on the index level explicitly so a metric column that happens
        # to be called 'prompt' cannot make the grouping key ambiguous.
        return self.results_df.groupby(level='prompt').mean()

    @classmethod
    def load_from_csv(cls, filepath: str) -> 'TestResultManager':
        """
        Build a manager from a CSV previously written by :meth:`save_to_csv`.

        The three index columns are read as strings. Without that, pandas
        infers integer dtypes for numeric-looking ids such as '5', and the
        string-keyed lookups (``get_result``, ``get_subject_results``) and
        ``add_result`` no longer address the loaded rows.
        """
        manager = cls()
        index_names = list(cls._INDEX_NAMES)
        # Read the key columns as plain columns with an explicit str dtype and
        # only then promote them to the index.
        frame = pd.read_csv(filepath, dtype={name: str for name in index_names})
        manager.results_df = frame.set_index(index_names)
        return manager


# from evaluations.brisques import eval_brisque
# from evaluations.ser_fiq import eval_ser_fiq
# from evaluations.ism_fdfr import eval_ism_fdfr

def batch_evaluation(metric: str, input_path: str, **kwargs) -> Any:
    """Run the evaluator selected by ``metric`` on ``input_path``.

    Args:
        metric: one of 'brisque', 'ism_fdfr' or 'ser_fiq'.
        input_path: path handed to the evaluator as its first argument.
        **kwargs: forwarded unchanged to the selected evaluator.

    Returns:
        Whatever the selected evaluator returns. The caller in this file
        unpacks the 'ism_fdfr' result as an ``(ism, fdfr)`` pair and stores
        the other results as a single value.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # Evaluator modules are imported inside each branch, so only the module
    # of the requested metric is ever imported.
    if metric == 'brisque':
        from evaluations.brisques import eval_brisque
        return eval_brisque(input_path, **kwargs)
    if metric == 'ism_fdfr':
        from evaluations.ism_fdfr import eval_ism_fdfr
        return eval_ism_fdfr(input_path, **kwargs)
    if metric == 'ser_fiq':
        from evaluations.ser_fiq import eval_ser_fiq
        return eval_ser_fiq(input_path, **kwargs)

    # Previously an unknown metric fell through and returned None, which the
    # caller then stored as if it were a real result.
    raise ValueError(
        f"Unsupported metric: {metric!r} (expected 'brisque', 'ism_fdfr' or 'ser_fiq')"
    )

def parse_args(input_args=None):
    """Parse the command-line options of the batch evaluation script.

    Args:
        input_args: optional list of argument strings; when None, argparse
            reads ``sys.argv[1:]``.

    Returns:
        The parsed namespace; ``metric`` is one of 'brisque', 'ism_fdfr'
        or 'ser_fiq'.
    """
    parser = argparse.ArgumentParser(description='Batch evaluation')
    # --metric is required: without it the script would evaluate with
    # metric=None and write its results to "test_None.csv".
    parser.add_argument(
        '--metric',
        choices=['brisque', 'ism_fdfr', 'ser_fiq'],
        required=True,
        help='evaluation metric to compute',
    )
    return parser.parse_args(input_args)

if __name__ == "__main__":
    # Script entry point: evaluate the chosen metric for every configured
    # (subject, subdir, prompt) combination and write the table to
    # "test_<metric>.csv" in the current working directory.
    args = parse_args()
    # * Loading previously saved results (kept for reference)
    # test_manager = TestResultManager.load_from_csv("test_brisque.csv")
    # res = test_manager.get_subdir_results("baseline_pert0")
    # print(res)

    # * Batch evaluation
    test_manager = TestResultManager()
    # The same file is rewritten for every checkpoint and for the final save.
    output_csv = f"test_{args.metric}.csv"

    for i, subject_id in enumerate(SUBJECT_ID_SET):
        for subdir in SUBDIRS:
            for prompt in PROMPT_SET:

                # Locate the images for this combination; skip missing ones
                # so their rows simply stay empty in the result table.
                input_path = os.path.join(BASE_OUTPUT_DIR, subject_id, subdir, "results", prompt)
                if not os.path.exists(input_path):
                    print(f"Input path {input_path} does not exist. Skipping...")
                    continue
                # Directory passed to the ism_fdfr evaluator as "emb_dirs".
                if DATASET == "VGGFACE":
                    emb_path = os.path.join("", subject_id, "set_C")
                elif DATASET == "CELEBA":
                    emb_path = os.path.join("", subject_id, "set_C")
                else:
                    raise ValueError(
                        f"Unsupported DATASET: {DATASET!r} (expected 'VGGFACE' or 'CELEBA')"
                    )

                if args.metric == 'ism_fdfr':
                    ism, fdfr = batch_evaluation(args.metric, input_path, emb_dirs=[emb_path])
                    metrics = {
                        # Compare against None explicitly: a truthiness test
                        # would also discard a legitimate score of exactly 0.
                        'ism': ism.item() if ism is not None else None,
                        'fdfr': fdfr,
                    }
                else:
                    # Single-valued metrics are stored under the metric's own name.
                    metrics = {args.metric: batch_evaluation(args.metric, input_path)}

                # Store the result.
                test_manager.add_result(
                    subject_id=subject_id,
                    subdir=subdir,
                    prompt=prompt,
                    metrics=metrics,
                )
        # Checkpoint after every second subject so an interrupted run keeps
        # most of its results.
        if (i + 1) % 2 == 0:
            test_manager.save_to_csv(output_csv, avoid_overwrite=False)

    # Final save.
    final_filepath = test_manager.save_to_csv(output_csv, avoid_overwrite=False)
    print(f"最终结果保存在: {final_filepath}")
