import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import itertools
from sklearn.model_selection import KFold

from typing import Dict

from lm_polygraph.estimators.estimator import Estimator
from lm_polygraph.generation_metrics.aggregated_metric import AggregatedMetric

from lm_polygraph.estimators.mahalanobis_distance import (
    compute_inv_covariance,
    mahalanobis_distance_with_known_centroids_sigma_inv,
    MahalanobisDistanceSeq,
    create_cuda_tensor_from_numpy,
    JITTERS
)
from lm_polygraph.generation_metrics.openai_fact_check import OpenAIFactCheck
from sklearn.metrics import mean_squared_error, roc_auc_score
from .compute_metrics import compute_processed_metrics

class MLP_NN(nn.Module):
    """Feed-forward network with two output logits.

    The stack is: input dropout (p=0.2), then linear layers of width
    256, 128 and 64, each followed by ReLU, then a final linear layer
    producing 2 logits.
    """

    def __init__(self, n_features: int = 4096):
        """Build the layer stack for inputs with ``n_features`` columns."""
        super().__init__()
        widths = (n_features, 256, 128, 64)
        stack = [nn.Dropout(0.2)]
        # Hidden blocks: Linear -> ReLU for each consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Output head: two logits.
        stack.append(nn.Linear(widths[-1], 2))
        self.layers = nn.ModuleList(stack)
        self.activation = nn.Softmax(dim=1)

    def forward(self, x, eval=False):
        """Run the network on ``x``.

        Returns the raw 2-column logits when ``eval`` is false; otherwise
        returns the softmax probability of column 1 for every row.
        """
        hidden = x
        for module in self.layers:
            hidden = module(hidden)
        if not eval:
            return hidden
        return self.activation(hidden)[:, 1]
        
class MLP:
    """Training / inference wrapper around ``MLP_NN`` for binary labels.

    ``fit`` trains the network with class-weighted cross-entropy and Adam;
    ``predict`` returns, for every input row, the softmax probability of
    class 1 as a NumPy array.
    """

    def __init__(self,
                 n_epochs: int = 20,
                 batch_size: int = 32,
                 lr: float = 5e-4,
                 weight_decay: float = 1e-5,
                 n_features: int = 4096,
                 device: str = None,
                 ):
        """Create the network and its optimizer.

        Args:
            n_epochs: number of passes over the training data in ``fit``.
            batch_size: mini-batch size used by ``fit`` and ``predict``.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            n_features: number of input columns of the network.
            device: torch device string. When ``None`` (default), "cuda" is
                used if available, otherwise "cpu".
        """
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.model = MLP_NN(n_features)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        if device is None:
            # Fall back to CPU instead of crashing on hosts without a GPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    @staticmethod
    def _as_float_tensor(data):
        """Return a detached float32 tensor copy of ``data``."""
        if isinstance(data, torch.Tensor):
            return data.clone().detach().float()
        return torch.tensor(data, dtype=torch.float32)

    def fit(self, X, y):
        """Train the network on features ``X`` and 0/1 labels ``y``.

        Both arguments may be torch tensors or anything ``torch.tensor``
        accepts. Batches are taken in input order (no shuffling).
        """
        X_torch = self._as_float_tensor(X)
        if isinstance(y, torch.Tensor):
            y_torch = y.clone().detach().long()
        else:
            y_torch = torch.tensor(y, dtype=torch.float32).long()

        # Weight each class by (1 - its frequency) so the rarer class
        # contributes more to the loss. Computed in float64 first, then cast.
        class_counts = torch.stack([(y_torch == 0).sum(), (y_torch == 1).sum()]).double()
        class_weights = (1 - class_counts / class_counts.sum()).float().to(self.device)
        self.loss = nn.CrossEntropyLoss(weight=class_weights).to(self.device)

        self.model.to(self.device)
        self.model.train()
        for _ in range(self.n_epochs):
            for start in range(0, len(X_torch), self.batch_size):
                stop = start + self.batch_size
                X_batch = X_torch[start:stop].to(self.device)
                y_batch = y_torch[start:stop].to(self.device)
                y_pred = self.model(X_batch)
                loss = self.loss(y_pred, y_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def predict(self, X):
        """Return the class-1 probability for every row of ``X``.

        Returns:
            1-D float32 ``np.ndarray`` with one score per input row (empty
            when ``X`` has no rows).
        """
        X_torch = self._as_float_tensor(X)
        self.model.to(self.device)
        self.model.eval()
        if len(X_torch) == 0:
            return np.empty(0, dtype=np.float32)

        chunks = []
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for start in range(0, len(X_torch), self.batch_size):
                X_batch = X_torch[start:start + self.batch_size].to(self.device)
                y_pred = self.model(X_batch, eval=True)
                chunks.append(y_pred.cpu().flatten())
        return torch.cat(chunks).numpy()

def cross_val_hp(X, y, model_init, params):
    """Grid-search hyper-parameters with 5-fold cross-validation on ROC-AUC.

    Args:
        X: feature array, indexable by an integer index array.
        y: label array, indexable by an integer index array.
        model_init: callable taking one parameter tuple and returning a fresh
            model exposing ``fit(X, y)`` and ``predict(X)``.
        params: mapping from parameter name to a list of candidate values;
            candidates are the Cartesian product of the value lists, in the
            mapping's iteration order.

    Returns:
        The parameter tuple with the highest mean ROC-AUC over the folds that
        could be scored. If no candidate could be scored, the first candidate.

    Raises:
        ValueError: if the grid contains no candidates.
    """
    metric = roc_auc_score
    candidates = list(itertools.product(*params.values()))
    if not candidates:
        raise ValueError("cross_val_hp: the hyper-parameter grid is empty")

    # The split is seeded, so every candidate is evaluated on the same folds.
    folds = list(KFold(n_splits=5, random_state=1, shuffle=True).split(np.arange(len(X))))

    best_score = -np.inf
    best_params = None
    for param in tqdm(candidates):
        scores_cv = []
        for i, (train, val) in enumerate(folds):
            # A fresh model per fold: reusing one instance would carry weights
            # trained on earlier folds, whose training data includes this
            # fold's validation samples.
            model = model_init(param)
            model.fit(X[train], y[train])
            try:
                scores_cv.append(metric(y[val], model.predict(X[val])))
            except Exception as e:
                # Best effort: e.g. ROC-AUC is undefined when a validation
                # fold contains a single class; such folds are skipped.
                print(f"Skip fold {i} with error: {e}")

        scores_mean = np.mean(scores_cv) if scores_cv else -np.inf
        if best_score < scores_mean:
            best_score = scores_mean
            best_params = param

    print("BEST:", best_params, "BEST SCORE:", best_score)
    if best_params is None:
        best_params = candidates[0]
    return best_params


def aggregate_token_embeddings(embeddings):
    """Pool the token embeddings of one sequence into a single vector.

    The result is the average of two terms: the mean of ``embeddings`` over
    its first axis, and the last entry of ``embeddings``.
    """
    mean_vector = np.mean(embeddings, axis=0)
    last_vector = np.array(embeddings[-1])
    return (mean_vector + last_vector) / 2
    

class MIND(Estimator):
    """Sequence-level uncertainty estimator trained on token embeddings.

    On the first call, an ``MLP`` classifier is fitted on aggregated token
    embeddings of the training generations. The binary target is 1 when the
    generation's quality metric is below ``metric_thr`` and 0 otherwise.
    Every call returns the classifier's prediction for each sequence of the
    current batch.
    """

    def __init__(
        self,
        embeddings_type: str = "decoder",
        parameters_path: str = None,
        normalize: bool = False,
        aggregation: str = "mean",
        hidden_layer: int = -1,
        metric = None,
        metric_name: str = "",
        aggregated: bool = False,
        device: str = "cuda",
        cv_hp: bool = False,
        metric_thr: float = 0.3,
    ):
        """Configure the estimator.

        Args:
            embeddings_type: part of the stats key the embeddings are read
                from (``token_embeddings_{embeddings_type}...``).
            parameters_path: stored on the instance; not used by this class.
            normalize: stored on the instance; not used by this class.
            aggregation: stored on the instance; not used by this class.
            hidden_layer: layer index appended to the embedding stats keys;
                ``-1`` means no suffix.
            metric: quality metric handed to ``compute_processed_metrics``
                when the training metrics are not already present in stats.
            metric_name: name used in ``__str__`` and in the stats cache key.
            aggregated: forwarded to ``compute_processed_metrics``.
            device: stored on the instance; not used by this class.
            cv_hp: when true, hyper-parameters are chosen with
                ``cross_val_hp`` before the final fit.
            metric_thr: metric values strictly below this threshold are
                labelled 1, all others 0.
        """
        self.hidden_layer = hidden_layer
        layer_suffix = "" if self.hidden_layer == -1 else f"_{self.hidden_layer}"
        super().__init__(
            [
                f"token_embeddings{layer_suffix}",
                f"train_token_embeddings{layer_suffix}",
                "train_greedy_tokens",
                "train_target_texts",
            ],
            "sequence",
        )
        self.centroid = None
        self.sigma_inv = None
        self.parameters_path = parameters_path
        self.embeddings_type = embeddings_type
        self.normalize = normalize
        self.min = 1e100
        self.max = -1e100
        self.is_fitted = False
        self.aggregation = aggregation
        self.metric_name = metric_name
        self.device = device
        self.cv_hp = cv_hp
        self.ue_predictor = MLP()
        self.metric_thr = metric_thr
        # Hyper-parameter grid for cross_val_hp; "n_features" is overwritten
        # with the real embedding width right before the search.
        self.params = {
                "n_epochs": [10, 20],
                "batch_size": [32, 64],
                "lr": [1e-3, 1e-4, 5e-4, 5e-5, 1e-5, 5e-6],
                "n_features": [4096],
        }
        # Maps one grid point (tuple ordered like self.params) to a model.
        self.model_init = lambda param: MLP(n_epochs=param[0],
                                            batch_size=param[1],
                                            lr=param[2],
                                            n_features=param[3])
        self.aggregated = aggregated
        if metric is not None:
            self.metric = metric

    def __str__(self):
        """Return the estimator name including metric, layer and CV flag."""
        cv = "cv, " if self.cv_hp else ""
        return f"MIND_{self.embeddings_type} ({cv}{self.metric_name}, {self.hidden_layer})"

    @staticmethod
    def _aggregate_sequences(token_embeddings, sequences_tokens):
        """Pool token embeddings into one vector per sequence.

        ``token_embeddings`` is consumed in consecutive slices along its first
        axis, one slice of ``len(tokens)`` entries per element of
        ``sequences_tokens``; each slice is pooled with
        ``aggregate_token_embeddings``.
        """
        pooled = []
        offset = 0
        for tokens in sequences_tokens:
            n_tokens = len(tokens)
            pooled.append(aggregate_token_embeddings(token_embeddings[offset:offset + n_tokens]))
            offset += n_tokens
        return np.array(pooled)

    def _fit_predictor(self, stats, layer_suffix):
        """Build binary targets from the training metrics and fit the MLP."""
        train_greedy_texts = stats["train_greedy_texts"]
        train_greedy_tokens = stats["train_greedy_tokens"]
        train_target_texts = stats["train_target_texts"]
        train_input_texts = stats["train_input_texts"]

        # Metrics are cached in stats under a key that includes the metric
        # name and the training-set size, so they are computed only once.
        metric_key = f"train_seq_{self.metric_name}_{len(train_greedy_texts)}"
        if metric_key in stats:
            seq_metrics = stats[metric_key]
        else:
            metric = getattr(self, "metric", None)
            if metric is None:
                raise ValueError(
                    f"MIND: no metric was provided and '{metric_key}' is not "
                    "available in stats, so training targets cannot be computed"
                )
            seq_metrics = compute_processed_metrics(
                metric,
                train_greedy_texts,
                train_target_texts,
                train_input_texts,
                self.aggregated,
            )
            stats[metric_key] = seq_metrics

        # Label 1 = metric below the threshold, label 0 otherwise.
        self.train_seq_metrics = (np.asarray(seq_metrics) < self.metric_thr).astype(int)

        train_embeddings = stats[f"train_token_embeddings_{self.embeddings_type}{layer_suffix}"]
        train_features = self._aggregate_sequences(train_embeddings, train_greedy_tokens)

        if self.cv_hp:
            self.params["n_features"] = [train_features.shape[-1]]
            best_params = cross_val_hp(train_features, self.train_seq_metrics, self.model_init, self.params)
            self.ue_predictor = self.model_init(best_params)
        else:
            self.ue_predictor = MLP(n_features=train_features.shape[-1])

        self.ue_predictor.fit(train_features, self.train_seq_metrics)
        self.is_fitted = True

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one uncertainty score per sequence of the current batch.

        Fits the predictor from the ``train_*`` entries of ``stats`` on the
        first call, then scores the batch embeddings with it.

        Raises:
            ValueError: on the first call, if the training metrics are not in
                ``stats`` and no ``metric`` was given to the constructor.
        """
        layer_suffix = "" if self.hidden_layer == -1 else f"_{self.hidden_layer}"

        if not self.is_fitted:
            self._fit_predictor(stats, layer_suffix)

        batch_features = self._aggregate_sequences(
            stats[f"token_embeddings_{self.embeddings_type}{layer_suffix}"],
            stats["greedy_tokens"],
        )
        return self.ue_predictor.predict(batch_features)