from copy import deepcopy
from collections import defaultdict


# Slot families whose entries are identified by a canonical head (the entry's
# 'value'). parse_output_to_nlu treats 4-element items whose first element is
# one of these names as (slot, canon_head, attribute, attribute_value), and
# create_clinical_note numbers the entries of these slots (except 'disease').
slots_with_canon_heads = [
    'symptom',
    'habit',
    'medical_history',
    'family_history',
    'medication',
    'medical_test',
    'exposure',
    'disease'
]

# Slot families that have no canonical head. parse_output_to_nlu treats
# 3-element items whose first element is one of these names as
# (slot, attribute, attribute_value).
slots_without_canon_heads = [
    'occupation',
    'travel',
    'basic_information',
    'residence'
]

# Dialog-state slot key -> section heading text. create_clinical_note emits
# each value after a "## " prefix and silently skips any dialog-state slot
# that has no entry here (e.g. there are no '*_disease' keys in this table).
slot_descriptions = {
    'positive_symptom': 'Positive Symptoms (current symptoms experienced by the patient)',
    'negative_symptom': 'Negative Symptoms (symptoms not experienced by the patient)',
    'unknown_symptom': 'Unknown Symptoms (symptoms whose status is unknown)',
    'positive_habit': 'Positive Habits (habits the patient has)',
    'negative_habit': 'Negative Habits (habits the patient does not have)',
    'unknown_habit': 'Unknown Habits (habits whose status is unknown)',
    'positive_medical_history': 'Positive Past Medical History (previous medical conditions experienced by the patient)',
    'negative_medical_history': 'Negative Past Medical History (medical conditions the patient has not experienced)',
    'unknown_medical_history': 'Unknown Past Medical History (medical conditions whose status is unknown)',
    'positive_family_history': 'Positive Family History (medical conditions present in the patient\'s family)',
    'negative_family_history': 'Negative Family History (medical conditions not present in the patient\'s family)',
    'unknown_family_history': 'Unknown Family History (medical conditions whose status is unknown)',
    'positive_medication': 'Positive Medications (medications the patient is taking)',
    'negative_medication': 'Negative Medications (medications the patient is not taking)',
    'unknown_medication': 'Unknown Medications (medications whose status is unknown)',
    # Medical tests use avail/unavail prefixes instead of positive/negative/unknown.
    'avail_medical_test': 'Available Medical Tests (medical tests the patient has undergone)',
    'unavail_medical_test': 'Unavailable Medical Tests (medical tests the patient has not undergone)',
    'positive_exposure': 'Positive Exposures (exposures the patient has had)',
    'negative_exposure': 'Negative Exposures (exposures the patient has not had)',
    'unknown_exposure': 'Unknown Exposures (exposures whose status is unknown)',
    # Slots without canonical heads are keyed by their bare name.
    'occupation': 'Patient Occupation Details',
    'travel': 'Patient Travel History',
    'basic_information': 'Patient General Information',
    'residence': 'Patient Accommodation Details'
}

# Slot family -> {attribute name -> question text describing that attribute}.
# Nothing in the visible part of this file reads this table; presumably it is
# used elsewhere to build prompts that impute missing attributes — TODO confirm.
# NOTE(review): some keys ('respone_to', 'legion_size', 'legion_peel_off') and
# some question texts look misspelled, but they are runtime keys/strings;
# check every consumer before renaming or rewording them.
slot_attribute_info = {
    'symptom': {
        'status': 'is it likely that the patient is currently experiencing this symptom (possible answers - LIKELY, UNLIKELY)?',
        'onset': 'when did this symptom likely begin? (e.g., two days ago)?',
        'location': 'where on the body is this symptom most likely to be experienced (e.g., around the neck)?',
        'severity': 'rate the likely severity of this symptom on a scale of 1 to 10 (e.g., 7)?',
        'initiation': 'what is the most likely way for this symptom to first appear? (e.g., adrupty, slowly, etc)?',
        'duration': 'for how long is this symptom likely to last (e.g., 2 hours)?',
        'positive_characteristics': 'what would be a positive characteristics of this symptom (e.g., it is sharp)?',
        'negative_characteristics': 'what would be a negative characteristics of this symptom (e.g., it is not sharp)?',
        'alleviating_factor': 'name one factor that would likely make this symptom better (e.g., rest)?',
        'not_alleviating_factor': 'name one factor that does not make this symptom better (e.g., exercise)?',
        'aggravating_factor': 'name one factor would likely make this symptom worse (e.g., exercise)?',
        'not_aggravating_factor': 'name one factor that does not make this symptom worse (e.g., rest)?',
        'not_alleviating_aggravating_factor': 'what neither makes this symptom better nor worse (e.g., rest)?',
        'progression': 'what is likely way this symptom could change or progress? (possible answers - unchanged with time, gradually getting worse over time, gradually getting better over time, rapidly getting worse over time, rapidly getting better over time, fluctuating)?',
        'volume': 'what is the likely volume of this symptom (e.g., four table spoon of sputum)? Answer NA if not applicable to this symptom.',
        'color': 'what is the likely color of this symptom (e.g., yellowish sputum)? Answer NA if not applicable to this symptom.',
        'frequency': 'how frequently is this symptom likely to occur (e.g., once a day)?',
        'rash_swollen': 'is it likely that the rash is swollen (possible answers - yes, no)? Answer NA if not applicable to this symptom.',
        'legion_size': 'how likely is that the legion is larger than 1cm (possible answers - yes, no)? Answer NA if not applicable to this symptom.',
        'legion_peel_off': 'is it likely that the legion is peeling off (possible answers - yes, no)? Answer NA if not applicable to this symptom.',
        'itching': "is it likely that the patient's rash is itching (possible answers - yes, no)? Answer NA if not applicable to this symptom.",
    },
    'medical_history': {
        "status": "is it likely that the patient was diagnosed with this medical condition in the past (possible answers - LIKELY, UNLIKELY)?",
        "starting": "when was the patient likely diagnosed with (e.g., 2 years ago)?",
        "frequency": "how frequently does the patient likely experience this medical condition (e.g., once a month)?",
    },
    "family_history": {
        "status": "is it likely that the patient's family member has this medical condition (possible answers - LIKELY, UNLIKELY)?",
        "relation": "what is the likely relation of the patient with the family member who has this medical condition (e.g., father)?",
    },
    "habit": {
        "status": "is it likely that the patient has this habit (possible answers - LIKELY, UNLIKELY)?",
        "starting": "when did the patient likely start this habit (e.g., 2 years ago)?",
        "frequency": "how frequently does the patient likely engage in this habit (e.g., once a day)?",
    },
    "exposure": {
        "status": "is it likely that the patient was exposed to this factor (possible answers - LIKELY, UNLIKELY)?",
        "when": "when was the patient likely exposed to this factor (e.g., 2 years ago)?",
        "where": "where was the patient likely exposed to this factor (e.g., at home)?",
    },
    "medication": {
        "status": "is it likely that the patient is taking this medication (possible answers - LIKELY, UNLIKELY)?",
        "respone_to": "for what medical condition is the patient likely taking this medication (e.g., headache)?",
        "start": "when did the patient likely start taking this medication (e.g., 2 years ago)?",
        "frequency": "how frequently does the patient likely take this medication (e.g., once a day)?",
        "impact": "what is the likely impact of this medication on the patient (possible answer - yes, no)?",
    },
    "medical_test": {
        "status": "is it likely that the patient has undergone this medical test (possible answers - LIKELY, UNLIKELY)?",
        "when": "when was the patient likely undergone this medical test (e.g., 2 years ago)?",
    },
    # The two head-less slots below expose a single 'value' attribute.
    "residence": {
        "value": "what is the likely residence of the patient (e.g., in a house)?",
    },
    "occupation": {
        "value": "what is the likely occupation of the patient (e.g., teacher)?",
    },
}


def expand_checks(checks):
    """Expand grouped check names into the attribute names they stand for.

    ``'ana_factors'`` is replaced by the five alleviating/aggravating factor
    attributes and ``'characteristics'`` by the positive/negative
    characteristics pair. Any other name is kept as is. The order of the
    input is preserved and a new list is returned.
    """
    factor_attrs = (
        'alleviating_factor', 'not_alleviating_factor',
        'aggravating_factor', 'not_aggravating_factor',
        'not_alleviating_aggravating_factor',
    )
    characteristic_attrs = ('positive_characteristics', 'negative_characteristics')

    def _attrs_for(name):
        # Group names fan out; everything else maps to itself.
        if name == 'ana_factors':
            return factor_attrs
        if name == 'characteristics':
            return characteristic_attrs
        return (name,)

    return [attr for name in checks for attr in _attrs_for(name)]


# (location keyword, specific symptom) pairs used to turn a generic 'pain'
# query into location-specific symptoms. A keyword matches when it is a
# substring of a queried location; pairs are tried in this order.
_PAIN_SYMPTOM_BY_LOCATION = (
    ("abdomen", "abdominal pain"),
    ("temporal region", "headache"),
    ("ear", "earache"),
    ("chest", "chest pain"),
    ("face", "facial pain"),
    ("eye", "eye pain"),
    ("neck", "neck pain"),
    ("throat", "sore throat"),
    ("pharyngeal structure", "sore throat"),
    ("set of muscles", "myalgia"),
)

# Locations (exact match) for which 'swelling' is also looked up as
# 'lymphadenopathy'.
_LYMPH_LOCATIONS = ('neck', 'lymph node', 'lymph nodes', 'throat')

# Specific symptom -> (generic symptom, locations) used for the reverse
# direction: a specific symptom is re-queried as the generic one at a location.
_GENERIC_SYMPTOM_FORM = {
    'abdominal pain': ('pain', ('abdomen',)),
    'headache': ('pain', ('temporal region',)),
    'chest pain': ('pain', ('chest',)),
    'chest tightness': ('pain', ('chest',)),
    'chest pressure': ('pain', ('chest',)),
    'facial pain': ('pain', ('face', 'ear structure', 'eye', 'ear')),
    'sore throat': ('pain', ('throat', 'pharyngeal structure')),
    'pharyngitis': ('pain', ('throat', 'pharyngeal structure')),
    'myalgia': ('pain', ('set of muscles',)),
    'earache': ('pain', ('ear', 'ear structure')),
    'eye pain': ('pain', ('eye',)),
    'neck pain': ('pain', ('neck', 'lateral part of the neck')),
    'nasal sinus pressure sensation': ('pain', ('nasal sinus',)),
    'lymphadenopathy': ('swelling', _LYMPH_LOCATIONS),
}


def check_special(SLOT, query):
    """Build alternative queries for symptoms that have generic/specific forms.

    Args:
        SLOT: Slot family being searched. Only ``'symptom'`` is expanded.
        query: Query dict with a ``'value'`` and optionally ``'checks'``
            (a list of dicts; those with ``'type' == 'location'`` contribute
            their ``'values'``). The query is never modified; every returned
            query is a deep copy.

    Returns:
        * ``[]`` when ``SLOT`` is not ``'symptom'``.
        * For ``'pain'``: one copy of the query per (location, keyword) match,
          with ``'value'`` replaced by the specific symptom, or ``None`` when
          nothing matched.
        * For ``'swelling'``: a single ``'lymphadenopathy'`` query when one of
          the queried locations is a lymph location, otherwise ``[]``.
        * For a known specific symptom: a single query for the generic
          symptom whose ``'checks'`` are replaced by one location check.
        * ``None`` for any other symptom.
    """
    if SLOT != 'symptom':
        return []

    val = query['value']
    # Checks without a 'type' are ignored rather than raising.
    locations = [
        loc
        for chk in query.get('checks', [])
        if chk.get('type') == 'location'
        for loc in chk.get('values', [])
    ]

    if val == 'pain':
        expanded_queries = []
        for loc in locations:
            for keyword, specific in _PAIN_SYMPTOM_BY_LOCATION:
                if keyword in loc:
                    tmp = deepcopy(query)
                    tmp['value'] = specific
                    expanded_queries.append(tmp)
        # The original query is deliberately not included in the expansion.
        return expanded_queries or None

    if val == 'swelling':
        if any(loc in _LYMPH_LOCATIONS for loc in locations):
            tmp = deepcopy(query)
            tmp['value'] = 'lymphadenopathy'
            return [tmp]
        return []

    generic = _GENERIC_SYMPTOM_FORM.get(val)
    if generic is None:
        return None

    generic_value, generic_locations = generic
    tmp = deepcopy(query)
    tmp['value'] = generic_value
    # The original checks are dropped; only the location check is kept.
    tmp['checks'] = [{'type': 'location', 'values': list(generic_locations)}]
    return [tmp]


def search_medical_slots(dialog_state, SLOT, query_item):
    """Look up one queried value in the status-prefixed slots of a dialog state.

    Args:
        dialog_state: Mapping of slot key (e.g. ``'positive_symptom'``) to a
            list of entry dicts, each with a ``'value'``.
        SLOT: Slot family to search; must be one of the medical families
            asserted below.
        query_item: Dict with the queried ``'value'`` and optionally
            ``'checks'``, a list of dicts. Only checks that carry a ``'type'``
            are considered.

    Returns:
        A ``(result, found)`` tuple. ``result`` is a dict with a single slot
        key mapping to a one-element list.

        * Value not present in any status slot: ``found`` is False and the
          value is reported under ``negative_<SLOT>`` (``unavail_medical_test``
          for medical tests).
        * Value present: ``found`` is True and the value is reported under the
          slot it was found in. For ``positive_``/``avail_`` matches the
          requested check attributes are copied from the entry, with
          ``"NOT SURE"`` for attributes the entry does not have.
        * Positive 'pain'/'swelling'/'erythema' whose stored locations do not
          overlap the queried locations: reported under ``negative_symptom``
          with ``found`` True.

    Raises:
        AssertionError: If ``SLOT`` is not a medical family or the queried
            value is one of ``'all'``, ``'other'``, ``'general'``.
    """
    assert SLOT in [
        'symptom', 'habit', 'medical_history', 'family_history',
        'medication', 'medical_test', 'exposure', 'disease'
    ]
    if SLOT == 'medical_test':
        DST_KEYS = [f"{x}_{SLOT}" for x in ['avail', 'unavail']]
    else:
        DST_KEYS = [f"{x}_{SLOT}" for x in ['positive', 'negative', 'unknown']]

    # We are assuming that all the attributes are present in the matched_entry.
    # This holds as we plan to impute them before the simulation starts.
    query = query_item['value']
    assert query not in ['all', 'other', 'general']
    # Checks without a 'type' are ignored everywhere in this function.
    typed_checks = [x for x in query_item.get('checks', []) if 'type' in x]
    checks = [x['type'] for x in typed_checks]

    matched_slot, matched_entry = None, None
    for slot in DST_KEYS:
        if slot not in dialog_state:
            continue
        matched_entry = next(
            (deepcopy(candidate) for candidate in dialog_state[slot] if candidate['value'] == query),
            None,
        )
        if matched_entry is not None:
            matched_slot = slot
            break

    if matched_entry is None:
        if SLOT == 'medical_test':
            return {'unavail_medical_test': [{'value': query}]}, False
        return {f"negative_{SLOT}": [{'value': query}]}, False

    # We found a match.
    ret = {'value': query}

    # Location-specific handling: a positive generic symptom only counts when
    # it is located where the query asks about.
    if (
        query in ['pain', 'swelling', 'erythema']
        and 'location' in checks
        and matched_slot == 'positive_symptom'
        and 'location' in matched_entry
    ):
        # Search only the typed checks so an untyped check cannot raise KeyError.
        loc_check = next(x for x in typed_checks if x['type'] == 'location')
        locs = loc_check.get('values', [])
        if len(locs) > 0 and len(matched_entry['location']) > 0:
            overlaps = any(
                (loc in known) or (known in loc)
                for known in matched_entry['location']
                for loc in locs
            )
            if not overlaps:
                # Nothing matched. So symptom is negative
                return {"negative_symptom": [ret]}, True

    if len(checks) > 0 and (matched_slot.startswith('positive_') or matched_slot.startswith('avail_')):
        for chk in checks:
            available = {
                echk: matched_entry[echk]
                for echk in expand_checks([chk])
                if echk in matched_entry
            }
            if available:
                ret.update(available)
                continue

            # None of the attributes behind this check are known.
            if chk == 'ana_factors':
                ret['alleviating_factor'] = "NOT SURE"
                ret['aggravating_factor'] = "NOT SURE"
            elif chk == 'characteristics':
                ret['positive_characteristics'] = "NOT SURE"
                ret['negative_characteristics'] = "NOT SURE"
            else:
                ret[chk] = "NOT SURE"

    return {matched_slot: [ret]}, True


def search_dst_for_actions(dialog_state, SLOT, query_item):
    """Answer a single query against the dialog state.

    For slots outside the medical families the whole slot is returned as a
    deep copy (or an empty dict when the slot is absent). For medical
    families the value is searched with ``search_medical_slots``; when that
    does not find it, the alternative queries from ``check_special`` are
    tried in order and the first one that is found wins. If none is found
    the result of the original search is returned.

    Raises:
        AssertionError: If the query has no value or its value is one of
            ``'all'``, ``'other'``, ``'general'``.
    """
    assert query_item.get('value') not in [None, 'all', 'other', 'general']

    medical_families = [
        'symptom', 'habit', 'medical_history', 'family_history',
        'medication', 'medical_test', 'exposure', 'disease'
    ]

    # Head-less slots: hand back everything stored for the slot.
    if SLOT not in medical_families:
        if SLOT not in dialog_state:
            return dict()
        return deepcopy({SLOT: dialog_state[SLOT]})

    direct_result, direct_found = search_medical_slots(dialog_state, SLOT, query_item)
    if direct_found:
        return direct_result

    # Fall back to the generic/specific variants of the query.
    alternatives = check_special(SLOT, query_item)
    if alternatives is None:
        return direct_result

    for alt_query in alternatives:
        alt_result, alt_found = search_medical_slots(dialog_state, SLOT, alt_query)
        if alt_found:
            return alt_result

    return direct_result


def merge_single(dialog_state, slot, entry):
    """Merge one entry into ``dialog_state[slot]``, in place.

    Steps:
      1. If ``slot`` carries a status marker (positive/negative/unknown/
         avail/unavail), entries with the same ``'value'`` are removed from
         every other existing slot whose key contains the same suffix (the
         part of ``slot`` after the first underscore), so a value lives under
         one status only.
      2. If the slot or the value is new, the entry is added as is.
      3. Otherwise the existing entry is updated attribute by attribute:
         when both the stored and the new value are lists they are combined
         without duplicates, keeping first-seen order; in every other case
         the new value replaces the stored one.

    Args:
        dialog_state: Mapping of slot key to list of entry dicts. Modified
            in place.
        slot: Slot key to merge into.
        entry: Entry dict with a ``'value'`` key.

    Returns:
        The same ``dialog_state`` object.

    Raises:
        AssertionError: If ``dialog_state[slot]`` already holds more than one
            entry with the same value.
    """
    status_markers = ['positive', 'negative', 'unknown', 'avail', 'unavail']
    related_slots = []
    if any(marker in slot for marker in status_markers):
        suffix = slot.split('_', 1)[-1]
        related_slots = [x for x in list(dialog_state) if (x != slot) and (suffix in x)]

    # A value moves between statuses: drop it from the sibling slots.
    for other in related_slots:
        dialog_state[other] = [
            x for x in dialog_state[other] if x['value'] != entry['value']
        ]

    # If slot is not in dialog state add it and return
    if slot not in dialog_state:
        dialog_state[slot] = [entry]
        return dialog_state

    match_idx = [
        ii for ii, oentry in enumerate(dialog_state[slot])
        if oentry['value'] == entry['value']
    ]
    assert len(match_idx) <= 1, f"Multiple matches found for slot {slot} value {entry['value']}."

    # If slot is in dialog state but entry is not add it in return
    if len(match_idx) == 0:
        dialog_state[slot].append(entry)
        return dialog_state

    # If slot is in the dialog state and entry is present then update it.
    existing = dialog_state[slot][match_idx[0]]
    for kk, new_value in entry.items():
        if kk == 'value':
            continue
        old_value = existing.get(kk)
        if kk in existing and isinstance(new_value, list) and isinstance(old_value, list):
            # Order-preserving union; membership test instead of set() so the
            # result is deterministic and unhashable items are accepted.
            combined = []
            for item in old_value + new_value:
                if item not in combined:
                    combined.append(item)
            existing[kk] = combined
        else:
            existing[kk] = new_value

    return dialog_state


def merge_states(old_state, new_state):
    """Merge ``new_state`` into ``old_state`` and return a cleaned result.

    Head-less slots (occupation, travel, basic_information, residence) are
    kept as a one-element list: the first stored entry updated with the first
    new entry. All other slots are merged entry by entry with
    ``merge_single``.

    The returned dict omits slots with no entries and, inside each entry,
    attributes whose value is ``None`` or an empty sized object (``''``,
    ``[]``, ``{}``, ...). Unsized values such as numbers are kept.

    Neither argument is modified; the merge works on a deep copy of
    ``old_state``.

    Args:
        old_state: Mapping of slot key to list of entry dicts.
        new_state: Mapping of slot key to list of entry dicts to merge in.

    Returns:
        A new dict with the merged, cleaned state.
    """
    slots_without_canon_heads = [
        'occupation',
        'travel',
        'basic_information',
        'residence'
    ]

    # One private copy up front: callers never see partial mutations and
    # merge_single can then work in place without a copy per entry.
    merged = deepcopy(old_state)

    for slot, entries in new_state.items():
        if slot in slots_without_canon_heads:
            if not entries:
                # Nothing new for this slot.
                continue
            if not merged.get(slot):
                merged[slot] = deepcopy(entries)
                continue

            # NOTE(review): only the first stored and first new entry take
            # part in the merge; any further entries are dropped. Confirm
            # that head-less slots are meant to hold a single dict.
            combined = merged[slot][0]
            combined.update(deepcopy(entries[0]))
            merged[slot] = [combined]
        else:
            for entry in entries:
                merged = merge_single(merged, slot, deepcopy(entry))

    ret = dict()
    for slot, entries in merged.items():
        if len(entries) == 0:
            continue
        ret[slot] = []
        for entry in entries:
            cleaned = dict()
            for key, value in entry.items():
                if value is None:
                    continue
                # Only sized values can be "empty"; ints/floats pass through.
                if hasattr(value, '__len__') and len(value) == 0:
                    continue
                cleaned[key] = value
            ret[slot].append(cleaned)

    return ret


def create_clinical_note(dialog_state):
    """Render a dialog state as a markdown-style clinical note.

    Each slot that has a heading in ``slot_descriptions`` becomes a
    ``## <heading>`` section followed by a blank line; other slots are
    skipped. Entries whose value is ``'other'``, ``'general'`` or ``'all'``
    are skipped.

    For slots belonging to a canonical-head family (every family in
    ``slots_with_canon_heads`` except ``'disease'``) each entry is written as
    ``<position>. <value>`` followed by one ``  - <attribute>: <value>`` line
    per remaining attribute. The position is the entry's 1-based index in the
    slot, so skipped entries leave gaps in the numbering. For the other slots
    every key of the entry, including ``'value'``, is written as an attribute
    line.

    Keys and values are lower-cased and underscores become spaces. Values
    that are not strings (numbers, for instance) are converted with ``str``
    first. List values are joined with ``", "``.

    Args:
        dialog_state: Mapping of slot key to list of entry dicts.

    Returns:
        The note as a single string with surrounding whitespace stripped.
    """
    canon_slots = [x for x in slots_with_canon_heads if x != 'disease']

    def humanize(text):
        # str() keeps non-string attribute values from raising on .lower().
        return str(text).lower().replace('_', ' ')

    lines = []
    for slot, entries in dialog_state.items():
        if slot not in slot_descriptions:
            continue
        lines.append(f"## {slot_descriptions[slot]}")
        numbered = any(x in slot for x in canon_slots)

        for jj, entry in enumerate(entries):
            if entry.get('value') in ['other', 'general', 'all']:
                continue

            if numbered:
                lines.append(f'{jj + 1}. {humanize(entry["value"])}')

            for k, v in entry.items():
                # The value is already shown on the numbered line.
                if numbered and k == 'value':
                    continue

                if isinstance(v, list):
                    rendered = ", ".join([humanize(z) for z in v])
                else:
                    rendered = humanize(v)
                lines.append(f'  - {humanize(k)}: {rendered}')

        lines.append('')

    return '\n'.join(lines).strip()




def parse_output_to_nlu(output_str):
    """Parse a (possibly fenced) Python list of tuples into a dialog-state dict.

    The text may be wrapped in a markdown code fence; the fence lines are
    removed and the remaining lines are concatenated and parsed as a Python
    literal.

    Items are interpreted as follows:

    * 4 elements with a first element in ``slots_with_canon_heads``:
      ``(slot, canon_head, attribute, attribute_value)``. Items whose
      canon_head is ``"none"`` are skipped. The result key is
      ``"<status>_<slot>"`` where ``<status>`` is the value of the
      ``"status"`` item for the same slot and canon_head (first one found),
      or ``"positive"`` when there is none. Each canon_head gets one entry
      ``{"value": canon_head, <attribute>: <value>, ...}``; the status itself
      is not stored as an attribute.
    * 3 elements with a first element in ``slots_without_canon_heads``:
      ``(slot, attribute, attribute_value)``. Each item appends a
      single-key dict ``{attribute: attribute_value}`` to ``result[slot]``.
    * Anything else is ignored.

    Args:
        output_str (str): Text containing a Python list of tuples, optionally
            inside a markdown code fence.

    Returns:
        dict: Slot key -> list of entry dicts. Empty when the text is blank,
        cannot be parsed as a literal, or is not a list/tuple/set.
    """
    # Function-scope import so the block stays self-contained.
    import ast

    # Clean the markdown formatting to extract just the Python list
    lines = output_str.strip().split('\n')
    # Remove markdown code block indicators if present
    if lines[0].startswith('```'):
        lines = lines[1:-1] if lines[-1].startswith('```') else lines[1:]

    clean_str = ''.join(lines).strip()
    if not clean_str:
        return {}

    # SECURITY: this text is generated output, so it must not be executed.
    # ast.literal_eval accepts only Python literals (lists, tuples, strings,
    # numbers, ...) and rejects names, calls and attribute access, which the
    # previous eval() would have run or crashed on with uncaught errors.
    try:
        parsed_tuples = ast.literal_eval(clean_str)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
        print(f"Error parsing output: {e}")
        return {}

    if not isinstance(parsed_tuples, (list, tuple, set)):
        print(f"Error parsing output: expected a list of tuples, got {type(parsed_tuples).__name__}")
        return {}

    # Only sequence items can be slot tuples; everything else is ignored.
    items = [item for item in parsed_tuples if isinstance(item, (list, tuple))]

    result = {}

    for item in items:
        # Handle slots with canonical headers (4 elements)
        if len(item) == 4 and item[0] in slots_with_canon_heads:
            slot, canon_head, attribute, attribute_value = item
            if canon_head == "none":
                continue

            if attribute == "status":
                status_prefix = attribute_value
            else:
                # Use the status reported for the same slot and canon_head,
                # wherever it appears in the list.
                status_prefix = next(
                    (
                        other[3]
                        for other in items
                        if len(other) == 4
                        and other[0] == slot
                        and other[1] == canon_head
                        and other[2] == "status"
                    ),
                    None,
                )
                # NOTE(review): the fallback is "positive" for every family,
                # but the medical_test slots elsewhere in this file use
                # avail/unavail prefixes — confirm whether a status-less
                # medical_test should become "avail_medical_test".
                if status_prefix is None:
                    status_prefix = "positive"

            key = f"{status_prefix}_{slot}"
            bucket = result.setdefault(key, [])

            # Find the entry for this canon_head or create a new one
            entry = next(
                (existing for existing in bucket if existing.get("value") == canon_head),
                None,
            )
            if entry is None:
                entry = {"value": canon_head}
                bucket.append(entry)

            # Add the attribute if it's not 'status'
            if attribute != "status":
                entry[attribute] = attribute_value

        # Handle slots without canonical headers (3 elements)
        elif len(item) == 3 and item[0] in slots_without_canon_heads:
            slot, attribute, attribute_value = item
            result.setdefault(slot, []).append({attribute: attribute_value})

    return result