import os, json, glob
from utils.myutils import errorMessage

def log_one_case(doctor, patient, log_path):
    """Append one doctor/patient interaction record as a single line of a JSONL log.

    If the patient's question list and answer list differ in length, an error
    is reported through ``errorMessage`` and nothing is written.

    Args:
        doctor: Object exposing ``client.get_tokens()``, ``time`` and
            ``additional_info`` (a dict or a falsy value).
        patient: Object exposing ``get_questions()``, ``get_answers()``,
            ``case_id``, ``category``, ``task``, ``choices``,
            ``ground_truth``, ``truncated``, ``time``, ``atom_information``
            and ``additional_info`` (a dict or a falsy value).
        log_path: Path of the JSONL file to append to.

    Raises:
        TypeError: If the record contains a value ``json`` cannot encode.
            The log file is left untouched in that case.
    """
    q = patient.get_questions()
    a = patient.get_answers()
    if len(q) != len(a):
        errorMessage(f"[ERROR] Patient {patient.case_id}: Q/A length mismatch.")
        return

    interaction_data = {
        'case_id': patient.case_id,
        'category': patient.category,
        'task': patient.task,
        'choices': patient.choices,
        'answer': patient.ground_truth,
        # Pairs each question with its answer; serialized as [q, a] lists.
        'interactions': list(zip(q, a)),
        'truncated': patient.truncated,
        'interaction_rounds': len(q),
        'doctor_token': doctor.client.get_tokens(),
        'doctor_time': doctor.time,
        'patient_time': patient.time,
        'atom_information': patient.atom_information,
        # Patient keys override doctor keys on collision.
        'additional_info': {
            **(doctor.additional_info or {}),
            **(patient.additional_info or {})
        }
    }

    # Serialize before touching the file. json.dump streams chunks, so a
    # failure midway would leave a truncated, newline-less fragment that
    # corrupts the next record appended to this JSONL file.
    line = json.dumps(interaction_data, ensure_ascii=False) + '\n'
    with open(log_path, 'a', encoding='utf-8') as file:
        file.write(line)
        

def log_one_case_static(data, log_path, limit_info):
    """Append a static case record as a single line of a JSONL log.

    Args:
        data: Mapping with required keys ``case_id``, ``task``, ``choices``
            and ``answer``. ``category`` is optional and defaults to None.
        log_path: Path of the JSONL file to append to.
        limit_info: Value stored verbatim under ``limit_info``.

    Raises:
        KeyError: If a required key is missing from *data*.
        TypeError: If the record contains a value ``json`` cannot encode.
            The log file is left untouched in that case.
    """
    static_data = {
        'case_id': data['case_id'],
        'category': data.get('category'),
        'task': data['task'],
        'choices': data['choices'],
        'answer': data['answer'],
        'limit_info': limit_info,
    }

    # Serialize first so a failure cannot leave a partial line in the log.
    line = json.dumps(static_data, ensure_ascii=False) + '\n'
    with open(log_path, 'a', encoding='utf-8') as file:
        file.write(line)

def merge_process_logs(folder, merged_filename="log_merged.jsonl"):
    """Concatenate the ``log_*.jsonl`` files in *folder* into one file, then delete them.

    Input files are read in sorted filename order. Files whose basename
    contains "merged" are neither read nor deleted. The output file itself is
    always excluded, even when *merged_filename* matches ``log_*.jsonl``
    without containing "merged". An existing output file is overwritten.

    Args:
        folder: Directory holding the per-process logs.
        merged_filename: Name of the output file created inside *folder*.

    Returns:
        The path of the merged file.
    """
    merged_path = os.path.join(folder, merged_filename)
    merged_abs = os.path.abspath(merged_path)

    # Pick the inputs up front and never include the output, so the merged
    # file is not read while being written and is never deleted below.
    part_files = [
        path
        for path in sorted(glob.glob(os.path.join(folder, "log_*.jsonl")))
        if "merged" not in os.path.basename(path)
        and os.path.abspath(path) != merged_abs
    ]

    with open(merged_path, "w", encoding="utf-8") as outfile:
        for part in part_files:
            with open(part, "r", encoding="utf-8") as infile:
                for line in infile:
                    outfile.write(line)
                    # A part whose last record lacks a trailing newline would
                    # otherwise fuse with the next part's first record.
                    if not line.endswith("\n"):
                        outfile.write("\n")

    # Remove the individual logs only after the merged file has been closed
    # (and therefore flushed), so a failed write never loses source data.
    for part in part_files:
        os.remove(part)
    return merged_path


def save_scores(save_folder, averages: dict, score_name="average_scores.json", details_name="detailed_scores.json"):
    """Write aggregate and detailed scores to two pretty-printed JSON files.

    Args:
        save_folder: Output directory, created with parents if missing. An
            empty string means the current working directory.
        averages: Mapping whose ``"avg_scores"`` entry is written to
            *score_name* and whose ``"details"`` entry is written to
            *details_name*. A missing entry is written as ``{}``.
        score_name: File name for the ``"avg_scores"`` payload.
        details_name: File name for the ``"details"`` payload.
    """
    # exist_ok=True avoids the exists()/makedirs() race. "" (cwd) needs no
    # creation, and os.makedirs("") would raise FileNotFoundError.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)

    for file_name, key in ((score_name, "avg_scores"), (details_name, "details")):
        with open(os.path.join(save_folder, file_name), "w", encoding="utf-8") as f:
            json.dump(averages.get(key, {}), f, indent=2, ensure_ascii=False)