import os
import numpy as np
import torch
import itertools
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, roc_auc_score

from typing import Dict
import json

from lm_polygraph.estimators.estimator import Estimator
from sklearn.linear_model import LogisticRegression

from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from catboost import CatBoostRegressor
from sklearn.linear_model import Ridge
from lm_polygraph.generation_metrics.alignscore import AlignScore
from lm_polygraph.generation_metrics.aggregated_metric import AggregatedMetric
from lm_polygraph.ue_metrics import PredictionRejectionArea
from lm_polygraph.generation_metrics.openai_fact_check import OpenAIFactCheck
from .compute_metrics import compute_processed_metrics


class LookBackLens(Estimator):
    """Sequence-level uncertainty estimator built on mean lookback ratios.

    On the first call a logistic-regression classifier is fitted on the
    training statistics found in ``stats``: every training sequence is
    represented by the mean of its per-token lookback-ratio rows, and its
    label comes from a quality metric (binarised with ``threshold`` unless
    ``metric_name`` is ``"Accuracy"``). Every call then scores the sequences
    in ``stats`` with ``1 - predict_proba(...)[:, 1]``.
    """

    def __init__(
        self,
        metric=None,
        metric_name="AlignScore",
        threshold: float = 0.3,
        aggregated: bool = False,
    ):
        """
        Args:
            metric: metric object forwarded to ``compute_processed_metrics``
                when the training labels are not already cached in ``stats``.
            metric_name: name shown by ``__str__`` and embedded in the cache
                key of the training labels; ``"Accuracy"`` disables
                thresholding of the labels.
            threshold: metric values strictly greater than this become
                label 1. Only consulted when ``metric_name != "Accuracy"``.
            aggregated: forwarded unchanged to ``compute_processed_metrics``.
        """
        super().__init__(
            [
                "lookback_ratios",
                "train_lookback_ratios",
                "train_greedy_texts",
                "train_target_texts",
            ],
            "sequence",
        )

        self.metric = metric
        self.metric_name = metric_name
        # Always stored so the attribute exists for every configuration; it
        # is simply unused for the "Accuracy" metric.
        self.threshold = threshold

        self.classifier = LogisticRegression(max_iter=1000)
        self.is_fitted = False
        self.aggregated = aggregated

    def __str__(self):
        return f"LookBackLens ({self.metric_name})"

    @staticmethod
    def _sequence_features(lookback_ratios, greedy_tokens):
        """Average the lookback-ratio rows that belong to each sequence.

        ``lookback_ratios`` is sliced row-wise, so it is expected to hold one
        row per generated token with all sequences concatenated in the same
        order as ``greedy_tokens``.
        NOTE(review): the meaning of the remaining axes is not visible here.
        """
        features = []
        start = 0
        for tokens in greedy_tokens:
            stop = start + len(tokens)
            features.append(lookback_ratios[start:stop].mean(0))
            start = stop
        return np.array(features)

    def _fit(self, stats):
        """Fit the classifier on the ``train_*`` entries of ``stats``.

        The training labels are cached back into ``stats`` under a key built
        from the metric name and the number of training texts, so a later
        estimator sharing the same ``stats`` can reuse them.
        """
        train_lookback_ratios = np.array(stats["train_lookback_ratios"])
        train_greedy_tokens = stats["train_greedy_tokens"]
        train_greedy_texts = stats["train_greedy_texts"]
        train_target_texts = stats["train_target_texts"]
        train_input_texts = stats["train_input_texts"]

        metric_key = f"train_seq_{self.metric_name}_{len(train_greedy_texts)}"
        if metric_key in stats:
            targets = stats[metric_key]
        else:
            targets = compute_processed_metrics(
                self.metric,
                train_greedy_texts,
                train_target_texts,
                train_input_texts,
                self.aggregated,
            )
            stats[metric_key] = targets

        if self.metric_name != "Accuracy":
            # np.asarray keeps the comparison valid when the metric values
            # arrive as a plain Python list.
            targets = (np.asarray(targets) > self.threshold).astype(int)

        features = self._sequence_features(
            train_lookback_ratios, train_greedy_tokens
        )
        self.classifier.fit(features, targets)
        self.is_fitted = True

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one uncertainty score per sequence in ``stats``.

        Args:
            stats: statistics dictionary. Reads ``lookback_ratios`` and
                ``greedy_tokens``; on the first call also reads the
                ``train_*`` entries used for fitting and may add the cached
                training labels.

        Returns:
            ``1 - predict_proba(features)[:, 1]`` of the fitted classifier.
        """
        if not self.is_fitted:
            self._fit(stats)

        # greedy_tokens is iterated as-is: wrapping the ragged list of token
        # sequences in np.array raises ValueError on NumPy >= 1.24.
        features = self._sequence_features(
            np.array(stats["lookback_ratios"]), stats["greedy_tokens"]
        )
        return 1 - self.classifier.predict_proba(features)[:, 1]
    
    
class LookBackLensClaim(Estimator):
    """Claim-level uncertainty estimator built on mean lookback ratios.

    On the first call a logistic-regression classifier is fitted on the
    training claims: each claim is represented by the mean of the
    lookback-ratio rows of its aligned tokens and labelled with its
    fact-check result. Every call then returns, for each sequence in
    ``stats``, the list of ``predict_proba(...)[:, 1]`` values of its claims.
    """

    def __init__(
        self,
    ):
        super().__init__(
            [
                "lookback_ratios",
                "train_lookback_ratios",
                "train_greedy_texts",
                "train_target_texts",
                "train_claims",
                "claims",
            ],
            "claim",
        )

        self.classifier = LogisticRegression(max_iter=1000)
        self.is_fitted = False
        self.factcheck = OpenAIFactCheck(openai_model="gpt-4o-mini")

    def __str__(self):
        return "LookBackLensClaim"

    def _get_targets(self, greedy_tokens, claims, factcheck):
        """Build one per-token label array for every sequence.

        Every token starts at 1.0. Claims whose fact-check label is NaN are
        ignored. Tokens aligned only with claims labelled 1 (collected as
        "false" tokens here) are set to 0.0; tokens that also belong to a
        claim with any other non-NaN label keep 1.0.

        Args:
            greedy_tokens: per-sequence token sequences.
            claims: per-sequence lists of claims exposing
                ``aligned_token_ids``.
            factcheck: per-sequence lists of fact-check labels, parallel to
                ``claims``.

        Returns:
            List with one float array per sequence.
        """
        targets = []
        for j, tokens in enumerate(greedy_tokens):
            target = np.zeros_like(tokens) + 1.0
            true_tokens = []
            false_tokens = []
            for i, claim in enumerate(claims[j]):
                label = factcheck[j][i]
                if np.isnan(label):
                    continue
                bucket = false_tokens if label == 1 else true_tokens
                bucket.extend(claim.aligned_token_ids)
            final_true_tokens = np.array(list(set(true_tokens) - set(false_tokens)))
            final_false_tokens = np.array(list(set(false_tokens) - set(true_tokens)))
            # The length guards matter: an empty np.array is float-typed and
            # cannot be used as an index.
            if len(final_true_tokens):
                target[final_true_tokens] = 1.0
            if len(final_false_tokens):
                target[final_false_tokens] = 0.0
            target = np.clip(target, 0, 1)
            targets.append(target)
        return targets

    @staticmethod
    def _claim_features(lookback_ratios, greedy_tokens, claims):
        """Average the lookback-ratio rows of each claim's aligned tokens.

        ``lookback_ratios`` is sliced row-wise per sequence, so it is
        expected to hold one row per generated token with all sequences
        concatenated in the same order as ``greedy_tokens``; a claim's
        ``aligned_token_ids`` index into its own sequence's slice.

        Returns:
            Array with one row per claim, in sequence-then-claim order
            (empty when there are no claims).
        """
        features = []
        start = 0
        for seq_claims, tokens in zip(claims, greedy_tokens):
            seq_ratios = lookback_ratios[start:start + len(tokens)]
            for claim in seq_claims:
                token_ids = np.array(claim.aligned_token_ids)
                features.append(seq_ratios[token_ids].mean(0))
            start += len(tokens)
        return np.array(features)

    def _fit(self, stats):
        """Fit the classifier on the ``train_*`` entries of ``stats``."""
        train_lookback_ratios = np.array(stats["train_lookback_ratios"])
        train_greedy_tokens = stats["train_greedy_tokens"]
        train_input_texts = stats["train_input_texts"]
        train_claims = stats["train_claims"]
        train_stats = {"claims": train_claims, "input_texts": train_input_texts}

        # NOTE(review): the cache key "factcheck" carries no train/test
        # marker; confirm it always holds labels for the training claims.
        cached = "factcheck" in stats
        if cached:
            self.factcheck_score = stats["factcheck"]
        else:
            self.factcheck_score = self.factcheck(train_stats, None)

        self.train_token_metrics = np.concatenate(
            self._get_targets(train_greedy_tokens, train_claims, self.factcheck_score)
        )
        if not cached:
            stats["train_token_metrics"] = self.train_token_metrics
            stats["factcheck"] = self.factcheck_score

        # np.concatenate returns a fresh array, so replacing NaN labels with
        # 0 below does not alter the cached fact-check scores.
        target = np.concatenate(self.factcheck_score)
        target[np.isnan(target)] = 0

        # scikit-learn consumes NumPy arrays directly; no tensor conversion
        # is needed.
        train_features_claims = self._claim_features(
            train_lookback_ratios, train_greedy_tokens, train_claims
        )
        self.classifier.fit(train_features_claims, target)
        self.is_fitted = True

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return per-claim scores grouped by sequence.

        Args:
            stats: statistics dictionary. Reads ``lookback_ratios``,
                ``greedy_tokens`` and ``claims``; on the first call also
                reads the ``train_*`` entries used for fitting and may add
                ``factcheck`` and ``train_token_metrics``.

        Returns:
            A list with one inner list per sequence holding
            ``predict_proba(...)[:, 1]`` for each of its claims.
            NOTE(review): the value is a list of lists although the
            signature is annotated as ``np.ndarray``.
        """
        if not self.is_fitted:
            self._fit(stats)

        lookback_ratios = np.array(stats["lookback_ratios"])
        # Used as-is: wrapping the ragged list of token sequences in
        # np.array raises ValueError on NumPy >= 1.24.
        greedy_tokens = stats["greedy_tokens"]
        claims = stats["claims"]

        features_claims = self._claim_features(lookback_ratios, greedy_tokens, claims)
        if len(features_claims) == 0:
            # No claims in this batch: nothing to score, and predict_proba
            # would reject the empty input.
            return [[] for _ in zip(claims, greedy_tokens)]

        claim_ues = self.classifier.predict_proba(features_claims)[:, 1]

        # Regroup the flat score vector into one list per sequence.
        lookbacklens_scores = []
        offset = 0
        for seq_claims, _ in zip(claims, greedy_tokens):
            stop = offset + len(seq_claims)
            lookbacklens_scores.append(list(claim_ues[offset:stop]))
            offset = stop

        return lookbacklens_scores