import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import itertools
from sklearn.model_selection import KFold

from typing import Dict

from lm_polygraph.estimators.estimator import Estimator
from lm_polygraph.generation_metrics.aggregated_metric import AggregatedMetric

from lm_polygraph.estimators.mahalanobis_distance import (
    compute_inv_covariance,
    mahalanobis_distance_with_known_centroids_sigma_inv,
    MahalanobisDistanceSeq,
    create_cuda_tensor_from_numpy,
    JITTERS
)
from lm_polygraph.generation_metrics.openai_fact_check import OpenAIFactCheck
from sklearn.metrics import mean_squared_error, roc_auc_score
from torch.nn.utils.rnn import pad_sequence
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from transformers import set_seed

import numpy as np

from typing import List

from lm_polygraph.ue_metrics.ue_metric import UEMetric, normalize
from lm_polygraph.ue_metrics import PredictionRejectionArea
from .compute_metrics import compute_processed_metrics

class AttentionPooling(nn.Module):
    """Pool a padded sequence of vectors into one vector with learned attention.

    A single linear layer produces one logit per position. The logits are
    softmax-normalised over dim 1 and used as weights for a weighted sum of ``x``.
    """

    def __init__(self, embedding_size):
        super().__init__()
        # One scalar attention logit per position.
        self.attn = nn.Linear(embedding_size, 1)

    def forward(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over dim 1.

        Args:
            x: tensor whose last dim is ``embedding_size``. Callers in this
                module pass padded batches built with ``pad_sequence(batch_first=True)``.
            mask: optional boolean tensor broadcastable to ``self.attn(x)``
                (callers pass ``(batch, seq_len, 1)``). ``True`` marks positions to ignore.

        Returns:
            ``x`` reduced over dim 1. A row whose positions are all masked
            pools to zeros instead of NaN.
        """
        attn_logits = self.attn(x)
        if mask is not None:
            # Out-of-place fill instead of mutating the autograd output in place.
            attn_logits = attn_logits.masked_fill(mask, float("-inf"))
        attn_weights = torch.softmax(attn_logits, dim=1)
        if mask is not None:
            # softmax over an all -inf row is NaN; give such rows zero weight
            # so NaN reaches neither the output nor the gradients.
            fully_masked = mask.all(dim=1, keepdim=True)
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        return (x * attn_weights).sum(dim=1)

class MLP_NN(nn.Module):
    """Two-class head: attention pooling over positions, then a linear layer.

    ``forward`` returns raw two-class logits, or, when ``eval=True``, the
    softmax probability of class 1 for each row.
    """

    def __init__(self, n_features: int = 4096):
        super().__init__()
        # Creation order (pooling, output, activation) is kept so parameter
        # initialisation under a fixed seed is unchanged.
        self.pooling = AttentionPooling(n_features)
        self.output = nn.Linear(n_features, 2)
        self.activation = nn.Softmax(dim=1)

    def forward(self, x, mask, eval=False):
        pooled = self.pooling(x, mask)
        logits = self.output(pooled)
        if not eval:
            return logits
        probabilities = self.activation(logits)
        return probabilities[:, 1]
    
class MLP_NN_ens(nn.Module):
    """Linear two-class head applied directly to its input (no pooling).

    ``mask`` is accepted only so the call signature matches ``MLP_NN``; it is
    ignored.
    """

    def __init__(self, n_features: int = 32):
        super().__init__()
        self.output = nn.Linear(n_features, 2)
        self.activation = nn.Softmax(dim=1)

    def forward(self, x, mask=None, eval=False):
        logits = self.output(x)
        # eval=True -> class-1 probability per row; otherwise raw logits.
        return self.activation(logits)[:, 1] if eval else logits
        
class MLP:
    """Train/predict wrapper around ``MLP_NN`` (or ``MLP_NN_ens`` when ``ensemble``).

    ``fit`` trains with Adam on cross-entropy over integer class labels.
    ``predict`` returns the model's class-1 probability for every input row.
    """

    def __init__(self,
                 n_epochs: int = 5,
                 batch_size: int = 64,
                 lr: float = 1e-3,
                 weight_decay: float = 1e-5,
                 n_features: int = 4096,
                 ensemble: bool = False,
                 device: str = None,
                ):
        """
        Args:
            n_epochs: number of passes over the training data in ``fit``.
            batch_size: rows per mini-batch in ``fit`` and ``predict``.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            n_features: size of the last input dimension.
            ensemble: use the pooling-free ``MLP_NN_ens`` head instead of ``MLP_NN``.
            device: torch device string. ``None`` selects ``"cuda"`` when a GPU
                is available and ``"cpu"`` otherwise.
        """
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        if ensemble:
            self.model = MLP_NN_ens(n_features)
        else:
            self.model = MLP_NN(n_features)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        if device is None:
            # Same "cuda" target as before when a GPU exists; no crash on CPU-only hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    @staticmethod
    def _to_tensor(data):
        """Return a detached tensor copy of ``data`` (float32 for non-tensor input)."""
        if isinstance(data, torch.Tensor):
            return data.clone().detach()
        return torch.tensor(data, dtype=torch.float32)

    def fit(self, X, y, mask):
        """Train the model in mini-batches for ``n_epochs`` epochs.

        Args:
            X: inputs indexable along the first axis (tensor or array-like).
            y: labels, cast to int64 class indices for ``CrossEntropyLoss``.
            mask: padding mask, truthy on positions to ignore. A trailing
                singleton dim is added before it reaches the model.
        """
        self.loss = nn.CrossEntropyLoss().to(self.device)

        X_torch = self._to_tensor(X).float()
        y_torch = self._to_tensor(y).long()
        mask_torch = self._to_tensor(mask).bool().unsqueeze(2)

        self.model.to(self.device)
        self.model.train()
        for _ in range(self.n_epochs):
            for start in range(0, len(X_torch), self.batch_size):
                batch = slice(start, start + self.batch_size)
                X_batch = X_torch[batch].to(self.device)
                y_batch = y_torch[batch].to(self.device)
                mask_batch = mask_torch[batch].to(self.device)
                y_pred = self.model(X_batch, mask_batch)
                loss = self.loss(y_pred, y_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def predict(self, X, mask):
        """Return the class-1 probability for every row of ``X`` as a 1-D numpy array.

        ``mask`` follows the same convention as in ``fit``. An empty ``X``
        yields an empty float32 array.
        """
        X_torch = self._to_tensor(X).float()
        mask_torch = self._to_tensor(mask).bool().unsqueeze(2)

        self.model.eval()
        self.model.to(self.device)
        predictions = []
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for start in range(0, len(X_torch), self.batch_size):
                batch = slice(start, start + self.batch_size)
                X_batch = X_torch[batch].to(self.device)
                mask_batch = mask_torch[batch].to(self.device)
                y_pred = self.model(X_batch, mask_batch, eval=True)
                predictions.append(y_pred.detach().cpu().flatten().numpy())
        if not predictions:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(predictions)

def cross_val_hp(X, y, attention_mask, model_init, params, n_splits: int = 5, random_state: int = 1):
    """Grid-search hyper-parameters by K-fold ROC-AUC and return the best combination.

    Args:
        X: inputs indexable with an integer index array (numpy array or tensor).
        y: binary labels indexable the same way.
        attention_mask: per-row padding mask indexable the same way.
        model_init: callable taking one parameter tuple (values ordered like
            ``params``) and returning an object with ``fit(X, y, mask)`` and
            ``predict(X, mask)``.
        params: mapping of hyper-parameter name -> list of candidate values.
        n_splits: number of K-fold splits.
        random_state: seed for the shuffled K-fold split.

    Returns:
        The parameter tuple with the highest mean validation ROC-AUC. If no fold
        could be scored for any candidate, returns the first combination.

    Raises:
        ValueError: if ``params`` yields no combinations.
    """
    metric = roc_auc_score
    candidates = list(itertools.product(*params.values()))
    if not candidates:
        raise ValueError("params must provide at least one value for every hyper-parameter")

    # The same shuffled folds are used for every candidate so scores are comparable.
    folds = list(KFold(n_splits=n_splits, random_state=random_state, shuffle=True).split(np.arange(len(X))))

    best_score = -np.inf
    best_params = None
    for param in tqdm(candidates):
        scores_cv = []
        for i, (train, val) in enumerate(folds):
            # A fresh model per fold. Reusing one model would carry weights trained
            # on earlier folds, which contain this fold's validation rows.
            model = model_init(param)
            model.fit(X[train], y[train], attention_mask[train])
            try:
                y_pred = model.predict(X[val], attention_mask[val])
                scores_cv.append(metric(y[val], y_pred))
            except Exception as e:
                # Best effort: e.g. ROC-AUC is undefined when a fold has one class.
                print(f"Skip fold {i} with error: {e}")

        scores_mean = np.mean(scores_cv) if scores_cv else -np.inf
        if scores_mean > best_score:
            best_score = scores_mean
            best_params = param

    if best_params is None:
        best_params = candidates[0]
    print("Sheeps BEST:", best_params, "BEST SCORE:", best_score)
    return best_params

class LayerSheeps(Estimator):
    """Sequence-level UE from the token embeddings of one hidden layer.

    On the first call, each training generation gets label 1 when its quality
    metric is strictly below ``metric_thr`` (0 otherwise). An attention-pooling
    ``MLP`` is then trained on the padded token embeddings. Every call returns
    that MLP's class-1 probability for each generation in ``stats``.
    """

    def __init__(
        self,
        embeddings_type: str = "decoder",
        parameters_path: str = None,
        normalize: bool = False,
        aggregation: str = "mean",
        hidden_layer: int = -1,
        metric = None,
        metric_name: str = "",
        aggregated: bool = False,
        device: str = "cuda",
        cv_hp: bool = False,
        metric_thr: float = 0.3,
    ):
        self.hidden_layer = hidden_layer
        if self.hidden_layer == -1:
            super().__init__(["token_embeddings", "train_token_embeddings", "train_greedy_tokens", "train_target_texts"], "sequence")
        else:
            super().__init__([f"token_embeddings_{self.hidden_layer}", f"train_token_embeddings_{self.hidden_layer}", "train_greedy_tokens", "train_target_texts"], "sequence")
        self.centroid = None
        self.sigma_inv = None
        self.parameters_path = parameters_path
        self.embeddings_type = embeddings_type
        self.normalize = normalize
        self.min = 1e100
        self.max = -1e100
        self.is_fitted = False
        self.aggregation = aggregation
        self.metric_name = metric_name
        self.device = device
        self.cv_hp = cv_hp
        # Placeholder; replaced by a correctly sized MLP when fitting.
        self.ue_predictor = MLP()
        self.metric_thr = metric_thr
        # Hyper-parameter grid for cross_val_hp (value order matches model_init).
        self.params = {
                "n_epochs": [5, 10, 20],
                "batch_size": [32, 64],
                "lr": [1e-1, 1e-2, 1e-3],
                "n_features": [4096],
        }
        self.model_init = lambda param: MLP(n_epochs=param[0],
                                            batch_size=param[1],
                                            lr=param[2],
                                            n_features=param[3])
        self.aggregated = aggregated
        if metric is not None:
            self.metric = metric

    def __str__(self):
        cv = "cv, " if self.cv_hp else ""
        return f"LayerSheeps_{self.embeddings_type} ({cv}{self.metric_name}, {self.hidden_layer})"

    @staticmethod
    def _pad_with_mask(flat_embeddings, greedy_tokens):
        """Cut flat per-token embeddings into per-generation chunks and pad them.

        Consecutive runs of ``len(tokens)`` rows of ``flat_embeddings`` belong to
        successive entries of ``greedy_tokens``.

        Returns:
            ``(padded, mask)``. ``padded`` comes from
            ``pad_sequence(batch_first=True)`` with zero padding. ``mask`` is an
            int tensor ``(n_generations, max_len)`` holding 1 on padding positions.
        """
        chunks, lengths = [], []
        offset = 0
        for tokens in greedy_tokens:
            n_tokens = len(tokens)
            chunks.append(torch.tensor(np.array(flat_embeddings[offset:offset + n_tokens])))
            lengths.append(n_tokens)
            offset += n_tokens
        padded = pad_sequence(chunks, batch_first=True, padding_value=0)
        positions = torch.arange(padded.shape[1]).unsqueeze(0)
        mask = (positions >= torch.tensor(lengths, dtype=torch.long).unsqueeze(1)).int()
        return padded, mask

    def _fit(self, stats, suffix):
        """Build binary training labels and train ``self.ue_predictor`` on train embeddings."""
        train_greedy_texts = stats["train_greedy_texts"]
        train_greedy_tokens = stats["train_greedy_tokens"]
        train_target_texts = stats["train_target_texts"]
        train_input_texts = stats["train_input_texts"]

        # Metric values are cached in stats, keyed by metric name and training-set size.
        metric_key = f"train_seq_{self.metric_name}_{len(train_greedy_texts)}"
        if metric_key in stats:
            seq_metrics = stats[metric_key]
        else:
            seq_metrics = compute_processed_metrics(self.metric, train_greedy_texts, train_target_texts, train_input_texts, self.aggregated)
            stats[metric_key] = seq_metrics
        # Label 1 = metric strictly below metric_thr. asarray also accepts list input.
        self.train_seq_metrics = (np.asarray(seq_metrics) < self.metric_thr).astype(int)

        embeddings, attention_mask = self._pad_with_mask(
            stats[f"train_token_embeddings_{self.embeddings_type}{suffix}"],
            train_greedy_tokens,
        )
        n_features = embeddings.shape[-1]
        if self.cv_hp:
            self.params["n_features"] = [n_features]
            best_params = cross_val_hp(embeddings, self.train_seq_metrics, attention_mask, self.model_init, self.params)
            self.ue_predictor = self.model_init(best_params)
        else:
            self.ue_predictor = MLP(n_features=n_features)

        self.ue_predictor.fit(embeddings, self.train_seq_metrics, attention_mask)
        self.is_fitted = True

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one UE score per generation in ``stats``, fitting on the first call."""
        # Layer -1 uses the un-suffixed stat names.
        suffix = "" if self.hidden_layer == -1 else f"_{self.hidden_layer}"

        if not self.is_fitted:
            self._fit(stats, suffix)

        embeddings, attention_mask = self._pad_with_mask(
            stats[f"token_embeddings_{self.embeddings_type}{suffix}"],
            stats["greedy_tokens"],
        )
        return self.ue_predictor.predict(embeddings, attention_mask)
    
class Sheeps(Estimator):
    """Sequence-level UE that stacks several per-layer ``LayerSheeps`` models.

    On the first call the training generations are split 50/50. Each per-layer
    ``LayerSheeps`` is trained on one half and scores the other half. A
    ``LogisticRegression`` is then fit on those held-out per-layer scores.
    Later calls feed the per-layer scores of ``stats`` through that regression.
    """

    def __init__(
        self,
        embeddings_type: str = "decoder",
        parameters_path: str = None,
        normalize: bool = False,
        aggregation: str = "mean",
        hidden_layers: int = -1,
        metric = None,
        metric_name: str = "",
        aggregated: bool = False,
        device: str = "cuda",
        cv_hp: bool = False,
        metric_thr: float = 0.3,
    ):
        # Accept a single layer index (e.g. the default -1) as well as a list of them.
        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]
        self.hidden_layers = hidden_layers

        # self.sheeps[i] is the estimator for self.hidden_layers[i].
        self.sheeps = []
        dependencies = ["train_greedy_tokens", "train_target_texts"]
        for layer in self.hidden_layers:
            if layer == -1:
                dependencies += ["token_embeddings", "train_token_embeddings"]
            else:
                dependencies += [f"token_embeddings_{layer}", f"train_token_embeddings_{layer}"]

            self.sheeps.append(LayerSheeps(embeddings_type, parameters_path=parameters_path, metric=metric, metric_name=metric_name,
                                           aggregated=aggregated, hidden_layer=layer, device=device, cv_hp=cv_hp, metric_thr=metric_thr))

        super().__init__(dependencies, "sequence")
        self.parameters_path = parameters_path
        self.embeddings_type = embeddings_type
        self.normalize = normalize
        self.min = 1e100
        self.max = -1e100
        self.is_fitted = False
        self.aggregation = aggregation
        self.metric_name = metric_name
        self.device = device
        self.cv_hp = cv_hp
        self.ue_predictor = LogisticRegression()
        self.metric_thr = metric_thr
        self.aggregated = aggregated
        if metric is not None:
            self.metric = metric

    def __str__(self):
        cv = "cv, " if self.cv_hp else ""
        return f"Sheeps_{self.embeddings_type} ({cv}{self.metric_name})"

    @staticmethod
    def _split_per_sequence(flat_embeddings, greedy_tokens):
        """Cut flat per-token embeddings into one slice per entry of ``greedy_tokens``."""
        chunks, offset = [], 0
        for tokens in greedy_tokens:
            chunks.append(flat_embeddings[offset:offset + len(tokens)])
            offset += len(tokens)
        return chunks

    def _fit(self, stats):
        """Train per-layer estimators on one half of the data and the stacker on the other."""
        set_seed(42)
        train_greedy_texts = stats["train_greedy_texts"]
        train_greedy_tokens = stats["train_greedy_tokens"]
        train_target_texts = stats["train_target_texts"]
        train_input_texts = stats["train_input_texts"]

        metric_key = f"train_seq_{self.metric_name}_{len(train_greedy_texts)}"
        if metric_key in stats:
            seq_metrics = stats[metric_key]
        else:
            seq_metrics = compute_processed_metrics(self.metric, train_greedy_texts, train_target_texts, train_input_texts, self.aggregated)
            stats[metric_key] = seq_metrics

        train_seq_metrics_orig = np.asarray(seq_metrics)
        # Label 1 = metric strictly below metric_thr.
        self.train_seq_metrics = (train_seq_metrics_orig < self.metric_thr).astype(int)

        train_idx, dev_idx = train_test_split(list(range(len(train_greedy_texts))), test_size=0.5, random_state=42)

        def take(values, indices):
            return [values[i] for i in indices]

        train_sheeps = []
        for layer, sheep in zip(self.hidden_layers, self.sheeps):
            suffix = "" if layer == -1 else f"_{layer}"
            per_sequence = self._split_per_sequence(
                stats[f"train_token_embeddings_{self.embeddings_type}{suffix}"],
                train_greedy_tokens,
            )
            # Present the train half as "train_*" and the dev half as the
            # evaluation inputs, so the layer model scores held-out data.
            layer_stats = {
                "train_greedy_tokens": take(train_greedy_tokens, train_idx),
                "train_input_texts": take(train_input_texts, train_idx),
                "greedy_tokens": take(train_greedy_tokens, dev_idx),
                "train_greedy_texts": take(train_greedy_texts, train_idx),
                "train_target_texts": take(train_target_texts, train_idx),
                f"train_token_embeddings_{self.embeddings_type}{suffix}": [emb for i in train_idx for emb in per_sequence[i]],
                f"token_embeddings_{self.embeddings_type}{suffix}": [emb for i in dev_idx for emb in per_sequence[i]],
                f"train_seq_{self.metric_name}_{len(train_idx)}": train_seq_metrics_orig[train_idx],
            }
            train_sheeps.append(sheep(layer_stats).reshape(-1))

        # Rows = dev generations, columns = layers (in self.hidden_layers order).
        train_sheeps = np.array(train_sheeps).T
        # Same NaN handling as at inference time.
        train_sheeps[np.isnan(train_sheeps)] = 0
        dev_labels = self.train_seq_metrics[dev_idx]
        self.ue_predictor.fit(train_sheeps, dev_labels)

        # Record (and print) the single layer with the highest PRR on the dev half.
        prr = PredictionRejectionArea()
        self.best_layer = -1
        best_prr = 0
        for column, layer in enumerate(self.hidden_layers):
            layer_prr = prr(train_sheeps[:, column], 1 - dev_labels)
            if layer_prr > best_prr:
                best_prr = layer_prr
                self.best_layer = layer
                print("BEST layer:", layer, layer_prr)

        self.is_fitted = True

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one UE score per generation in ``stats``, fitting on the first call."""
        if not self.is_fitted:
            self._fit(stats)

        eval_scores = np.array([sheep(stats).reshape(-1) for sheep in self.sheeps]).T
        eval_scores[np.isnan(eval_scores)] = 0
        return self.ue_predictor.predict_proba(eval_scores)[:, 1]
    
    
class LayerSheepsClaim(Estimator):
    """Claim-level UE from the token embeddings of one hidden layer.

    Each claim is represented by the embeddings of its ``aligned_token_ids``.
    An attention-pooling ``MLP`` is trained to predict the per-claim fact-check
    value (NaN counted as 0). Its class-1 probability is returned per claim,
    grouped per generation.
    """

    def __init__(
        self,
        hidden_layer: int = -1,
        embeddings_type: str = "decoder",
        cv_hp: bool = True,
    ):
        self.cv_hp = cv_hp
        self.hidden_layer = hidden_layer
        self.embeddings_type = embeddings_type
        if self.hidden_layer == -1:
            super().__init__(["token_embeddings", "train_token_embeddings", "train_greedy_tokens", "train_target_texts", "train_claims", "claims"], "claim")
        else:
            super().__init__([f"token_embeddings_{self.hidden_layer}", f"train_token_embeddings_{self.hidden_layer}", "train_greedy_tokens", "train_target_texts", "train_claims", "claims"], "claim")

        # Placeholder; replaced by a correctly sized MLP when fitting.
        self.ue_predictor = MLP()
        # Hyper-parameter grid for cross_val_hp (value order matches model_init).
        self.params = {
                "n_epochs": [5, 10, 20],
                "batch_size": [32, 64],
                "lr": [1e-1, 1e-2, 1e-3],
                "n_features": [4096],
        }
        self.model_init = lambda param: MLP(n_epochs=param[0],
                                            batch_size=param[1],
                                            lr=param[2],
                                            n_features=param[3])
        self.is_fitted = False
        self.factcheck = OpenAIFactCheck(openai_model="gpt-4o-mini")

    def __str__(self):
        return f"LayerSheepsClaim ({self.hidden_layer})"

    def _get_targets(self, greedy_tokens, claims, factcheck):
        """Build one per-token target array per generation.

        Every token starts at 1.0. Tokens aligned only to claims whose
        fact-check value equals 1 (and to no claim with any other non-NaN
        value) are set to 0.0. Claims with a NaN value are ignored.
        """
        targets = []
        for j in range(len(greedy_tokens)):
            target = np.zeros_like(greedy_tokens[j]) + 1.0
            true_tokens, false_tokens = set(), set()
            for i, claim in enumerate(claims[j]):
                verdict = factcheck[j][i]
                if np.isnan(verdict):
                    continue
                if verdict == 1:
                    false_tokens.update(claim.aligned_token_ids)
                else:
                    true_tokens.update(claim.aligned_token_ids)
            only_false = np.array(sorted(false_tokens - true_tokens), dtype=int)
            if len(only_false):
                target[only_false] = 0.0
            targets.append(target)
        return targets

    @staticmethod
    def _claim_batch(flat_embeddings, greedy_tokens, claims):
        """Gather each claim's aligned-token embeddings, pad them and build a padding mask.

        Consecutive runs of ``len(tokens)`` rows of ``flat_embeddings`` belong to
        successive generations. Claims are emitted in generation order, then
        claim order.

        Returns:
            ``(padded, mask)`` where ``mask`` is an int tensor holding 1 on
            padding positions, or ``(None, None)`` when there are no claims.
        """
        chunks, lengths = [], []
        offset = 0
        for sample_claims, tokens in zip(claims, greedy_tokens):
            sample = flat_embeddings[offset:offset + len(tokens)]
            for claim in sample_claims:
                # dtype=int so an empty id list is still a valid (empty) index.
                ids = np.asarray(claim.aligned_token_ids, dtype=int)
                chunks.append(torch.tensor(np.array(sample[ids])))
                lengths.append(len(ids))
            offset += len(tokens)
        if not chunks:
            return None, None
        padded = pad_sequence(chunks, batch_first=True, padding_value=0)
        positions = torch.arange(padded.shape[1]).unsqueeze(0)
        mask = (positions >= torch.tensor(lengths, dtype=torch.long).unsqueeze(1)).int()
        return padded, mask

    def _fit(self, stats, suffix):
        """Obtain fact-check values for the training claims and train ``self.ue_predictor``."""
        train_greedy_tokens = stats["train_greedy_tokens"]
        train_input_texts = stats["train_input_texts"]
        train_claims = stats["train_claims"]

        if "factcheck" in stats:
            self.factcheck_score = stats["factcheck"]
            self.train_token_metrics = np.concatenate(self._get_targets(train_greedy_tokens, train_claims, self.factcheck_score))
        else:
            self.factcheck_score = self.factcheck({"claims": train_claims, "input_texts": train_input_texts}, None)
            self.train_token_metrics = np.concatenate(self._get_targets(train_greedy_tokens, train_claims, self.factcheck_score))
            # Cache for other estimators sharing this stats dict.
            stats["train_token_metrics"] = self.train_token_metrics
            stats["factcheck"] = self.factcheck_score

        # One label per claim: the fact-check value, with NaN mapped to 0.
        target = np.concatenate(self.factcheck_score)
        target[np.isnan(target)] = 0

        train_embeddings = np.array(stats[f"train_token_embeddings_{self.embeddings_type}{suffix}"])
        embeddings, attention_mask = self._claim_batch(train_embeddings, train_greedy_tokens, train_claims)
        if embeddings is None:
            raise ValueError("LayerSheepsClaim: no training claims to fit on")

        n_features = embeddings.shape[-1]
        if self.cv_hp:
            self.params["n_features"] = [n_features]
            best_params = cross_val_hp(embeddings, target, attention_mask, self.model_init, self.params)
            self.ue_predictor = self.model_init(best_params)
        else:
            self.ue_predictor = MLP(n_features=n_features)

        self.ue_predictor.fit(embeddings, target, attention_mask)
        self.is_fitted = True

    def __call__(self, stats: Dict[str, np.ndarray]) -> List[List[float]]:
        """Return per-generation lists of per-claim UE scores, fitting on the first call."""
        # Layer -1 uses the un-suffixed stat names.
        suffix = "" if self.hidden_layer == -1 else f"_{self.hidden_layer}"

        if not self.is_fitted:
            self._fit(stats, suffix)

        embeddings = np.array(stats[f"token_embeddings_{self.embeddings_type}{suffix}"])
        greedy_tokens = stats["greedy_tokens"]
        claims = stats["claims"]

        batch, attention_mask = self._claim_batch(embeddings, greedy_tokens, claims)
        if batch is None:
            return [[] for _ in range(len(greedy_tokens))]
        ue = self.ue_predictor.predict(batch, attention_mask)

        # Regroup the flat per-claim scores into one list per generation.
        ue_iter = iter(ue)
        return [[next(ue_iter) for _ in claims[idx]] for idx in range(len(greedy_tokens))]
    
    
class SheepsClaim(Estimator):
    """Claim-level UE that stacks several per-layer ``LayerSheepsClaim`` models.

    On the first call the training generations are split 50/50. Each per-layer
    model is trained on one half and scores the claims of the other half. A
    ``LogisticRegression`` is fit on those held-out scores against the
    fact-check values (NaN counted as 0).
    """

    def __init__(
        self,
        hidden_layers: int = -1,
        embeddings_type: str = "decoder",
    ):
        # Accept a single layer index (e.g. the default -1) as well as a list of them.
        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]
        self.embeddings_type = embeddings_type
        self.hidden_layers = hidden_layers
        # self.sheeps[i] is the estimator for self.hidden_layers[i].
        self.sheeps = []
        dependencies = ["train_greedy_tokens", "train_target_texts", "train_claims", "claims"]
        for layer in self.hidden_layers:
            if layer == -1:
                dependencies += ["token_embeddings", "train_token_embeddings"]
            else:
                dependencies += [f"token_embeddings_{layer}", f"train_token_embeddings_{layer}"]

            # Pass embeddings_type through so the child looks up the same stat keys built here.
            self.sheeps.append(LayerSheepsClaim(hidden_layer=layer, embeddings_type=embeddings_type))

        super().__init__(dependencies, "claim")
        self.ue_predictor = LogisticRegression()
        self.factcheck = OpenAIFactCheck(openai_model="gpt-4o-mini")
        self.is_fitted = False

    def __str__(self):
        return "SheepsClaim"

    def _get_targets(self, greedy_tokens, claims, factcheck):
        """Build one per-token target array per generation.

        Every token starts at 1.0. Tokens aligned only to claims whose
        fact-check value equals 1 (and to no claim with any other non-NaN
        value) are set to 0.0. Claims with a NaN value are ignored.
        """
        targets = []
        for j in range(len(greedy_tokens)):
            target = np.zeros_like(greedy_tokens[j]) + 1.0
            true_tokens, false_tokens = set(), set()
            for i, claim in enumerate(claims[j]):
                verdict = factcheck[j][i]
                if np.isnan(verdict):
                    continue
                if verdict == 1:
                    false_tokens.update(claim.aligned_token_ids)
                else:
                    true_tokens.update(claim.aligned_token_ids)
            only_false = np.array(sorted(false_tokens - true_tokens), dtype=int)
            if len(only_false):
                target[only_false] = 0.0
            targets.append(target)
        return targets

    @staticmethod
    def _split_per_sequence(flat_embeddings, greedy_tokens):
        """Cut flat per-token embeddings into one slice per entry of ``greedy_tokens``."""
        chunks, offset = [], 0
        for tokens in greedy_tokens:
            chunks.append(flat_embeddings[offset:offset + len(tokens)])
            offset += len(tokens)
        return chunks

    def _fit(self, stats):
        """Train per-layer estimators on one half of the data and the stacker on the other."""
        set_seed(42)
        train_greedy_texts = stats["train_greedy_texts"]
        train_greedy_tokens = stats["train_greedy_tokens"]
        train_target_texts = stats["train_target_texts"]
        train_input_texts = stats["train_input_texts"]
        train_claims = stats["train_claims"]

        if "factcheck" in stats:
            self.factcheck_score = stats["factcheck"]
            self.train_token_metrics = np.concatenate(self._get_targets(train_greedy_tokens, train_claims, self.factcheck_score))
        else:
            self.factcheck_score = self.factcheck({"claims": train_claims, "input_texts": train_input_texts}, None)
            self.train_token_metrics = np.concatenate(self._get_targets(train_greedy_tokens, train_claims, self.factcheck_score))
            stats["train_token_metrics"] = self.train_token_metrics
            stats["factcheck"] = self.factcheck_score

        train_idx, dev_idx = train_test_split(list(range(len(train_greedy_texts))), test_size=0.5, random_state=42)

        def take(values, indices):
            return [values[i] for i in indices]

        train_sheeps = []
        for layer, sheep in zip(self.hidden_layers, self.sheeps):
            suffix = "" if layer == -1 else f"_{layer}"
            train_embeddings = np.array(stats[f"train_token_embeddings_{self.embeddings_type}{suffix}"])
            per_sequence = self._split_per_sequence(train_embeddings, train_greedy_tokens)
            # Train half as "train_*", dev half as evaluation inputs.
            layer_stats = {
                "train_greedy_tokens": take(train_greedy_tokens, train_idx),
                "train_input_texts": take(train_input_texts, train_idx),
                "greedy_tokens": take(train_greedy_tokens, dev_idx),
                "train_greedy_texts": take(train_greedy_texts, train_idx),
                "train_target_texts": take(train_target_texts, train_idx),
                f"train_token_embeddings_{self.embeddings_type}{suffix}": [emb for i in train_idx for emb in per_sequence[i]],
                f"token_embeddings_{self.embeddings_type}{suffix}": [emb for i in dev_idx for emb in per_sequence[i]],
                "claims": take(train_claims, dev_idx),
                "train_claims": take(train_claims, train_idx),
                # LayerSheepsClaim uses stats["factcheck"] when present instead of
                # calling its own fact-checker, so these claims are not re-checked.
                "factcheck": take(self.factcheck_score, train_idx),
            }
            train_sheeps.append(np.concatenate(sheep(layer_stats)))

        target = np.concatenate(take(self.factcheck_score, dev_idx))
        target[np.isnan(target)] = 0
        # Rows = dev claims, columns = layers (in self.hidden_layers order).
        train_sheeps = np.array(train_sheeps).T
        # Same NaN handling as at inference time.
        train_sheeps[np.isnan(train_sheeps)] = 0
        self.ue_predictor.fit(train_sheeps, target)
        self.is_fitted = True

    def __call__(self, stats: Dict[str, np.ndarray]) -> List[List[float]]:
        """Return per-generation lists of per-claim UE scores, fitting on the first call."""
        if not self.is_fitted:
            self._fit(stats)

        eval_scores = np.array([np.concatenate(sheep(stats)) for sheep in self.sheeps]).T
        eval_scores[np.isnan(eval_scores)] = 0
        ue = self.ue_predictor.predict_proba(eval_scores)[:, 1]

        # Regroup the flat per-claim scores into one list per generation.
        claims = stats["claims"]
        ue_iter = iter(ue)
        return [[next(ue_iter) for _ in claims[idx]] for idx in range(len(stats["greedy_tokens"]))]