import json
import os
import sys

import benepar
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import spacy
import tqdm


# Merged generated-instruction training file (JSON) that this script reads.
# For a quick local experiment, "context.json" can be used instead.
generated_data_path = (
    "/scratch/datasets/MIMIC-CXR/"
    "merged_instructions_train.json"
)
# Stem for the output files of the verb/noun analysis further down.
section = "context"



# nlp = spacy.load('en_core_web_md')
# doc = nlp("The time for action is now. It's never too late to do something.")
def find_root_verb_and_its_dobj(tree_root):
    # first check if the current node and its children satisfy the condition
    if tree_root.pos_ == "VERB":
        for child in tree_root.children:
            if child.dep_ == "dobj" and child.pos_ == "NOUN":
                return tree_root.lemma_, child.lemma_
        return tree_root.lemma_, None
    # if not, check its children
    for child in tree_root.children:
        return find_root_verb_and_its_dobj(child)
    # if no children satisfy the condition, return None
    return None, None

def find_root_verb_and_its_dobj_in_string(s):
    doc = nlp(s)
    first_sent = list(doc.sents)[0]
    return find_root_verb_and_its_dobj(first_sent.root)



# --- Task distribution pie chart --------------------------------------------

with open(generated_data_path, 'r', encoding="utf-8") as fin:
    gpt4_machine_generated_tasks = json.load(fin)

# Pie-chart labels, in the same order as `values` below.
labels = ["No Context", "Medical Record", "Template", "Previous Visit", "Revision"]
count = {"history": 0, "template": 0, "comparison": 0, "correction": 0, "no_ctx": 0}

# Training split: a single merged file where the task is encoded as a
# substring of each example's key. Substrings are checked in this order and
# the first match wins; keys matching none are counted as "no_ctx".
for k in gpt4_machine_generated_tasks["data"]:
    for task in ("history", "correction", "template", "comparison"):
        if task in k:
            count[task] += 1
            break
    else:
        count["no_ctx"] += 1

# Also count the test split: one file without context plus one file per task,
# all in the same directory as the training file.
data_dir = os.path.dirname(generated_data_path)
with open(os.path.join(data_dir, "instructions_test.json"), encoding="utf-8") as f:
    generated_data = json.load(f)
count["no_ctx"] += len(generated_data["data"])
for task in ["history", "comparison", "template", "correction"]:
    with open(os.path.join(data_dir, f"{task}_instructions_test.json"), encoding="utf-8") as f:
        generated_data = json.load(f)
    count[task] += len(generated_data["data"])

values = [
    count["no_ctx"],
    count["history"],
    count["template"],
    count["comparison"],
    count["correction"],
]

fig = go.Figure(data=[go.Pie(
    labels=labels,
    values=values,
    # textinfo='label+value+percent',
    texttemplate="%{label}<br>%{value} (%{percent})",
    insidetextorientation='radial',
    textfont={"size": 11},
)])
# fig.show()
fig.write_html("task_output.html")
fig.write_image("task.pdf")

# Stop here: the verb/noun analysis below is currently disabled. Use
# sys.exit() because the builtin exit() is an interactive-shell helper
# installed by the site module and is not guaranteed to exist.
sys.exit()

# --- Verb / direct-object sunburst -------------------------------------------
# Analyses the generated instructions and draws a task -> verb -> noun
# sunburst. Currently unreachable: the script exits right after writing the
# task pie chart.

# Flatten {task: [sentence, ...]} into {"<task>_<idx>": sentence}.
# NOTE(review): the counting code earlier reads examples from
# gpt4_machine_generated_tasks["data"], so with that file this loop walks the
# top-level keys instead. It looks written for a {task: [sentences]} file
# (e.g. the commented-out "context.json"); the commented line below is the
# alternative for the merged format. Confirm before re-enabling.
instruction_outputs = {}
for task, generated_list in gpt4_machine_generated_tasks.items():
    for idx, sentence in enumerate(generated_list):
        instruction_outputs[f"{task}_{idx}"] = sentence


# instruction_outputs = {k:v[section] for k,v in gpt4_machine_generated_tasks["data"].items()}  # if you are interested in studying the instructions, please change the task key
# Keep a fixed 1-in-50 subsample.
instruction_outputs = dict(list(instruction_outputs.items())[::50])
print(len(instruction_outputs))

# Substring of the instruction key -> task label used as the sunburst's inner
# ring. Substrings are checked in this order; the first match wins.
task_labels = (
    ("template", "(2)"),
    ("comparison", "(3)"),
    ("history", "(4)"),
    ("correction", "(1)"),
)

raw_phrases = []
for key, out in tqdm.tqdm(instruction_outputs.items()):
    task = next((label for name, label in task_labels if name in key), None)
    if task is None:
        # No task label applies. Skip the item rather than raising NameError
        # or silently reusing the previous item's label.
        print(f"Skipping instruction with unrecognised task key: {key}")
        continue
    try:
        verb, noun = find_root_verb_and_its_dobj_in_string(out)
    except Exception as e:  # best effort: report unparsable outputs, go on
        print(e)
        print(out)
        continue
    raw_phrases.append({
        "task": task,
        "verb": verb,
        "noun": noun,
        "instruction_output": out,
    })

# Persist the raw phrases, then read them back. The CSV round trip also turns
# None / empty strings into NaN, so dropna() removes incomplete rows. Explicit
# columns keep the schema even when no phrase was extracted.
raw_phrases = pd.DataFrame(
    raw_phrases, columns=["task", "verb", "noun", "instruction_output"]
)
raw_phrases.to_csv(f'{section}.csv')
raw_phrases = pd.read_csv(f'{section}.csv')

phrases = raw_phrases.dropna()

# Alternative without the per-task split:
# count_list = phrases[["verb", "noun"]].groupby(["verb", "noun"]).size().sort_values(ascending=False)
# top_verbs = phrases[["verb"]].groupby(["verb"]).size().nlargest(20).reset_index()
# df = phrases[phrases["verb"].isin(top_verbs["verb"].tolist())]
# df = df.groupby(["verb", "noun"]).size().reset_index().rename(columns={0: "count"}).sort_values(by=["count"], ascending=False)
# all_df = df.groupby("verb").apply(lambda x: x.sort_values("count", ascending=False).head(4)).reset_index(drop=True)

# Per task: keep the 20 most frequent verbs, then for each of those verbs at
# most 10 of its most frequent (verb, noun) pairs.
task_frames = []
for t in sorted(set(phrases["task"])):  # sorted: set order varies between runs
    task_phrases = phrases[phrases["task"] == t]
    top_verbs = task_phrases.groupby("verb").size().nlargest(20).index
    task_df = (
        task_phrases[task_phrases["verb"].isin(top_verbs)]
        .groupby(["verb", "noun"])
        .size()
        .reset_index(name="count")
        .sort_values(["verb", "count"], ascending=[True, False], kind="stable")
        .groupby("verb")
        .head(10)
    )
    task_frames.append(task_df.assign(task=t))

# Concatenate once; concatenating onto an empty seed frame is deprecated in
# pandas and can distort the result dtypes.
sunburst_columns = ["task", "verb", "noun", "count"]
if task_frames:
    all_df = pd.concat(task_frames, ignore_index=True)[sunburst_columns]
else:
    all_df = pd.DataFrame(columns=sunburst_columns)

print(all_df)


# all_df = all_df[all_df["count"] > 5]
fig = px.sunburst(all_df, path=['task', 'verb', 'noun'], values='count')
# fig.update_layout(uniformtext=dict(minsize=10, mode='hide'))
fig.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
    font_family="Times New Roman",
)
fig.show()
fig.write_html(f"{section}_output.html")
fig.write_image(f"{section}.pdf")