import os, json
from multiprocessing import Pool
from tqdm import tqdm
from utils.config_loader import ConfigLoader
from dialogue_manager import DialogueManager
from utils.logger import *
from evaluators.evaluator import Evaluator
from utils.myutils import systemMessage, department_map, select_basic_info

def run_one_case(args): # dealing with one patient
    """Simulate the doctor-patient dialogue for a single case in a worker process.

    Args:
        args: Tuple ``(config_dict, patient_data, doctor_class, patient_class,
            base_folder)``, packed into one argument so this function can be
            used with ``Pool.imap_unordered``.

    Returns:
        ``{"pid": ..., "case_id": ...}`` once the case has been logged, or
        ``None`` when ``doctor_class`` is ``DoctorApi`` and the case has no
        department mapping: its category is missing, absent from
        ``department_map``, or mapped to ``None``.
    """
    pid = os.getpid()
    config_dict, patient_data, doctor_class, patient_class, base_folder = args

    # The config is shipped as a plain dict so it can be pickled to workers.
    config = ConfigLoader.from_dict(config_dict)
    if doctor_class.__name__ == "DoctorApi":
        category = patient_data.get("category")
        # Skip cases without a department mapping. Indexing directly would raise
        # KeyError here, and imap_unordered would re-raise it in the parent,
        # aborting the whole run instead of just this case.
        if department_map.get(category) is None:
            return None
        doctor = doctor_class(config, patient_data["task"], patient_data["choices"], category)
    else:
        doctor = doctor_class(config, patient_data["task"], patient_data["choices"])
    patient = patient_class(config, patient_data)

    DialogueManager(config).interact(doctor, patient)

    # One log file per worker process (keyed by pid); MAQuE.run later combines
    # the files in base_folder via merge_process_logs.
    log_path = os.path.join(base_folder, f"log_{pid}.jsonl")

    log_one_case(doctor, patient, log_path)

    return {
        "pid": pid,
        "case_id": patient_data["case_id"]
    }
    
def run_one_case_static(args):
    """Log one patient case for static (dialogue-free) evaluation.

    Args:
        args: Tuple ``(patient_data, base_folder)``. ``patient_data`` must
            contain ``"atomized_information"`` and ``"case_id"``.

    Returns:
        ``{"pid": <worker pid>, "case_id": <case id>}``.
    """
    patient_data, base_folder = args

    # Reduce the atomized facts to the subset chosen by select_basic_info.
    basic_info = select_basic_info(patient_data["atomized_information"])

    worker_pid = os.getpid()
    # Each worker appends to its own per-pid file in base_folder.
    log_one_case_static(
        patient_data,
        os.path.join(base_folder, f"log_{worker_pid}.jsonl"),
        basic_info,
    )

    return dict(pid=worker_pid, case_id=patient_data["case_id"])


class MAQuE:
    """Driver for doctor-patient interaction runs and their evaluation.

    ``evaluation.mode`` is matched by substring, so one mode string can enable
    several stages, which ``run`` executes in this order:

    * ``"interact"``: run the dialogues in a process pool and merge the
      per-process logs in ``folder``.
    * ``"evaluate"``: score a merged log (the freshly merged one, or the
      existing ``log_merged.jsonl`` in ``folder``).
    * ``"static"``: log each case without dialogue, merge, and score.
    """

    def __init__(self, config, doctor_class, patient_class, folder):
        """Prepare the output folder, write run metadata and load patient cases.

        Args:
            config: Configuration object exposing ``get``, ``get_section`` and
                ``to_dict``.
            doctor_class: Doctor agent class, instantiated inside each worker.
            patient_class: Patient agent class, instantiated inside each worker.
            folder: Output directory for ``meta.json`` and the logs; created
                if missing.
        """
        self.config = config

        self.doctor_class = doctor_class
        self.patient_class = patient_class
        self.dialogue_manager = DialogueManager(self.config)

        self.mode = self.config.get('evaluation', 'mode')

        systemMessage("MAQuE folder: " + folder)
        self.folder = folder
        os.makedirs(self.folder, exist_ok=True)

        if "interact" in self.mode:
            self._save_interaction_meta()

        if 'static' in self.mode or 'evaluate' in self.mode:
            self.evaluator = Evaluator(self.config.get_section("evaluation"), folder)

        self.patients_data = []
        if 'interact' in self.mode or 'static' in self.mode:
            data_path = self.config.get('patient', 'data_path')
            patient_section = self.config.get_section('patient')
            start_index = int(patient_section.get('start_index', 0))
            test_num = int(patient_section.get('test_num', 10))
            self.patients_data = self._read_patients(data_path, start_index, test_num)

            systemMessage("Number of session: " + str(len(self.patients_data)))

    def _save_interaction_meta(self):
        """Write the doctor/patient config sections plus chosen class names to meta.json."""
        # Copy the sections so recording the class names cannot mutate the live
        # config, which is later serialized and sent to the workers.
        interaction_config = {
            'doctor': dict(self.config.get_section('doctor')),
            'patient': dict(self.config.get_section('patient')),
        }
        interaction_config['doctor']['doctor_list'] = self.doctor_class.__name__
        interaction_config['patient']['patient_list'] = self.patient_class.__name__
        meta_path = os.path.join(self.folder, 'meta.json')
        with open(meta_path, 'w', encoding='utf-8') as file:
            json.dump(interaction_config, file, indent=2, ensure_ascii=False)
        systemMessage("Interaction meta saved to " + meta_path)

    @staticmethod
    def _read_patients(data_path, start_index, test_num):
        """Return records ``[start_index, start_index + test_num)`` of a JSONL file.

        The file is decoded as UTF-8, the same encoding used for meta.json.
        Blank lines are skipped.
        """
        with open(data_path, 'r', encoding='utf-8') as file:
            records = [json.loads(line) for line in file if line.strip()]
        return records[start_index:start_index + test_num]

    @staticmethod
    def _run_pool(worker, args_list, desc, max_processes):
        """Map ``worker`` over ``args_list`` in a process pool with a progress bar.

        Returns the non-None worker results. The bar advances for every finished
        case, including cases the worker skipped by returning None, so it
        always reaches its total.
        """
        results = []
        with Pool(processes=max_processes) as pool, tqdm(total=len(args_list), desc=desc) as pbar:
            for res in pool.imap_unordered(worker, args_list):
                if res is not None:
                    results.append(res)
                pbar.update(1)
        return results

    def _report_scores(self, merged_log):
        """Evaluate ``merged_log``, print the average scores and return the full result."""
        scores = self.evaluator.evaluate_from_log(merged_log)

        systemMessage("Average results for " + str(self.config.get('patient', 'test_num')) + " patients:")
        systemMessage(json.dumps(scores['avg_scores'], indent=4))
        return scores

    def run(self):
        """Run every stage enabled by ``self.mode`` (interact, evaluate, static).

        Returns:
            The ``avg_scores`` of the last evaluation performed, or a message
            naming the log folder when no evaluation ran.
        """
        systemMessage("Mode: " + self.mode)
        max_processes = int(self.config.get('evaluation', 'max_processes_patient'))

        merged_log = None
        scores = None

        if 'interact' in self.mode: # run interaction
            systemMessage(f"Running in parallel with {max_processes} processes.")
            # The config is identical for every case; serialize it once.
            config_dict = self.config.to_dict()
            args_list = [
                (config_dict, pd, self.doctor_class, self.patient_class, self.folder)
                for pd in self.patients_data
            ]
            results = self._run_pool(run_one_case, args_list, "Doctor-Patient Interaction", max_processes)
            skipped = len(args_list) - len(results)
            if skipped:
                systemMessage(f"Skipped {skipped} cases (no department mapping for their category).")

            merged_log = merge_process_logs(self.folder)
            systemMessage(f"Interaction logs saved to {self.folder}")

        if 'evaluate' in self.mode: # evaluate based on previous logs
            systemMessage(f"Evaluating from logs in {self.folder}")
            if merged_log is None:
                # No interaction in this run: fall back to a previously merged log.
                merged_log = os.path.join(self.folder, 'log_merged.jsonl')
            scores = self._report_scores(merged_log)

        if 'static' in self.mode: # static evaluation without interaction
            args_list = [(pd, self.folder) for pd in self.patients_data]
            self._run_pool(run_one_case_static, args_list, "Static Logging", max_processes)

            merged_log = merge_process_logs(self.folder)
            systemMessage(f"Static logs saved to {self.folder}")
            scores = self._report_scores(merged_log)

        if scores is not None:
            return scores['avg_scores']
        return f"Interaction logs saved to {self.folder}"