import ast
import json
import os
import re
import string
from pathlib import Path

import emoji
import torch.distributed as dist
from PIL import Image


# === Variables ===

# OpenAI API settings; both are unset here and must be filled in elsewhere.
OPENAI_API_KEY = None
OPENAI_BASE_URL = None

# Model parameters
# NOTE(review): -100 is presumably the label value skipped by the loss - confirm.
IGNORE_INDEX = -100

# NOTE(review): presumably maximum sequence lengths (unit not shown here) - confirm.
AUDIO_MAX_LENGTH = 448
LLM_MAX_LENGTH = 1536

# Qwen2 special tokens
QWEN2_START_TOKEN = "<|im_start|>"
QWEN2_END_TOKEN = "<|im_end|>"
QWEN2_IMAGE_TOKEN = "<|image_pad|>"

# Every directory below is derived from EXPERIMENT_BASE_DIR.
EXPERIMENT_BASE_DIR = "./"
IMAGE_FOLDER = EXPERIMENT_BASE_DIR + "/datasets/LLaVA-Pretrain"
DATASET_BASE_DIR = EXPERIMENT_BASE_DIR + "/datasets/local"
MODEL_DIR = EXPERIMENT_BASE_DIR + "/model_weights"
CACHE_DIR = EXPERIMENT_BASE_DIR + "/cache/huggingface/datasets"


# LibriTTS subset names, keyed first by split ("train"/"dev"/"test") and then
# by the "clean"/"other" group; only the train split lists an "other" group.
LIBRITTS_CONFIG = {
    "train": {
        "clean": ["train.clean.100", "train.clean.360"],
        "other": ["train.other.500"],
    },
    "dev": {"clean": ["dev.clean"]},
    "test": {"clean": ["test.clean"]},
}


# Sub-module names selected for LoRA training.
# NOTE(review): presumably passed as the LoRA target-module list - confirm at the call site.
LORA_TRAINABLE_MODULES = [
    "q_proj", "v_proj", "k_proj", "o_proj",
    "gate_proj", "down_proj", "up_proj",
]

# === Variables End ===


# === OpenAudioBench Evaluation ===

def extract_score_alpaca_eval(text: str):
    """Pull a 1-10 score out of ``text``.

    The ``[[number]]`` form is tried first, then the looser ``[number]`` form.
    Whichever pattern matches first decides the result: its number is returned
    when it lies in [1, 10], otherwise 1. Text without any bracketed number
    also yields 1.
    """
    # Ordered from strictest to loosest pattern.
    for pattern in (r"\[\[(\d+)\]\]", r"\[(\d+)\]"):
        found = re.search(pattern, text)
        if found is None:
            continue
        value = int(found.group(1))
        # Out-of-range numbers fall back to the minimum score.
        return value if 1 <= value <= 10 else 1
    return 1


def extract_score_llama_questions(text: str):
    """Convert a Correct/Incorrect verdict found in ``text`` into 1 or 0.

    Two formats are recognised, in this order:

    1. ``The score is [Correct]`` / ``the score is [Incorrect]`` - used only
       when it occurs exactly once in ``text``.
    2. Any bracketed ``[Correct]`` / ``[Incorrect]`` (first letter in either
       case); the first occurrence decides.

    Returns:
        1 for a "correct" verdict, 0 for an "incorrect" verdict or when no
        verdict can be found at all.
    """
    # [Tt]he score is \[(Correct|Incorrect)\]
    verdicts = re.findall(r"[Tt]he score is \[(Correct|Incorrect)\]", text)
    if len(verdicts) == 1:
        return 1 if verdicts[0] == 'Correct' else 0

    # [Correct|Incorrect]
    verdicts = re.findall(r"\[([Cc]orrect|[Ii]ncorrect)\]", text)
    if not verdicts:
        # No verdict anywhere: score it as incorrect instead of raising IndexError.
        return 0
    # The pattern accepts a lower-case first letter, so compare case-insensitively.
    return 1 if verdicts[0].lower() == 'correct' else 0


def extract_score_trivia_qa_and_web_questions(text: str):
    """Read the ``judgment`` field of a judge reply and map it to 1 or 0.

    If ``text`` contains a fenced ``json`` code block, only the first such
    block is parsed; otherwise the whole text is parsed. The payload is parsed
    as JSON first and, failing that, as a Python literal (which covers replies
    written with single quotes).

    The payload is model-generated and therefore untrusted, so it is only ever
    parsed as data and never executed.

    Returns:
        1 if the payload is a mapping whose ``'judgment'`` equals ``'correct'``,
        0 in every other case (including unparsable input).
    """
    # extract content within code block, if there is one
    fenced = re.findall(r"```json(.*?)```", text, re.DOTALL)
    if fenced:
        text = fenced[0]
    payload = text.strip()

    try:
        verdict = json.loads(payload)
    except (ValueError, RecursionError):
        # Not strict JSON; accept a Python-literal dict as a fallback.
        try:
            verdict = ast.literal_eval(payload)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return 0

    if not isinstance(verdict, dict):
        return 0
    return 1 if verdict.get('judgment') == 'correct' else 0


def get_next_level(path, target_dir):
    """Return the path component that directly follows ``target_dir``.

    ``path`` is split into components with ``pathlib.Path``. The first
    component equal to ``target_dir`` is located and the one right after it is
    returned. ``None`` is returned when ``target_dir`` is absent or is the
    last component.
    """
    components = Path(path).parts
    if target_dir not in components:
        return None

    position = components.index(target_dir)
    # Slicing avoids an IndexError when target_dir is the final component.
    following = components[position + 1:position + 2]
    return following[0] if following else None


# Per-benchmark settings. Each entry is a 3-item list: two string keys followed
# by the function that turns the judge's reply into a score.
# NOTE(review): the two keys are presumably the question and reference-answer
# field names of the dataset - confirm against the evaluation script.
OPENAUDIOBENCH_ATTRIBUTES = {
    "alpaca_eval": [
        "instruction",
        "output",
        extract_score_alpaca_eval,
    ],
    "llama_questions": [
        "Questions",
        "Answer",
        extract_score_llama_questions,
    ],
    "trivia_qa": [
        "question",
        "answer_normalized_value",
        extract_score_trivia_qa_and_web_questions,
    ],
    "web_questions": [
        "question",
        "answers",
        extract_score_trivia_qa_and_web_questions,
    ],
}

# === OpenAudioBench Evaluation End ===


# === VoiceBench Evaluation ===

def extract_score_voicebench(texts: str, eval_type: str) -> int:
    """Aggregate one or more judge replies into a single VoiceBench score.

    Args:
        texts: The judge replies. Despite the annotation, the function
            iterates over this argument, so an iterable of strings is the
            usual input; a single string is treated as one reply.
        eval_type: ``'qa'`` or ``'open'``.

    Returns:
        * ``'qa'``: ``True`` when the word "yes" occurs more often than "no"
          across all replies (case-insensitive, punctuation ignored),
          otherwise ``False``.
        * ``'open'``: the mean of the per-reply scores, where a reply's score
          is the reply itself when it parses as a number, else the number in
          a ``[[number]]`` tag; replies with neither are skipped. ``1`` is
          returned when no reply yields a score.
        * any other ``eval_type``: ``None``.
    """
    if isinstance(texts, str):
        # A bare string would otherwise be iterated character by character.
        texts = [texts]

    if eval_type == 'qa':
        # Delete punctuation so that "Yes." / "no," still count as whole words.
        strip_punctuation = str.maketrans('', '', string.punctuation)
        yes_cnt = 0
        no_cnt = 0
        for text in texts:
            for word in text.translate(strip_punctuation).lower().split():
                if word == 'yes':
                    yes_cnt += 1
                elif word == 'no':
                    no_cnt += 1
        return yes_cnt > no_cnt

    if eval_type == 'open':
        # Rating in the format [[number]]
        pattern = r"\[\[(\d+)\]\]"

        scores = []
        for text in texts:
            try:
                # The reply may be nothing but the number itself.
                scores.append(float(text))
                continue
            except (TypeError, ValueError):
                pass
            match = re.search(pattern, text)
            if match:
                scores.append(int(match.group(1)))
        return 1 if not scores else sum(scores) / len(scores)

    # Unknown eval_type: kept as None for backward compatibility.
    return None

# === VoiceBench Evaluation End ===


def print_rank_0(message: str) -> None:
    """Print ``message`` unless this process is a non-zero distributed rank.

    The message is printed when ``torch.distributed`` has not been initialized
    or when the current rank is 0; every other rank stays silent.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message)


def get_text_model_path(text_model_type: str) -> str:
    """Map a text model type to its directory under ``EXPERIMENT_BASE_DIR/models``.

    Supported types are ``'qwen2_7b_instruct'``, ``'qwen2_3b_instruct'`` and
    ``'qwen2_05b_instruct'``.

    Raises:
        ValueError: if ``text_model_type`` is not one of the supported types.
    """
    known_models = (
        ('qwen2_7b_instruct', 'Qwen2.5-7B-Instruct'),
        ('qwen2_3b_instruct', 'Qwen2.5-3B-Instruct'),
        ('qwen2_05b_instruct', 'Qwen2.5-0.5B-Instruct'),
    )
    for type_name, dir_name in known_models:
        if text_model_type == type_name:
            return os.path.join(EXPERIMENT_BASE_DIR, 'models', dir_name)

    raise ValueError(f'Invalid text model type: {text_model_type}')


def get_whisper_model_path(whisper_model_type: str) -> str:
    """Map a whisper model type to its directory under ``EXPERIMENT_BASE_DIR/models``.

    Only ``'large_v3_turbo'`` is supported.

    Raises:
        ValueError: if ``whisper_model_type`` is anything else.
    """
    if whisper_model_type != 'large_v3_turbo':
        raise ValueError(f'Invalid whisper model type: {whisper_model_type}')

    return os.path.join(EXPERIMENT_BASE_DIR, 'models', 'whisper-large-v3-turbo')


def remove_emoji(text: str) -> str:
    """Return ``text`` with emoji replaced by nothing and outer whitespace stripped."""
    without_emoji = emoji.replace_emoji(text, replace='')
    return without_emoji.strip()


def remove_unmatched_quotes(text: str) -> str:
    """Drop double quotes that have no partner, keeping matched pairs intact.

    Handles the straight quote ``"`` (which is its own closer) and the curly
    pair ``“`` / ``”``. Single quotes are not balanced because they also occur
    as apostrophes (``'s``, ``'t``).

    Steps:
        1. Scan left to right. A closing quote matching the innermost open
           quote closes it; a closing quote that matches nothing is dropped.
        2. Opening quotes still open at the end are removed - only those
           specific characters, not every occurrence of the same symbol.
        3. If the result is wrapped in one pair of ``"`` or ``'`` with no
           further occurrence of that quote inside, the wrapper is stripped.

    Returns:
        The cleaned text.
    """
    quote_pairs = {
        '"': '"',
        '“': '”',
    }
    closing_quotes = set(quote_pairs.values())

    kept = []        # characters retained so far
    open_stack = []  # (opening quote, its index in ``kept``) still awaiting a closer

    for ch in text:
        if open_stack and quote_pairs[open_stack[-1][0]] == ch:
            # Closes the innermost open quote.
            kept.append(ch)
            open_stack.pop()
        elif ch in quote_pairs:
            # Opening quote: remember where it landed so it can be removed
            # later if it is never closed.
            open_stack.append((ch, len(kept)))
            kept.append(ch)
        elif ch in closing_quotes:
            # Closing quote with nothing to close.
            continue
        else:
            kept.append(ch)

    # Remove only the openers that were never closed.
    unmatched = {index for _, index in open_stack}
    final_text = ''.join(c for i, c in enumerate(kept) if i not in unmatched)

    # remove quotes that are at the beginning and end
    if final_text.startswith('"') and final_text.endswith('"') and '"' not in final_text[1:-1]:
        final_text = final_text[1:-1]
    if final_text.startswith('\'') and final_text.endswith('\'') and '\'' not in final_text[1:-1]:
        final_text = final_text[1:-1]

    return final_text


def preprocess_text_for_training(text: str, _rm_emoji: bool = True, _rm_hyphen: bool = False) -> str:
    """Normalise a transcript before it is used for speech recognition training.

    The text is stripped, unmatched quotes are removed, emoji are removed
    (when ``_rm_emoji``), hyphens become spaces (when ``_rm_hyphen``), and one
    trailing ASCII or full-width colon is dropped. A ``.`` is appended unless
    the text already ends in sentence punctuation or a quote. Text that is
    empty before or after cleaning becomes ``"."``.
    """
    cleaned = text.strip()
    if not cleaned:
        return "."

    cleaned = remove_unmatched_quotes(cleaned)

    if _rm_emoji:
        cleaned = remove_emoji(cleaned)

    if _rm_hyphen:
        # Hyphens are replaced by a space, not deleted.
        cleaned = re.sub(r'\-', ' ', cleaned)

    # Drop a single trailing colon (ASCII or full-width).
    if cleaned.endswith((':', '：')):
        cleaned = cleaned[:-1].strip()

    cleaned = cleaned.strip()
    if not cleaned:
        return "."

    # Endings that already terminate the sentence; anything else gets a period.
    terminal_marks = ('.', '?', '!', '。', '！', '？', '\'', '"')
    if not cleaned.endswith(terminal_marks):
        cleaned += "."

    return cleaned


def extract_and_clean_state_dict(full_sd: dict,
                                 old_prefix: str = "audio_model.",
                                 new_prefix: str = "") -> dict:
    """Select the entries of ``full_sd`` under ``old_prefix`` and re-prefix them.

    - full_sd: the complete state_dict
    - old_prefix: only keys starting with this prefix are kept; the prefix is cut off
    - new_prefix: prepended to every kept key (empty by default)

    Returns a new dict mapping ``new_prefix + <rest of key>`` to the original
    value. Raises ValueError when no key starts with ``old_prefix``.
    """
    cut = len(old_prefix)
    renamed = {
        new_prefix + key[cut:]: value
        for key, value in full_sd.items()
        if key.startswith(old_prefix)
    }
    if not renamed:
        raise ValueError(f"No keys found with prefix '{old_prefix}' in the checkpoint.")
    return renamed


def expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square instead of cropping it.

    A square image is returned unchanged. Otherwise a new image whose side is
    the longer dimension is created with ``background_color`` and the original
    is pasted centred along the shorter dimension.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        # Landscape: centre vertically.
        offset = (0, (width - height) // 2)
    else:
        # Portrait: centre horizontally.
        offset = ((height - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas


def get_generation_output_path(audio_model_path: str, dataset_name: str, subset_name: str, prompt_type: str = "", ratio: float = 1.0) -> str:
    """Build the ``.jsonl`` path for model outputs and prepare it for writing.

    The path is ``outputs/<dataset_name>/<subset_name>/<model part>/<file>``:

    - ``<model part>`` is everything after the ``outputs`` component of
      ``audio_model_path``, or its last two components when there is none.
    - ``<file>`` is ``model_outputs<prompt_type>.jsonl``, with ``_<ratio>``
      inserted before the extension when ``ratio`` is not 1.0.

    If the file does not exist, its parent directory is created. If it exists
    and contains a non-blank line, a message is printed and the process exits
    with status 0; if it exists but is blank, a notice is printed.
    """
    segments = audio_model_path.split('/')
    try:
        start = segments.index('outputs') + 1
    except ValueError:
        # No 'outputs' component: fall back to the last two components.
        start = len(segments) - 2

    file_name = (
        f'model_outputs{prompt_type}.jsonl'
        if ratio == 1.0
        else f'model_outputs{prompt_type}_{ratio}.jsonl'
    )
    result_path = os.path.join(
        f'outputs/{dataset_name}/{subset_name}',
        '/'.join(segments[start:]),
        file_name,
    )

    if os.path.exists(result_path):
        with open(result_path, 'r', encoding='utf-8') as file:
            # Refuse to clobber a file that already holds results.
            if any(line.strip() for line in file):
                print(f'{result_path} is not empty, please delete it first!')
                exit(0)

        print(f'{result_path} is empty and will be overwritten!')
    else:
        os.makedirs(os.path.dirname(result_path), exist_ok=True)

    return result_path
