import os
import re
import h5py
import torch
import torch.nn.functional as F

import numpy as np

#from transformers import AutoModelForCausalLM, AutoTokenizer

def split_into_sentences(text: str) -> list[str]:
    """
    Split the text into sentences.

    Dots that do not end a sentence (title abbreviations such as "Dr.",
    company suffixes such as "Inc.", single-letter initials, acronyms such as
    "U.S.", decimal numbers, web domains such as ".com", "Ph.D.") are
    temporarily replaced by a "<prd>" marker. Sentence ends are marked with
    "<stop>", the markers are resolved, and the text is split on "<stop>".
    Runs of two or more dots are kept intact and end a sentence. Closing
    quotes are moved in front of the terminating punctuation so they stay
    with their sentence. Newlines are treated as spaces.

    If the text contains substrings "<prd>" or "<stop>", they would lead
    to incorrect splitting because they are used as markers for splitting.

    :param text: text to be split into sentences
    :type text: str

    :return: list of stripped sentences; an empty list for blank input
    :rtype: list[str]
    """
    # Raw strings throughout: "\s" in a normal string literal is an invalid
    # escape sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
    alphabets = r"([A-Za-z])"
    prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
    suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
    starters = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
    acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
    websites = r"[.](com|net|org|io|gov|edu|me)"
    digits = r"([0-9])"
    multiple_dots = r"\.{2,}"

    # Padding lets the space-anchored patterns below match at both ends.
    text = " " + text + "  "
    text = text.replace("\n", " ")
    text = re.sub(prefixes, r"\1<prd>", text)
    text = re.sub(websites, r"<prd>\1", text)
    text = re.sub(digits + "[.]" + digits, r"\1<prd>\2", text)
    # Keep every dot of an ellipsis, but let the ellipsis end the sentence.
    text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
    if "Ph.D" in text:
        text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = re.sub(r"\s" + alphabets + "[.] ", r" \1<prd> ", text)
    # An acronym followed by a typical sentence starter does end a sentence.
    text = re.sub(acronyms + " " + starters, r"\1<stop> \2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>", text)
    text = re.sub(" " + suffixes + "[.] " + starters, r" \1<stop> \2", text)
    text = re.sub(" " + suffixes + "[.]", r" \1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", r" \1<prd>", text)
    # Move closing quotes before the terminator so they stay in the sentence.
    if "”" in text:
        text = text.replace(".”", "”.")
    if "\"" in text:
        text = text.replace(".\"", "\".")
    if "!" in text:
        text = text.replace("!\"", "\"!")
    if "?" in text:
        text = text.replace("?\"", "\"?")
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")
    sentences = [s.strip() for s in text.split("<stop>")]
    # The padding leaves an empty trailing piece after the last terminator.
    if sentences and not sentences[-1]:
        sentences = sentences[:-1]
    return sentences

# Section headings recognised by ``preprocess_report`` (matched
# case-insensitively at the start of a line).
_REPORT_DELIMITERS = (
    "DATE OF STUDY",
    "START DATE OF STUDY",
    "END DATE OF STUDY",
    "REASON FOR STUDY",
    "DAY",
    "RECORDING ENVIRONMENT",
    "EVENT",
    "EEG TYPE",
    "HISTORY",
    "CLINICAL HISTORY",
    "MEDICATIONS",
    "AED",
    "INTRODUCTION",
    "SEDATION",
    "SEIZURE TIME",
    "SEIZURES",
    "ABNORMAL DISCHARGES",
    "TECHNIQUE",
    "TECHNICAL ISSUE",
    "TECHNICAL PROBLEM",
    "TECHNICAL DIFFICULTIES",
    "DESCRIPTION OF THE RECORD",
    "DESCRIPTION OF RECORD",
    "DESCRIPTION",
    "EEG BACKGROUND",
    "BACKGROUND",
    "EPILEPTIFORM ACTIVITY",
    "OTHER PAROXYSMAL ACTIVITY (NON-EPILEPTIFORM)",
    "ACTIVATION PROCEDURES",
    "INDUCTION PROCEDURES",
    "EVENTS",
    "HEART RATE",
    "HR",
    "CONDITIONS OF THE RECORDING",
    "RANDOM WAKEFULNESS AND SLEEP",
    "IMPRESSION",
    "CLINICAL CORRELATION",
    "CORRELATION",
    "CONCLUSION",
    "SUMMARY",
    "SUMMARY OF FINDINGS",
    "DIAGNOSIS",
    "INTERPRETATION",
)

# Regex alternation takes the FIRST alternative that matches, not the longest,
# so longer headings must come first: otherwise "EVENT" matches the start of
# "EVENTS" (leaving a stray "S" in the body) and "SUMMARY" swallows the start
# of "SUMMARY OF FINDINGS". Leading spaces/tabs before a heading are allowed.
_REPORT_HEADING_RE = re.compile(
    r"(^|\n)[ \t]*("
    + "|".join(map(re.escape, sorted(_REPORT_DELIMITERS, key=len, reverse=True)))
    + r")",
    flags=re.IGNORECASE | re.MULTILINE,
)


def _clean_text(text):
    """Turn newlines into spaces, collapse double spaces and trim both ends."""
    return text.replace("\n", " ").replace("  ", " ").strip()


def _with_prefix(prefix, text):
    """Return cleaned *text* preceded by *prefix*, separated by exactly one space."""
    body = _clean_text(text)
    head = prefix.strip()
    if not head:
        return body
    return f"{head} {body}".rstrip()


def _split_sections(report, include_heading):
    """Split *report* at known headings into parallel (headings, paragraphs) lists.

    Text before the first heading is discarded, and so are sections whose body is empty.
    """
    # With two capture groups re.split yields
    # [text_before_first_heading, sep, heading, body, sep, heading, body, ...],
    # so headings and bodies are paired by position, not by filtering blanks.
    parts = _REPORT_HEADING_RE.split(report)
    headings, paragraphs = [], []
    for k in range(1, len(parts) - 2, 3):
        heading = parts[k + 1].strip()
        # Drop only the colon separating heading and body; colons inside the
        # body (e.g. clock times under "SEIZURE TIME") are preserved.
        content = parts[k + 2].strip().lstrip(":").strip()
        if not content:
            continue
        paragraph = f"{heading}:\n{content}" if include_heading else content
        headings.append(heading)
        paragraphs.append(paragraph)
    return headings, paragraphs


def preprocess_report(report, text_sample_mode, requested_headings, simple=False, sampling=True, include_heading=False, prefix="Clinical EEG report: "):
    """
    Extract text from a clinical EEG report, optionally restricted to requested sections.

    The report is split into sections at lines that start with one of the
    headings in ``_REPORT_DELIMITERS`` (case-insensitive). If the first line
    contains none of those headings it is treated as a preamble and dropped.
    Every returned string is ``prefix`` (trimmed) followed by one space and
    the cleaned text (newlines turned into spaces).

    :param report: raw report text
    :param text_sample_mode: "report" (all selected sections joined into one
        string), "paragraph" (one string per section) or "sentence" (one
        string per sentence)
    :param requested_headings: substrings matched against the upper-cased
        section headings; ``["all"]`` selects every section
    :param simple: if True, skip parsing and return the whole report (with
        any existing "Clinical EEG report: " prefix removed) as one string
    :param sampling: if True return one randomly sampled paragraph/sentence;
        if False (readout) return all of them
    :param include_heading: prepend "HEADING:" to each section's text
    :param prefix: text placed in front of every returned string
    :return: ``(texts, requests_found)``: a list of strings, and whether any
        section matched ``requested_headings`` (always False in simple mode
        and for reports without usable sections). If nothing matched,
        sampling mode draws from a random section and readout mode returns
        ``[]``. Reports without usable sections fall back to sentence
        splitting of the report, ignoring ``text_sample_mode``.
    :raises ValueError: for an unknown ``text_sample_mode`` (when sections
        were found), or from ``numpy.random.choice`` when sampling from a
        report that yields no sentences.
    """
    requests_found = False

    report = report.replace("\u2028", "\n")

    # Simple mode just dumps the report back with minimal preprocessing.
    if simple:
        report = report.replace("Clinical EEG report: ", "")
        return [_with_prefix(prefix, report)], requests_found

    # Double spaces can mess up header detection.
    report = report.replace("  ", " ")

    # Drop random text preceding the first header. Slice it off instead of
    # str.replace so identical text elsewhere in the report survives.
    preamble = report.split("\n")[0]
    if any(heading in preamble.upper() for heading in _REPORT_DELIMITERS):
        body = report
    else:
        body = report[len(preamble):]

    headings, paragraphs = _split_sections(body, include_heading)

    if not paragraphs:
        # No usable sections: fall back to sentences. Without any heading the
        # first line precedes no header, so keep it (a one-line unstructured
        # report would otherwise be emptied entirely).
        source = body if _REPORT_HEADING_RE.search(body) else report
        all_sentences = split_into_sentences(source.replace("\n", " "))
        if sampling:
            sentence = np.random.choice(all_sentences)
            return [_with_prefix(prefix, sentence)], requests_found
        return [_with_prefix(prefix, sentence) for sentence in all_sentences], requests_found

    # Keep only paragraphs whose heading contains one of the requested headings.
    if requested_headings == ["all"]:
        mask = [True] * len(paragraphs)
    else:
        mask = [any(req_head in head.upper() for req_head in requested_headings) for head in headings]

    requests_found = any(mask)
    if not requests_found:
        if not sampling:
            print("No requested headings found while in readout mode. Return empty list.")
            return [], requests_found
        # In sampling mode fall back to one random section.
        mask[np.random.randint(0, len(paragraphs))] = True

    selected_headings = [head for head, keep in zip(headings, mask) if keep]
    selected_paragraphs = [para for para, keep in zip(paragraphs, mask) if keep]

    if text_sample_mode == "report":
        return [_with_prefix(prefix, " ".join(selected_paragraphs))], requests_found

    if text_sample_mode not in ("paragraph", "sentence"):
        raise ValueError(
            f"Unknown text_sample_mode {text_sample_mode!r}; expected 'report', 'paragraph' or 'sentence'."
        )

    if not sampling:  # readout: return all relevant paragraphs or sentences
        if text_sample_mode == "paragraph":
            return [_with_prefix(prefix, paragraph) for paragraph in selected_paragraphs], requests_found
        all_sentences = []
        for paragraph in selected_paragraphs:
            paragraph_sentences = split_into_sentences(paragraph.replace("\n", " "))
            all_sentences.extend(_with_prefix(prefix, sentence) for sentence in paragraph_sentences)
        return all_sentences, requests_found

    # Sampling: for both paragraph and sentence modes, first sample one paragraph.
    index = np.random.randint(0, len(selected_paragraphs))
    paragraph = selected_paragraphs[index]
    heading = selected_headings[index]

    if text_sample_mode == "paragraph":
        return [_with_prefix(prefix, paragraph)], requests_found

    sentence = np.random.choice(split_into_sentences(paragraph.replace("\n", " ")))
    # Remove the heading text if the sampled sentence carries it (include_heading=True).
    sentence = sentence.replace(heading, "")
    return [_with_prefix(prefix, sentence)], requests_found
            
# Text prompts describing each class of the binary pathology label: the key
# ending in "_0" lists descriptions of a normal EEG, the key ending in "_1"
# descriptions of an abnormal EEG.
# NOTE(review): each list contains one exact duplicate ("No abnormalities." /
# "Abnormalities observed."), giving both lists 21 entries; kept as-is in case
# the equal length or extra weighting is intentional — confirm.
class_prompts = {}
class_prompts["class-prompt_pathology_0"] = [
"Normal EEG.",
"No pathology present.",
"No abnormalities.", 
"Normal routine EEG.",
"Normal awake record.",
"Normal EEG record.",
"This EEG is normal.",
"This is a normal EEG.", 
"This EEG is within normal limits",
"Normal awake EEG.",
"Normal asleep EEG.",
"Normal awake and asleep EEG.",
"Normal EEG in wakefulness and drowsiness.",
"No pathology.",
"EEG shows no pathology.",
"No abnormalities.",
"No abnormalities observed.",
"EEG shows no abnormalities.", 
"No clinical events detected.", 
"No indications of pathology observed.", 
"The EEG is normal."] 

class_prompts["class-prompt_pathology_1"] = [
"Abnormal EEG.",
"Pathology present.",
"Abnormalities observed.", 
"Markedly abnormal EEG.",
"Abnormal awake record.",
"Abnormal EEG record.",
"This EEG is abnormal.",
"This is an abnormal EEG.", 
"This EEG is mildly abnormal.",
"Abnormal awake EEG.",
"Abnormal asleep EEG.",
"Abnormal awake and asleep EEG.",
"Abnormal EEG in wakefulness and drowsiness.",
"Abnormal EEG due to:",
"Abnormal EEG for a subject of this age due to:",
"Abnormalities in the EEG.",
"Abnormalities observed.",
"EEG shows abnormalities.", 
"Clinical events detected.",
"Indications of pathology observed.",
"The EEG is pathologically abnormal"]


def LLM_summarization(token, path="/path/to/reports_as_txt_files/", save_path="/path/to/save_directory/",
                      custom_cache_dir='/path/to/customcache/for/models/', device="cuda:0"):
    """
    Summarise every report file in ``path`` into one short sentence with Llama-3-8B-Instruct.

    Files are processed in sorted filename order; entries that are not regular
    files are skipped. Each file is read with the default encoding, falling
    back to an encoding guessed by ``chardet`` from its first 10000 bytes;
    files that still cannot be read are skipped. The report (newlines turned
    into spaces) is sent to the model with greedy decoding and at most 50 new
    tokens. The answer, prefixed with "Clinical EEG report: ", is written to
    ``save_path`` under the same file name. The token count of every summary
    (excluding special tokens) is saved to ``tokens.npy`` in ``save_path``.

    :param token: Hugging Face access token passed to ``from_pretrained``
    :param path: directory containing the report text files
    :param save_path: output directory (created if missing)
    :param custom_cache_dir: Hugging Face cache directory for model files
    :param device: torch device the model and inputs are placed on
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch
    import chardet

    url = "meta-llama/Meta-Llama-3-8B-Instruct"

    max_length = 1000
    trunc = True

    model = AutoModelForCausalLM.from_pretrained(url, cache_dir=custom_cache_dir, torch_dtype=torch.bfloat16,
                                                 token=token).to(device)
    tokenizer = AutoTokenizer.from_pretrained(url, cache_dir=custom_cache_dir, token=token,
                                              truncation_side="left", max_length=max_length, truncation=trunc)
    # Truncate over-long prompts from the left so the end of the templated
    # prompt (report tail + generation header) is kept. Set explicitly: Required!
    tokenizer.truncation_side = "left"

    os.makedirs(save_path, exist_ok=True)
    reports = sorted(os.listdir(path))

    number_of_tokens = []

    for i, report in enumerate(reports):
        report_path = os.path.join(path, report)
        print("Report #", i, " loading:", report_path)
        if not os.path.isfile(report_path):
            continue

        try:
            with open(report_path, 'r', newline='') as rf:
                text = rf.read()
        except (OSError, UnicodeError):
            # Default encoding failed: guess it from the first 10000 bytes.
            try:
                with open(report_path, 'rb') as file:
                    raw_data = file.read(10000)
                encoding = chardet.detect(raw_data)["encoding"]
                with open(report_path, 'r', encoding=encoding) as rf:
                    text = rf.read()
            except (OSError, UnicodeError, LookupError):
                print("*"*40)
                print("SKIPPING!")
                print("*"*40)
                continue
        
        text = text.replace("\n", " ")
        
        messages = [
            {"role": "system", "content": "You are a helpful assistant adhering strictly to instructions."},
            {"role": "user", "content": f"""You are provided a clinical EEG report. Please create an EXTREMELY BRIEF summary of the report into a SINGLE, SHORT sentence. It should ONLY state whether the EEG is normal or pathologically abnormal and why. It is CRUCIAL you are EXTREMELY BRIEF. Include no text beyond this short summary itself.

        If the EEG is normal, do NOT further explicitly state that no abnormalities were observed of any kind. This includes mentioning the absence of pathology or epileptiform features etc.

        Here is the report:
                
        {text}"""},
        ]

        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt",
                                                  max_length=max_length, truncation=True).to(device)

        # Llama-3 ends an assistant turn with <|eot_id|>, so stop on it as well as EOS.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        outputs = model.generate(
            input_ids,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=50,
            eos_token_id=terminators,
            do_sample=False,
        )
        # Keep only the newly generated tokens (drop the echoed prompt).
        response = outputs[0][input_ids.shape[-1]:]
        decoded_response = tokenizer.decode(response, skip_special_tokens=True)

        # Count output tokens; exclude special tokens (e.g. BOS) that encode() adds by default.
        output_token_count = len(tokenizer.encode(decoded_response, add_special_tokens=False))
        number_of_tokens.append(output_token_count)

        with open(os.path.join(save_path, report), 'w', newline='') as wf:
            wf.write("Clinical EEG report: " + decoded_response)

    np.save(os.path.join(save_path, "tokens.npy"), number_of_tokens)
    
#LLM_summarization("") #Requires token!
