import json
import yaml
import torch
from copy import deepcopy
from unsloth import FastLanguageModel
from transformers import GenerationConfig
from utils import merge_states
from transformers import AutoModel, AutoTokenizer


class NLUModel(object):
    """Natural-language-understanding (NLU) model for the patient side of a dialogue.

    Loads a 4-bit Unsloth causal LM from ``./models/nlu/`` (settings from
    ``./models/nlu.yml``) and prompts it with the dialogue and the last doctor
    action; the answer is expected as a fenced JSON block of intents/slot-values.
    """

    # Prompt template; ``{{dialogue}}`` and ``{{last_doctor_action}}`` are
    # substituted in ``predict``.
    USER_PROMPT = """Given a doctor-patient dialogue and the last doctor action output patient's intent and associated slot-values.

# Doctor-Patient Dialogue
{{dialogue}}

# Last Doctor Action
```json
{{last_doctor_action}}
```"""

    def __init__(self, model_name="nlu"):
        """Load the config, the 4-bit model and its tokenizer.

        Args:
            model_name: accepted for interface compatibility but unused; the
                config and weight paths are fixed.
        """
        with open('./models/nlu.yml', 'r') as fp:
            cfg = yaml.safe_load(fp)
        model_path = './models/nlu/'

        self.cfg = cfg
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_path, max_seq_length=cfg['model']['max_seq_length'],
            dtype=None, load_in_4bit=True
        )
        FastLanguageModel.for_inference(model)
        # EOS doubles as the pad token; left padding suits decoder-only generation.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'left'
        eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        # Greedy decoding (single beam, no sampling) for deterministic output.
        self.gconfig = GenerationConfig(
            num_beams=1, do_sample=False, max_new_tokens=512,
            eos_token_id=eos_id, pad_token_id=eos_id,
        )
        self.model = model
        self.tokenizer = tokenizer
        print('NLU Model loaded successfully.')

    def _generate(self, user, answer_prefix):
        """Generate a reply to a single user message and return only the new text."""
        messages = [{'role': 'user', 'content': user}]
        gen_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Prime the assistant turn with a fixed prefix.
        gen_prompt += answer_prefix
        # The chat template is rendered to text already; presumably it carries
        # BOS itself, hence add_special_tokens=False.
        inps = self.tokenizer(gen_prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inps['input_ids'].to(self.model.device)
        attn_mask = inps['attention_mask'].to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids, attention_mask=attn_mask, generation_config=self.gconfig,
                use_cache=True, stop_strings=["<|end_header_id|>"], tokenizer=self.tokenizer,
            )
        # Drop the echoed prompt tokens; keep only the continuation.
        new_tokens = outputs.to('cpu')[:, input_ids.size(1):]
        return self.tokenizer.batch_decode(new_tokens)[0]

    def _parse_nlu_output(self, text, eos_token=None):
        """Decode the JSON payload of a raw NLU generation.

        Anything after ``eos_token`` (when given) is discarded first, so an
        unfenced JSON answer followed by the decoded EOS token still parses.
        If a ```json fence is present, only its contents are decoded.

        Raises:
            ValueError: (``json.JSONDecodeError``) if the payload is not JSON.
        """
        if eos_token:
            text = text.split(eos_token, 1)[0]
        text = text.split("```json", 1)[-1].strip()
        text = text.split("```", 1)[0].strip()
        return json.loads(text)

    @staticmethod
    def _slot_updates(nlu):
        """Return, in order, the ``slots`` payloads to merge from parsed NLU output.

        A list is treated as a list of intent objects, and a lone dict as a
        single intent. Anything else, and any non-dict item, contributes nothing.
        """
        if isinstance(nlu, dict):
            entries = [nlu]
        elif isinstance(nlu, list):
            entries = nlu
        else:
            entries = []
        return [entry['slots'] for entry in entries
                if isinstance(entry, dict) and 'slots' in entry]

    def predict(self, **kwargs):
        """Predict the patient's intents/slots and the resulting dialog state.

        Keyword Args:
            dialog_history: sequence of utterance strings, joined with newlines.
                The full history is used; ``cfg['model']['history_size']`` is
                not applied here.
            last_action: last doctor action, serialized with ``json.dumps``.
            dialog_state: current state; deep-copied, never mutated.

        Returns:
            ``(nlu, new_dst)``: the parsed model output (``{}`` if it could not
            be parsed), and the dialog state with every intent's ``slots``
            merged in via ``merge_states``.
        """
        user = (self.USER_PROMPT
                .replace('{{dialogue}}', '\n'.join(kwargs['dialog_history']))
                .replace('{{last_doctor_action}}', json.dumps(kwargs['last_action'])))

        response = self._generate(user, "The answer is")

        try:
            nlu = self._parse_nlu_output(response, eos_token=self.tokenizer.eos_token)
        except (ValueError, RecursionError) as e:
            print(f"Error parsing NLU output: {e}")
            print(f"Prediction: {response}")
            nlu = {}

        new_dst = deepcopy(kwargs['dialog_state'])
        for slots in self._slot_updates(nlu):
            new_dst = merge_states(new_dst, slots)

        return nlu, new_dst


class NLGModel(object):
    """Natural-language-generation (NLG) model that verbalizes doctor actions.

    Loads a 4-bit Unsloth causal LM from ``./models/nlg/`` (settings from
    ``./models/nlg.yml``) and prompts it with the dialogue and the doctor
    action(s); the reply is read from an ``<answer>...</answer>`` span.
    """

    # Prompt template; ``{{dialogue}}`` and ``{{doctor_action}}`` are
    # substituted in ``predict``.
    USER_PROMPT = """Given a doctor-patient dialogue and the doctor action output doctor's response.

# Doctor-Patient Dialogue
{{dialogue}}

# Doctor Action
```json
{{doctor_action}}
```"""

    def __init__(self, model_name):
        """Load the config, the 4-bit model and its tokenizer.

        Args:
            model_name: accepted for interface compatibility but unused; the
                config and weight paths are fixed.
        """
        with open('./models/nlg.yml', 'r') as fp:
            cfg = yaml.safe_load(fp)
        model_path = './models/nlg/'

        self.cfg = cfg
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_path, max_seq_length=cfg['model']['max_seq_length'],
            dtype=None, load_in_4bit=True
        )
        FastLanguageModel.for_inference(model)
        # EOS doubles as the pad token; left padding suits decoder-only generation.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'left'
        eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        # Greedy decoding (single beam, no sampling) for deterministic output.
        self.gconfig = GenerationConfig(
            num_beams=1, do_sample=False, max_new_tokens=1024,
            eos_token_id=eos_id, pad_token_id=eos_id,
        )
        self.model = model
        self.tokenizer = tokenizer
        print('Models loaded successfully.')

    def _generate(self, user, answer_prefix):
        """Generate a reply to a single user message and return only the new text."""
        messages = [{'role': 'user', 'content': user}]
        gen_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Prime the assistant turn with a fixed prefix.
        gen_prompt += answer_prefix
        # The chat template is rendered to text already; presumably it carries
        # BOS itself, hence add_special_tokens=False.
        inps = self.tokenizer(gen_prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inps['input_ids'].to(self.model.device)
        attn_mask = inps['attention_mask'].to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids, attention_mask=attn_mask, generation_config=self.gconfig,
                use_cache=True, stop_strings=["<|end_header_id|>"], tokenizer=self.tokenizer,
            )
        # Drop the echoed prompt tokens; keep only the continuation.
        new_tokens = outputs.to('cpu')[:, input_ids.size(1):]
        return self.tokenizer.batch_decode(new_tokens)[0]

    def _parse_nlg_output(self, text, eos_token=None):
        """Return the response text that precedes ``</answer>``, stripped.

        When ``eos_token`` is given, text from it onwards is also dropped. This
        keeps the decoded EOS marker out of the response when the model stops
        without closing the ``<answer>`` tag.
        """
        text = text.split('</answer>', 1)[0]
        if eos_token:
            text = text.split(eos_token, 1)[0]
        return text.strip()

    def predict(self, **kwargs):
        """Generate the doctor's natural-language response.

        Keyword Args:
            dialog_history: sequence of utterance strings, joined with newlines
                (the full history is used).
            actions: doctor action(s), serialized with ``json.dumps``.

        Returns:
            The generated response string.
        """
        user = (self.USER_PROMPT
                .replace('{{dialogue}}', '\n'.join(kwargs['dialog_history']))
                .replace('{{doctor_action}}', json.dumps(kwargs['actions'])))

        # The prompt opens the <answer> tag; the model is expected to close it.
        response = self._generate(user, "The answer is\n<answer>")
        return self._parse_nlg_output(response, eos_token=self.tokenizer.eos_token)


class RAGNet(object):
    """Dense retriever over disease descriptions, using BAAI/bge-m3 embeddings.

    Each disease in ``./models/bn/structure.json`` becomes one document. Its
    embedding is computed once at construction, and ``get_posterior`` returns
    the diseases most similar to the current dialog state and history.
    """

    def __init__(self, device='cuda'):
        """Load disease descriptions and the encoder, then embed every document.

        Args:
            device: torch device the encoder runs on (default ``'cuda'``).
        """
        structure_path = "./models/bn/structure.json"
        with open(structure_path, 'r') as fp:
            structure = json.load(fp)

        print("Number of diseases:", len(structure))
        self.diseases = sorted(structure)
        self.structure = structure
        self.documents = [
            f"Disease: {dis}\n\n{structure[dis]['text']}" for dis in self.diseases
        ]
        self.model = AutoModel.from_pretrained('BAAI/bge-m3').to(device)
        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
        self.document_embeddings = self.get_embeddings(self.documents)

    def get_embeddings(self, texts, batch_size=32):
        """Embed ``texts`` in batches and return L2-normalized first-token vectors.

        Every text gets the BGE retrieval instruction as a prefix and is
        truncated to 8192 tokens. ``texts`` must be non-empty.

        Returns:
            A ``(len(texts), hidden)`` tensor on the model's device.
        """
        # NOTE(review): this prefix is applied to documents and queries alike.
        # BGE model cards describe it as a query-side instruction, and say
        # bge-m3 needs none. Confirm before changing it, because doing so
        # changes every stored embedding.
        prefix = "Represent this sentence for searching relevant passages: "
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = [f"{prefix}{t}" for t in texts[i:i + batch_size]]
            inputs = self.tokenizer(
                batch_texts, return_tensors="pt", padding=True,
                truncation=True, max_length=8192
            ).to(self.model.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
                batch_embeddings = outputs.last_hidden_state[:, 0, :]  # [CLS] token
                batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)
            embeddings.append(batch_embeddings)

        return torch.cat(embeddings, dim=0)

    @staticmethod
    def _build_query(dst, dialog_history=None):
        """Build the retrieval query text from the dialog state and optional history."""
        query = "[dialog_state]\n" + json.dumps(dst)
        # Previously a None history (the default) raised TypeError here.
        if dialog_history is not None:
            query += "\n\n[dialog_history]\n" + dialog_history
        return query

    def _rank_documents(self, query_embedding, top_k=5):
        """Return the ``top_k`` diseases most similar to a ``(1, hidden)`` query embedding."""
        # Both sides are L2-normalized, so the dot product is cosine similarity.
        similarities = torch.matmul(query_embedding, self.document_embeddings.T).squeeze(0)
        k = min(top_k, similarities.numel())
        values, indices = torch.topk(similarities, k=k)
        # Plain Python floats/ints keep the result JSON-serializable.
        scores = values.cpu().tolist()
        idxs = indices.cpu().tolist()
        return [
            {
                'disease': self.diseases[idx],
                'score': score,
                'text': self.structure[self.diseases[idx]]['text'],
            }
            for score, idx in zip(scores, idxs)
        ]

    def get_posterior(self, dst, dialog_history=None, top_k=5):
        """Retrieve the diseases most relevant to the current dialogue.

        Args:
            dst: dialog state, serialized with ``json.dumps`` into the query.
            dialog_history: optional history string appended to the query.
            top_k: maximum number of diseases to return (default 5).

        Returns:
            A list of at most ``top_k`` dicts with keys ``disease``, ``score``
            (cosine similarity, float) and ``text``, sorted by descending score.
        """
        query_embedding = self.get_embeddings([self._build_query(dst, dialog_history)])
        return self._rank_documents(query_embedding, top_k)


class POLModel(object):
    """Dialogue-policy model: predicts the doctor's next action(s) as JSON.

    Loads a 4-bit Unsloth causal LM for one of the supported backbones. The
    config can add disease hints to the prompt, either from a Bayes net
    (``use_bayes_net``) or from ``RAGNet`` retrieval (``use_rag``).
    """

    # (config path, weights path) for each supported ``model_name``.
    _MODEL_PATHS = {
        'llama3': ('./models/llama3.yml', './models/llama3/'),
        'bayesnet': ('./models/bayesnet.yml', './models/bayesnet/'),
        'rag': ('./models/rag.yml', './models/rag/'),
    }

    # Prompt used when a disease-hint source (Bayes net or RAG) is configured.
    _HINTED_PROMPT = """Given a doctor-patient dialogue and patient clinical report output the doctor's action as a continuation of the dialogue.
To help in this task, a list of possible diseases and their symptoms is also provided.

# Doctor-Patient Dialogue
{{dialogue}}

# Patient Clinical Report
{{dialogue_state}}

# Possible Diseases and Symptoms
{{hints}}"""

    # Prompt used without any hint source.
    _PLAIN_PROMPT = """Given a doctor-patient dialogue and patient clinical report output the doctor's action as a continuation of the dialogue.

# Doctor-Patient Dialogue
{{dialogue}}

# Patient Clinical Report
{{dialogue_state}}"""

    def __init__(self, model_name):
        """Load config, model, tokenizer and the optional hint source.

        Args:
            model_name: one of ``'llama3'``, ``'bayesnet'``, ``'rag'``.

        Raises:
            ValueError: if ``model_name`` is None or unknown.
        """
        # Explicit raise instead of ``assert``, which is stripped under -O.
        if model_name is None:
            raise ValueError("Model name is required.")
        if model_name not in self._MODEL_PATHS:
            raise ValueError(f"Unknown model name: {model_name}")
        cfg_path, model_path = self._MODEL_PATHS[model_name]

        with open(cfg_path, 'r') as fp:
            self.cfg = yaml.safe_load(fp)

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_path, max_seq_length=self.cfg['model']['max_seq_length'],
            dtype=None, load_in_4bit=True
        )
        FastLanguageModel.for_inference(model)
        # EOS doubles as the pad token; left padding suits decoder-only generation.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'left'
        eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        # Greedy decoding (single beam, no sampling) for deterministic output.
        self.gconfig = GenerationConfig(
            num_beams=1, do_sample=False, max_new_tokens=1024,
            eos_token_id=eos_id, pad_token_id=eos_id,
        )
        self.model = model
        self.tokenizer = tokenizer
        print('POL Model loaded successfully.')

        # ``self.bayesnet`` is whichever hint source is active (it may be a
        # RAGNet); both expose ``get_posterior``. The Bayes net takes precedence.
        model_cfg = self.cfg['model']
        if model_cfg.get('use_bayes_net', False):
            print('Loading Bayes Net....')
            from bayes_net import OneDiseaseBayesNet
            self.bayesnet = OneDiseaseBayesNet(self.cfg['bayes_net'])
            self.USER_PROMPT = self._HINTED_PROMPT
        elif model_cfg.get('use_rag', False):
            print('Loading RAG....')
            self.bayesnet = RAGNet()
            self.USER_PROMPT = self._HINTED_PROMPT
        else:
            self.bayesnet = None
            self.USER_PROMPT = self._PLAIN_PROMPT

    def create_posterior_note(self, posterior, desc_type):
        """Render posterior entries as markdown ``##`` sections for the prompt.

        Args:
            posterior: iterable of dicts with ``disease``, plus ``score`` (a
                fraction, shown x100 rounded to 2 d.p.) and/or ``text``. A field
                is only read when the chosen template uses it.
            desc_type: ``'text'`` (likelihood + description), ``'rag'``
                (description only), or anything else (likelihood only).

        Returns:
            The rendered sections joined with newlines.
        """
        if desc_type == 'text':
            template = "## {{disease}} (Likelihood - {{score}})\n{{text}}\n"
        elif desc_type == 'rag':
            template = "## {{disease}}\n{{text}}\n"
        else:
            template = "## {{disease}} (Likelihood - {{score}})"
        uses_score = '{{score}}' in template
        uses_text = '{{text}}' in template

        content = []
        for entry in posterior:
            note = template.replace('{{disease}}', entry['disease'].capitalize())
            if uses_score:
                note = note.replace('{{score}}', str(round(100.0 * entry['score'], 2)))
            if uses_text:
                note = note.replace('{{text}}', entry['text'])
            content.append(note)

        return '\n'.join(content)

    def _parse_pol_output(self, text, eos_token=None):
        """Decode the JSON payload of a raw policy generation.

        Anything after ``eos_token`` (when given) is discarded first. If a
        ```json fence is present, only its contents are decoded.

        Raises:
            ValueError: (``json.JSONDecodeError``) if the payload is not JSON.
        """
        if eos_token:
            text = text.split(eos_token, 1)[0]
        text = text.split("```json", 1)[-1].strip()
        text = text.split("```", 1)[0].strip()
        return json.loads(text)

    @staticmethod
    def _default_actions(actions):
        """Return ``actions``, or a single chit-chat action if it is empty or not a list/dict."""
        if isinstance(actions, (list, dict)) and actions:
            return actions
        return [{"action": "chit-chat"}]

    def _generate(self, user, answer_prefix):
        """Generate a reply to a single user message and return only the new text."""
        messages = [{'role': 'user', 'content': user}]
        gen_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Prime the assistant turn with a fixed prefix.
        gen_prompt += answer_prefix
        # The chat template is rendered to text already; presumably it carries
        # BOS itself, hence add_special_tokens=False.
        inps = self.tokenizer(gen_prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inps['input_ids'].to(self.model.device)
        attn_mask = inps['attention_mask'].to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids, attention_mask=attn_mask, generation_config=self.gconfig,
                use_cache=True, stop_strings=["<|end_header_id|>"], tokenizer=self.tokenizer,
            )
        # Drop the echoed prompt tokens; keep only the continuation.
        new_tokens = outputs.to('cpu')[:, input_ids.size(1):]
        return self.tokenizer.batch_decode(new_tokens)[0]

    def predict(self, **kwargs):
        """Predict the doctor's next action(s).

        Keyword Args:
            dialog_history: sequence of utterance strings. Only the last
                ``cfg['model']['history_size']`` are used (-1 means all).
            dialog_state: clinical report, embedded as fenced JSON.
            return_gen: if true, also return the raw generated text.

        Returns:
            A list of action dicts (``[{"action": "chit-chat"}]`` when the output
            is unparsable or empty), or ``(actions, response)`` if ``return_gen``.
        """
        dialog_history = kwargs['dialog_history']
        dialog_state = kwargs['dialog_state']
        history_size = self.cfg['model']['history_size']

        if history_size == -1:
            start = 0
        else:
            start = max(len(dialog_history) - history_size, 0)
        history = '\n'.join(dialog_history[start:])

        clinical_note = "```json\n" + json.dumps(dialog_state) + "\n```"
        user = self.USER_PROMPT.replace('{{dialogue}}', history)
        user = user.replace('{{dialogue_state}}', clinical_note)

        if self.bayesnet is not None:
            posterior = self.bayesnet.get_posterior(dialog_state, dialog_history=history)
            hints = self.create_posterior_note(posterior, self.cfg['model'].get('desc_type', 'text'))
            user = user.replace('{{hints}}', hints)

        response = self._generate(user, "The answer is")

        try:
            actions = self._parse_pol_output(response, eos_token=self.tokenizer.eos_token)
        except (ValueError, RecursionError) as e:
            print(f"Error parsing POL output: {e}")
            print(f"Prediction: {response}")
            actions = []

        actions = self._default_actions(actions)

        if kwargs.get('return_gen', False):
            return actions, response

        return actions
