import pandas as pd
from .admission import Admission
from .event import RadiologyEvent, LabEvent, MicrobiologyEvent
import numpy as np
import os
import json
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from tqdm import tqdm
import sys
from deprecated import deprecated
import copy
import random

# Maximum character length of a description kept by the length filter below.
# Environment variables are always strings, so the value is coerced to int;
# otherwise an exported INPUT_MAX_CHAR_LENGTH would make `len(x) <= limit`
# raise TypeError (int compared with str).
input_max_char_length = int(os.environ.get("INPUT_MAX_CHAR_LENGTH", 30000))

def _admission_wise(patient: 'Patient') -> list:
    return patient.admission_wise()

def _prepare_prediction_list(patient: 'Patient') -> list:
    return patient.prepare_prediction_list()

class Patient:

    def __init__(self, subject_id):
        """Create a patient that only knows its identifier.

        Other attributes (gender, birth, dod, admission) are attached later
        by ``add_patient_info`` or ``from_dict``.

        Args:
            subject_id: Identifier of the patient; stored unchanged.
        """
        self.subject_id = subject_id

    def add_patient_info(self, conn):
        """Populate demographics and admissions from the database.

        Reads the ``mimiciv_hosp.patients`` row for ``self.subject_id`` and
        sets ``gender``, ``birth`` (``anchor_year - anchor_age``), ``dod``
        and ``admission`` on this instance, in that order.

        Args:
            conn: Connection object passed to ``pd.read_sql`` and to
                ``Admission.get_admission_list``.

        Raises:
            AssertionError: If the query does not return exactly one row.
        """
        query = f"""
        SELECT * FROM mimiciv_hosp.patients WHERE subject_id = {self.subject_id}
        """
        # NaN cells become None so downstream serialisation sees plain nulls.
        rows = pd.read_sql(query, conn).replace({np.nan: None})
        assert len(rows) == 1, f"Expected 1 row, got {len(rows)}"
        record = rows.iloc[0]
        # Assignment order is kept on purpose: to_dict() iterates __dict__.
        self.gender = record['gender']
        self.birth = record['anchor_year'] - record['anchor_age']
        self.dod = record['dod']
        self.admission = Admission.get_admission_list(conn, self.subject_id)

    def to_dict(self):
        patient_dict = {}
        for attr, value in self.__dict__.items():
            if isinstance(value, list):
                patient_dict[attr] = [item.to_dict() for item in value]
            else:
                patient_dict[attr] = value
        return patient_dict

    @staticmethod
    def from_dict(patient_dict):
        patient = Patient(patient_dict['subject_id'])
        for key, value in patient_dict.items():
            if key == 'admission':
                patient.admission = [Admission.from_dict(adm) for adm in value]
            else:
                setattr(patient, key, value)
        return patient

    @staticmethod
    def get_subject_id_list(engine, limit=None):
        with engine.connect() as conn:
            query = """
            SELECT subject_id FROM mimiciv_hosp.patients
            """
            if limit is not None:
                query += f" LIMIT {limit}"
            df = pd.read_sql(query, conn)
            df = df.replace({np.nan: None})
            subject_id_list = df['subject_id'].tolist()
        return subject_id_list

    @staticmethod
    @deprecated
    def get_patient_list(path, max_workers=48, offset=0, batch_size=None):
        """Load patients from ``<path>/<folder>/*.json`` (deprecated).

        Args:
            path: Root directory whose sub-directories hold patient JSON files.
            max_workers: Number of threads used to read the files.
            offset: Index of the first JSON path (in sorted order) to load.
            batch_size: Maximum number of files to load; ``None`` loads
                everything from ``offset`` onward.

        Returns:
            list: ``Patient`` objects in the order of the selected paths.
        """
        json_list = []
        for folder in os.listdir(path):
            folder_path = os.path.join(path, folder)
            if not os.path.isdir(folder_path):
                # A stray file next to the group folders would make the
                # nested os.listdir raise NotADirectoryError.
                continue
            json_list.extend(
                os.path.join(folder_path, file)
                for file in os.listdir(folder_path)
                if file.endswith('.json')
            )

        # os.listdir returns entries in arbitrary order; sort so that a given
        # (offset, batch_size) always selects the same files.
        json_list.sort()
        stop = None if batch_size is None else offset + batch_size
        json_list = json_list[offset:stop]

        def patient_from_json(json_path):
            with open(json_path, 'r') as f:
                return Patient.from_dict(json.load(f))

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            loaded = executor.map(patient_from_json, json_list)
            return list(tqdm(loaded, total=len(json_list), desc="Loading patients"))

    @staticmethod
    def _get_discharge_note_dict(patient_list):
        notes = {}
        for patient in patient_list:
            for adm in patient.admission:
                notes[adm.discharge_note] = adm
        return notes

    @staticmethod
    def _get_json_path_in_group(group_path):
        json_list = os.listdir(group_path)
        json_list = [os.path.join(group_path, file)
                     for file in json_list if file.endswith('.json')]
        return json_list

    @staticmethod
    def get_json_batch_list(path, batch_size=10000):
        """Collect all patient JSON paths under ``path`` and chunk them.

        Args:
            path: Root directory; each entry is treated as a group folder.
            batch_size: Maximum number of paths per batch.

        Returns:
            list: Lists of JSON paths, globally sorted, each at most
            ``batch_size`` long.
        """
        group_paths = [os.path.join(path, folder) for folder in os.listdir(path)]
        json_list = []

        with ThreadPoolExecutor(max_workers=48) as executor:
            # Run the directory listing itself in the pool. Previously only
            # os.path.join was mapped, so the I/O still ran serially here.
            listings = executor.map(Patient._get_json_path_in_group, group_paths)
            for group_json_paths in tqdm(listings, total=len(group_paths), desc="Loading JSON paths"):
                json_list.extend(group_json_paths)

        json_list.sort()

        print(f'{len(json_list)} patients found.')

        return [
            json_list[start:start + batch_size]
            for start in range(0, len(json_list), batch_size)
        ]

    @staticmethod
    def get_patient_list_from_batch(patient_batch, max_workers=48):
        patient_list = []
        first_patient = os.path.split(patient_batch[0])[-1]
        last_patient = os.path.split(patient_batch[-1])[-1]
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for patient in tqdm(executor.map(lambda x: Patient.from_dict(json.load(open(x, 'r'))), patient_batch), total=len(patient_batch), desc=f"Loading patients from {first_patient} to {last_patient}"):
                patient_list.append(patient)
        return patient_list

    @staticmethod
    def _discharge_note_length_filter(notes, length):
        ret_notes = dict()
        print(f'{len(notes)} notes before length filter')
        for note in notes.keys():
            if len(note) <= length:
                ret_notes[note] = notes[note]
        print(f'{len(ret_notes)} notes after length filter')
        return ret_notes
    
    @staticmethod
    def _discharge_note_filter(patient_list):
        print(f'{len(patient_list)} patients before discharge note filter')
        patient_with_discharge_note = set()
        for patient in patient_list:
            for adm in patient.admission:
                if not hasattr(adm, 'discharge_note') or adm.discharge_note is None:
                    if patient in patient_with_discharge_note:
                        patient_with_discharge_note.remove(patient)
                    break
                patient_with_discharge_note.add(patient)
        print(f'{len(patient_with_discharge_note)} patients after discharge note filter')
        return list(patient_with_discharge_note)

    @staticmethod
    def parse_disch(patient_list, enable_few_shot):
        """Parse discharge notes and attach the parsed info to admissions.

        Patients are first reduced to those whose every admission has a
        discharge note. Each distinct note text is then sent to
        ``inference.prompt_parse_disch.parse_disch`` and the result is handed
        to the owning admission's ``add_info``.

        Args:
            patient_list: Patients to process.
            enable_few_shot: Forwarded to the parser as ``enable_few_shot``.

        Returns:
            tuple: ``(results, successfully_parsed_patients)``. ``results``
            holds the ``add_info`` return values, one per distinct note text.
            The second element lists the patients for which *every*
            admission was parsed successfully. ``([], [])`` is returned when
            there are no notes.
        """
        patient_list = Patient._discharge_note_filter(patient_list)

        project_root = os.path.abspath(os.path.join(
            os.path.dirname(__file__), "..", ".."))
        if project_root not in sys.path:
            # Avoid growing sys.path by one duplicate entry per call.
            sys.path.append(project_root)
        from inference.prompt_parse_disch import parse_disch

        notes = Patient._get_discharge_note_dict(patient_list)
        if len(notes) == 0:
            return [], []

        note_list = list(notes.keys())
        parsed_json_list = parse_disch(
            note_list, enable_few_shot=enable_few_shot)
        results = [notes[note].add_info(parsed_json) for note, parsed_json in zip(
            note_list, parsed_json_list)]

        # Identity set gives O(1) membership instead of scanning a list for
        # every admission.
        # NOTE(review): assumes admissions are compared by identity (no
        # custom __eq__ on Admission) - confirm.
        parsed_adm_ids = {
            id(notes[note]) for note, ok in zip(note_list, results) if ok
        }
        print(len(parsed_adm_ids), 'admissions successfully parsed')

        # A patient counts only if all of its admissions were parsed. The
        # earlier add/remove loop depended on admission order: it kept
        # [failed, parsed] but dropped [parsed, failed].
        successfully_parsed_patients = [
            patient for patient in patient_list
            if patient.admission
            and all(id(adm) in parsed_adm_ids for adm in patient.admission)
        ]

        return results, successfully_parsed_patients

    @staticmethod
    def _get_radio_note_dict(patient_list):
        notes = {}
        for patient in patient_list:
            for adm in patient.admission:
                for event in adm.event_list:
                    if isinstance(event, RadiologyEvent):
                        notes[event.text] = event
        return notes

    @staticmethod
    def parse_radio(patient_list, enable_few_shot):
        """Determine the modality of every radiology report.

        Each distinct report text is sent once to
        ``inference.prompt_parse_radio.parse_radio``. The standardised
        modality (``modality_std``) is then stored through ``add_modality``
        on every radiology event carrying that text, and ``modality_valid``
        decides whether the parse counts as successful.

        Args:
            patient_list: Patients to process.
            enable_few_shot: Forwarded to the parser as ``enable_few_shot``.

        Returns:
            tuple: ``(results, successfully_parsed_patients)``. ``results``
            is a list of booleans, one per distinct report text, ``True``
            when ``modality_valid`` returned ``"Valid"``. The second element
            lists the patients all of whose radiology events are valid.
            ``([], [])`` is returned when there are no radiology reports.
        """
        project_root = os.path.abspath(os.path.join(
            os.path.dirname(__file__), "..", ".."))
        if project_root not in sys.path:
            # Avoid growing sys.path by one duplicate entry per call.
            sys.path.append(project_root)
        from inference.prompt_parse_radio import parse_radio, modality_valid, modality_std

        notes = Patient._get_radio_note_dict(patient_list)

        if len(notes) == 0:
            return [], []

        note_list = list(notes.keys())
        modality_list = parse_radio(
            note_list, enable_few_shot=enable_few_shot)

        verdicts = [modality_valid(modality) for modality in modality_list]

        valid_count = verdicts.count("Valid")
        others_count = verdicts.count("Others")
        invalid_count = verdicts.count("Invalid")
        print(f"Valid: {valid_count}, Others: {others_count}, Invalid: {invalid_count}")

        results = [verdict == "Valid" for verdict in verdicts]

        # The note dict keeps only one event per distinct text, so decide by
        # text instead of by event object. Events that share a report text
        # then all receive the modality and are judged alike; previously only
        # the last such event was updated and the others always failed.
        modality_by_text = {
            note: modality_std(modality)
            for note, modality in zip(note_list, modality_list)
        }
        valid_texts = {note for note, ok in zip(note_list, results) if ok}

        successfully_parsed_patients = []
        for patient in patient_list:
            is_patient_fully_successful = True
            for adm in patient.admission:
                for event in adm.event_list:
                    if not isinstance(event, RadiologyEvent):
                        continue
                    if event.text in modality_by_text:
                        event.add_modality(modality_by_text[event.text])
                    if event.text not in valid_texts:
                        # No early break: later events still need a modality.
                        is_patient_fully_successful = False
            if is_patient_fully_successful:
                successfully_parsed_patients.append(patient)

        return results, successfully_parsed_patients
    
    def prepare_description_dict_list(self, predicion_type):
        description_dict_list = []
        for adm in self.admission:
            description_dict_list.extend(adm.prepare_description_dict_list(predicion_type))
        return description_dict_list
    
    @staticmethod
    def prepare_description_dict_list_from_patient_list(patient_list, prediction_type):
        description_dict_list = []
        for patient in patient_list:
            description_dict_list.extend(patient.prepare_description_dict_list(prediction_type))
        return description_dict_list
    
    @staticmethod
    def _description_dict_length_filter(description_dict_list, length):
        """Drop description dicts whose stringified description is too long.

        Prints the number of descriptions before and after filtering.

        Args:
            description_dict_list: List of dicts with a ``'description'`` key.
            length: Inclusive upper bound on ``len(str(d['description']))``.
                Coerced with ``int()`` because the module-level limit may
                come from an environment variable, i.e. be a string.

        Returns:
            list: The dicts that passed, in their original order.
        """
        length = int(length)
        print(f'{len(description_dict_list)} descriptions before length filter')
        # Iterate directly: the earlier 48-thread pool only mapped the
        # identity function over the list, which added overhead and nothing
        # else.
        ret_description_list = [
            description
            for description in tqdm(
                description_dict_list,
                total=len(description_dict_list),
                desc="Filtering descriptions by length",
            )
            if len(str(description['description'])) <= length
        ]
        print(f'{len(ret_description_list)} descriptions after length filter')
        return ret_description_list
    
    @staticmethod
    def prepare_description_list_from_patient_list(patient_list, prediction_type):
        """Build description records, including generated text, for patients.

        Collects the description dicts of all patients, drops those longer
        than ``input_max_char_length``, asks
        ``inference.prompt_desciption.get_description_text`` for a text per
        remaining description and merges both into one record each.

        Args:
            patient_list: Iterable of patients.
            prediction_type: Forwarded to the description-dict builders.

        Returns:
            list: Dicts with keys ``'description'``, ``'ground_truth'``,
            ``'total_record'`` and ``'description_text'``.
        """
        candidates = Patient.prepare_description_dict_list_from_patient_list(patient_list, prediction_type)

        # Make the project root importable before pulling in the helper.
        sys.path.append(os.path.abspath(os.path.join(
            os.path.dirname(__file__), "..", "..")))
        from inference.prompt_desciption import get_description_text

        candidates = Patient._description_dict_length_filter(candidates, input_max_char_length)

        raw_descriptions = [str(entry['description']) for entry in candidates]
        generated_texts = get_description_text(raw_descriptions)

        records = []
        for entry, generated in zip(candidates, generated_texts):
            records.append({
                'description': entry['description'],
                'ground_truth': entry['ground_truth'],
                'total_record': entry['total_record'],
                'description_text': generated,
            })
        return records
    
    def admission_wise(self) -> list:
        """Split this patient into one patient per admission.

        Returns:
            list: ``[self]`` when there are zero or one admissions.
            Otherwise one new patient per admission, each a deep copy of
            this patient's other attributes with ``admission`` set to a
            one-element list holding the *original* admission object.
        """
        if len(self.admission) <= 1:
            return [self]

        # Deep-copy a template that has no admissions. Copying `self`
        # directly would deep-copy every admission for every output patient
        # only to throw them away, which is quadratic in the admission count.
        template = copy.copy(self)
        template.admission = []

        patient_list = []
        for adm in self.admission:
            new_patient = copy.deepcopy(template)
            new_patient.admission = [adm]
            patient_list.append(new_patient)
        return patient_list
        
    @staticmethod
    def get_admission_wise_patient_list(patient_list: list, max_workers=48) -> list:
        """Split every patient into admission-wise patients.

        Runs ``_admission_wise`` on each patient through a
        ``ProcessPoolExecutor`` and flattens the results.

        Args:
            patient_list: Patients to split.
            max_workers: Number of worker processes.

        Returns:
            list: All admission-wise patients, in input order.
        """
        flattened = []
        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            per_patient = pool.map(_admission_wise, patient_list)
            progress = tqdm(per_patient, total=len(patient_list), desc="Getting admission-wise patients")
            for split_patients in progress:
                flattened.extend(split_patients)
        return flattened
    
    def prepare_prediction_list(self) -> list:
        prediction_list = []
        assert len(self.admission) == 1, "This method is only for admission-wise patients"
        for i in range(len(self.admission[0].event_list)):
            event = self.admission[0].event_list[i]
            if isinstance(event, RadiologyEvent) or isinstance(event, LabEvent) or isinstance(event, MicrobiologyEvent):
                info_pre = copy.deepcopy(self)
                if hasattr(info_pre, 'dod'):
                    delattr(info_pre, 'dod')
                if hasattr(info_pre.admission[0], 'diagnosis_list'):
                    delattr(info_pre.admission[0], 'diagnosis_list')
                if hasattr(info_pre.admission[0], 'discharge_note'):
                    delattr(info_pre.admission[0], 'discharge_note')
                if hasattr(info_pre.admission[0], 'progress_info'):
                    delattr(info_pre.admission[0], 'progress_info')
                if hasattr(info_pre.admission[0], 'discharge_info'):
                    delattr(info_pre.admission[0], 'discharge_info')
                info_pre.admission[0].event_list = info_pre.admission[0].event_list[:i]
                
                if isinstance(event, RadiologyEvent):
                    info_pred = copy.deepcopy(event)
                    delattr(info_pred, 'note_id')
                    delattr(info_pred, 'note_type')
                    delattr(info_pred, 'note_seq')
                    delattr(info_pred, 'text')
                    info_pred_gt = copy.deepcopy(event)
                elif isinstance(event, LabEvent):
                    info_pred = copy.deepcopy(event)
                    delattr(info_pred, 'specimen_id')
                    delattr(info_pred, 'lab_list')
                    info_pred_gt = copy.deepcopy(event)
                elif isinstance(event, MicrobiologyEvent):
                    info_pred = copy.deepcopy(event)
                    delattr(info_pred, 'microbiology_list')
                    info_pred_gt = copy.deepcopy(event)
                
                info_pre_dict = info_pre.to_dict()
                admission_info = info_pre_dict['admission'][0].pop('admission_info', None)
                if admission_info is not None:
                    info_pre_dict['admission'][0] = {
                        'admission_info': admission_info,
                        **info_pre_dict['admission'][0]
                    }
                
                prediction_list.append({
                    'info_pre': json.dumps(info_pre_dict, default=str),
                    "info_pred": json.dumps(info_pred.to_dict(), default=str),
                    'info_pred_gt': json.dumps(info_pred_gt.to_dict(), default=str),
                })
                
        return prediction_list
    
    @staticmethod
    def prepare_prediction_list_from_patient_list(patient_list: list, max_workers=48) -> list:
        """Build prediction samples for every admission-wise patient.

        Runs ``_prepare_prediction_list`` on each patient through a
        ``ProcessPoolExecutor`` and flattens the results.

        Args:
            patient_list: Admission-wise patients.
            max_workers: Number of worker processes.

        Returns:
            list: All prediction samples, in input order.
        """
        all_samples = []
        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            per_patient = pool.map(_prepare_prediction_list, patient_list)
            progress = tqdm(per_patient, total=len(patient_list), desc="Preparing prediction list")
            for patient_samples in progress:
                all_samples.extend(patient_samples)

        return all_samples