import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
from copy import deepcopy
from models import Attn

# Most recent raw (un-thresholded) prediction map. text_fn stores it here and
# reset_fn / thresh_fn read it back; it stays None until the model has run.
pred_global = None

# Flat indices of the mask cells equal to 0. Every rendered heatmap paints
# these pixels white.
land_mask = np.load("data/ocean_mask.npy")
(mask_inds,) = np.where(land_mask.reshape(-1) == 0)


class RangeModel(nn.Module):
    """Turn a text embedding into a single-channel 1002 x 2004 map.

    Pipeline: ``Attn`` applied to the stored features ``self.x`` and the text
    embedding, a reshape of the resulting tokens into a 2-D feature map with
    225 rows, bilinear upsampling to 1002 x 2004, and a bias-free 1x1
    convolution that reduces 128 channels to 1.

    ``self.x`` is initialised to ``None`` and has to be assigned by the
    caller before the first forward pass.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the meaning of Attn's two arguments lives in
        # models.py -- presumably feature width and text-embedding width.
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(size=(1002, 2004), mode="bilinear")
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        # Features fed to cross_attn as its first argument; set externally.
        self.x = None

    def forward(self, text):
        """Return the upsampled single-channel map for ``text``.

        Args:
            text: Text embedding passed as the second argument to
                ``cross_attn`` (expected shape not visible here).

        Returns:
            Tensor produced by the 1x1 convolution, i.e. one channel at the
            1002 x 2004 resolution fixed by ``self.upsample``.
        """
        tokens = self.cross_attn(self.x, text)
        # The token axis is unfolded into a grid with 225 rows; channels move
        # to dim 1 as the convolutional layers expect.
        feature_map = rearrange(tokens, "b (h w) d -> b d h w", h=225)
        return self.out(self.upsample(feature_map))


# Build the network, restore the demo weights on CPU, attach the stored
# embeddings as ``model.x`` and switch to inference mode.
# NOTE(review): torch.load and np.load(allow_pickle=True) both unpickle
# arbitrary objects -- only ever point them at trusted, bundled files.
model = RangeModel()
model.load_state_dict(torch.load("model/demo_model.pt", map_location="cpu"))
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).to(torch.float32)
model.eval()

# Per-rank lookup tables. Each file holds a pickled mapping wrapped in a 0-d
# object array, which is why it is unwrapped with ``[()]`` before use.
species = np.load("data/species_70b.npy", allow_pickle=True)
clas = np.load("data/class_70b.npy", allow_pickle=True)
order = np.load("data/order_70b.npy", allow_pickle=True)
genus = np.load("data/genus_70b.npy", allow_pickle=True)
family = np.load("data/family_70b.npy", allow_pickle=True)

# Names offered in the "Name" dropdown for each rank.
species_list = list(species[()])
class_list = list(clas[()])
order_list = list(order[()])
genus_list = list(genus[()])
family_list = list(family[()])


def update_fn(val):
    """Rebuild the "Name" dropdown for the selected taxonomic rank.

    Args:
        val: Selected rank; one of "Class", "Order", "Family", "Genus" or
            "Species".

    Returns:
        An interactive ``gr.Dropdown`` labelled "Name" whose choices are the
        names available for that rank, or ``None`` when ``val`` is not one of
        the five ranks.
    """
    if val not in ("Class", "Order", "Family", "Genus", "Species"):
        return None

    names_by_rank = {
        "Class": class_list,
        "Order": order_list,
        "Family": family_list,
        "Genus": genus_list,
        "Species": species_list,
    }
    return gr.Dropdown(
        label="Name", choices=names_by_rank[val], interactive=True
    )


# Quantile used for the first render right after a model run.
_DEFAULT_QUANTILE = 0.95


def _render_heatmap(preds):
    """Colour a 2-D map with the "plasma" colormap and whiten masked cells.

    Args:
        preds: 2-D float array to colour; it is not modified.

    Returns:
        A new ``(H, W, 3)`` float RGB array in which every flat pixel index
        listed in ``mask_inds`` has been set to white.
    """
    rgba_img = plt.get_cmap("plasma")(preds)
    # Drop the alpha channel. The copy is contiguous, so the reshape below is
    # a writable view into rgb_img rather than a detached temporary.
    rgb_img = rgba_img[..., :3].copy()
    rgb_img.reshape(-1, 3)[mask_inds] = [1, 1, 1]
    return rgb_img


def _zero_below_quantile(preds, q):
    """Return a copy of ``preds`` with values below its ``q``-quantile zeroed."""
    kept = preds.copy()
    kept[kept < np.quantile(kept, q)] = 0
    return kept


def text_fn(taxon, name):
    """Run the model for one taxon and return its thresholded heatmap.

    The raw sigmoid output is stored in the module-level ``pred_global`` so
    that ``reset_fn`` and ``thresh_fn`` can re-render it without running the
    model again.

    Args:
        taxon: Rank selected in the UI ("Class", "Order", "Family", "Genus"
            or "Species").
        name: Key into the embedding table of that rank.

    Returns:
        ``(H, W, 3)`` float RGB image with predictions below the 0.95
        quantile zeroed and masked cells painted white.

    Raises:
        gr.Error: If no valid rank is selected, or ``name`` is not present in
            the table for that rank (e.g. a name left over from another rank).
    """
    global pred_global

    tables = {
        "Class": clas,
        "Order": order,
        "Family": family,
        "Genus": genus,
        "Species": species,
    }
    table = tables.get(taxon)
    if table is None:
        raise gr.Error("Select a taxonomic hierarchy before running the model.")

    embeddings = table[()]
    if name not in embeddings:
        raise gr.Error(f"Select a valid {taxon} name before running the model.")

    text_embeds = torch.tensor(embeddings[name])
    # Pure inference: skip autograd bookkeeping for the large upsampled map.
    with torch.no_grad():
        preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).numpy()

    # Keep the raw map untouched; thresholding always works on a copy.
    pred_global = preds
    return _render_heatmap(_zero_below_quantile(preds, _DEFAULT_QUANTILE))


def reset_fn():
    """Render the last raw prediction map without any thresholding.

    Returns:
        ``(H, W, 3)`` float RGB image of ``pred_global``.

    Raises:
        gr.Error: If the model has not been run yet.
    """
    if pred_global is None:
        raise gr.Error("Run the model before resetting the heatmap.")
    return _render_heatmap(pred_global)


def thresh_fn(val):
    """Re-render the last prediction map with a new quantile threshold.

    Args:
        val: Quantile in ``[0, 1]``; predictions below it are zeroed.

    Returns:
        ``(H, W, 3)`` float RGB image of the thresholded map.

    Raises:
        gr.Error: If the model has not been run yet.
    """
    if pred_global is None:
        raise gr.Error("Run the model before changing the threshold.")
    return _render_heatmap(_zero_below_quantile(pred_global, val))


# UI layout: rank + name selectors, a run button, a quantile slider, a reset
# button and the heatmap image, wired to the callbacks defined above.
with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Hierarchical Species Distribution Model!
    This model predicts the distribution of species based on geographic, environmental, and natural language features.
    """
    )
    with gr.Row():
        inp = gr.Dropdown(
            label="Taxonomic Hierarchy",
            choices=["Class", "Order", "Family", "Genus", "Species"],
        )
        out = gr.Dropdown(label="Name", interactive=True)
        # Picking a rank repopulates the name dropdown.
        inp.change(update_fn, inp, out)

    with gr.Row():
        check_button = gr.Button("Run Model")

    with gr.Row():
        # `value` (not `default`) sets the slider's initial position; 0.95
        # matches the quantile text_fn applies on its first render.
        slider = gr.Slider(
            minimum=0.9,
            maximum=1,
            step=0.01,
            value=0.95,
            label="Confidence Threshold (Set after running model)",
        )

    with gr.Row():
        reset_button = gr.Button("Reset to Raw Predictions")

    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)

    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
    reset_button.click(reset_fn, outputs=[pred])

demo.launch()
