import json
import numpy as np
from copy import deepcopy


# Dialog-state symptom values that carry no concrete symptom and are skipped.
_IGNORED_SYMPTOM_VALUES = frozenset(
    ['all', 'past experience', 'other', 'general', 'symptom speedup']
)

# Values expanded into one "<value> # <location>" node per entry location
# (or the bare value when the entry has no locations).
_LOCATED_SYMPTOM_VALUES = frozenset(['pain', 'swelling', 'erythema'])

# Values renamed to a single fixed node name.
_SYMPTOM_ALIASES = {
    'back pain': 'pain # back',  # back, upper back structure
    'bone pain': 'pain # skeletal bone',
    'chest pain at rest': 'chest pain',
    'facial pain': 'pain # face',
    'head ache': 'headache',
    'neck pain': 'pain # neck',
    'swelling of eyelid': 'swelling # eye lids',
}


def _expand_symptom_value(value, entry):
    """Map one dialog-state symptom value (plus its entry) to node names."""
    if value in _LOCATED_SYMPTOM_VALUES:
        locations = entry.get('location') or []
        # A single location given as a bare string must not be iterated
        # character by character.
        if isinstance(locations, str):
            locations = [locations]
        if locations:
            return [f"{value} # {ll}" for ll in locations]
        return [value]
    if value in _SYMPTOM_ALIASES:
        return [_SYMPTOM_ALIASES[value]]
    return [value]


def dst_to_symptoms(dst, findings=None):
    """
    Extract positive and negative symptom nodes from the dialog state.

    Only the 'positive_symptom' and 'negative_symptom' slots of ``dst`` are
    read. Each is iterated as a sequence of entry dicts. Entries without a
    'value' key, or whose value is one of the generic placeholders in
    ``_IGNORED_SYMPTOM_VALUES``, are skipped.

    Values are then mapped to node names:
    - 'pain', 'swelling' and 'erythema' become one "<value> # <location>"
      node per item of the entry's optional 'location', or the bare value
      when there is none.
    - Values listed in ``_SYMPTOM_ALIASES`` are renamed.
    - Any other value is passed through unchanged.

    Args:
        dst: mapping of slot name -> iterable of entry dicts.
        findings: optional iterable of known node names; when given, only
            nodes contained in it are returned.

    Returns:
        (positive_findings, negative_findings): two sorted, de-duplicated
        lists of node names.
    """
    positive_findings = []
    negative_findings = []

    for slot in dst:
        if slot == 'positive_symptom':
            target = positive_findings
        elif slot == 'negative_symptom':
            target = negative_findings
        else:
            continue

        for entry in dst[slot]:
            if 'value' not in entry:
                continue
            value = entry['value']
            if value in _IGNORED_SYMPTOM_VALUES:
                continue
            target.extend(_expand_symptom_value(value, entry))

    if findings is not None:
        # Build the lookup set once: `findings` may be a long list and would
        # otherwise be scanned linearly for every candidate node.
        known = set(findings)
        positive_findings = [x for x in positive_findings if x in known]
        negative_findings = [x for x in negative_findings if x in known]

    return sorted(set(positive_findings)), sorted(set(negative_findings))


class OneDiseaseBayesNet(object):
    """
    Disease scorer assuming exactly one disease is present.

    For a disease ``d``, the probability that a finding stays absent when only
    ``d`` is active is ``(1 - leak_prob) * (1 - P(finding | d))``, where
    ``P(finding | d)`` is 0 for findings not listed under ``d`` in the
    structure. Posteriors are proportional to
    ``P(f+, f- | d=1, others=0) * P(d=1, others=0)`` and are normalised over
    all diseases.

    Config keys read by ``__init__``:
        structure_path: JSON file; ``structure[dis]['symptoms']`` (collection
            of symptom names) and ``structure[dis]['text']`` are read.
        disease_priors_path: JSON file mapping disease -> prior probability.
        disease_finding_probs_path: JSON file mapping disease ->
            {symptom: probability}.
        leak_prob: leak probability used in every finding term.
        mode: stored as ``self.mode``; not used inside this class.
        topk: number of diseases returned by ``get_posterior`` (used as a
            slice bound, so ``None`` returns all).
    """

    def __init__(self, config):
        """Load the network definition from JSON and precompute log terms."""
        print()
        print('BayesNet Config', config)
        print()
        structure = self._load_json(config['structure_path'])
        disease_priors = self._load_json(config['disease_priors_path'])
        disease_finding_probs = self._load_json(config['disease_finding_probs_path'])

        self.leak_prob = config['leak_prob']
        self.mode = config['mode']
        self.topk = config['topk']

        # Per-disease prior p, log(p) and log(1 - p).
        self.p_dis = dict()
        self.log_p_dis = dict()
        self.minus_log_p_dis = dict()
        for dis in disease_priors:
            self.p_dis[dis] = disease_priors[dis]
            self.log_p_dis[dis] = np.log(disease_priors[dis])
            self.minus_log_p_dis[dis] = np.log(1.0 - disease_priors[dis])

        # Every symptom mentioned for any disease in the probability table.
        self.all_symptoms = sorted(
            {ss for dd in disease_finding_probs for ss in disease_finding_probs[dd]}
        )
        self.structure = structure
        self.diseases = list(disease_priors)
        self.likelihoods = deepcopy(disease_finding_probs)
        print(f'Number of diseases {len(self.diseases)}')
        self.pos_term, self.neg_term = self.build_bn(structure, self.all_symptoms, disease_finding_probs)

    @staticmethod
    def _load_json(path):
        """Read and parse a UTF-8 encoded JSON file."""
        with open(path, 'r', encoding='utf-8') as fp:
            return json.load(fp)

    def build_bn(self, structure, symptom_list, disease_finding_probs):
        """
        Precompute per-(disease, symptom) log terms for all symptoms.

        With ``tt = (1 - leak_prob) * (1 - prob)`` (the probability the
        symptom is absent when only the disease is present):
        ``pos_term[dis][symp] = log(1 - tt)`` and
        ``neg_term[dis][symp] = log(tt)``. When ``tt == 0`` the negative term
        is clamped to the finite sentinel -99999.0 instead of -inf.

        Returns:
            (pos_term, neg_term): nested dicts keyed by disease, then symptom.
        """
        print('Building combined BN')
        pos_term = dict()
        neg_term = dict()
        for dis in self.diseases:
            pos_term[dis] = dict()
            neg_term[dis] = dict()
            # Set membership instead of scanning the symptom list per lookup.
            linked = set(structure[dis]['symptoms'])

            for symp in symptom_list:
                if symp in linked:
                    prob = disease_finding_probs[dis][symp]
                else:
                    prob = 0.0

                tt = (1.0 - self.leak_prob) * (1.0 - prob)
                if tt == 0:
                    pos_term[dis][symp] = 0.0
                    neg_term[dis][symp] = -99999.0
                    continue

                # NOTE: if tt == 1 (leak_prob == 0 and prob == 0) this is
                # log(0) = -inf, i.e. the finding is impossible under dis.
                pos_term[dis][symp] = np.log(1.0 - tt)
                neg_term[dis][symp] = np.log(tt)

        return pos_term, neg_term

    def get_pos_neg_findings(self, dst):
        """Return (positive, negative) symptom sets known to this network."""
        positive_findings, negative_findings = dst_to_symptoms(dst, self.all_symptoms)
        return set(positive_findings), set(negative_findings)

    def posterior_from_findings(self, positive_findings, negative_findings):
        """
        Score every disease given already-extracted findings.

        Args:
            positive_findings: iterable of symptom names observed as present;
                each must be a key of ``self.pos_term[dis]``.
            negative_findings: iterable of symptom names observed as absent;
                each must be a key of ``self.neg_term[dis]``.

        Returns:
            Up to ``self.topk`` dicts ``{'disease', 'score', 'text'}`` sorted
            by descending posterior; an empty list when there are no diseases.
        """
        # Materialise once so one-shot iterables are reused for every disease.
        positive_findings = tuple(positive_findings)
        negative_findings = tuple(negative_findings)

        # Step 1: log P(f+, f- | d=1, others=0) + log P(d=1, others=0).
        # The prior term swaps this disease's log(1 - p) for log(p) in the
        # sum of log(1 - p) over all diseases.
        comb_prob = sum(self.minus_log_p_dis.values())
        log_joint = dict()
        for dis in self.diseases:
            logprob = 0.0
            for pos in positive_findings:
                logprob += self.pos_term[dis][pos]
            for neg in negative_findings:
                logprob += self.neg_term[dis][neg]
            logprob += (comb_prob - self.minus_log_p_dis[dis] + self.log_p_dis[dis])
            log_joint[dis] = logprob

        if not log_joint:
            return []

        # Step 2: log P(f+, f-) via log-sum-exp. Shifting by the maximum keeps
        # the largest term at exp(0) = 1. A naive sum of exp() underflows to 0
        # once every log-joint is below about -745 (e.g. with the -99999
        # sentinel), which would make every posterior inf/nan.
        values = np.array([log_joint[dis] for dis in self.diseases])
        shift = np.max(values)
        log_evidence = shift + np.log(np.sum(np.exp(values - shift)))

        # Step 3: normalised posteriors, highest first, truncated to topk.
        posterior = {dis: np.exp(log_joint[dis] - log_evidence) for dis in self.diseases}
        ranked = sorted(posterior.items(), key=lambda x: x[1], reverse=True)[:self.topk]

        result = [
            {'disease': dis, 'score': prob, 'text': self.structure[dis]['text']}
            for dis, prob in ranked
        ]
        # Copy so callers cannot mutate values shared with self.structure.
        return deepcopy(result)

    def get_posterior(self, dst, dialog_history=None):
        """
        Rank diseases given the dialog state ``dst``.

        ``dialog_history`` is accepted for interface compatibility and is not
        used. See ``posterior_from_findings`` for the return format.
        """
        positive_findings, negative_findings = self.get_pos_neg_findings(dst)
        return self.posterior_from_findings(positive_findings, negative_findings)
