"""
Queries a language model (GPT-3) for object/room co-occurrence priors, caches the scores, and plots them.
"""
from typing import List, Tuple

import argparse
import os
import numpy as np
from tqdm import tqdm
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from utils import (
    gpt3_score_prompt, save_gpt3_result, load_gpt3_cache,
    sunrgb_items, sunrgb_items_r,
    rednet_rooms, hm3d_rooms, hm3d_items
)
from RedNet_data import SUNRGBD
from scipy.special import kl_div
from scipy.linalg import block_diag
import random
import torch

import math

# Global seaborn styling for every figure drawn below.
sns.set_theme()

# NOTE(review): DEBUG is not read anywhere in the visible code.
DEBUG = False

# LM engine used for scoring; its name also determines the cache file path.
engine = "text-davinci-002"
gpt_scoring_cache_fn = f"GPT3_scoring/{engine}.jsonl"
# Loaded once at import time; passed to gpt3_score_prompt as its lookup cache.
gpt3_scoring_cache = load_gpt3_cache(gpt_scoring_cache_fn)
# Few-shot examples per prompt type, split by label. get_gpt3_cooccurrences
# prepends up to three of each label to every query, skipping an example
# that is identical to the query being scored. These strings are sent to the
# LM verbatim, so edits here change the prompts (and the cache keys).
prompt_examples_reservoir = {
    "obj_cooccurs_with_obj": {
        "plausible": [
            "The picture cooccurs with the wall: plausible",
            "The tv cooccurs with the sofa: plausible",
            "The counter cooccurs with the sink: plausible",
            "The floor mat cooccurs with the floor: plausible",
        ],
        "implausible": [
            "The shower cooccurs with the sofa: implausible",
            "The fridge cooccurs with the shower curtain: implausible",
            "The pillow cooccurs with the toilet: implausible",
            "The bookshelf cooccurs with the fridge: implausible",
        ],
    },
    "room_has_obj": {
        "plausible": [
            "A bathroom has toilet: plausible",
            "A bedroom has bed: plausible",
            "A living room has tv: plausible",
            "A kitchen has fridge: plausible",
        ],
        "implausible": [
            "An office has night stand: implausible",
            "A kitchen has shower curtain: implausible",
            "A bathroom has sofa: implausible",
            "A bedroom has bathtub: implausible",
        ],
    },
    "obj_in_room": {
        "plausible": [
            "A toilet is in bathroom: plausible",
            "A bed is in bedroom: plausible",
            "A tv is in living room: plausible",
            "A fridge is in kitchen: plausible",
        ],
        "implausible": [
            "A night stand is in office: implausible",
            "A shower curtain is in kitchen: implausible",
            "A sofa is in bathroom: implausible",
            "A bathtub is in bedroom: implausible",
        ],
    },
    "obj_lookslike": {
        "plausible": [
            "The table looks like the desk: plausible",
            "The curtain looks like the shower curtain: plausible",
            "The bed looks like the sofa: plausible",
            "The sofa looks like the chair: plausible",
        ],
        "implausible": [
            "The dresser looks like the lamp: implausible",
            "The towel looks like the toilet: implausible",
            "The shelves looks like the pillow: implausible",
            "The paper looks like the bathtub: implausible",
        ],
    }
}


def _with_article(noun):
    """Prefix ``noun`` with "A" or "An", chosen from its first character (lowercase vowels only)."""
    article = "An" if noun[:1] in ("a", "e", "i", "o", "u") else "A"
    return f"{article} {noun}"


def _build_query(prompt_type, obj, co_item):
    """Return the unlabeled query line for one (obj, co_item) pair, ending in ": ".

    Underscores in both names are rendered as spaces so the query reads like
    the few-shot examples in ``prompt_examples_reservoir``.

    Raises:
        NotImplementedError: for an unknown ``prompt_type``.
    """
    obj_nl = obj.replace("_", " ")
    co_nl = co_item.replace("_", " ")
    if prompt_type == "obj_cooccurs_with_obj":
        return f"The {obj_nl} cooccurs with the {co_nl}: "
    if prompt_type == "obj_lookslike":
        # Same wording as the few-shot examples ("looks like") so the
        # duplicate-example filter in _few_shot_context can match.
        return f"The {obj_nl} looks like the {co_nl}: "
    if prompt_type == "room_has_obj":
        return f"{_with_article(co_nl)} has {obj_nl}: "
    if prompt_type == "obj_in_room":
        # The subject is derived fresh on every call, so articles never
        # accumulate across rooms.
        return f"{_with_article(obj_nl)} is in {co_nl}: "
    raise NotImplementedError(prompt_type)


def _few_shot_context(prompt_type, query):
    """Return newline-joined few-shot examples for ``prompt_type``.

    Takes the first three "plausible" and then the first three "implausible"
    examples. Any example equal to ``query`` plus its label is skipped, so the
    model is never shown the answer to the question being scored.
    """
    reservoir = prompt_examples_reservoir[prompt_type]
    shots = []
    for label in ("plausible", "implausible"):
        shots += [ex for ex in reservoir[label] if ex != query + label][:3]
    return "\n".join(shots)


def _plausible_probability(cond_scores):
    """Softmax-normalize the label scores and return the share of the first ("plausible") label.

    Assumes ``cond_scores`` holds one log-score per label, in the order the
    labels were passed to ``gpt3_score_prompt`` -- TODO confirm against utils.
    """
    scores = np.asarray(cond_scores, dtype=float)
    # Shifting by the max leaves the softmax unchanged but avoids overflow in exp.
    weights = np.exp(scores - scores.max())
    return float(weights[0] / weights.sum())


def _score_matrix(items, cooccuring_items, prompt_type):
    """Query the LM for every (item, co-item) pair and return p("plausible") per pair.

    Returns a float array of shape (len(items), len(cooccuring_items)); entry
    [r, c] is p(plausible | row item r, column item c). Each call's new
    results are handed to ``save_gpt3_result`` immediately.
    """
    matrix = np.zeros((len(items), len(cooccuring_items)), dtype=float)
    for row, obj in enumerate(tqdm(items)):
        for col, co_item in enumerate(cooccuring_items):
            query = _build_query(prompt_type, obj, co_item)
            context = _few_shot_context(prompt_type, query)
            cond_scores, new_results = gpt3_score_prompt(
                engine,
                f"{context}\n{query}",
                ["plausible", "implausible"],
                gpt3_scoring_cache,
            )
            save_gpt3_result(gpt_scoring_cache_fn, new_results)
            matrix[row, col] = _plausible_probability(cond_scores)
    return matrix


def _save_heatmap(cooccur_data, fig_path):
    """Draw ``cooccur_data`` as a heatmap (rows = items, columns = co-items) and save it."""
    fig, ax = plt.subplots()
    sns.heatmap(cooccur_data, ax=ax)
    # Heatmap cells are centred at i + 0.5; x follows the columns, y the rows.
    ax.set_xticks(np.arange(len(cooccur_data.columns)) + 0.5)
    ax.set_xticklabels(list(cooccur_data.columns), fontsize=9, rotation=90)
    ax.set_yticks(np.arange(len(cooccur_data.index)) + 0.5)
    ax.set_yticklabels(list(cooccur_data.index), fontsize=9)
    fig.tight_layout()
    fig.savefig(fig_path)
    plt.close(fig)


def _save_per_item_plots(cooccur_data, fig_path):
    """Plot each row of ``cooccur_data`` as a descending curve, one subplot per row item.

    The 10x3 grid is filled column by column; rows beyond the first 30 are
    not plotted.
    """
    n_rows, n_cols = 10, 3
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 30), squeeze=False)
    n_plots = min(len(cooccur_data), n_rows * n_cols)
    for k in range(n_plots):
        grid_col, grid_row = divmod(k, n_rows)
        ax = axs[grid_row][grid_col]
        # Positional lookup so duplicate row labels still yield a single Series.
        ranked = cooccur_data.iloc[k].sort_values(ascending=False)
        ax.plot(range(len(ranked)), ranked)
        ax.set_xticks(range(len(ranked)))
        ax.set_xticklabels(ranked.index.tolist(), rotation=90)
        ax.set_title(cooccur_data.index[k])
    fig.suptitle("p(*|obj) -- biggest indicators of obj")
    fig.tight_layout()
    fig.savefig(fig_path)
    plt.close(fig)


def get_gpt3_cooccurrences(output_logits_path, items, prompt_type):
    """Score LM plausibility for item/item or item/room pairs, cache the matrix, and plot it.

    Args:
        output_logits_path: ``.npy`` path for the score matrix. If the file
            already exists, it is loaded instead of querying the LM. It should
            end in ``.npy``: ``np.save`` appends that suffix otherwise, and the
            existence check would then never hit. Figures are written next to
            it: ``<stem>_heatmap.png`` (full heatmap) and ``<stem>.png``
            (sorted per-item curves for up to 30 items).
        items: row labels; underscores are shown to the LM as spaces.
        prompt_type: one of the keys of ``prompt_examples_reservoir``.
            "obj_cooccurs_with_obj" and "obj_lookslike" pair ``items`` with
            themselves. "room_has_obj" and "obj_in_room" pair them with
            ``hm3d_rooms``.

    Returns:
        Array of shape (len(items), number of co-items) holding p("plausible")
        for each pair.

    Raises:
        NotImplementedError: for an unsupported ``prompt_type``.
    """
    if prompt_type in ("obj_cooccurs_with_obj", "obj_lookslike"):
        cooccuring_items = items
    elif prompt_type in ("room_has_obj", "obj_in_room"):
        cooccuring_items = hm3d_rooms
    else:
        raise NotImplementedError(prompt_type)

    stem, _ = os.path.splitext(output_logits_path)

    if os.path.exists(output_logits_path):
        cooccur_rates = np.load(output_logits_path)
    else:
        cooccur_rates = _score_matrix(items, cooccuring_items, prompt_type)
        out_dir = os.path.dirname(output_logits_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.save(output_logits_path, cooccur_rates)

    cooccur_data = pd.DataFrame(data=cooccur_rates, index=items, columns=cooccuring_items)
    # Separate files so the per-item figure no longer overwrites the heatmap.
    _save_heatmap(cooccur_data, f"{stem}_heatmap.png")
    _save_per_item_plots(cooccur_data, f"{stem}.png")
    return cooccur_rates



def main(args):
    """Run the hard-coded prior-extraction job: HM3D objects scored against rooms.

    NOTE(review): ``args`` is accepted for CLI compatibility but is not
    consulted; the prompt type, item list and output path are fixed here.
    """
    logits_path = "lm_priors/gpt3_room_obj_plausimplaus_hm3d_cooccurence.npy"
    get_gpt3_cooccurrences(logits_path, items=hm3d_items, prompt_type="obj_in_room")


if __name__ == "__main__":
    # NOTE(review): main() does not currently read any of these flags.
    parser = argparse.ArgumentParser(description='Query a language model for co-occurrence priors')
    parser.add_argument('--get-cooccurrences', action='store_true', default=False,
                        help='get existing co-occurrence statistics')
    parser.add_argument('--domain', type=str, choices=['sunrgb', 'hm3d', 'objnav'],
                        help='which dataset label set to get co-occurrence statistics for')
    parser.add_argument('--cooccur-type', type=str, default='objobj', choices=['objobj', 'objroom', 'obj_lookslike'],
                        help='what type of co-occurrence statistics to get')
    args = parser.parse_args()
    main(args)