from models.rangemodel import LightningRangeModel
import torch
import numpy as np
from config import cfg
import matplotlib.pyplot as plt
from copy import deepcopy
from PIL import Image
import fire


def get_range_map(
    model_path,
    label,
    threshold=0.95,
    env_cov=False,
    taxonomy_level="species",
    llm_type="Llama-2-7b-hf",
    device="cuda",
    save_img=False,
    mask_path="../sinr/data/masks/ocean_mask.npy",
):
    """Predict a range map for ``label`` and optionally save it as a PNG.

    The checkpoint at ``model_path`` is loaded into a ``LightningRangeModel``,
    the text embedding stored under ``label`` is fed through it, the output is
    passed through a sigmoid, and every value below the ``threshold`` quantile
    of the prediction is set to zero.

    Args:
        model_path: Path to the ``LightningRangeModel`` checkpoint.
        label: Key used to look up the text embedding; also used in the
            output file name.
        threshold: Quantile in ``[0, 1]`` of the prediction below which
            values are zeroed.
        env_cov: If truthy, the array at ``cfg.data.env_cov_path`` is loaded
            and passed to the model as a second input.
        taxonomy_level: First component of the embedding file name.
        llm_type: Second component of the embedding file name.
        device: Torch device the model and its inputs are moved to.
        save_img: If truthy, write
            ``result_plots/{label}_range_map.png`` (300 dpi, "plasma"
            colormap, masked cells painted white).
        mask_path: ``.npy`` file whose zero-valued entries mark the cells to
            paint white (presumably ocean cells, judging by the default file
            name -- confirm against the mask data). Only read when
            ``save_img`` is truthy.

    Returns:
        None. The thresholded prediction is only used to render the image.

    Raises:
        KeyError: If ``label`` is not present in the embedding file.
    """
    range_model = LightningRangeModel.load_from_checkpoint(model_path).to(device).eval()

    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load
    # embedding files that come from a trusted source.
    embedding_table = np.load(
        f"data/text_embeddings/{taxonomy_level}_{llm_type}.npy", allow_pickle=True
    )
    # The pickled lookup table is wrapped in a 0-d object array; [()] unwraps it.
    label_embedding = torch.tensor(embedding_table[()][label]).to(device)

    # Pure inference: skip building an autograd graph that would be discarded.
    with torch.no_grad():
        if env_cov:
            # Kept under its own name so the boolean flag is not overwritten.
            env_features = (
                torch.tensor(np.load(cfg.data.env_cov_path)).unsqueeze(0).to(device)
            )
            pred = range_model(label_embedding, env_features)
        else:
            pred = range_model(label_embedding)
        pred = pred.squeeze(0).sigmoid()

    pred = pred.detach().cpu().numpy()
    # Keep only the top (1 - threshold) fraction of scores; zero the rest.
    pred[pred < np.quantile(pred, threshold)] = 0

    if save_img:
        import os

        mask = np.load(mask_path)
        masked_cells = np.flatnonzero(mask.reshape(-1) == 0)

        rgba_img = plt.get_cmap("plasma")(pred)
        # Drop the alpha channel; the explicit copy keeps rgba_img untouched
        # and gives a contiguous buffer to edit through a flat view.
        rgb_img = rgba_img[..., :3].copy()
        rgb_img.reshape(-1, 3)[masked_cells] = [1, 1, 1]

        out_path = f"result_plots/{label}_range_map.png"
        # Saving fails if the output directory is missing, so create it first.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        Image.fromarray((rgb_img * 255).astype(np.uint8)).save(
            out_path, dpi=(300, 300)
        )


if __name__ == "__main__":
    # Expose get_range_map as a command-line interface via python-fire.
    fire.Fire(get_range_map)
