import json
from copy import deepcopy


def invert_pol_output(ret):
    """Convert a list of flat policy entries into nested action dicts.

    Each element of ``ret`` is passed to ``invert_entry``. Conversion is
    best-effort: an entry whose conversion raises, or yields ``None``, is
    dropped, so the result may be shorter than the input.

    Args:
        ret: Iterable of flat policy entries.

    Returns:
        A list with the converted actions, in input order.
    """
    pol = []
    for entry in ret:
        try:
            action = invert_entry(entry)
        except Exception:
            # Malformed entries are skipped on purpose. ``Exception`` (rather
            # than a bare ``except``) lets KeyboardInterrupt and SystemExit
            # propagate instead of being silently swallowed.
            continue

        if action is not None:
            pol.append(action)

    return pol


def invert_entry(entry):
    """Turn one flat policy entry into its nested action representation.

    An entry holding a single key is reduced to ``{'action': ...}``.
    Otherwise the entry's ``slot`` becomes a key of the result whose value is
    a one-element list ``[{'value': ...}]``; when the entry has more than
    three keys and one of them is ``check``, that element also gets
    ``'checks': [{'type': <check>}]``.

    Args:
        entry: Mapping with an ``action`` key and, unless it is the only key,
            ``slot`` and ``value`` keys.

    Returns:
        The nested action dict.

    Raises:
        KeyError: If a required key is missing from ``entry``.
    """
    n_fields = len(entry)
    if n_fields == 1:
        # Action-only entry: there is no slot to nest.
        return {'action': entry['action']}

    act, slot_name, slot_value = entry['action'], entry['slot'], entry['value']
    slot_item = {'value': slot_value}

    # A three-key entry is exactly action/slot/value, so only larger entries
    # can carry the optional check.
    if n_fields != 3 and 'check' in entry:
        slot_item['checks'] = [{'type': entry['check']}]

    return {'action': act, slot_name: [slot_item]}


def parse_pol_output(text):
    """Extract and decode the JSON payload from a raw policy output string.

    The text is lower-cased first. Everything up to and including the first
    ``[answer]`` marker is discarded (the whole text is kept when the marker
    is absent), then everything from the first ``[done]`` marker onwards is
    discarded. The remaining text is stripped and parsed as JSON.

    Args:
        text: Raw output string.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the extracted payload is not valid JSON.
    """
    lowered = text.lower()

    # partition() reports an empty separator when the marker is missing, in
    # which case the full text is the candidate payload.
    _, answer_marker, after_answer = lowered.partition('[answer]')
    candidate = after_answer if answer_marker else lowered
    candidate = candidate.strip()

    payload = candidate.partition('[done]')[0]
    return json.loads(payload.strip())


def create_clinical_note(dialog_state):
    """Render a dialog state as a plain-text clinical note.

    Slots are emitted in ``dialog_state`` order; slots without an entry in
    the description table below are ignored. Each rendered slot starts with a
    ``## <description>`` header and ends with a blank line.

    For slots whose name contains one of the canonical slot names (symptom,
    habit, ...), every entry becomes a numbered line holding its ``value``,
    followed by ``  - key: value`` bullets for the remaining keys. For the
    other slots, all keys of every entry are rendered as bullets. Entries
    whose ``value`` is ``'other'``, ``'general'`` or ``'all'`` are skipped.
    Keys and values are lower-cased and underscores are shown as spaces;
    list values are joined with ``', '``.

    Numbering counts only the entries that are actually rendered, so skipped
    placeholder entries do not leave gaps in the list.

    Args:
        dialog_state: Mapping of slot name to a list of entry dicts whose
            values are strings or lists of strings.

    Returns:
        The note as a single string with surrounding whitespace stripped.

    Raises:
        KeyError: If an entry of a canonical slot has no ``value`` key.
    """
    # Slots whose entries are headed by a canonical ``value``.
    canon_slots = (
        'symptom',
        'habit',
        'medical_history',
        'family_history',
        'medication',
        'medical_test',
        'exposure',
    )
    slot_descriptions = {
        'positive_symptom': 'Positive Symptoms (current symptoms experienced by the patient)',
        'negative_symptom': 'Negative Symptoms (symptoms not experienced by the patient)',
        'unknown_symptom': 'Unknown Symptoms (symptoms whose status is unknown)',
        'positive_habit': 'Positive Habits (habits the patient has)',
        'negative_habit': 'Negative Habits (habits the patient does not have)',
        'unknown_habit': 'Unknown Habits (habits whose status is unknown)',
        'positive_medical_history': 'Positive Past Medical History (previous medical conditions experienced by the patient)',
        'negative_medical_history': 'Negative Past Medical History (medical conditions the patient has not experienced)',
        'unknown_medical_history': 'Unknown Past Medical History (medical conditions whose status is unknown)',
        'positive_family_history': 'Positive Family History (medical conditions present in the patient\'s family)',
        'negative_family_history': 'Negative Family History (medical conditions not present in the patient\'s family)',
        'unknown_family_history': 'Unknown Family History (medical conditions whose status is unknown)',
        'positive_medication': 'Positive Medications (medications the patient is taking)',
        'negative_medication': 'Negative Medications (medications the patient is not taking)',
        'unknown_medication': 'Unknown Medications (medications whose status is unknown)',
        'avail_medical_test': 'Available Medical Tests (medical tests the patient has undergone)',
        'unavail_medical_test': 'Unavailable Medical Tests (medical tests the patient has not undergone)',
        'positive_exposure': 'Positive Exposures (exposures the patient has had)',
        'negative_exposure': 'Negative Exposures (exposures the patient has not had)',
        'unknown_exposure': 'Unknown Exposures (exposures whose status is unknown)',
        'occupation': 'Patient Occupation Details',
        'travel': 'Patient Travel History',
        'basic_information': 'Patient General Information',
        'residence': 'Patient Accommodation Details'
    }

    def humanize(text):
        # 'Chest_Pain' -> 'chest pain'
        return text.lower().replace('_', ' ')

    lines = []
    for slot, entries in dialog_state.items():
        if slot not in slot_descriptions:
            continue
        lines.append(f"## {slot_descriptions[slot]}")
        has_canon_head = any(x in slot for x in canon_slots)

        item_no = 0
        for entry in entries:
            # Placeholder values carry no clinical information.
            if entry.get('value') in ('other', 'general', 'all'):
                continue

            if has_canon_head:
                # Count only rendered entries so the list stays consecutive.
                item_no += 1
                lines.append(f'{item_no}. {humanize(entry["value"])}')

            for k, v in entry.items():
                if has_canon_head and k == 'value':
                    # Already shown as the numbered head line.
                    continue

                if isinstance(v, list):
                    lines.append(f'  - {humanize(k)}: {", ".join([humanize(z) for z in v])}')
                else:
                    lines.append(f'  - {humanize(k)}: {humanize(v)}')

        lines.append('')

    return '\n'.join(lines).strip()


def merge_single(dialog_state, slot, entry):
    """Merge one entry into ``dialog_state[slot]`` in place and return the state.

    Steps:

    1. If ``slot`` contains a status marker (``positive``, ``negative``,
       ``unknown``, ``avail``, ``unavail``), entries with the same ``value``
       are removed from every other slot whose name contains the part of
       ``slot`` after its first underscore (e.g. adding to
       ``positive_symptom`` removes the value from ``negative_symptom``).
    2. If ``slot`` is new, it is created holding ``entry``.
    3. If no existing entry in ``slot`` has the same ``value``, ``entry`` is
       appended.
    4. Otherwise the matching entry is updated key by key: missing keys are
       added, list values are concatenated and de-duplicated keeping first
       occurrence order, and any other value is overwritten.

    Args:
        dialog_state: Mapping of slot name to list of entry dicts. Mutated.
        slot: Name of the slot receiving the entry.
        entry: Entry dict; its ``value`` key identifies it within the slot.

    Returns:
        The same ``dialog_state`` object that was passed in.

    Raises:
        AssertionError: If more than one existing entry in ``slot`` has the
            same ``value`` as ``entry``.
    """
    status_markers = ('positive', 'negative', 'unknown', 'avail', 'unavail')
    related_slots = []
    if any(marker in slot for marker in status_markers):
        # 'positive_symptom' -> 'symptom'
        base = slot.split('_', 1)[-1]
        related_slots = [x for x in dialog_state if x != slot and base in x]

    # A value lives under one status at a time: drop it from sibling slots.
    # Entries without a 'value' key are left untouched, as in the match scan.
    for other in related_slots:
        dialog_state[other] = [
            old for old in dialog_state[other]
            if not ('value' in old and old['value'] == entry['value'])
        ]

    # If slot is not in dialog state add it and return
    if slot not in dialog_state:
        dialog_state[slot] = [entry]
        return dialog_state

    # Check matches
    matches = [
        idx for idx, old in enumerate(dialog_state[slot])
        if 'value' in old and old['value'] == entry['value']
    ]
    if len(matches) > 1:
        # Raised explicitly so the check also runs under ``python -O``.
        raise AssertionError(
            f"Multiple matches found for slot {slot} value {entry['value']}."
        )

    # If slot is in dialog state but entry is not, add it and return
    if not matches:
        dialog_state[slot].append(entry)
        return dialog_state

    # If slot is in the dialog state and entry is present then update it.
    target = dialog_state[slot][matches[0]]
    for key, new_value in entry.items():
        if key == 'value':
            continue
        if key in target and isinstance(new_value, list):
            existing = target[key]
            existing.extend(new_value)
            # dict.fromkeys de-duplicates while keeping insertion order, so
            # the merged list is deterministic (a set would not be).
            target[key] = list(dict.fromkeys(existing))
        else:
            target[key] = new_value

    return dialog_state


def merge_states(old_state, new_state):
    """Merge ``new_state`` into ``old_state`` and return a cleaned new state.

    Neither argument is modified; merging happens on a deep copy of
    ``old_state``.

    * Free-form slots (``occupation``, ``travel``, ``basic_information``,
      ``residence``): when the slot is missing or empty in the old state the
      new entries are copied over; otherwise the first old entry is updated
      with the first new entry and becomes the slot's only entry.
    * All other slots: every new entry that has a ``value`` key is merged
      with ``merge_single``; entries without ``value`` are ignored.

    Afterwards empty slots are dropped, and within each entry every key whose
    value is an empty sized object (``''``, ``[]``, ``{}``, ...) is dropped.
    Values without a length, such as numbers, are kept.

    Args:
        old_state: Mapping of slot name to list of entry dicts.
        new_state: Mapping of slot name to list of entry dicts to merge in.

    Returns:
        A new dict holding the merged, cleaned state.
    """
    slots_without_canon_heads = (
        'occupation',
        'travel',
        'basic_information',
        'residence',
    )

    # Copy once so the caller's state is never touched and every merge below
    # can safely work in place.
    state = deepcopy(old_state)

    for slot, entries in new_state.items():
        if slot in slots_without_canon_heads:
            if not entries:
                # Nothing to merge for this slot.
                continue

            if not state.get(slot):
                state[slot] = deepcopy(entries)
                continue

            merged = state[slot][0]
            merged.update(deepcopy(entries[0]))
            state[slot] = [merged]
        else:
            for entry in entries:
                if 'value' not in entry:
                    continue
                state = merge_single(state, slot, deepcopy(entry))

    ret = {}
    for slot, entries in state.items():
        if len(entries) == 0:
            continue
        ret[slot] = [
            {
                key: value
                for key, value in entry.items()
                # Only sized values can be "empty"; numbers etc. are kept.
                if not (hasattr(value, '__len__') and len(value) == 0)
            }
            for entry in entries
        ]

    return ret
