import json
from typing import List, Tuple
import numpy as np
import os

def split_data_by_accuracy(origin_data: List[dict]) -> tuple[List[dict], List[dict]]:
    """Partition records into "correct" and "incorrect" lists.

    A record lands in the first list only when its ``'is_correct'`` value
    compares equal to ``True``; every other record lands in the second list.
    The relative order of the input is preserved inside each list.

    Args:
        origin_data: Records, each providing an ``'is_correct'`` key.

    Returns:
        A ``(correct, incorrect)`` pair of lists.

    Raises:
        KeyError: If a record has no ``'is_correct'`` key.
    """
    matched: List[dict] = []
    unmatched: List[dict] = []
    for record in origin_data:
        # Explicit equality (not truthiness) so e.g. a non-empty string is
        # still treated as "incorrect".
        bucket = matched if record['is_correct'] == True else unmatched
        bucket.append(record)
    return matched, unmatched


def summary_result(origin_data: List[dict], correct_data: List[dict], incorrect_data: List[dict], verbose = False) -> dict:
    """Summarize the ``'entropy'`` sequence lengths of three record groups.

    For each of the groups ("origin data", "correct data", "incorrect data")
    the result holds the minimum, maximum and average length of the
    ``'entropy'`` sequences, plus "middle": the length of the record sitting
    at index ``len(group) // 2`` in the group's given order (the group is not
    sorted, so this is not a statistical median).

    An empty group yields zeros for all four statistics instead of raising.

    Args:
        origin_data: All records; each must provide an ``'entropy'`` sequence.
        correct_data: The subset of records considered correct.
        incorrect_data: The subset of records considered incorrect.
        verbose: When true, print each group's statistics and the accuracy.

    Returns:
        A dict mapping the group title to a dict with the keys
        ``"min"``, ``"max"``, ``"avg"`` and ``"middle"``.
    """
    res = {}
    # Guarded so an empty dataset cannot trigger ZeroDivisionError in verbose mode.
    accuracy = len(correct_data) / len(origin_data) if origin_data else 0

    def show_stat(data: List[dict], title: str) -> None:
        lengths = [len(item['entropy']) for item in data]
        if lengths:
            min_len = min(lengths)
            max_len = max(lengths)
            avg_len = sum(lengths) / len(lengths)
            # Length of the record at the middle position, in input order.
            middle_len = lengths[len(lengths) // 2]
        else:
            # An empty group (e.g. no incorrect records) has no lengths to
            # summarize; report zeros rather than indexing into nothing.
            min_len = max_len = avg_len = middle_len = 0
        if verbose:
            print(f"{title} length - Min: {min_len}, Max: {max_len}, Avg: {avg_len}, Middle length: {middle_len}")
            print(f"correct number: {len(correct_data)}//{len(origin_data)}, accuracy: {accuracy}")
            print("-----------------------------------------------------------------")
        res[title] = {
            "min": min_len,
            "max": max_len,
            "avg": avg_len,
            "middle": middle_len,
        }

    show_stat(origin_data, "origin data")
    show_stat(correct_data, "correct data")
    show_stat(incorrect_data, "incorrect data")
    return res
    
def interpolate_data(data: List[dict], stat_res: dict, name:str, target_mode: str = "max_length") -> List:
    """Resample every entropy series to a common length and average them.

    Each record's ``'entropy'`` sequence is treated as samples spread evenly
    over ``[0, 1]``, linearly interpolated onto a grid of the target length,
    and the resampled series are then averaged point by point.

    Args:
        data: Records, each providing a non-empty ``'entropy'`` sequence
            (``np.interp`` rejects an empty sample array).
        stat_res: Length statistics keyed by group name; ``stat_res[name]``
            must provide the keys ``"max"``, ``"avg"``, ``"min"`` and
            ``"middle"``.
        name: The group in ``stat_res`` whose statistics pick the target length.
        target_mode: One of ``"max_length"``, ``"avg_length"``,
            ``"min_length"`` or ``"middle_length"``.

    Returns:
        The point-wise mean of the resampled series as a list of floats, or
        an empty list when ``data`` is empty.

    Raises:
        ValueError: If ``target_mode`` is not one of the supported modes.
    """
    stat_key_by_mode = {
        "max_length": "max",
        "avg_length": "avg",
        "min_length": "min",
        "middle_length": "middle",
    }
    if target_mode not in stat_key_by_mode:
        raise ValueError("target_mode should be one of ['max_length', 'avg_length', 'min_length', 'middle_length']")
    target_length = stat_res[name][stat_key_by_mode[target_mode]]

    # With nothing to average, np.mean would warn and return a NaN scalar
    # rather than a list; return an empty list instead.
    if not data:
        return []

    # The target grid is the same for every series, so build it once.
    # int() truncates a fractional average length.
    target_x = np.linspace(0, 1, int(target_length))
    interpolated_data = [
        np.interp(
            target_x,
            np.linspace(0, 1, len(series['entropy'])),
            series['entropy'],
        )
        for series in data
    ]
    return np.mean(interpolated_data, axis=0).tolist()
    

    


def main():
    """Load each configured JSON file and print its entropy-length summary.

    Every file named in ``data_list`` is read from ``entropy_data_path``,
    split into correct and incorrect records, summarized, and the resulting
    statistics are printed.
    """
    entropy_data_path = '/path/to/data'
    data_list = ['file_name']
    for data_name in data_list:
        with open(os.path.join(entropy_data_path, data_name), 'r', encoding='utf-8') as f:
            origin_data = json.load(f)

        # Processing lives inside the loop so every listed file is
        # summarized (not only the last one) and an empty data_list simply
        # does nothing instead of hitting an undefined origin_data.
        correct_data, incorrect_data = split_data_by_accuracy(origin_data)
        stat_res: dict = summary_result(origin_data, correct_data, incorrect_data, verbose=False)
        print(stat_res)



# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()