from ..models.rangemodel import LightningRangeModel
import torch
import numpy as np
import pandas as pd
from ..config import cfg
from PIL import Image
import matplotlib.pyplot as plt
from copy import deepcopy


def get_presence_map_gt(
    obs,
    species,
    bins=(900, 1800),
    save_img=False,
    mask_path="../sinr/data/masks/ocean_mask.npy",
):
    """
    Build a binary presence raster for one species from point observations.

    obs: pandas dataframe with columns species, decimalLatitude, decimalLongitude
    species: string; only rows whose ``species`` column equals this value are used
    bins: pair of ints (n_latitude_bins, n_longitude_bins) passed to
        ``np.histogram2d``; a list is accepted as well as a tuple
    save_img: bool; if True, also write a colour-mapped PNG named
        "<species> gt.png" into the current working directory
    mask_path: string; path to a ``.npy`` mask, used only when ``save_img`` is
        True. Cells where the mask equals 0 are painted white in the saved
        image. The mask must have exactly as many cells as the histogram.

    Returns a float array of shape ``bins``: 1 where the species was observed
    at least once, 0 elsewhere. Latitude is negated before binning over
    [-90, 90], so row 0 is the +90 (northernmost) band; column 0 starts at
    longitude -180.

    Raises ValueError when ``save_img`` is True and the mask cannot be
    reshaped to the histogram's shape.
    """
    species_obs = obs[obs["species"] == species]
    # Negating latitude puts north at row 0, i.e. at the top of the image.
    hist, *_ = np.histogram2d(
        -species_obs["decimalLatitude"],
        species_obs["decimalLongitude"],
        bins=bins,
        range=[[-90, 90], [-180, 180]],
    )
    # Collapse observation counts to presence / absence.
    hist[hist > 0] = 1
    if save_img:
        # reshape() raises ValueError on a size mismatch instead of silently
        # misaligning the mask or failing with an opaque IndexError.
        mask = np.load(mask_path).reshape(hist.shape)
        cmap = plt.get_cmap("plasma")
        # Drop the alpha channel of the RGBA colormap output.
        rgb_img = cmap(hist)[..., :3]
        rgb_img[mask == 0] = 1.0  # masked-out cells rendered white
        Image.fromarray((rgb_img * 255).astype(np.uint8)).save(
            f"{species} gt.png", dpi=(300, 300)
        )
    return hist


def get_pred_range_map(
    range_model,
    label,
    env_cov=False,
    taxonomy_level="species",
    llm_type="Llama-2-70b-hf",
    device="cuda",
    text_embeddings=None,
    threshold=0.95,
):
    """
    Predict a quantile-thresholded range map for one label from its text embedding.

    range_model: LightningRangeModel (called as ``range_model(embedding)`` or
        ``range_model(embedding, env_cov_tensor)``)
    label: string (species name); key into the text-embedding mapping
    env_cov: bool; if True, load covariates from ``cfg.data.env_cov_path``,
        add a leading axis, and pass them to the model as a second argument
    taxonomy_level: string (one of 'class', 'order', 'family', 'genus', 'species');
        only used to build the embedding file path when ``text_embeddings`` is None
    llm_type: string (one of 'Llama-2-7b-hf', 'Llama-2-13b-hf', 'Llama-2-70b-hf');
        only used to build the embedding file path when ``text_embeddings`` is None
    device: string (one of 'cpu', 'cuda')
    text_embeddings: optional preloaded embeddings. Either the 0-d object array
        returned by ``np.load(..., allow_pickle=True)`` wrapping a dict, or that
        dict itself. Passing it avoids reloading the file on every call.
    threshold: float in [0, 1]; prediction values below this quantile of the
        predictions are set to 0.

    Returns a numpy array of sigmoid scores, with dim 0 squeezed (if it has
    size 1) and every entry below the ``threshold`` quantile zeroed.
    """
    if text_embeddings is None:
        # NOTE(review): allow_pickle=True can execute arbitrary code from the
        # file; only load embedding files from a trusted source.
        text_embeddings = np.load(
            f"data/text_embeddings/{taxonomy_level}_{llm_type}.npy", allow_pickle=True
        )
    # np.save on a dict stores it inside a 0-d object array; unwrap it so a
    # plain dict passed by the caller is handled the same way.
    if isinstance(text_embeddings, np.ndarray) and text_embeddings.ndim == 0:
        text_embeddings = text_embeddings.item()
    embedding = torch.tensor(text_embeddings[label]).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if env_cov:
            env_cov_tensor = (
                torch.tensor(np.load(cfg.data.env_cov_path)).unsqueeze(0).to(device)
            )
            pred = range_model(embedding, env_cov_tensor)
        else:
            pred = range_model(embedding)

    pred = pred.squeeze(0).sigmoid().detach().cpu().numpy()
    pred[pred < np.quantile(pred, threshold)] = 0
    return pred
