import numpy as np
import torch
from typing import Dict, Tuple
import math

from .config import ExperimentConfig
from .model import MSNWedge2D
from .dataset import sample_evaluation_grid, get_true_solution


def _uses_sine_condition(bc_type: str) -> bool:
    """Return True for BC types whose eigen-condition is checked with sin (DD, NN)."""
    return bc_type in ("DD", "NN")


def _mode_index(mu: float, omega: float, bc_type: str) -> int:
    """Map an exponent ``mu`` to an integer mode index for ``bc_type``.

    DD/NN round ``mu * omega / pi``; the other BC types round
    ``mu * omega / pi + 0.5`` (so ``mu * omega = (k + 1/2) * pi`` maps to
    ``k + 1``). Only differences of indices are used by ``evaluate_model``,
    so the offset convention cancels out.
    """
    offset = 0.0 if _uses_sine_condition(bc_type) else 0.5
    return int(round(float(mu) * omega / math.pi + offset))


@torch.no_grad()
def evaluate_model(
    model: MSNWedge2D,
    config: ExperimentConfig,
    n_r: int = 50,
    n_theta: int = 50,
) -> Dict[str, float]:
    """Evaluate a trained wedge model against the reference solution.

    Metrics computed:
      * ``dominant_mu`` is the exponent whose coefficient has the largest
        magnitude. ``abs_err_mu`` / ``rel_err_mu`` (percent) compare it with
        ``config.true_mu``.
      * ``constraint_violation`` is the coefficient-magnitude-weighted mean of
        ``sin^2(mu * omega)`` for DD/NN, or ``cos^2(mu * omega)`` otherwise.
      * ``solution_l2_rel`` is the relative RMS error of ``model.forward`` versus
        ``get_true_solution`` on an ``n_r`` x ``n_theta`` grid from
        ``sample_evaluation_grid`` (``r_min=1e-3``).
      * ``bc_dirichlet_err`` / ``bc_neumann_err`` are the losses returned by
        ``model.edge_bc_loss_theta_omega`` at 100 log-spaced radii in [1e-3, 1].
      * ``mode_error`` is the absolute difference between the predicted and
        true mode indices.

    The model is switched to eval mode for the evaluation, and its previous
    train/eval mode is restored afterwards.

    Args:
        model: Model to evaluate. It must provide ``get_exponents``,
            ``get_coeffs``, ``forward`` and ``edge_bc_loss_theta_omega``.
        config: Supplies ``device``, ``omega``, ``true_mu`` and ``bc_type``.
        n_r: Radial resolution of the evaluation grid.
        n_theta: Angular resolution of the evaluation grid.

    Returns:
        Metric dict. Scalar entries are plain Python ``float``/``int``, so they
        serialise cleanly. ``"exponents"`` and ``"coefficients"`` are lists.
    """
    was_training = model.training
    model.eval()
    try:
        device = config.device
        omega = config.omega
        true_mu = float(config.true_mu)
        bc_type = config.bc_type

        # --- Exponent recovery --------------------------------------------
        # Assumes get_exponents()/get_coeffs() return equal-length numpy
        # arrays (np.abs / argmax / .tolist() are applied) -- TODO confirm.
        mus = model.get_exponents()
        coeffs = model.get_coeffs()
        coeffs_abs = np.abs(coeffs)

        # Cast to a Python float. Otherwise a float32 numpy scalar leaks into
        # the result dict, and those are not JSON-serialisable.
        dominant_mu = float(mus[np.argmax(coeffs_abs)])
        abs_err_mu = abs(dominant_mu - true_mu)
        # A relative error is undefined for true_mu == 0. Report inf rather
        # than raising ZeroDivisionError.
        rel_err_mu = abs_err_mu / abs(true_mu) * 100 if true_mu != 0 else math.inf

        # --- Eigen-condition violation ------------------------------------
        trig = np.sin if _uses_sine_condition(bc_type) else np.cos
        constraint_vals = trig(mus * omega) ** 2
        # Each term is weighted by its share of total |coefficient|;
        # the 1e-8 avoids 0/0 when all coefficients vanish.
        weights = coeffs_abs / (coeffs_abs.sum() + 1e-8)
        constraint_violation = float(np.sum(weights * constraint_vals))

        # --- Field error on the evaluation grid ---------------------------
        R, TH, _, _ = sample_evaluation_grid(n_r, n_theta, omega, r_min=1e-3)
        u_true = get_true_solution(R, TH, true_mu, bc_type)

        r_flat = torch.tensor(R.flatten(), dtype=torch.float32, device=device)
        theta_flat = torch.tensor(TH.flatten(), dtype=torch.float32, device=device)
        u_pred = model.forward(r_flat, theta_flat).cpu().numpy().reshape(R.shape)

        l2_err = np.sqrt(np.mean((u_pred - u_true) ** 2))
        l2_norm = np.sqrt(np.mean(u_true**2))
        solution_l2_rel = float(l2_err / (l2_norm + 1e-12))

        # --- Boundary-condition residuals on the theta = omega edge -------
        # NOTE(review): this runs under torch.no_grad(). If
        # edge_bc_loss_theta_omega obtains the Neumann derivative via
        # autograd, it would fail here -- confirm it is computed analytically.
        n_edge_eval = 100
        r_edge_t = torch.tensor(
            np.logspace(-3, 0, n_edge_eval), dtype=torch.float32, device=device
        )
        dirichlet_loss, neumann_loss = model.edge_bc_loss_theta_omega(r_edge_t, omega)
        bc_dirichlet_err = float(dirichlet_loss.item())
        bc_neumann_err = float(neumann_loss.item())

        # --- Mode identification ------------------------------------------
        mode_error = abs(
            _mode_index(dominant_mu, omega, bc_type)
            - _mode_index(true_mu, omega, bc_type)
        )

        return {
            "rel_err_mu": float(rel_err_mu),
            "abs_err_mu": float(abs_err_mu),
            "dominant_mu": dominant_mu,
            "true_mu": true_mu,
            "constraint_violation": constraint_violation,
            "solution_l2_rel": solution_l2_rel,
            "bc_dirichlet_err": bc_dirichlet_err,
            "bc_neumann_err": bc_neumann_err,
            "mode_error": mode_error,
            "exponents": mus.tolist(),
            "coefficients": coeffs.tolist(),
        }
    finally:
        # Restore the caller's mode, so that evaluating mid-training does not
        # silently leave the model in eval mode.
        model.train(was_training)


def compute_success_rate(
    results: list,
    threshold: float = 5.0,
) -> float:
    """Return the percentage of results with ``rel_err_mu`` strictly below *threshold*.

    Each result is a dict holding a ``"rel_err_mu"`` entry. NaN values never
    count as successes. An empty list yields ``0.0``.
    """
    total = len(results)
    if total == 0:
        return 0.0

    hits = len([res for res in results if res["rel_err_mu"] < threshold])
    return hits / total * 100


def compute_statistics(
    results: list,
    metric: str = "rel_err_mu",
) -> Dict[str, float]:
    """Summarise one scalar *metric* across a list of result dicts.

    Returns a dict with keys ``mean``, ``std`` (population, ``ddof=0``),
    ``median``, ``p90``, ``min`` and ``max`` as Python floats. When *results*
    is empty, every statistic is NaN.
    """
    reducers = {
        "mean": np.mean,
        "std": np.std,
        "median": np.median,
        "p90": lambda arr: np.percentile(arr, 90),
        "min": np.min,
        "max": np.max,
    }

    if not results:
        return {name: float("nan") for name in reducers}

    samples = np.array([res[metric] for res in results])
    return {name: float(reduce(samples)) for name, reduce in reducers.items()}


def categorize_results(results: list) -> Dict[str, list]:
    """Group result dicts into named, overlapping categories.

    Categories:
      * ``all``: the input list itself.
      * ``convex`` / ``reentrant``: split on ``omega_deg`` (<= 180 vs > 180).
        Records whose ``omega_deg`` is missing or ``None`` belong to neither
        group. Previously they defaulted to 0 and were silently counted as
        convex, and a ``None`` value raised ``TypeError``.
      * ``DD``, ``NN``, ``DN``, ``ND``: records matched on ``bc_type``.
      * ``naive``, ``constraint``: records matched on ``method``.
    """

    def matching(key: str, value: str) -> list:
        # Records lacking *key* simply do not match.
        return [r for r in results if r.get(key) == value]

    with_angle = [r for r in results if r.get("omega_deg") is not None]

    return {
        "all": results,
        "convex": [r for r in with_angle if r["omega_deg"] <= 180],
        "reentrant": [r for r in with_angle if r["omega_deg"] > 180],
        "DD": matching("bc_type", "DD"),
        "NN": matching("bc_type", "NN"),
        "DN": matching("bc_type", "DN"),
        "ND": matching("bc_type", "ND"),
        "naive": matching("method", "naive"),
        "constraint": matching("method", "constraint"),
    }


def generate_summary_statistics(results: list) -> dict:
    """Build a per-category summary of experiment results.

    Every non-empty category from ``categorize_results`` maps to a dict with
    the category size (``n_experiments``), its ``success_rate`` (default
    threshold of ``compute_success_rate``), and ``compute_statistics``
    summaries for ``rel_err_mu``, ``constraint_violation`` and
    ``solution_l2_rel``. Empty categories are omitted.
    """
    tracked_metrics = ("rel_err_mu", "constraint_violation", "solution_l2_rel")

    def summarize(group: list) -> dict:
        entry = {
            "n_experiments": len(group),
            "success_rate": compute_success_rate(group),
        }
        for metric_name in tracked_metrics:
            entry[metric_name] = compute_statistics(group, metric_name)
        return entry

    return {
        name: summarize(group)
        for name, group in categorize_results(results).items()
        if group
    }
