from typing import Literal, Callable, Optional, Any
import os
import re
import json
from functools import cache
from tqdm import tqdm
from collections import defaultdict
import random
import shutil
from copy import deepcopy
from olym_gen.utils.utils import get_logger, retrieve_id_from_name, UNKNOWN_FIELD

logger = get_logger()

def filter_by_check(check_file_path: str, threshold: float, greater_or_less: Literal['greater', 'less'], save_path: str) -> None:
    """
    Given a dir with checked report json files, filter the QA pairs based on the check results.
    NOTE: this function is not used now.

    Reports are grouped by ``{problem_index}_{proof_index}`` (as parsed from the
    file name by ``retrieve_id_from_name``). Reports whose ``pass_check`` is
    falsy or whose ``check_result`` is a string are ignored entirely. The score
    of a group is ``#pass / (#pass + #fail)`` over the remaining reports.

    Args:
        check_file_path: str: The path to the directory containing the checked report json files.
        threshold: float: keep the QA pairs with check score greater or less than the threshold, the score is the average score of all check generations for the QA pair.
        greater_or_less: Literal['greater', 'less']: whether to keep the QA pairs with check score greater or less than the threshold (both comparisons are strict).
        save_path: str: The path to save the filtered results (one JSON object per line).

    Raises:
        ValueError: If ``greater_or_less`` is neither 'greater' nor 'less'.
    """
    # Validate up front so a bad argument fails before any file is read.
    if greater_or_less not in ['greater', 'less']:
        raise ValueError(f"greater_or_less should be 'greater' or 'less', but got {greater_or_less}.")

    valid_json = 0

    # 1. first run through the json files and record them into one dict
    return_dict = {}

    file_list = os.listdir(check_file_path)
    for file_name in file_list:
        try:
            problem_index, proof_index, _ = retrieve_id_from_name(file_name)
        except ValueError:
            # the name does not follow the report naming scheme, skip it
            continue
        valid_json += 1
        qa_name = f'{problem_index}_{proof_index}'

        # Imported lazily, only once a report file has actually been found.
        from olym_gen.generator.problem_proof_generator import UNKNOWN_INDEX
        assert proof_index != UNKNOWN_INDEX, "The proof index should not be in the name."

        with open(os.path.join(check_file_path, file_name), 'r', encoding='UTF-8') as f:
            data = json.load(f)

        if not data['pass_check']:
            continue
        if isinstance(data["check_result"], str):
            continue
        proof_correct = bool(data["check_result"]["proof_correct"])

        if qa_name not in return_dict:
            # NOTE(review): only the first report's 'thinking' is stored even
            # though the field is a list; confirm whether later reports should
            # be appended as well.
            return_dict[qa_name] = {
                'question': data['question'],
                'checked_proof': data['checked_proof'],
                'field': data.get('field', UNKNOWN_FIELD),
                'thinkings': [data['thinking']],
                'score': -1.0,
                'source': {
                    'pass': [], # pass the check and labeled as correct
                    'fail': [], # pass the check and labeled as incorrect
                },
            }

        return_dict[qa_name]['source']['pass' if proof_correct else 'fail'].append(file_name)

    # 2. calc the score (every entry has at least one source, so no division by zero)
    for qa_data in return_dict.values():
        passed = len(qa_data['source']['pass'])
        failed = len(qa_data['source']['fail'])
        qa_data['score'] = passed / (passed + failed)

    # 3. filter the results based on the threshold
    if greater_or_less == 'greater':
        return_dict = {k: v for k, v in return_dict.items() if v['score'] > threshold}
    else:
        return_dict = {k: v for k, v in return_dict.items() if v['score'] < threshold}

    qa_num = sum(len(v['source']['pass']) + len(v['source']['fail']) for v in return_dict.values())

    # 4. save the filtered results to a jsonl file
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # os.makedirs('') raises, so only create the directory when there is one.
        os.makedirs(save_dir, exist_ok=True)
    with open(save_path, 'w', encoding='UTF-8') as f:
        for value in return_dict.values():
            f.write(json.dumps(value, ensure_ascii=False) + '\n')
    logger.info(f"Filtered results saved to {save_path}. Total read {len(file_list)} files with {valid_json} valid json, kept {qa_num} QA pairs with {len(return_dict)} questions with score {greater_or_less} than threshold {threshold}.")

def create_random_sample(source_path: str, target_path: str, sample_ratio: float = 0.05, random_seed: int = 42, filter_fn: Optional[Callable[[dict[str,Any]], bool]] = None, sample_num: int | None = None) -> None:
    """
    Sample a portion of data from source_path and save them to target_path as JSON files.
    
    Args:
        source_path: str: Path to a .jsonl file or a directory containing .json files.
        target_path: str: Directory path to save the sampled data as JSON files.
        sample_ratio: float: Portion of data to sample (default: 0.1).
        random_seed: int: Random seed for reproducibility (default: 42).
        filter_fn: Optional[Callable[[dict[str, any]], bool]]: Function to filter the data before sampling (default: None).
        sample_num: int | None: Number of samples to take (default: None, which means take according to sample_ratio).
    """

    random.seed(random_seed)

    # Create target directory
    os.makedirs(target_path, exist_ok=True)

    if os.path.isfile(source_path) and source_path.endswith('.jsonl'):
        # Handle .jsonl file
        with open(source_path, 'r', encoding='UTF-8') as f:
            lines = f.readlines()

        data_list = []

        for i, line in enumerate(lines):
            data = json.loads(line)
            data_list.append((data, i))

        # filter lines
        if filter_fn:
            data_list = [data for data in data_list if filter_fn(data[0])]

        # Sample lines
        sample_size = max(int(len(data_list) * sample_ratio), min(sample_num if sample_num else 0, len(data_list)))

        sampled_data = random.sample(data_list, sample_size)

        # Write each line as a separate JSON file
        for i, (data, problem_index) in enumerate(sampled_data):
            output_file = os.path.join(target_path, f"problem_{problem_index}.json")
            with open(output_file, 'w', encoding='UTF-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)

        logger.info(f"Sampled {sample_size}/{len(data_list)} entries from {source_path} to {target_path}")

    elif os.path.isdir(source_path):
        # Handle directory of json files
        json_files = [f for f in os.listdir(source_path) if f.endswith('.json')]

        if filter_fn:
            json_files = [f for f in json_files if filter_fn(json.loads(open(os.path.join(source_path, f)).read()))]

        data_dict = defaultdict(list)

        for f in json_files:
            try:
                problem_index, proof_index, generate_index = retrieve_id_from_name(f)
            except Exception as e:
                logger.error(f"Error retrieving IDs from filename {f}: {e}")
                continue
            data_dict[(problem_index, proof_index)].append(f)

        json_files = [random.choice(v) for v in data_dict.values()]

        if len(json_files) == 0:
            logger.warning(f"No valid JSON files found in {source_path}")
            return

        # Sample files
        sample_size = max(int(len(json_files) * sample_ratio), min(sample_num if sample_num else 0, len(json_files)))

        sampled_files = random.sample(json_files, sample_size)

        # Copy sampled files to target
        for file_name in sampled_files:
            shutil.copy2(
                os.path.join(source_path, file_name),
                os.path.join(target_path, file_name)
            )

        logger.info(f"Sampled {sample_size}/{len(json_files)} files from {source_path} to {target_path}")

    else:
        raise ValueError(f"Source path {source_path} must be a .jsonl file or a directory containing .json files")

def sample_from_dir(source_dir: str, target_dir: str, sample_ratio: float = 0.1, random_seed: int = 42, filter_fn: Optional[Callable[[dict[str, Any]], bool]] = None, sample_num: int | None = None, dir_filter: Optional[Callable[[str], bool]] = None) -> None:
    """
    Sample a portion of data from each subdirectory separately in a directory of subdirectories containing JSON files and save them to a target directory in corresponding subdirectories.

    Args:
        source_dir: str: Path to the source directory containing subdirectories with .json files.
        target_dir: str: Path to the target directory to save the sampled data.
        sample_ratio: float: Portion of data to sample (default: 0.1).
        random_seed: int: Random seed for reproducibility (default: 42).
        filter_fn: Optional[Callable[[dict[str, any]], bool]]: Function to filter the data before sampling (default: None).
        sample_num: int | None: Number of samples to take (default: None, which means take according to sample_ratio).
        dir_filter: Optional[Callable[[str], bool]]: Function to filter the subdirectories (default: None).
    """
    random.seed(random_seed)

    # Create target directory
    os.makedirs(target_dir, exist_ok=True)

    for dir_path, dir_names, file_names in os.walk(source_dir):
        # Compute relative path to maintain directory structure
        if dir_filter and not dir_filter(dir_path):
            continue

        json_files = [f for f in file_names if f.endswith('.json')]

        if json_files == []:
            continue

        rel_path = os.path.relpath(dir_path, source_dir)
        if os.path.basename(rel_path) != "data":
            rel_path = os.path.join(rel_path, "data")
        target_subdir = os.path.join(target_dir, rel_path)
        os.makedirs(target_subdir, exist_ok=True)

        create_random_sample(
            source_path=dir_path,
            target_path=target_subdir,
            sample_ratio=sample_ratio,
            random_seed=random_seed,
            filter_fn=filter_fn,
            sample_num=sample_num,
        )

def shuffle_data(source_dirs: list[str], target_dir: str, dir_filter: Optional[Callable[[str], bool]] = None):
    '''
    Merge JSON files from several source directories into ``<target_dir>/data`` and
    record where each copied file came from in ``<target_dir>/correspondence.json``.

    Every copied file is renamed to ``<original stem>_<counter>.json`` where the
    counter increases across all source directories, so files with the same name
    in different directories do not overwrite each other.

    NOTE(review): despite the name, no random permutation is applied here; files
    are numbered in the order the directories are given and ``os.listdir``
    returns them. Confirm whether real shuffling is expected.

    Args:
        source_dirs: List[str]: List of source directories containing JSON files.
        target_dir: str: Path to the target directory to save the shuffled data.
        dir_filter: Optional[Callable[[str], bool]]: Function to filter the source directories.
            Defaults to keeping directories that contain at least one .json file.
    '''
    # 1. Create target directory
    data_dir = os.path.join(target_dir, 'data')
    os.makedirs(data_dir, exist_ok=True)

    # 2. Prepare the correspondence dictionary and the running file counter
    correspondence = {}
    count = 0
    if dir_filter is None:
        def dir_filter(path: str) -> bool:
            return any(name.endswith('.json') for name in os.listdir(path))

    # 3. Traverse each source directory
    for source_dir in source_dirs:
        if not dir_filter(source_dir):
            logger.debug(f"Skipping source directory: {source_dir}")
            continue
        logger.info(f"Processing source directory: {source_dir}")
        # Get all JSON files in the source directory
        json_files = [f for f in os.listdir(source_dir) if f.endswith('.json')]
        logger.info(f"Found {len(json_files)} JSON files")

        # 4. Process each extracted data file
        for json_file in json_files:
            source_path = os.path.join(source_dir, json_file)

            # Copy the file to the target folder. splitext strips only the
            # final ".json", so names containing other dots are kept intact.
            stem = os.path.splitext(json_file)[0]
            new_file_name = f"{stem}_{count}.json"
            count += 1
            target_path = os.path.join(data_dir, new_file_name)
            shutil.copy(source_path, target_path)

            # Save correspondence
            correspondence[target_path] = {
                "original_dir": source_dir,
                "original_file": json_file
            }

    # 5. Save correspondence to a file
    correspondence_path = os.path.join(target_dir, 'correspondence.json')
    with open(correspondence_path, 'w', encoding='utf-8') as f:
        json.dump(correspondence, f, indent=4)

    logger.info(
        f"Data extracted and shuffled. Correspondence saved in {os.path.join(target_dir, 'correspondence.json')}"
    )
    logger.info(f"Total files extracted: {count}")

@cache
def list_check(check_dir: str, problem_retrieve: Callable[[dict[str, Any]], str] | None = None, proof_retrieve: Callable[[dict[str, Any]], str] | None = None, result_retrieve: Callable[[dict[str, Any]], bool] | None = None) -> dict[str, list[bool]]:
    """
    List the check results for all check files in the check directory.

    The result is memoized per argument combination (``functools.cache``): the
    directory is read only once per process for the same arguments and the
    returned mapping is shared between callers, so it should be treated as
    read-only. All arguments must be hashable.

    Args:
        check_dir: The directory containing the check JSON files.
        problem_retrieve: Extracts the problem text from a check record
            (default: the 'question' field, '' if missing).
        proof_retrieve: Extracts the proof text from a check record
            (default: the 'checked_proof' field, '' if missing).
        result_retrieve: Extracts the verdict from a check record
            (default: ``check_result.proof_correct``). Records for which it
            raises or returns None are skipped with a debug log.

    Returns:
        A dictionary mapping each ``"{problem}___{proof}"`` key to the list of check results collected for it.
    """
    results = defaultdict(list)
    if problem_retrieve is None:
        problem_retrieve = lambda x: x.get('question', '')
    if proof_retrieve is None:
        proof_retrieve = lambda x: x.get('checked_proof', '')
    if result_retrieve is None:
        result_retrieve = lambda x: x.get('check_result', {}).get('proof_correct', None)

    # Load the check JSON files
    for check_file in tqdm(os.listdir(check_dir)):
        if not check_file.endswith('.json'):
            continue

        with open(os.path.join(check_dir, check_file), 'r', encoding='utf-8') as f:
            check_data = json.load(f)

        problem = problem_retrieve(check_data)
        proof = proof_retrieve(check_data)
        key = f"{problem}___{proof}"

        try:
            result = result_retrieve(check_data)
        except Exception as e:
            # Best effort: a malformed record (e.g. a string 'check_result')
            # must not abort the whole scan.
            logger.debug(f"Error processing {check_file}: {e}")
            continue

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if result is None:
            logger.debug(f"Error processing {check_file}: Result is None in {check_file}")
            continue

        results[key].append(result)

    return results

def _retrieve_check(source_file: str, check_dir: str, problem_retrieve: Callable[[dict[str, Any]], str] | None = None, proof_retrieve: Callable[[dict[str, Any]], str] | None = None, check_problem_retrieve: Callable[[dict[str, Any]], str] | None = None, check_proof_retrieve: Callable[[dict[str, Any]], str] | None = None, check_result_retrieve: Callable[[dict[str, Any]], bool] | None = None) -> tuple[int, int]:
    """
    Retrieve the check information for a specific source file.

    The source record is matched against the check records through the key
    ``"{problem}___{proof}"``; a source file with no matching check yields (0, 0).

    Args:
        source_file: The path to the source JSON file.
        check_dir: The directory containing the check JSON files.
        problem_retrieve: Extracts the problem text from the source record
            (default: the 'question' field, '' if missing).
        proof_retrieve: Extracts the proof text from the source record
            (default: the 'new_solution' field, '' if missing).
        check_problem_retrieve: Forwarded to list_check as problem_retrieve.
        check_proof_retrieve: Forwarded to list_check as proof_retrieve.
        check_result_retrieve: Forwarded to list_check as result_retrieve.

    Returns:
        A tuple containing the number of correct and incorrect check results.
    """
    if problem_retrieve is None:
        problem_retrieve = lambda x: x.get('question', '')
    if proof_retrieve is None:
        proof_retrieve = lambda x: x.get('new_solution', '')

    # Load the check JSON files
    check_results = list_check(check_dir, check_problem_retrieve, check_proof_retrieve, check_result_retrieve)

    # Use a context manager so the file handle is closed deterministically.
    with open(source_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    problem = problem_retrieve(data)
    proof = proof_retrieve(data)

    key = f"{problem}___{proof}"
    # .get() does not insert missing keys, so the shared mapping returned by
    # list_check is never modified here.
    outcomes = check_results.get(key, [])
    correct_count = outcomes.count(True)
    incorrect_count = outcomes.count(False)

    return (correct_count, incorrect_count)

def retrieve_check(source_dir: str, check_dir: str, problem_retrieve: Callable[[dict[str, Any]], str] | None = None, proof_retrieve: Callable[[dict[str, Any]], str] | None = None, check_problem_retrieve: Callable[[dict[str, Any]], str] | None = None, check_proof_retrieve: Callable[[dict[str, Any]], str] | None = None, check_result_retrieve: Callable[[dict[str, Any]], bool] | None = None) -> dict[str, tuple[int, int]]:
    """
    Retrieve the check information for all source files in a directory.

    Besides returning the result, the mapping is written to
    ``check_results.json`` in the parent directory of source_dir, and a summary
    counting how many files share each (correct, incorrect) pair is logged.

    Args:
        source_dir: The directory containing the source JSON files.
        check_dir: The directory containing the check JSON files.
        problem_retrieve: Forwarded to _retrieve_check (source problem extractor).
        proof_retrieve: Forwarded to _retrieve_check (source proof extractor).
        check_problem_retrieve: Forwarded to _retrieve_check (check problem extractor).
        check_proof_retrieve: Forwarded to _retrieve_check (check proof extractor).
        check_result_retrieve: Forwarded to _retrieve_check (check verdict extractor).

    Returns:
        A dictionary mapping each source file to its tuple of correct and incorrect check results.
    """
    results = {}
    for file in os.listdir(source_dir):
        if not file.endswith('.json'):
            continue
        source_file = os.path.join(source_dir, file)
        results[source_file] = _retrieve_check(
            source_file,
            check_dir,
            problem_retrieve,
            proof_retrieve,
            check_problem_retrieve,
            check_proof_retrieve,
            check_result_retrieve
        )

    # json serializes the (correct, incorrect) tuples as two-element lists.
    with open(os.path.join(source_dir, os.pardir, 'check_results.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    # Tally identical (correct, incorrect) pairs in a single pass instead of
    # rescanning the whole value list for every distinct pair.
    summary = defaultdict(int)
    for value in results.values():
        summary[str(value)] += 1
    logger.info(f"Check results summary: \n{json.dumps(summary, ensure_ascii=False, indent=4)}")

    return results
