from models.rangemodel import LightningRangeModel
from data_utils.datasets import OccuranceDataModule
import torch
import numpy as np
import pandas as pd
from utils.pwcd_dist import get_all_metrics
from config import cfg
import fire


def eval_species_range_prediction(
    model_path, mask_path, taxonomy_level, llm_type, device
):
    """
    Evaluate a trained range model on the test split, restricted to taxa that
    have an entry in the text-embedding file for ``taxonomy_level``/``llm_type``.

    Prints the NaN-ignoring mean of the MAP and PWCD scores returned by
    ``get_all_metrics`` over all evaluated batches.

    model_path: string, path to a LightningRangeModel checkpoint
    mask_path: string, path to a .npy array; the indices of its flattened
        entries equal to 0 are passed to ``get_all_metrics`` as ``mask_inds``
    taxonomy_level: string (one of 'class', 'order', 'family', 'genus', 'species');
        also the column of the test parquet whose unique values are matched
        against the embedding keys
    llm_type: string (one of 'Llama-2-7b-hf', 'Llama-2-13b-hf', 'Llama-2-70b-hf')
    device: string (one of 'cpu', 'cuda')
    """
    # Read the env-covariate flag once so the data module layout and the
    # batch unpacking below can never disagree.
    use_env_cov = cfg.data.env_cos
    extra_kwargs = {"env_cov_path": cfg.data.env_cov_path} if use_env_cov else {}
    training_module = OccuranceDataModule(
        cfg.data.text_embeddings_path,
        cfg.eval.test_parquet_path,
        bins=cfg.model.img_size,
        batch_size=cfg.train.batch_size,
        shuffle=cfg.train.shuffle,
        num_workers=cfg.train.num_workers,
        **extra_kwargs,
    )

    mask = np.load(mask_path)
    mask_inds = np.where(mask.reshape(-1) == 0)[0]

    range_model = LightningRangeModel.load_from_checkpoint(model_path).to(device).eval()
    text_embeddings = np.load(
        f"data/text_embeddings/{taxonomy_level}_{llm_type}.npy", allow_pickle=True
    )
    # [()] unwraps the 0-d object array; the stored object is presumably a
    # dict keyed by taxon name (only .keys() is used here).
    known_taxa = set(text_embeddings[()].keys())

    # Read the same parquet the data module was built from, so the per-batch
    # taxon lookup below indexes the matching table.
    test_parquet = pd.read_parquet(cfg.eval.test_parquet_path)
    test_taxa = list(test_parquet[taxonomy_level].unique())

    map_scores = []
    pwcd_scores = []

    # NOTE(review): pairing batch i with test_taxa[i] assumes one taxon per
    # batch, no shuffling, and the same order as DataFrame.unique() —
    # confirm against OccuranceDataModule.
    with torch.no_grad():  # inference only; don't build autograd graphs
        for i, batch in enumerate(training_module.train_dataloader()):
            if test_taxa[i] not in known_taxa:
                continue
            if use_env_cov:
                image, text, target = batch
                image = image.to(device)
            else:
                text, target = batch
                image = None
            text = text.to(device)
            target = target.squeeze().detach().cpu().numpy()
            pred = range_model(text, image).squeeze().detach().cpu().numpy()
            # Zero out every prediction below the configured quantile.
            pred[pred < np.quantile(pred, cfg.eval.threshold)] = 0
            map_score, pwcd_score = get_all_metrics(pred, target, mask_inds)
            map_scores.append(map_score)
            pwcd_scores.append(pwcd_score)

    print(f"MAP score: {np.nanmean(map_scores)}")
    print(f"PWCD score: {np.nanmean(pwcd_scores)}")


def eval_zero_shot(model_path, mask_path, device):
    """
    Evaluate a trained range model on the unseen-taxa split
    (``cfg.eval.unseen_parquet_path`` with ``cfg.eval.text_embeddings_path``).

    Prints the NaN-ignoring mean of the MAP and PWCD scores returned by
    ``get_all_metrics`` over all batches.

    model_path: string, path to a LightningRangeModel checkpoint
    mask_path: string, path to a .npy array; the indices of its flattened
        entries equal to 0 are passed to ``get_all_metrics`` as ``mask_inds``
    device: string (one of 'cpu', 'cuda')
    """
    # Read the env-covariate flag once so the data module layout and the
    # batch unpacking below can never disagree.
    use_env_cov = cfg.data.env_cos
    extra_kwargs = {"env_cov_path": cfg.data.env_cov_path} if use_env_cov else {}
    training_module = OccuranceDataModule(
        cfg.eval.text_embeddings_path,
        cfg.eval.unseen_parquet_path,
        bins=cfg.model.img_size,
        batch_size=cfg.train.batch_size,
        shuffle=cfg.train.shuffle,
        num_workers=cfg.train.num_workers,
        **extra_kwargs,
    )

    mask = np.load(mask_path)
    mask_inds = np.where(mask.reshape(-1) == 0)[0]

    range_model = LightningRangeModel.load_from_checkpoint(model_path).to(device).eval()
    map_scores = []
    pwcd_scores = []

    with torch.no_grad():  # inference only; don't build autograd graphs
        for batch in training_module.train_dataloader():
            if use_env_cov:
                image, text, target = batch
                image = image.to(device)
            else:
                text, target = batch
                image = None
            text = text.to(device)
            target = target.squeeze().detach().cpu().numpy()
            pred = range_model(text, image).squeeze().detach().cpu().numpy()
            # Zero out every prediction below the configured quantile.
            pred[pred < np.quantile(pred, cfg.eval.threshold)] = 0
            map_score, pwcd_score = get_all_metrics(pred, target, mask_inds)
            map_scores.append(map_score)
            pwcd_scores.append(pwcd_score)

    print(f"MAP score: {np.nanmean(map_scores)}")
    print(f"PWCD score: {np.nanmean(pwcd_scores)}")


if __name__ == "__main__":
    # Command-line entry point: Fire maps CLI arguments onto the parameters
    # of eval_species_range_prediction (the only function exposed here).
    fire.Fire(component=eval_species_range_prediction)
