import os
import logging
import numpy as np
from tqdm import tqdm
from typing import List

from collections import Counter

from scipy.special import kl_div
from sklearn.base import BaseEstimator
from sklearn.ensemble import ExtraTreesRegressor

from data.data_loader import load_results_data, load_judgment_data, load_score_data

# Module-wide logger, looked up under the fixed name "rich".
# NOTE(review): presumably a handler for this name is configured elsewhere
# in the project -- confirm, since nothing in this file sets one up.
logger = logging.getLogger("rich")

class Mean(BaseEstimator):
    """
    Training-free baseline scorer.

    The prediction for a sample is the NaN-ignoring mean of its feature row,
    mapped through the affine transform ``mean * scale + bias``.
    """
    def __init__(self, scale=9, bias=1):
        # NOTE(review): the defaults would map a mean in [0, 1] onto [1, 10];
        # presumably that matches the label range -- confirm with the data.
        super().__init__()
        self.scale = scale
        self.bias = bias

    def fit(self, X, y):
        """
        Do nothing (there are no parameters to learn) and return ``self``.
        """
        return self

    def predict(self, X):
        """
        Return one score per row of ``X``.

        ``X`` is converted to an array and averaged along axis 1, so it has
        to be two-dimensional; NaN entries are skipped by the average.
        """
        row_means = np.nanmean(np.asarray(X), axis=1)
        return row_means * self.scale + self.bias

def score_hist(scores: list):
    """
    Return the relative frequency of every distinct value in ``scores``.

    The result holds one entry per distinct value, ordered by first
    appearance in ``scores``; the entries sum to 1.
    """
    frequency = Counter(scores)
    # Iterating a Counter yields its keys in insertion order, so the bins
    # follow the order in which the values were first seen.
    occurrences = np.array([frequency[value] for value in frequency])
    total = occurrences.sum()
    return occurrences / total

def kl_uniform_weight(X: np.ndarray, num_classes: int = 10):
    """
    Compute the KL-divergence based weight for the given score distribution.

    The empirical histogram of ``X`` is compared against the uniform
    distribution over ``num_classes`` classes. The divergence is normalised
    by its largest possible value (all mass on a single class) and flipped,
    so the weight is 1 when the scores are spread uniformly over all classes
    and 0 when every score is identical.

    Args:
        X: Observed scores; every distinct value counts as one class.
        num_classes: Number of classes the scores can take.

    Returns:
        ``(max_kl - KL(hist || uniform)) / max_kl`` as a float.

    Raises:
        ValueError: If ``X`` contains more than ``num_classes`` distinct
            values.
    """
    uniform_dist = np.ones(num_classes) / num_classes
    # Divergence of a one-hot distribution from uniform: the upper bound used
    # for normalisation.
    max_kl = kl_div([1.0] + [0.0] * (num_classes - 1), uniform_dist).sum()
    hist = score_hist(X)
    padding = num_classes - hist.shape[0]
    # Explicit check instead of ``assert`` so the validation survives
    # ``python -O`` and reports a meaningful message.
    if padding < 0:
        raise ValueError(
            f"found {hist.shape[0]} distinct scores but num_classes is {num_classes}"
        )
    # Classes that never occur get zero mass. The order of the bins is
    # irrelevant because the reference distribution is uniform.
    padded_hist = np.pad(hist, (0, padding), 'constant')
    return (max_kl - kl_div(padded_hist, uniform_dist).sum()) / max_kl

def fit_score(X: np.ndarray, y: np.ndarray):
    """
    Fit the supervised score predictor on features ``X`` and targets ``y``.

    Returns the fitted ``ExtraTreesRegressor``: a small ensemble of 10
    trees of depth 2 with a fixed random seed.
    """
    hyperparameters = {
        "n_estimators": 10,
        "max_depth": 2,
        "random_state": 42,
    }
    regressor = ExtraTreesRegressor(**hyperparameters)
    regressor.fit(X, y)
    return regressor

def predict_score(model: BaseEstimator,  X: np.ndarray, weight: float = 0):
    """
    Blend the prediction of ``model`` with the ``Mean`` baseline.

    ``weight`` is the share given to the baseline: with 0 only ``model``
    contributes, with 1 only the default-constructed ``Mean`` predictor does.
    """
    learned = model.predict(X)
    baseline = Mean().predict(X)
    return learned * (1 - weight) + baseline * weight

def create_score(
    dataset_name: str = "wildbench",
    judge: str = "gpt-4o",
    label_judge: str = "gpt-4o",
    train_model_names: List[str] | None = None,
    test_model_names: List[str] | None = None,
    data_dir: str = "data/",
    task_id: str = "",
    add_output: bool = False
) -> None:
    """
    Create score files for the given dataset.

    Judgments of the test models are grouped per ``session_id`` and scored.
    When training models are given, one regressor is fitted per session on
    the training judgments (``norm_probability``) against the label scores
    (``score``), and its prediction is blended with the ``Mean`` baseline
    using the weight from ``kl_uniform_weight``. Otherwise the ``Mean``
    baseline alone is used. One JSON file per test model is written to
    ``<data_dir>/<dataset_name>/score/<judge>/<model>.json``.

    Args:
        dataset_name: Dataset whose judgments are loaded and scored.
        judge: Judge whose judgments are used as features; also names the
            output directory and is stored in every output record.
        label_judge: Judge whose scores serve as training labels.
        train_model_names: Models used to fit the per-session scorers.
            ``None`` or an empty list selects the training-free baseline.
        test_model_names: Models to score. Defaults to
            ``['Meta-Llama-3-8B-Instruct']`` when ``None``.
        data_dir: Root directory passed to the loaders and used for output.
        task_id: Identifier stored in every output record.
        add_output: If true, attach the first entry of each model ``output``
            as ``model_output``; otherwise ``model_output`` is ``None``.
    """
    # The default is resolved per call so that no list object is shared
    # between calls (and with the loaders it is handed to).
    if test_model_names is None:
        test_model_names = ['Meta-Llama-3-8B-Instruct']

    # One row per session; every column becomes an array over that session's
    # judgment rows.
    test_judgment = load_judgment_data(data_dir=data_dir, dataset_name=dataset_name, judge=judge, model_names=test_model_names).set_index(["session_id"])
    test_sample = test_judgment.groupby("session_id").agg(list).map(np.array)
    if train_model_names:
        train_judgment = load_judgment_data(data_dir=data_dir, dataset_name=dataset_name, judge=judge, model_names=train_model_names).set_index(["session_id", "model_test"])
        train_label = load_score_data(data_dir=data_dir, dataset_name=dataset_name, model_names=train_model_names, judge=label_judge).set_index(["session_id", "model_test"])

        # Keep only (session, model) pairs that have both a judgment and a label.
        train_sample = train_judgment.join(train_label, how="inner")
        train_sample = train_sample.loc[:, ["norm_probability", "score"]].groupby("session_id").agg(list).map(np.array)
        # One regressor and one blending weight per session.
        train_sample["scorer"] = train_sample.apply(lambda x: fit_score(x["norm_probability"], x["score"]), axis=1)
        train_sample["weight"] = train_sample["score"].apply(lambda x: kl_uniform_weight(x))
        # Inner join: test sessions without training data are dropped.
        test_sample = test_sample.join(train_sample.loc[:, ["scorer", "weight"]], how="inner")
    else:
        # Training-free path: with weight 0 the prediction comes entirely
        # from the scorer, which here is the Mean baseline itself.
        test_sample["scorer"] = Mean()
        test_sample["weight"] = 0

    test_sample["score"] = test_sample.apply(lambda x: predict_score(x["scorer"], x["norm_probability"], x["weight"]), axis=1)
    # Unfold the per-session arrays back into one row per (session, model).
    test_sample = test_sample.loc[:, ["model_test", "score"]].explode(["model_test", "score"]).reset_index().set_index(["session_id", "model_test"])

    if add_output:
        results = load_results_data(data_dir=data_dir, dataset_name=dataset_name, model_names=test_model_names).set_index(["session_id", "model_test"]).rename(columns={"output": "model_output"})
        # Only the first element of each output entry is kept.
        results["model_output"] = results["model_output"].apply(lambda x: x[0])
        test_sample = test_sample.join(results.loc[:, ["model_output"]], how="inner").reset_index()
    else:
        test_sample["model_output"] = None
        test_sample = test_sample.reset_index()
    test_sample["judge"] = judge
    test_sample["task_id"] = task_id

    output_dir = os.path.join(data_dir, dataset_name, "score", judge)
    # exist_ok=True already covers an existing directory; a separate
    # existence check would only add a race window.
    os.makedirs(output_dir, exist_ok=True)
    for model in tqdm(test_model_names):
        model_test_sample = test_sample[test_sample["model_test"] == model]
        output_file = os.path.join(output_dir, f"{model}.json")
        model_test_sample.to_json(output_file, orient='records', indent=2)
        logger.info(f"""Score of "{model}" output to "{output_file}" """)
