import json
import logging
import os
import re
import uuid
from copy import deepcopy
from typing import Dict, List, Tuple

from fastapi import HTTPException
from pydantic import BaseModel

from utils import search_dst_for_actions, merge_states, create_clinical_note, parse_output_to_nlu
from llm_clients import get_llm_client
from prompts import (
    DIAGNOSIS_RESPONSE_PROMPT, SALUTATIONS_RESPONSE_PROMPT, CHIT_CHAT_RESPONSE_PROMPT, INQUIRY_RESPONSE_PROMPT,
    GENERAL_RESPONSE_PROMPT, ALL_RESPONSE_PROMPT, NLU_PROMPT
)


# Send INFO-and-above log records to patient_logs.log using a
# "time - level - message" layout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='patient_logs.log',
)

# One LLM client (and its model name) resolved at import time and shared by
# every Patient created in this process.
client, MODEL = get_llm_client()
print(f'Using {MODEL} for completion.')


# Request body for opening a new simulated-patient session.
class CreateSessionRequest(BaseModel):
    did: str  # id of the reference dialog to simulate
    mode: str  # simulation mode (PatientService.create_session accepts only 'actions')


# Request body for one doctor turn: the raw utterance plus its NLU actions.
class ExecuteActionsRequest(BaseModel):
    doctor_uttr: str  # Doctor's utterance
    actions: List[Dict]  # Doctor's actions; each dict carries an 'action' key


# -------------------- Patient Service -------------------- #
class Patient:
    """LLM-backed simulated patient that answers a doctor turn by turn.

    The patient's ground truth is the dialogue state of a reference dialogue.
    Each doctor turn arrives as a list of NLU actions plus the raw utterance.
    The patient chooses a prompt template from the action type, asks the
    module-level LLM ``client`` for a reply, and records the findings revealed
    so far in ``covered_findings``.
    """

    # Matches ``{{name}}`` placeholders in prompt templates.
    _PLACEHOLDER_RE = re.compile(r"\{\{(\w+)\}\}")

    def __init__(self, dialog, dst_key=None, ddx_key=None, max_turns=100):
        """Build a patient from one reference dialogue.

        Args:
            dialog: Reference dialogue record. ``dialog["utterances"]`` is a
                list of dicts with ``"speaker"`` and ``"text"``. Patient turns
                also carry ``"dialog_state"``. The differential list lives under
                ``ddx_key`` (default ``"differential"``).
            dst_key: Optional key in ``dialog`` whose value replaces the last
                patient dialogue state as the patient profile.
            ddx_key: Optional key in ``dialog`` holding the differential list.
            max_turns: Number of doctor turns allowed before the session ends.

        Raises:
            ValueError: If no patient utterance has a non-empty dialogue state,
                so there is no turn to prime the conversation with.
        """
        utterances = dialog["utterances"]
        self.ddx = dialog[ddx_key] if ddx_key is not None else dialog['differential']
        self.cutoff = max_turns  # remaining doctor turns

        ref_dialog = []
        first = None
        first_dialog_state = None
        last_dialog_state = None
        for uid, uttr in enumerate(utterances):
            ref_dialog.append(f"{uttr['speaker'].capitalize()}: {uttr['text']}")
            if uttr['speaker'] != 'patient':
                continue
            # The first patient turn that reveals any findings is where the
            # simulated conversation starts.
            if first is None and len(uttr['dialog_state']) > 0:
                first = uid
                first_dialog_state = deepcopy(uttr['dialog_state'])
            last_dialog_state = uttr['dialog_state']

        if dst_key is not None:
            print(f"Using {dst_key} field for patient profile.")
            last_dialog_state = dialog[dst_key]

        if first is None:
            # Without a primed patient turn there is no starting history and no
            # initial state to track.
            raise ValueError(
                "Dialog has no patient utterance with a non-empty dialog_state; "
                "cannot prime the simulated conversation."
            )

        self.first_dialog_state = first_dialog_state
        self.last_dialog_state = deepcopy(last_dialog_state)
        # Ground-truth state consulted when answering the doctor's inquiries.
        self.dialog_state = deepcopy(last_dialog_state)
        self.ref_dialog = ref_dialog
        self.clinical_note = create_clinical_note(last_dialog_state)
        # Built from self.ddx so that ``ddx_key`` is honoured. A dialog stored
        # under a custom key need not also carry a 'differential' field.
        self.differential_diagnosis = '\n'.join(x.strip() for x in self.ddx if x.strip())
        # Misspelled legacy name kept as an alias for existing readers.
        self.diffferential_diagnosis = self.differential_diagnosis

        # On-going dialogue history primed till first patient utterance.
        # Dialogue history always ends with a patient utterance.
        self.ongoing_dialog = ref_dialog[:first + 1]
        self.ongoing_dialog_orig = ref_dialog[:first + 1]

        # Latest doctor / patient utterances, updated by execute_and_commit.
        self.latest_patient_uttr = None
        self.latest_doctor_uttr = None

        # Findings revealed so far; starts with what the primed turn disclosed.
        self.covered_findings = deepcopy(first_dialog_state)

        self.logs = [{
            "clinical_note": self.clinical_note,
            "ddx": '\n'.join(self.ddx),
            "running_dialogue_state": deepcopy(self.covered_findings),
            "llm_used": MODEL,
        }]
        self.handle_system = 'gemma' in MODEL
        print(f"System handling: {self.handle_system}")

    def reset(self):
        """Restart with an empty history and the first-turn findings.

        NOTE(review): unlike ``__init__``, this clears the history entirely
        instead of restoring ``ongoing_dialog_orig``. It also allows 62 turns
        rather than the constructor's budget. Confirm both are intended.
        """
        self.ongoing_dialog = []
        self.covered_findings = deepcopy(self.first_dialog_state)
        self.cutoff = 62

    def _remove_unknown(self, state, remove_slot=False):
        """Return a copy of ``state`` without empty or ``['unknown']`` list fields.

        Args:
            state: Mapping of slot name to a list of entry dicts.
            remove_slot: If true, also drop every slot whose name contains
                ``'unknown'``.
        """
        tmp = deepcopy(state)
        for slot in list(tmp):
            if remove_slot and 'unknown' in slot:
                del tmp[slot]
                continue
            for entry in tmp[slot]:
                for key in list(entry):
                    value = entry[key]
                    if isinstance(value, list) and (not value or value == ['unknown']):
                        del entry[key]
        return tmp

    def _post_process(self, text):
        """Extract the patient utterance from an LLM response.

        The response is expected to contain ``<answer>...</answer>``.
        Progressively looser fallbacks handle malformed output. Returns
        ``"I don't know."`` when no answer tag can be found at all.
        """
        text = text or ""  # some clients return None for empty completions

        # 1) Closed tags that are not just a back-ticked template placeholder.
        pattern = r"<answer>(?!`.*`</answer>)\s*(.*?)\s*</answer>"
        matches = re.findall(pattern, text, flags=re.DOTALL)
        if matches:
            return matches[-1].strip().strip('"')

        # 2) Any closed tag, skipping the literal "` and `" placeholder and
        #    preferring the longest candidate.
        pattern = r"<answer>\s*(.*?)\s*</answer>"
        matches = [m for m in re.findall(pattern, text, flags=re.DOTALL) if m != '` and `']
        if matches:
            return max(matches, key=len).strip("`").strip().strip('"')

        # 3) Unclosed tags: pick a line that opens an answer.
        lines = [x for x in text.split('\n') if "<answer>" in x]
        if not lines:
            logging.info("Error in processing output: %s", text)
            return "I don't know."

        if len(lines) == 1:
            chosen = lines[-1]
        else:
            outlines = [x for x in lines if "**output**" in x]
            # If nothing is marked as the output line, fall back to the first.
            chosen = outlines[-1] if outlines else lines[0]
        return chosen.split('<answer>', 1)[-1].split('</answer>', 1)[0].strip().strip('"')

    def _normalize_inquire_actions(self, actions):
        """Split multi-entry inquire actions into one action per queried entry.

        [{'action': 'inquire', "symptom": [{"value": "headache"}, {"value": "coughing"}]}]
        --> [{'action': 'inquire', "symptom": [{"value": "headache"}]},
             {'action': 'inquire', "symptom": [{"value": "coughing"}]}]
        """
        new_actions = []
        for action in actions:
            assert action['action'] == 'inquire'
            for slot, entries in action.items():
                if slot == 'action':
                    continue
                new_actions.extend(
                    {'action': 'inquire', slot: [deepcopy(entry)]} for entry in entries
                )
        return new_actions

    @classmethod
    def _fill_prompt(cls, template, values):
        """Substitute ``{{name}}`` placeholders in ``template`` in a single pass.

        Only names present in ``values`` are replaced; other placeholders are
        left untouched. Because substitution is single-pass, text inserted for
        one placeholder (e.g. a doctor utterance) is never re-scanned for
        another placeholder.
        """
        return cls._PLACEHOLDER_RE.sub(
            lambda m: values.get(m.group(1), m.group(0)), template
        )

    def _complete(self, prompt, max_tokens):
        """Send ``prompt`` as one user message; return the first completion."""
        messages = [{"role": "user", "content": prompt}]
        completion = client.chat_completion(
            model=MODEL, messages=messages, temperature=0.0, max_tokens=max_tokens, n=1
        )
        return completion[0]

    def _respond_with_template(self, template, doctor_uttr, max_tokens=128, **placeholders):
        """Fill a response template with the ongoing dialogue and the doctor turn,
        query the LLM, and return the extracted patient utterance.

        Extra keyword arguments fill additional ``{{name}}`` placeholders.
        """
        values = {
            "dialogue_history": "\n".join(self.ongoing_dialog),
            "last_doctor_uttr": f"Doctor: {doctor_uttr}",
        }
        values.update(placeholders)
        prompt = self._fill_prompt(template, values)
        return self._post_process(self._complete(prompt, max_tokens))

    def _compile_diagnose_response(self, actions, doctor_uttr):
        """Reply to a diagnosis action (``actions`` is currently unused)."""
        return self._respond_with_template(DIAGNOSIS_RESPONSE_PROMPT, doctor_uttr)

    def _compile_salutation_response(self, actions, doctor_uttr):
        """Reply to a salutation action (``actions`` is currently unused)."""
        return self._respond_with_template(SALUTATIONS_RESPONSE_PROMPT, doctor_uttr)

    def _compile_chitchat_response(self, actions, doctor_uttr):
        """Reply to chit-chat / other actions (``actions`` is currently unused)."""
        return self._respond_with_template(CHIT_CHAT_RESPONSE_PROMPT, doctor_uttr)

    @staticmethod
    def _dst_key_for_slot(SLOT):
        """Map an inquiry slot to its dialogue-state key."""
        return f"avail_{SLOT}" if SLOT == 'medical_test' else f"positive_{SLOT}"

    def _compile_coverage_response(self, SLOT, doctor_uttr, covered):
        """Answer a slot-wide question from the ground-truth state.

        Args:
            SLOT: Slot being asked about.
            doctor_uttr: Raw doctor utterance.
            covered: If true, the answer sketch lists the ground-truth entries
                already revealed in ``covered_findings``. If false, it lists
                those not yet revealed.

        The reply is then run through NLU and merged into ``covered_findings``.
        """
        dst_key = self._dst_key_for_slot(SLOT)
        covered_values = [
            x.get('value')
            for x in self.covered_findings.get(dst_key, [])
            if x.get('value') is not None
        ]
        matches = [
            deepcopy(entry)
            for entry in self.dialog_state.get(dst_key, [])
            if (entry.get('value') in covered_values) == covered
        ]
        if SLOT == "medical_history":
            matches = [mm for mm in matches if mm.get('value') != 'immunization']

        patient_uttr = self._respond_with_template(
            ALL_RESPONSE_PROMPT,
            doctor_uttr,
            max_tokens=1024,
            slot=SLOT,
            answer_sketch=json.dumps({dst_key: matches}),
        )
        nlu_answer = self.get_nlu(doctor_uttr, patient_uttr, SLOT)
        self.covered_findings = merge_states(self.covered_findings, nlu_answer)
        return patient_uttr

    def _compile_all_response(self, SLOT, query_item, doctor_uttr):
        """Answer an inquiry with value ``'all'`` from already-revealed entries."""
        return self._compile_coverage_response(SLOT, doctor_uttr, covered=True)

    def _compile_other_response(self, SLOT, query_item, doctor_uttr):
        """Answer an inquiry with value ``'other'`` from not-yet-revealed entries."""
        return self._compile_coverage_response(SLOT, doctor_uttr, covered=False)

    def get_nlu(self, last_doctor_uttr, patient_response, SLOT):
        """Parse a patient reply into dialogue-state form via the NLU prompt.

        Also appends the request and result to ``nlu.txt`` for debugging.
        Returns whatever ``parse_output_to_nlu`` produces.
        """
        separator = "-" * 50 + "\n"
        with open("nlu.txt", "a", encoding="utf-8") as f:
            f.write(separator)
            f.write(f"Last Doctor Utterance: {last_doctor_uttr}\n")
            f.write(f"Patient Response: {patient_response}\n")
            f.write(f"Slot: {SLOT}\n")
            f.write(separator)

        prompt = self._fill_prompt(NLU_PROMPT, {
            "last_doctor_uttr": last_doctor_uttr,
            "patient_response": patient_response,
            "slot": SLOT,
        })
        response = self._complete(prompt, max_tokens=128)
        parsed_output = parse_output_to_nlu(response)
        result = {
            "input": prompt,
            "output": response,
            "parsed_output": parsed_output,
        }
        with open("nlu.txt", "a", encoding="utf-8") as f:
            f.write(json.dumps(result, indent=4) + "\n")
            f.write(separator)  # Separator for each entry
        return parsed_output

    def _compile_general_response(self, SLOT, query_item, doctor_uttr):
        """Answer an inquiry with value ``'general'`` from the dialogue alone."""
        return self._respond_with_template(GENERAL_RESPONSE_PROMPT, doctor_uttr)

    def _compile_inquire_response(self, answers, doctor_uttr):
        """Verbalize looked-up ``answers`` and mark them as covered."""
        patient_uttr = self._respond_with_template(
            INQUIRY_RESPONSE_PROMPT,
            doctor_uttr,
            max_tokens=1024,
            answer_sketch=json.dumps(answers),
        )
        self.covered_findings = merge_states(self.covered_findings, answers)
        return patient_uttr

    def get_action_answer(self, actions, doctor_uttr):
        """Produce the patient's reply to one doctor turn.

        Non-inquiry actions take precedence in the order diagnosis, salutation,
        chit-chat, other. Otherwise each ``inquire`` action is split into
        single-entry queries. The special values ``all``, ``other`` and
        ``general`` short-circuit to their dedicated responses. The remaining
        queries are looked up in ``self.dialog_state`` and answered together.
        """
        def of_type(name):
            return [act for act in actions if act['action'] == name]

        diagnosis_actions = of_type('diagnosis')
        if diagnosis_actions:
            logging.info("Diagnosis action received.")
            return self._compile_diagnose_response(diagnosis_actions, doctor_uttr)

        salutation_actions = of_type('nod_prompt_salutations')
        if salutation_actions:
            logging.info("Salutation action received.")
            return self._compile_salutation_response(salutation_actions, doctor_uttr)

        chitchat_actions = of_type('chit-chat')
        if chitchat_actions:
            logging.info("Chit-chat action received.")
            return self._compile_chitchat_response(chitchat_actions, doctor_uttr)

        other_actions = of_type('other')
        if other_actions:
            logging.info("Other action received.")
            # "other" turns are answered with the chit-chat template.
            return self._compile_chitchat_response(other_actions, doctor_uttr)

        inquire_actions = of_type('inquire')
        assert len(inquire_actions) > 0, f"At least one inquire action is expected\n{json.dumps(actions)}"
        answers = []
        for action in self._normalize_inquire_actions(inquire_actions):
            SLOT = next(key for key in action if key != 'action')
            query_item = action[SLOT][0]
            query_value = query_item.get('value')

            if query_value is None:
                continue
            if query_value == 'all':
                logging.info("All action received. Slot: %s", SLOT)
                return self._compile_all_response(SLOT, query_item, doctor_uttr)
            if query_value == 'other':
                logging.info("Other action received. SLOT: %s", SLOT)
                return self._compile_other_response(SLOT, query_item, doctor_uttr)
            if query_value == 'general':
                logging.info("General action received. SLOT: %s", SLOT)
                return self._compile_general_response(SLOT, query_item, doctor_uttr)
            answers.append(search_dst_for_actions(self.dialog_state, SLOT, query_item))
        assert len(answers) > 0, f"At least one answer is expected\n{actions}"

        combined_answer = {}
        for answer in answers:
            combined_answer = merge_states(combined_answer, answer)
        logging.info("Inquire action received. Answer: %s", combined_answer)

        return self._compile_inquire_response(combined_answer, doctor_uttr)

    def execute_and_commit(self, actions: List[Dict], doctor_uttr: str) -> Tuple[str, str]:
        """Answer one doctor turn and append it to the ongoing dialogue.

        Returns:
            ``(patient_uttr, status)``. ``status`` is ``"success"`` while turns
            remain and ``"terminated"`` once the turn budget is used up. After
            that, every call returns ``("Session has ended", "terminated")``.
        """
        if self.cutoff <= 0:
            return "Session has ended", "terminated"

        logging.info("Received Action: %s", actions)
        patient_uttr = self.get_action_answer(actions, doctor_uttr)
        logging.info("Doctor's Action: %s", actions)
        logging.info("Doctor's utterance: %s", doctor_uttr)
        logging.info("Patient's utterance: %s", patient_uttr)

        self.ongoing_dialog.append(f"Doctor: {doctor_uttr}")
        self.ongoing_dialog.append(f"Patient: {patient_uttr}")
        self.latest_doctor_uttr = doctor_uttr
        self.latest_patient_uttr = patient_uttr
        self.cutoff -= 1

        self.logs.append({
            "doctor_action": actions,
            "doctor_uttr": doctor_uttr,
            "patient_uttr": patient_uttr,
            "updated_dialogue_state": deepcopy(self.covered_findings),
        })

        status = "success" if self.cutoff > 0 else "terminated"
        return patient_uttr, status

    def get_dialog(self):
        """Return the ongoing dialogue lines (the live list, not a copy)."""
        return self.ongoing_dialog

    def get_covered_state(self):
        """Return revealed findings without empty / ``['unknown']`` fields."""
        return self._remove_unknown(self.covered_findings)

    def get_logs(self):
        """Return a deep copy of the per-turn logs."""
        return deepcopy(self.logs)

# Session Manager for PatientService
class PatientService:
    """In-memory registry of simulated-patient sessions.

    Sessions are keyed by random UUID strings. Reference dialogs are loaded
    once from a JSON file and looked up by the client-supplied dialog id.
    """

    def __init__(self, dialog_path, dst_key=None, ddx_key=None):
        """Load reference dialogs.

        Args:
            dialog_path: Path to a JSON file. ``self.dialogs[did]`` must yield a
                dialog record accepted by ``Patient`` (presumably a mapping of
                dialog id to record; confirm against the data files).
            dst_key: Forwarded to every ``Patient`` (patient-profile field).
            ddx_key: Forwarded to every ``Patient`` (differential field).
        """
        self.sessions = {}
        # JSON is UTF-8 by spec; don't rely on the platform default encoding.
        with open(dialog_path, 'r', encoding='utf-8') as fp:
            self.dialogs = json.load(fp)
        self.dst_key = dst_key
        self.ddx_key = ddx_key

    def create_session(self, did, mode='actions'):
        """
        Let's allow the client to dictate what dids they would like to use for simulation.

        Returns the new session id.

        Raises:
            ValueError: If ``mode`` is not ``'actions'``.
            HTTPException: 404 if ``did`` is not a known dialog id.
        """
        if mode != 'actions':
            raise ValueError(f"Invalid mode: {mode}")
        try:
            dialog = self.dialogs[did]
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Dialog not found: {did}") from None

        logging.info("Creating session for dialog %s.", did)
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = Patient(dialog, self.dst_key, self.ddx_key)

        logging.info(f"Session created {session_id}.")
        return session_id

    def get_patient(self, session_id):
        """Return the ``Patient`` for ``session_id`` or raise an HTTP 404."""
        if session_id not in self.sessions:
            raise HTTPException(status_code=404, detail="Session not found")
        return self.sessions[session_id]

    def close(self, session_id):
        """Drop a session; unknown ids are ignored."""
        if session_id in self.sessions:
            del self.sessions[session_id]
            logging.info(f"Session deleted {session_id}.")
