import torch
import numpy as np


def extract_rules_structured(
    rules_model,
    X_scaled,
    Y_data,
    feature_names,
    scaler_x,
    responsibility_threshold,
    activation_threshold,
    assign_max_resp=False,
):
    """Extract per-rule sample statistics into a list of plain dicts.

    For every non-disabled rule of ``rules_model`` the samples assigned to
    that rule are selected and summarised feature by feature. Statistics are
    reported on the original feature scale when ``scaler_x`` is given.

    Args:
        rules_model: Object exposing ``rules`` (a sequence of rule objects),
            ``forward(X)`` (whose first returned element is indexed as
            ``[sample, rule]`` responsibilities) and ``forward_raw(X)``
            (indexed as ``[sample, rule]`` raw activations). Each rule must
            provide ``disabled``, ``limits`` (indexed ``[feature, 0/1]`` for
            min/max), ``and_layer.and_weights`` and
            ``discretizer.is_discrete``.
        X_scaled: 2-D tensor of (scaled) inputs, one row per sample.
        Y_data: Tensor of targets, indexable by a boolean sample mask.
        feature_names: Sequence of feature names, or ``None`` to generate
            ``"Feature 0"``, ``"Feature 1"``, ...
        scaler_x: Object with an ``inverse_transform`` method operating on
            NumPy arrays (presumably a fitted scikit-learn scaler), or
            ``None`` when the inputs are already on their original scale.
        responsibility_threshold: A sample must have a responsibility
            strictly above this value to be assigned to a rule.
        activation_threshold: A sample must have a raw activation strictly
            above this value to be assigned to a rule.
        assign_max_resp: When true, samples whose highest responsibility
            belongs to the rule are assigned to it as well, regardless of
            the two thresholds.

    Returns:
        A tuple ``(extracted_rules, X_orig, Y_data)``. ``extracted_rules``
        holds one dict per enabled rule with the keys ``"component"`` (rule
        index), ``"samples"`` (number of assigned samples), ``"predicates"``
        (feature name -> dict with ``min``, ``max``, ``mean``, ``values``,
        ``is_discrete``, ``data_min``, ``data_max``, ``weight``; empty when
        no sample was assigned) and ``"target_values"``. ``X_orig`` is the
        input on the original scale and ``Y_data`` is returned unchanged.
    """
    with torch.no_grad():
        n_features = X_scaled.shape[1]
        if feature_names is None:
            feature_names = [f"Feature {i}" for i in range(n_features)]

        responsibilities, _ = rules_model.forward(X_scaled)
        activations = rules_model.forward_raw(X_scaled)
        # The winner-takes-all assignment is only needed for assign_max_resp.
        argmax_assignments = (
            torch.argmax(responsibilities, dim=1) if assign_max_resp else None
        )

        # The data limits are read from the first rule only.
        first_rule_limits = rules_model.rules[0].limits
        if scaler_x is not None:
            X_orig = torch.tensor(
                scaler_x.inverse_transform(X_scaled.cpu().numpy()),
                device=X_scaled.device,
            )
            # Map the per-feature limits back to the original scale by
            # pushing the row of minima and the row of maxima through the
            # scaler separately.
            scaled_limits = first_rule_limits.cpu().numpy()
            scaled_mins = scaled_limits[:, 0].reshape(1, -1)
            scaled_maxs = scaled_limits[:, 1].reshape(1, -1)
            orig_mins = scaler_x.inverse_transform(scaled_mins).flatten()
            orig_maxs = scaler_x.inverse_transform(scaled_maxs).flatten()
            data_limits_orig = torch.tensor(
                np.stack([orig_mins, orig_maxs], axis=1), device=X_scaled.device
            )
        else:
            X_orig = X_scaled.clone()
            data_limits_orig = first_rule_limits.clone()

        # Convert once instead of calling .item() per feature and per rule.
        data_limits = data_limits_orig.tolist()

        extracted_rules = []
        for comp_idx, rule in enumerate(rules_model.rules):
            if rule.disabled:
                continue

            # A sample belongs to the rule when BOTH its responsibility and
            # its raw activation exceed their thresholds; with
            # assign_max_resp the argmax winners are added on top.
            combined_mask = (
                responsibilities[:, comp_idx] > responsibility_threshold
            ) & (activations[:, comp_idx] > activation_threshold)
            if assign_max_resp:
                combined_mask = combined_mask | (argmax_assignments == comp_idx)
            n_samples = combined_mask.sum().item()

            rule_data = {
                "component": comp_idx,
                "samples": n_samples,
                "predicates": {},
                "target_values": Y_data[combined_mask].cpu().numpy()
                if n_samples > 0
                else np.array([]),
            }

            if n_samples > 0:
                comp_samples_orig = X_orig[combined_mask]
                mins = torch.min(comp_samples_orig, dim=0).values.tolist()
                maxs = torch.max(comp_samples_orig, dim=0).values.tolist()
                means = torch.mean(comp_samples_orig, dim=0).tolist()

                for feat_idx in range(n_features):
                    rule_data["predicates"][feature_names[feat_idx]] = {
                        "min": mins[feat_idx],
                        "max": maxs[feat_idx],
                        "mean": means[feat_idx],
                        "values": comp_samples_orig[:, feat_idx].cpu().numpy(),
                        "is_discrete": rule.discretizer.is_discrete[feat_idx],
                        "data_min": data_limits[feat_idx][0],
                        "data_max": data_limits[feat_idx][1],
                        "weight": rule.and_layer.and_weights[0][feat_idx].item(),
                    }
            extracted_rules.append(rule_data)

    return extracted_rules, X_orig, Y_data


def format_rules_as_text_table(structured_rules, weight_threshold=0.1):
    """Format the structured rules as an ASCII table.

    Each rule becomes one row. Continuous features are shown as
    ``[min, max]`` with two decimals; discrete features as ``T``/``F`` when
    all assigned samples agree (``min`` and ``max`` differ by less than
    1e-6) and ``Mixed`` otherwise. A cell is wrapped in parentheses when the
    predicate weight does not exceed ``weight_threshold``; a feature missing
    from a rule is shown as ``-``.

    Args:
        structured_rules: List of rule dicts with the keys ``"component"``,
            ``"samples"`` and ``"predicates"`` (feature name -> dict with
            ``min``, ``max``, ``is_discrete`` and ``weight``).
        weight_threshold: Predicates with a weight at or below this value
            are rendered as inactive (parenthesised).

    Returns:
        The table as a single newline-joined string, or a short message when
        ``structured_rules`` is empty.
    """
    if not structured_rules:
        return "No active rules to display."

    # Collect the columns from every rule (first-seen order). Looking only at
    # the first rule would drop all feature columns whenever that rule
    # happens to have no predicates.
    all_features = list(
        dict.fromkeys(
            name for rule in structured_rules for name in rule["predicates"]
        )
    )
    header = ["Component", "Samples", *all_features]

    rows_data = []
    for rule in structured_rules:
        row = [f"Comp {rule['component'] + 1}", str(rule["samples"])]
        for feat_name in all_features:
            pred = rule["predicates"].get(feat_name)
            if not pred:
                row.append("-")
                continue
            if pred["is_discrete"]:
                if abs(pred["min"] - pred["max"]) < 1e-6:
                    text = "T" if pred["min"] > 0.5 else "F"
                else:
                    text = "Mixed"
            else:
                text = f"[{pred['min']:.2f}, {pred['max']:.2f}]"
            is_active = pred["weight"] > weight_threshold
            row.append(text if is_active else f"({text})")
        rows_data.append(row)

    # Each column is as wide as its widest cell, header included.
    col_widths = [
        max(len(cell) for cell in column) for column in zip(header, *rows_data)
    ]

    # The outer "+" corners make the separator exactly as long as the
    # "| ... |" rows so that the column junctions line up.
    separator = "+" + "+".join("-" * (w + 2) for w in col_widths) + "+"
    header_str = " | ".join(h.center(w) for h, w in zip(header, col_widths))

    table_lines = [separator, f"| {header_str} |", separator]
    for row in rows_data:
        row_str = " | ".join(cell.ljust(w) for cell, w in zip(row, col_widths))
        table_lines.append(f"| {row_str} |")
    table_lines.append(separator)
    return "\n".join(table_lines)


def format_rules_as_html(structured_rules, weight_threshold=0.1):
    """Format the structured rules as a styled HTML table.

    Each rule becomes one table row. Continuous features are shown as
    ``[min, max]`` with two decimals; discrete features as ``True``/``False``
    when all assigned samples agree (``min`` and ``max`` differ by less than
    1e-6) and ``Mixed`` otherwise. Cells whose predicate weight does not
    exceed ``weight_threshold`` get the CSS class ``inactive``; a feature
    missing from a rule is shown as ``-``.

    Args:
        structured_rules: List of rule dicts with the keys ``"component"``,
            ``"samples"`` and ``"predicates"`` (feature name -> dict with
            ``min``, ``max``, ``is_discrete`` and ``weight``).
        weight_threshold: Predicates with a weight at or below this value
            are rendered as inactive.

    Returns:
        An HTML fragment (inline ``<style>`` plus a scrollable ``<table>``),
        or a short ``<p>`` message when ``structured_rules`` is empty.
    """
    # Imported locally so the function does not depend on a module-level
    # import being present.
    from html import escape

    if not structured_rules:
        return "<p>No active rules to display.</p>"

    # Collect the columns from every rule (first-seen order). Looking only at
    # the first rule would drop all feature columns whenever that rule
    # happens to have no predicates.
    all_features = list(
        dict.fromkeys(
            name for rule in structured_rules for name in rule["predicates"]
        )
    )

    preamble = """
    <style>
        .rules-table { border-collapse: collapse; width: 100%; font-family: sans-serif; table-layout: auto; }
        .rules-table th, .rules-table td { border: 1px solid #ddd; padding: 8px; text-align: left; white-space: nowrap; }
        .rules-table th { background-color: #f2f2f2; font-weight: bold; position: sticky; top: 0; }
        .rules-table tr:nth-child(even){ background-color: #f9f9f9; }
        .rules-table .inactive { color: #999; font-style: italic; }
    </style>
    <div style="overflow-x: auto;">
    <table class="rules-table">
    """
    # Accumulate fragments and join once instead of repeated string +=.
    parts = [preamble, "<thead><tr><th>Component</th><th>Samples</th>"]
    # Feature names come from the data set, so escape HTML metacharacters.
    parts.extend(
        f"<th>{escape(str(feat_name), quote=False)}</th>"
        for feat_name in all_features
    )
    parts.append("</tr></thead><tbody>")

    for rule in structured_rules:
        parts.append(
            f"<tr><td><b>Comp {rule['component'] + 1}</b></td>"
            f"<td>{rule['samples']}</td>"
        )
        for feat_name in all_features:
            pred = rule["predicates"].get(feat_name, {})
            cell_content = "-"
            css_class = ""
            if pred:
                if not (pred.get("weight", 0) > weight_threshold):
                    css_class = "inactive"
                low = pred.get("min", 0)
                high = pred.get("max", 0)
                if pred.get("is_discrete"):
                    if abs(low - high) < 1e-6:
                        cell_content = "True" if low > 0.5 else "False"
                    else:
                        cell_content = "Mixed"
                else:
                    cell_content = f"[{low:.2f}, {high:.2f}]"
            parts.append(f"<td class='{css_class}'>{cell_content}</td>")
        parts.append("</tr>")

    parts.append("</tbody></table></div>")
    return "".join(parts)
