from collections import defaultdict
import logging
from typing import Dict

import torch
import pandas as pd
import numpy as np

from torch.utils.data import DataLoader

from text_ood.utils.dataset_util import load_dataset_embeddings, load_dataset_embeddings_ragged

# load_dataset_embeddings = load_dataset_embeddings_ragged


# Module-level logger for progress messages. This is separate from the
# experiment logger object that OODEvaluator receives as its `logger` argument.
logger = logging.getLogger(name=__name__)


class OODEvaluator:
    """Score an in-distribution dataset and several OOD datasets, then compare them.

    For each OOD dataset, every metric in ``metrics`` is called as
    ``metric(in_scores, out_scores)``. Per-dataset results, plus their mean over
    all OOD datasets, are collected into one flat dict. That dict is logged via
    the module logger and, if provided, sent to the experiment logger.
    """

    def __init__(self, id_dataset, out_datasets, metrics, logger, batch_size, embedding_type, config, device=None):
        """
        Args:
            id_dataset: In-distribution dataset identifier, passed to
                ``load_dataset_embeddings``.
            out_datasets: Iterable of OOD dataset identifiers. Each is passed to
                ``load_dataset_embeddings`` and interpolated into result keys.
            metrics: Mapping ``metric_name -> callable(in_scores, out_scores)``.
            logger: Optional object with a ``log(results, epoch=...)`` method;
                a falsy value disables it. Distinct from the module-level
                ``logging`` logger.
            batch_size: Batch size of the DataLoader used for scoring.
            embedding_type: Forwarded to ``load_dataset_embeddings``.
            config: Must provide ``num_max_eval_samples``; also forwarded to
                ``load_dataset_embeddings``.
            device: Device that embeddings and masks are moved to before scoring.
        """
        self.in_dataset = id_dataset
        self.out_datasets = out_datasets
        self.metrics = metrics
        self.logger = logger
        self.device = device
        self.batch_size = batch_size
        self.embedding_type = embedding_type
        self.config = config
        self.num_max_samples = config.num_max_eval_samples  # Max samples per dataset to evaluate

    def evaluate(self, ood_tester, epoch=None, prefix=None):
        """Evaluate ``ood_tester`` on the ID dataset against every OOD dataset.

        Args:
            ood_tester: Score callable ``(embeddings, input_ids, masks, texts)``,
                see ``compute_scores``.
            epoch: Forwarded to ``self.logger.log``.
            prefix: Optional string prepended (followed by ``_``) to every key.

        Returns:
            Dict with keys ``'{prefix}mean_id_score'``,
            ``'{prefix}{metric_name}.{out_dataset}'`` and
            ``'{prefix}{metric_name}.mean'``.
        """
        prefix = f'{prefix}_' if prefix is not None else ''
        results = {}
        metric_results = defaultdict(list)
        with torch.no_grad():
            in_scores = self.compute_scores(self.in_dataset, ood_tester, num_samples=self.num_max_samples)
            in_mean_score = torch.mean(in_scores, dim=0).detach().item()
            # `prefix` already carries its trailing '_', so no extra separator
            # here; this keeps the key consistent with the metric keys below.
            qualifier = f'{prefix}mean_id_score'
            results[qualifier] = in_mean_score
            logger.info(f'{qualifier}: {in_mean_score}')
            for out_dataset in self.out_datasets:
                out_scores = self.compute_scores(out_dataset, ood_tester, num_samples=self.num_max_samples)
                for metric_name, metric in self.metrics.items():
                    metric_result = metric(in_scores, out_scores)
                    qualifier = f'{prefix}{metric_name}.{out_dataset}'
                    results[qualifier] = metric_result
                    metric_results[metric_name].append(metric_result)
                    logger.info(f'{qualifier}: {metric_result}')

        # Average each metric over all OOD datasets.
        for metric_name, values in metric_results.items():
            qualifier = f'{prefix}{metric_name}.mean'
            results[qualifier] = np.mean(values)
            logger.info(f'{qualifier}: {results[qualifier]}')

        if self.logger:
            self.logger.log(results, epoch=epoch)

        return results

    def compute_scores(self, dataset, score_fn, num_samples=10000) -> torch.Tensor:
        """Score up to ``num_samples`` examples of ``dataset`` in random order.

        Args:
            dataset: Dataset identifier passed to ``load_dataset_embeddings``.
            score_fn: Callable ``(embeddings, input_ids, masks, texts)`` that
                returns a tensor of scores concatenable along dim 0.
            num_samples: Maximum number of examples to score; ``None`` scores
                the whole dataset.

        Returns:
            All batch scores concatenated along dim 0, truncated to
            ``num_samples`` rows.

        Raises:
            RuntimeError: If no batch was scored (empty dataset or
                ``num_samples <= 0``).
        """
        loaded = load_dataset_embeddings(dataset, self.embedding_type, self.config)
        # shuffle=True: the first `num_samples` examples are a random subset,
        # which differs between calls.
        loader = DataLoader(loaded, batch_size=self.batch_size, shuffle=True)
        total_samples = 0
        scores = []
        for embeddings, input_ids, masks, texts in loader:
            # Check the budget before doing any device work on this batch.
            if num_samples is not None and total_samples >= num_samples:
                break
            embeddings = embeddings.to(self.device).float()
            masks = masks.to(self.device)
            # Zero the embeddings wherever the mask is 0.
            # Assumes embeddings is (batch, seq, dim) and masks is (batch, seq); TODO confirm.
            embeddings = torch.where(masks.unsqueeze(-1).bool(), embeddings, torch.zeros_like(embeddings))
            total_samples += len(embeddings)
            scores.append(score_fn(embeddings, input_ids, masks, texts))
        if not scores:
            raise RuntimeError(
                f'No samples were scored for dataset {dataset!r} (num_samples={num_samples})'
            )
        # Slicing with None keeps every row.
        return torch.cat(scores, dim=0)[:num_samples]
