import random
from helper import get_response
import ast

def _split_list_string(body):
    """Split the inside of a malformed list literal on top-level commas.

    A quote character only closes a quoted region when it matches the quote
    that opened it, so an apostrophe inside a double-quoted fact (or a double
    quote inside a single-quoted one) does not break the split. Each item has
    surrounding whitespace and quote characters stripped.

    Args:
        body: the text between the outer ``[`` and ``]``.

    Returns:
        List of item strings.
    """
    items = []
    current = []
    open_quote = None  # quote char that opened the current quoted region

    for char in body:
        if char in ("'", '"'):
            if open_quote is None:
                open_quote = char
            elif char == open_quote:
                open_quote = None

        # A comma with nothing accumulated yet is kept as text (original behavior).
        if char == ',' and open_quote is None and current:
            items.append(''.join(current).strip().strip("'\""))
            current = []
        else:
            current.append(char)

    if current:
        items.append(''.join(current).strip().strip("'\""))

    return items


def parse_atomic_facts(sample):
    """Return the sample's atomic facts, normalized to a list when possible.

    ``sample['atomic_facts']`` may be missing/None, a list, or a string holding
    a list literal such as ``"['fact one', 'fact two']"``. List literals are
    parsed with ``ast.literal_eval``. If that fails (typically because of
    unescaped quotes inside the facts), a lenient comma splitter is used.

    Args:
        sample: mapping that may contain an ``'atomic_facts'`` entry.

    Returns:
        None when there are no atomic facts; a list when the value is a list
        or a string that can be read as one; otherwise the value unchanged.
    """
    atomic_facts = sample.get('atomic_facts', None)

    if atomic_facts is None:
        return None

    if isinstance(atomic_facts, list):
        return atomic_facts

    if isinstance(atomic_facts, str):
        stripped = atomic_facts.strip()
        if stripped.startswith('[') and stripped.endswith(']'):
            try:
                parsed_list = ast.literal_eval(stripped)
            except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
                # Not a valid Python literal: split it by hand instead.
                return _split_list_string(stripped[1:-1])

            if isinstance(parsed_list, list):
                return parsed_list

    return atomic_facts


class Patient:
    """Base simulated patient that answers a doctor's questions from case context.

    The sample's ``context`` is normalized into three attributes:
    ``context_list`` (the context split into pieces), ``context_para`` (the
    context as one paragraph) and ``initial_info`` (what the patient states up
    front). Subclasses implement :meth:`respond`.
    """

    def __init__(self, args, sample):
        """Build a patient from a dataset sample.

        Args:
            args: run configuration; reads ``patient_model``, ``use_vllm`` and
                ``use_api``.
            sample: mapping with a ``'context'`` entry. The entry may be a list,
                a string, or a string holding a list literal. Optional
                ``'initial_info'`` and ``'atomic_facts'`` entries may also be
                present. An explicit ``'initial_info'`` always overrides the
                derived one.
        """
        context = sample['context']
        if isinstance(context, str):
            # A string may encode a list, e.g. "['sentence 1', 'sentence 2']".
            context = self._maybe_parse_list(context)

        if isinstance(context, list) and len(context) > 0:
            self.context_list = context
            self.context_para = " ".join(str(piece) for piece in context)
            default_initial_info = context[0]  # first item of the list
        elif isinstance(context, str):
            self.context_list = self._split_sentences(context)
            self.context_para = context
            # Assume sentences are separated by ". " and take the first one.
            default_initial_info = context.split(". ")[0]
        else:
            # Empty list, unparseable type, etc.
            self.context_list = []
            self.context_para = 'None'
            default_initial_info = ""  # Default fallback

        if 'initial_info' in sample:
            self.initial_info = sample['initial_info']
        else:
            self.initial_info = default_initial_info

        self.model_name = args.patient_model
        self.history = []
        self.facts = parse_atomic_facts(sample)

        self.max_length = 50  # Maximum length of the response (different from the expert system)
        self.use_vllm = args.use_vllm
        self.use_api = args.use_api  # Use an API to generate responses

    @staticmethod
    def _maybe_parse_list(text):
        """Return *text* parsed as a list if it is a list literal, else *text* unchanged."""
        stripped = text.strip()
        if not (stripped.startswith('[') and stripped.endswith(']')):
            return text
        try:
            parsed = ast.literal_eval(stripped)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
            # Fall back to treating the context as plain text instead of
            # leaving the context attributes unset.
            print(f'parse string error: {e}')
            return text
        return parsed if isinstance(parsed, list) else text

    @staticmethod
    def _split_sentences(text):
        """Split *text* on ". " and restore the period on every piece except the last."""
        pieces = text.split(". ")
        last = len(pieces) - 1
        return [piece if i == last or piece.endswith('.') else piece + '.'
                for i, piece in enumerate(pieces)]

    def update_state(self, question, answer):
        """Append one question/answer pair to the interaction history."""
        self.history.append({"question": question, "answer": answer})

    def get_state(self):
        """Return the initial information and the interaction history."""
        return {
            "initial_info": self.initial_info,
            "interaction_history": self.history
        }

    def get_questions(self):
        """Return the list of questions asked so far."""
        return [qa["question"] for qa in self.history]

    def get_answers(self):
        """Return the list of answers provided so far."""
        return [qa["answer"] for qa in self.history]

    def get_response(self, messages, max_length=None):
        """Query the patient model via ``helper.get_response``.

        Args:
            messages: chat messages (list of ``{"role", "content"}`` dicts).
            max_length: generation limit; defaults to ``self.max_length``.

        Returns:
            Whatever ``helper.get_response`` returns; subclasses unpack it as
            ``(text, log_probs, num_tokens)``.
        """
        if max_length is None:
            max_length = self.max_length
        return get_response(messages, self.model_name, use_vllm=self.use_vllm, use_api=self.use_api, max_length=max_length)

    def respond(self, question):
        """Answer *question*; must be implemented by subclasses."""
        raise NotImplementedError
    

class RandomPatient(Patient):
    """Baseline patient that ignores the question and answers at random."""

    def respond(self, question):
        """Return either a refusal or a random context sentence, and log the exchange.

        A coin flip (``random.random() < 0.5``) picks the refusal. A refusal is
        also forced when there is no context to sample from.
        """
        coin_says_refuse = random.random() < 0.5
        if not coin_says_refuse and self.context_list:
            answer = random.choice(self.context_list)
        else:
            answer = "The patient cannot answer this question, please do not ask this question again."
        self.update_state(question, answer)
        return answer

class DirectPatient(Patient):
    """Patient that answers each question using only ``initial_info`` as context."""

    def respond(self, question):
        """Ask the model to answer *question* from ``self.initial_info``; record and return the reply."""
        user_content = f"Context: \"{self.initial_info}\"\nQuestion: \"{question}\"\n"
        chat = [
            {"role": "system", "content": "Answer the question with the given context."},
            {"role": "user", "content": user_content},
        ]
        reply, _log_probs, _num_tokens = self.get_response(chat)
        self.update_state(question, reply)
        return reply

class InstructPatient(Patient):
    """Patient that answers from the full context paragraph under strict instructions."""

    def respond(self, question):
        """Answer *question* using ``self.context_para`` and record the exchange.

        The model is told to reply with a fixed refusal sentence when the
        paragraph does not contain the answer.
        """
        instructions = (
            "You are a truthful assistant that understands the patient's information, "
            "and you are trying to answer questions from a medical doctor about the patient."
        )
        request = (
            "Below is a context paragraph describing the patient and their conditions:\n"
            f"\"{self.context_para}\"\n"
            f"Question from the doctor: \"{question}\"\n"
            "Use the context paragraph to answer the doctor's question. "
            "If the paragraph does not answer the question, simply say "
            "\"The patient cannot answer this question, please do not ask this question again.\" "
            "Answer only what the question asks for. "
            "Do not provide any analysis, inference, or implications. "
            "Respond with a straightforward answer to the question ONLY and NOTHING ELSE."
        )
        chat = [
            {"role": "system", "content": instructions},
            {"role": "user", "content": request},
        ]
        reply, _log_probs, _num_tokens = self.get_response(chat)
        self.update_state(question, reply)
        return reply
    
class FactSelectPatient(Patient):
    """Patient that answers by returning, verbatim, the atomic facts relevant to a question.

    ``self.facts`` holds the sample's atomic facts. It may be None, a list, or
    a plain string. When no facts are available, the model first breaks the
    context paragraph into atomic facts. The result is cached on
    ``self.facts``, so it is only recomputed while it stays empty.
    """

    def respond(self, question):
        """Select the facts that answer *question*, record the exchange and return the reply."""
        facts = self._as_fact_list(self.facts)
        if not facts:
            # Decompose context into facts if not already done
            facts = self._decompose_context()
        self.facts = facts

        facts_prompt = "\n".join(facts)
        system_prompt = "You are a truthful medical assistant that understands the patient's information, and you are trying to answer questions from a medical doctor about the patient given a list of factual statements describing the patient. Please return the facts that answer the doctor's question verbatim. If none of the facts answer to the question, simply say \"The patient cannot answer this question, please do not ask this question again.\""
        prompt = f"List of facts:\n{facts_prompt}\n\nQuestion from the doctor: \"{question}\"\n\nStatements that answer the question:"
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
        response, _log_probs, _num_tokens = self.get_response(messages)
        self.update_state(question, response)
        return response

    @staticmethod
    def _as_fact_list(facts):
        """Normalize *facts* to a list of strings (one fact per prompt line).

        A plain string is split into non-blank lines rather than joined
        character by character. Other iterables have each item converted with
        ``str``.
        """
        if not facts:
            return []
        if isinstance(facts, str):
            return [line.strip() for line in facts.splitlines() if line.strip()]
        return [str(fact) for fact in facts]

    def _decompose_context(self):
        """Ask the model to break ``self.context_para`` into atomic facts, one per line."""
        system_prompt = "You are a truthful medical assistant that understands the patient's information."
        user_prompt = f"Break the following patient information into a list of independent atomic facts, with one piece of information in each statement. Each fact should only include the smallest unit of information, but should be self-contained.\n\"{self.context_para}\"\nResponse with the list of atomic facts and nothing else, prepend each fact by an index starting from 1. No sub-list allowed."
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
        response_text, _log_probs, _num_tokens = self.get_response(messages, max_length=1000)
        # Drop blank lines so they do not become empty "facts".
        return [line.strip() for line in response_text.splitlines() if line.strip()]
