import os
from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# Directory that save_sub_data writes the sampled sub-matrices (.npy) and
# text subsets (.txt) into. It is created at import time (no error if it
# already exists) so those later writes have a destination.
OUTPUT_DIR = "greedy_samples"
os.makedirs(OUTPUT_DIR, exist_ok=True)


class KnowledgePointEntropyAnalyzer:
    """Shannon entropy (in bits) of a matrix viewed as a distribution over its cells.

    The matrix is smoothed by adding a constant background mass spread evenly
    over all cells, normalized so the cells sum to one, and the entropy of the
    resulting cell distribution is returned.
    """

    def __init__(self, alpha: float = 1e-6):
        # Total background mass added to the whole matrix; it is split evenly
        # over all n * M cells so that zero cells receive a small probability.
        self.alpha = alpha

    def add_background(self, B: np.ndarray) -> np.ndarray:
        """Return ``B`` with ``alpha / (n * M)`` added to every cell.

        For an empty matrix (``n * M == 0``) no background is added, instead of
        raising ``ZeroDivisionError``.
        """
        n, M = B.shape
        cells = n * M
        background = self.alpha / cells if cells else 0.0
        return B + background

    def normalize_to_probability(self, B_prime: np.ndarray) -> np.ndarray:
        """Scale ``B_prime`` so that its entries sum to one.

        If the total is zero (an empty matrix, or ``alpha == 0`` with an
        all-zero matrix) there is no mass to normalize. In that case an
        all-zero array of the same shape is returned, so the entropy evaluates
        to 0 rather than NaN from ``0 / 0``.
        """
        S = np.sum(B_prime)
        if S == 0:
            return np.zeros_like(B_prime, dtype=float)
        return B_prime / S

    def calculate_type2_entropy(self, P: np.ndarray) -> float:
        """Return ``-sum(p * log2(p))`` over the strictly positive entries of ``P``."""
        P_flat = P.flatten()
        # Zero cells contribute nothing (lim p->0 of p*log p is 0) and would
        # otherwise produce log2(0) = -inf.
        P_nonzero = P_flat[P_flat > 0]
        return -np.sum(P_nonzero * np.log2(P_nonzero))

    def analyze(self, B: np.ndarray) -> float:
        """Smooth, normalize and return the cell entropy of ``B`` in bits."""
        B_prime = self.add_background(B)
        P = self.normalize_to_probability(B_prime)
        return self.calculate_type2_entropy(P)


def load_messages(txt_path: str) -> List[str]:
    """Read ``txt_path`` as UTF-8 text and return its lines.

    Each line has leading and trailing whitespace (including the newline)
    removed; blank lines are kept as empty strings, so positions match line
    numbers in the file.
    """
    with open(txt_path, encoding="utf-8") as handle:
        return [raw_line.strip() for raw_line in handle]


def save_sub_data(sample_size: int, sub_matrix: np.ndarray, sub_indices: List[int], messages: List[str]):
    """Persist one sampled subset into ``OUTPUT_DIR``.

    Two files are written:
    * ``sub_matrix_<sample_size>_plain.npy`` holds ``sub_matrix``;
    * ``sub_dataset_<sample_size>_plain.txt`` holds ``messages[i]`` for each
      ``i`` in ``sub_indices``, one per line and in the given order.
    """
    np.save(os.path.join(OUTPUT_DIR, f"sub_matrix_{sample_size}_plain.npy"), sub_matrix)

    dataset_file = os.path.join(OUTPUT_DIR, f"sub_dataset_{sample_size}_plain.txt")
    with open(dataset_file, "w", encoding="utf-8") as out:
        out.writelines(messages[idx] + "\n" for idx in sub_indices)


def greedy_entropy_sampling_optimized(matrix: np.ndarray, n_select: int, return_index: bool = False,
                                      alpha: float = 1.0, beta: float = 1.0
                                      ) -> Union[np.ndarray, Tuple[np.ndarray, List[int]]]:
    """Greedily select ``n_select`` rows of a 0/1 membership matrix.

    Selection runs in two phases.

    1. Coverage: columns are visited in descending order of their binary
       entropy, i.e. columns whose frequency of 1s is closest to 0.5 come first.
       For each column, the unselected rows that have a 1 there and contain at
       least one not-yet-covered column are ranked by
       ``alpha * row_entropy + beta * n_new_columns`` and added in that order.
       The phase stops once every column is covered or the quota is reached.
    2. Balance: if more rows are still needed, every remaining row is ranked by
       ``alpha * row_entropy + beta * n_least_covered`` and the best ones fill
       the quota. Here ``n_least_covered`` counts the row's positive columns
       whose selection count equals the current minimum.

    Ties are broken by the higher row index.

    Args:
        matrix: 2-D array of shape (n, m); entries equal to 1 mark membership.
            Integer, float or bool dtypes are accepted.
        n_select: number of rows to select. Values <= 0 select nothing.
        return_index: when True, also return the selected row indices.
        alpha: weight of the row-entropy term.
        beta: weight of the coverage / balance term.

    Returns:
        ``matrix[selected]``, or ``(matrix[selected], selected)`` when
        ``return_index`` is True. ``selected`` is sorted ascending.

    Raises:
        ValueError: if ``n_select`` exceeds the number of rows.
    """
    n, m = matrix.shape
    if n_select > n:
        raise ValueError(f"n_select={n_select} exceeds the number of rows ({n})")
    B = matrix.copy()

    if n_select <= 0:
        empty = B[:0]
        return (empty, []) if return_index else empty

    def binary_entropy(p):
        if p == 0 or p == 1:
            return 0
        return -p * np.log2(p) - (1 - p) * np.log2(1 - p)

    def row_entropy(row_vec):
        # Binary entropy of the row's fraction of 1s; 0 for a row with no columns.
        if len(row_vec) == 0:
            return 0
        p = np.sum(row_vec) / len(row_vec)
        return binary_entropy(p)

    p_j = np.mean(B, axis=0)
    H_j = np.array([binary_entropy(p) for p in p_j])
    sorted_col_indices = np.argsort(-H_j)

    selected_rows = set()
    # Per-column selection counts. The dtype follows B so that `+= B[row]`
    # also works for float matrices; an int counter raises a same_kind
    # casting error there.
    kp_counter = np.zeros(m, dtype=np.result_type(B.dtype, np.int64))
    covered_kps = set()

    # Phase 1: coverage.
    for col in sorted_col_indices:
        rows_with_col = np.where(B[:, col] == 1)[0]
        candidate_rows = [row for row in rows_with_col if row not in selected_rows]

        scored_candidates = []
        for row in candidate_rows:
            row_vec = B[row]
            row_kps = set(np.where(row_vec == 1)[0])
            new_kps = row_kps - covered_kps
            if not new_kps:
                continue
            entropy_score = row_entropy(row_vec)
            new_kp_score = len(new_kps)
            total_score = alpha * entropy_score + beta * new_kp_score
            scored_candidates.append((total_score, row))

        # NOTE(review): new_kps is measured before any row of this column is
        # added, so later rows are still taken even if earlier picks already
        # covered their columns. Confirm this is intended rather than
        # re-checking coverage before each add.
        scored_candidates.sort(reverse=True)
        for _, row in scored_candidates:
            selected_rows.add(row)
            kp_counter += B[row]
            covered_kps.update(np.where(B[row] == 1)[0])
            if len(covered_kps) == m or len(selected_rows) >= n_select:
                break
        if len(covered_kps) == m or len(selected_rows) >= n_select:
            break

    # Phase 2: balance. Because n_select <= n, the remaining rows always
    # suffice to fill the quota here, so no random top-up is needed.
    if len(selected_rows) < n_select:
        remaining_rows = list(set(range(n)) - selected_rows)
        # The least-covered columns do not change while scoring; compute once.
        min_count = kp_counter.min() if m else 0
        least_covered = kp_counter == min_count
        scored_remaining = []
        for row in remaining_rows:
            row_vec = B[row]
            entropy_score = row_entropy(row_vec)
            balance_score = np.sum((row_vec > 0) & least_covered)
            total_score = alpha * entropy_score + beta * balance_score
            scored_remaining.append((total_score, row))

        scored_remaining.sort(reverse=True)
        needed = n_select - len(selected_rows)
        selected_rows.update(row for _, row in scored_remaining[:needed])

    selected = sorted(int(row) for row in selected_rows)
    selected_matrix = B[selected]

    return (selected_matrix, selected) if return_index else selected_matrix


def greedy_sample_with_entropy(B: np.ndarray, 
                              indices: np.ndarray,
                              messages: List[str],
                              sample_sizes: List[int],
                              alpha: float = 1e-6) -> pd.DataFrame:
    """Run greedy sampling at each size and record the subset's entropy curve.

    For every ``size`` in ``sample_sizes``, the function:
    * selects rows of ``B`` with ``greedy_entropy_sampling_optimized``
      (default sampler weights);
    * saves the sub-matrix and the matching ``messages`` via
      ``save_sub_data``. The ``messages`` are looked up through
      ``indices[selected_rows]``.

    It then records these columns:
    * ``H_element``: entropy of the sub-matrix from
      ``KnowledgePointEntropyAnalyzer(alpha)``;
    * ``H_element_norm``: ``H_element / log2(size)``. This is NaN for
      ``size <= 1``, where ``log2(size)`` is not positive;
    * ``slope``: change in ``H_element_norm`` per unit of sample size
      relative to the previous entry. It is NaN for the first entry or a
      repeated size;
    * ``knowledge_point_distribution``: per-column sums of the sub-matrix.

    Args:
        B: matrix passed to the greedy sampler.
        indices: for each row of ``B``, the position of its line in ``messages``.
        messages: text lines written out for the selected rows.
        sample_sizes: subset sizes to evaluate, in order.
        alpha: background mass for the entropy analyzer. This is not the
            sampler's ``alpha``, which keeps its default.

    Returns:
        DataFrame with one row per entry of ``sample_sizes``.
    """
    analyzer = KnowledgePointEntropyAnalyzer(alpha)
    records = []

    prev_size = None
    prev_entropy_norm = None

    for size in tqdm(sample_sizes, desc="Greedy Sampling with Entropy"):
        sub_matrix, selected_idx = greedy_entropy_sampling_optimized(B, size, return_index=True)
        sub_index = indices[selected_idx]
        save_sub_data(size, sub_matrix, sub_index, messages)

        entropy_val = analyzer.analyze(sub_matrix)
        # log2(size) is 0 at size 1 (and not positive below), which would give
        # inf/NaN plus a RuntimeWarning; record the normalized value as NaN.
        log_n = np.log2(size) if size > 1 else np.nan
        entropy_norm = entropy_val / log_n

        kp_distribution = np.sum(sub_matrix, axis=0)

        # Finite-difference slope against the previous sample size.
        slope = np.nan
        if prev_size is not None and size != prev_size:
            slope = (entropy_norm - prev_entropy_norm) / (size - prev_size)

        records.append({
            "sample_size": size,
            "H_element": entropy_val,
            "H_element_norm": entropy_norm,
            "slope": slope,
            "knowledge_point_distribution": kp_distribution.tolist(),
        })

        prev_size = size
        prev_entropy_norm = entropy_norm

    return pd.DataFrame(records)


def random_sample_with_entropy(B: np.ndarray, 
                              indices: np.ndarray,
                              messages: List[str],
                              sample_sizes: List[int],
                              alpha: float = 1e-6,
                              seed: Optional[int] = None) -> pd.DataFrame:
    """Random-subset baseline for the greedy entropy curve.

    For every ``size`` in ``sample_sizes``, ``size`` distinct rows of ``B``
    are drawn uniformly without replacement. The function records the same
    columns as the greedy variant: ``H_element``, ``H_element_norm`` (NaN for
    ``size <= 1``), ``slope`` and ``knowledge_point_distribution``. Nothing is
    written to disk.

    Args:
        B: matrix to sample rows from.
        indices: unused. It is accepted so the signature matches
            ``greedy_sample_with_entropy``.
        messages: unused, for the same reason.
        sample_sizes: subset sizes to evaluate, in order. Each must be <= the
            number of rows, otherwise NumPy raises ``ValueError``.
        alpha: background mass for the entropy analyzer.
        seed: when given, draws come from ``np.random.default_rng(seed)`` so
            the baseline is reproducible. When None (the default), NumPy's
            global RNG is used as before.

    Returns:
        DataFrame with one row per entry of ``sample_sizes``.
    """
    analyzer = KnowledgePointEntropyAnalyzer(alpha)
    records = []

    n = B.shape[0]
    choose = np.random.choice if seed is None else np.random.default_rng(seed).choice
    prev_size = None
    prev_entropy_norm = None

    for size in tqdm(sample_sizes, desc="Random Sampling with Entropy"):
        selected_idx = choose(n, size=size, replace=False)
        sub_matrix = B[selected_idx]

        entropy_val = analyzer.analyze(sub_matrix)
        # log2(size) is 0 at size 1 (and not positive below), which would give
        # inf/NaN plus a RuntimeWarning; record the normalized value as NaN.
        log_n = np.log2(size) if size > 1 else np.nan
        entropy_norm = entropy_val / log_n

        kp_distribution = np.sum(sub_matrix, axis=0)

        # Finite-difference slope against the previous sample size.
        slope = np.nan
        if prev_size is not None and size != prev_size:
            slope = (entropy_norm - prev_entropy_norm) / (size - prev_size)

        records.append({
            "sample_size": size,
            "H_element": entropy_val,
            "H_element_norm": entropy_norm,
            "slope": slope,
            "knowledge_point_distribution": kp_distribution.tolist(),
        })

        prev_size = size
        prev_entropy_norm = entropy_norm

    return pd.DataFrame(records)


if __name__ == "__main__":
    # The saved array's last column gives, per row, the index used to look up
    # that row's line in the messages file; the other columns form the matrix.
    matrix_raw = np.load("kp_matrix.npy")
    kp_matrix = matrix_raw[:, :-1]
    row_to_message = matrix_raw[:, -1].astype(int)
    corpus = load_messages("diabetes.txt")

    sample_sizes = list(range(100, 1000, 10))

    runs = (
        (greedy_sample_with_entropy, "greedy_entropy_curve.csv"),
        (random_sample_with_entropy, "random_entropy_curve.csv"),
    )
    for run_sampling, csv_path in runs:
        curve = run_sampling(kp_matrix, row_to_message, corpus, sample_sizes, alpha=1e-6)
        curve.to_csv(csv_path, index=False)
