import pandas as pd
from .diagnosis import Diagnosis
from .event import Event, RadiologyEvent
from .info import AdmissionInfo, ProgressInfo, DischargeInfo
import numpy as np
import json

class Admission:
    """One hospital admission (``hadm_id``) and the records attached to it.

    Attributes are populated incrementally: the ``add_*`` methods query the
    database connection they are given, ``add_info`` attaches sections parsed
    from a JSON string, and ``from_dict`` rebuilds an instance from the output
    of ``to_dict``.
    """

    # Section -> field headings expected in the JSON consumed by ``add_info``.
    # Order is significant: it fixes the key order of the parsed and the
    # snake_cased dicts. Snake_case names are derived with ``_snake_case``
    # (e.g. 'Chief Complaint' -> 'chief_complaint').
    _SECTION_SCHEMA = {
        'Admission Info': (
            'Allergies',
            'Chief Complaint',
            'History of Present Illness',
            'Past Medical History',
            'Social History',
            'Family History',
            'Admission Physical Exam',
            'Medications on Admission',
        ),
        'Progress Info': (
            'Major Surgical or Invasive Procedure',
            'Pertinent Results',
            'Brief Hospital Course',
        ),
        'Discharge Info': (
            'Discharge Condition',
            'Discharge Physical Exam',
            'Discharge Instructions',
            'Discharge Diagnosis',
            'Discharge Medications',
            'Followup Instructions',
        ),
    }

    # Attributes removed from the ``description`` part of every sample built
    # by ``prepare_description_dict_list``.
    _HIDDEN_IN_DESCRIPTION = (
        'diagnosis_list',
        'discharge_note',
        'progress_info',
        'discharge_info',
    )

    def __init__(self, hadm_id):
        """Create an admission that initially holds only its ``hadm_id``."""
        self.hadm_id = hadm_id

    def add_event_list(self, conn):
        """Load this admission's events via ``Event.get_event_list``."""
        self.event_list = Event.get_event_list(conn, self.hadm_id)

    def add_diagnosis_list(self, conn):
        """Load this admission's diagnoses via ``Diagnosis.get_diagnosis_list``."""
        self.diagnosis_list = Diagnosis.get_diagnosis_list(conn, self.hadm_id)

    def add_discharge_note(self, conn):
        """Attach the discharge note text from ``mimiciv_note.discharge``.

        Sets ``self.discharge_note`` to the ``text`` column when exactly one
        row matches ``hadm_id``; leaves the attribute unset when none does.

        Raises:
            ValueError / TypeError: if ``hadm_id`` cannot be converted to int.
            AssertionError: if more than one note row matches.
        """
        # The id is spliced into the SQL text, so force it to a plain integer
        # first; an arbitrary string here could inject SQL.
        hadm_id = int(self.hadm_id)
        query = f"""
        SELECT 
        
        note_id,
        "note_type",
        "note_seq",
        charttime,
        storetime,
        text
        
        FROM mimiciv_note.discharge
        WHERE hadm_id = {hadm_id}
        """

        df = pd.read_sql(query, conn)
        df = df.replace({np.nan: None})
        assert len(df) <= 1, f"Expected 0 or 1 row, got {len(df)}"
        if len(df) == 1:
            self.discharge_note = df.iloc[0]['text']

    def to_dict(self):
        """Return a dict of every instance attribute, converted for serialization.

        List attributes are converted element-wise (each item must provide
        ``to_dict``); ``AdmissionInfo``/``ProgressInfo``/``DischargeInfo``
        values via their own ``to_dict``; all other values are copied as-is.
        """
        admission_dict = {}
        for attr, value in self.__dict__.items():
            if isinstance(value, list):
                admission_dict[attr] = [item.to_dict() for item in value]
            elif isinstance(value, (AdmissionInfo, ProgressInfo, DischargeInfo)):
                admission_dict[attr] = value.to_dict()
            else:
                admission_dict[attr] = value
        return admission_dict

    @staticmethod
    def get_admission_list(conn, subject_id):
        """Load every admission of ``subject_id`` from ``mimiciv_hosp.admissions``.

        Each returned ``Admission`` gets ``language``, ``marital_status`` and
        ``race`` from its row, then its diagnoses, discharge note and events
        are loaded with further queries on ``conn``.

        Raises:
            ValueError / TypeError: if ``subject_id`` cannot be converted to int.
        """
        # Coerced for the same SQL-injection reason as in add_discharge_note.
        subject_id = int(subject_id)
        query = f"""
        SELECT 
        
        hadm_id,
        admittime,
        dischtime,
        deathtime,
        admission_type,
        admit_provider_id,
        admission_location,
        discharge_location,
        insurance,
        language,
        marital_status,
        race,
        edregtime,
        edouttime,
        hospital_expire_flag
        
        FROM mimiciv_hosp.admissions
        WHERE subject_id = {subject_id}
        """

        df = pd.read_sql(query, conn)
        df = df.replace({np.nan: None})
        admission_list = []
        for _, row in df.iterrows():
            admission = Admission(row['hadm_id'])
            admission.language = row['language']
            admission.marital_status = row['marital_status']
            admission.race = row['race']
            admission.add_diagnosis_list(conn)
            admission.add_discharge_note(conn)
            admission.add_event_list(conn)
            admission_list.append(admission)
        return admission_list

    @staticmethod
    def from_dict(admission_dict):
        """Rebuild an ``Admission`` from a dict shaped like ``to_dict`` output.

        ``event_list``/``diagnosis_list`` items and the three info sections are
        converted back through their classes' ``from_dict``; every other key
        is set as a plain attribute. ``hadm_id`` is required.
        """
        list_item_loaders = {
            'event_list': Event.from_dict,
            'diagnosis_list': Diagnosis.from_dict,
        }
        section_loaders = {
            'admission_info': AdmissionInfo.from_dict,
            'progress_info': ProgressInfo.from_dict,
            'discharge_info': DischargeInfo.from_dict,
        }
        admission = Admission(admission_dict['hadm_id'])
        for key, value in admission_dict.items():
            if key in list_item_loaders:
                value = [list_item_loaders[key](item) for item in value]
            elif key in section_loaders:
                value = section_loaders[key](value)
            setattr(admission, key, value)
        return admission

    @staticmethod
    def _dict_format(reference_dict, current_dict):
        """Fill the template ``reference_dict`` in place from ``current_dict``.

        Template leaves (``None`` or str) are replaced by the value under the
        same key in ``current_dict``; nested dict nodes are filled
        recursively. Keys of ``current_dict`` absent from the template are
        ignored.

        Returns:
            The filled ``reference_dict``.

        Raises:
            TypeError: if ``current_dict`` (at any level) is not a dict.
            KeyError: if a template key is missing from ``current_dict``.
            ValueError: if a template value has an unsupported type.
        """
        if not isinstance(current_dict, dict):
            # Without this guard ``key in "some string"`` does a substring test
            # and the failure surfaces as an unrelated indexing error.
            raise TypeError(f"Expected a dict, got {type(current_dict)}")
        for key, value in reference_dict.items():
            if key not in current_dict:
                raise KeyError(f"Key {key} not found in current_dict")
            if value is None or isinstance(value, str):
                reference_dict[key] = current_dict[key]
            elif isinstance(value, dict):
                Admission._dict_format(value, current_dict[key])
            else:
                raise ValueError(f"Unsupported type for key {key}: {type(value)}")
        return reference_dict

    @staticmethod
    def _parsed_json_to_dict(parsed_json):
        """Parse a JSON string into the fixed three-section structure.

        Returns:
            ``{section: {field: value}}`` following ``_SECTION_SCHEMA`` (extra
            JSON keys are dropped), or ``None`` when the input is not a JSON
            str/bytes, is not valid JSON, or lacks a required section/field.
        """
        try:
            parsed = json.loads(parsed_json)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and undecodable bytes;
            # TypeError covers non-str input such as None.
            return None
        template = {
            section: dict.fromkeys(fields)
            for section, fields in Admission._SECTION_SCHEMA.items()
        }
        try:
            return Admission._dict_format(template, parsed)
        except (KeyError, TypeError, ValueError):
            # Structure does not match the schema: report as unparseable.
            return None

    @staticmethod
    def _snake_case(heading):
        """Convert a heading such as ``'Chief Complaint'`` to ``'chief_complaint'``."""
        return heading.lower().replace(' ', '_')

    @staticmethod
    def _lowercase_dict_keys(parsed_dict):
        """Rename the sections and fields of ``parsed_dict`` to snake_case keys.

        ``parsed_dict`` must contain every section/field of ``_SECTION_SCHEMA``
        (as returned by ``_parsed_json_to_dict``); otherwise KeyError is raised.
        """
        return {
            Admission._snake_case(section): {
                Admission._snake_case(field): parsed_dict[section][field]
                for field in fields
            }
            for section, fields in Admission._SECTION_SCHEMA.items()
        }

    def add_info(self, parsed_json):
        """Attach admission/progress/discharge sections parsed from JSON text.

        Returns:
            True if the JSON matched the expected structure and
            ``admission_info``, ``progress_info`` and ``discharge_info`` were
            set; False otherwise, in which case no attribute is changed.
        """
        parsed_dict = Admission._parsed_json_to_dict(parsed_json)
        if parsed_dict is None:
            return False
        parsed_dict = Admission._lowercase_dict_keys(parsed_dict)
        self.admission_info = AdmissionInfo.from_dict(parsed_dict['admission_info'])
        self.progress_info = ProgressInfo.from_dict(parsed_dict['progress_info'])
        self.discharge_info = DischargeInfo.from_dict(parsed_dict['discharge_info'])
        return True

    def _make_description_sample(self, cutoff):
        """Split this admission at event index ``cutoff`` into one sample.

        Returns a dict with:
            description: ``to_dict()`` keeping only events before ``cutoff``
                and without the ``_HIDDEN_IN_DESCRIPTION`` attributes.
            ground_truth: dicts of the events from ``cutoff`` onward.
            total_record: the complete ``to_dict()``.
        """
        description = self.to_dict()
        description['event_list'] = description['event_list'][:cutoff]
        for attr in Admission._HIDDEN_IN_DESCRIPTION:
            # pop with a default: any of these may never have been loaded.
            description.pop(attr, None)
        return {
            'description': description,
            'ground_truth': [event.to_dict() for event in self.event_list[cutoff:]],
            'total_record': self.to_dict(),
        }

    def prepare_description_dict_list(self, prediction_type):
        """Build description/ground-truth samples from this admission's events.

        Args:
            prediction_type: ``'radiology'`` for one sample per
                ``RadiologyEvent`` in ``event_list`` (cut just before that
                event, so it starts the ground truth), or a list of numbers in
                ``[0, 1]`` for one sample per number, cut at
                ``int(fraction * len(event_list))``.

        Returns:
            A list of dicts as produced by ``_make_description_sample``.
        """
        assert prediction_type in ['radiology'] or isinstance(prediction_type, list), f"Unknown prediction type: {prediction_type}"
        if prediction_type == 'radiology':
            cutoffs = [
                index
                for index, event in enumerate(self.event_list)
                if isinstance(event, RadiologyEvent)
            ]
        elif isinstance(prediction_type, list):
            assert all((isinstance(i, float) or isinstance(i, int)) and 0 <= i <= 1 for i in prediction_type)
            cutoffs = [int(fraction * len(self.event_list)) for fraction in prediction_type]
        else:
            cutoffs = []
        return [self._make_description_sample(cutoff) for cutoff in cutoffs]
    