import math
import numpy as np
import torch
from typing import Tuple, Dict, Optional


def set_seed(seed: int):
    """Seed every random number generator the experiments rely on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch's
    CPU generator. When CUDA is available, every CUDA device is seeded as
    well and cuDNN is switched to deterministic, non-benchmarking mode.

    Args:
        seed: Seed value applied to every generator.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel choice.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device(gpu_id: Optional[int] = None) -> torch.device:
    """Choose the torch device to run on.

    Args:
        gpu_id: Index of the CUDA device to use; ignored when CUDA is not
            available.

    Returns:
        ``cuda:<gpu_id>`` when CUDA is available and an id was given, plain
        ``cuda`` when CUDA is available without an id, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    device_name = "cuda" if gpu_id is None else f"cuda:{gpu_id}"
    return torch.device(device_name)


def sample_arc_points(n: int, omega: float, seed: int = 0) -> np.ndarray:
    """Draw ``n`` points on the unit-radius arc of the wedge.

    Angles are drawn uniformly from ``[0, omega)`` with a private
    ``numpy.random.default_rng(seed)`` generator, so the global NumPy RNG
    state is not touched.

    Returns:
        float32 array of shape ``(n, 2)`` whose rows are ``(cos θ, sin θ)``.
    """
    angles = np.random.default_rng(seed).uniform(0, omega, n).astype(np.float32)
    points = np.column_stack((np.cos(angles), np.sin(angles)))
    return points.astype(np.float32)


def sample_edge_points(
    n: int,
    omega: float,
    edge: str,
    r_min: float = 1e-4,
    r_power: float = 2.0,
    seed: int = 0,
) -> np.ndarray:
    """Draw ``n`` points on one straight edge of the wedge.

    Radii are ``r_min + (1 - r_min) * t**r_power`` with ``t ~ U(0, 1)``, so a
    ``r_power`` above 1 concentrates samples toward the corner at the origin.

    Args:
        n: Number of points.
        omega: Wedge opening angle in radians.
        edge: ``"theta0"`` selects the edge at angle 0; any other value
            selects the edge at angle ``omega``.
        r_min: Smallest radius that can be sampled.
        r_power: Exponent applied to the uniform draw.
        seed: Seed for a private ``numpy.random.default_rng``.

    Returns:
        float32 array of shape ``(n, 2)``.
    """
    u = np.random.default_rng(seed).uniform(0, 1, n).astype(np.float32)
    radius = r_min + (1.0 - r_min) * (u**r_power)

    edge_angle = 0 if edge == "theta0" else omega
    angle = np.full(n, edge_angle, dtype=np.float32)

    coords = np.stack([radius * np.cos(angle), radius * np.sin(angle)], axis=1)
    return coords.astype(np.float32)


def sample_interior_points(
    n: int, omega: float, r_min: float = 1e-4, r_power: float = 2.0, seed: int = 0
) -> np.ndarray:
    """Draw ``n`` points inside the wedge ``{r_min <= r <= 1, 0 <= θ < omega}``.

    Radii follow ``r_min + (1 - r_min) * t**r_power`` with ``t ~ U(0, 1)``
    (clustered toward the corner for ``r_power > 1``); angles are uniform on
    ``[0, omega)``.

    Returns:
        float32 array of shape ``(n, 2)``.
    """
    gen = np.random.default_rng(seed)
    # Radii are drawn before angles; keep this order so a given seed
    # reproduces the same point cloud.
    u = gen.uniform(0, 1, n).astype(np.float32)
    radius = r_min + (1.0 - r_min) * (u**r_power)
    angle = gen.uniform(0, omega, n).astype(np.float32)

    coords = np.stack((radius * np.cos(angle), radius * np.sin(angle)), axis=1)
    return coords.astype(np.float32)


def get_true_exponents(omega: float, bc_type: str, n_modes: int = 3) -> np.ndarray:
    """Return the first ``n_modes`` singular exponents for a wedge of angle ``omega``.

    Like-type boundary conditions ("DD", "NN") give ``μ_m = m π / ω``; every
    other pair is treated as mixed and gives ``μ_m = (m - 1/2) π / ω``, for
    ``m = 1 .. n_modes``. ``bc_type`` is matched case-insensitively.

    Returns:
        float32 array of shape ``(n_modes,)``.
    """
    same_type = bc_type.upper() in ("DD", "NN")
    shift = 0.0 if same_type else 0.5
    values = [(m - shift) * math.pi / omega for m in range(1, n_modes + 1)]
    return np.array(values, dtype=np.float32)


def compute_true_solution(
    xy: np.ndarray, omega: float, bc_type: str, mode: int = 1, coefficient: float = 1.0
) -> np.ndarray:
    """Evaluate the separable corner solution ``c * r**μ * Φ(μ θ)`` at ``xy``.

    ``xy[:, 0]`` and ``xy[:, 1]`` are the Cartesian coordinates. Polar angles
    are shifted into ``[0, 2π)`` so wedges wider than π are handled. The
    exponent is ``μ = mode π / ω`` for "DD"/"NN" and ``(mode - 1/2) π / ω``
    otherwise; ``Φ`` is ``sin`` for "DD"/"DN" and ``cos`` otherwise. A 1e-12
    term under the square root keeps the radius strictly positive at the
    origin. ``bc_type`` is matched case-insensitively.
    """
    kind = bc_type.upper()
    px = xy[:, 0]
    py = xy[:, 1]

    radius = np.sqrt(px**2 + py**2 + 1e-12)
    angle = np.arctan2(py, px)
    angle = np.where(angle < 0, angle + 2 * np.pi, angle)

    mu = (mode if kind in ("DD", "NN") else mode - 0.5) * math.pi / omega
    angular_fn = np.sin if kind in ("DD", "DN") else np.cos

    return coefficient * np.power(radius, mu) * angular_fn(mu * angle)


def compute_arc_boundary_target(
    theta: np.ndarray,
    omega: float,
    bc_type: str,
    mode: int = 1,
    coefficient: float = 1.0,
) -> np.ndarray:
    """Angular profile of the exact solution on the unit arc (r = 1).

    Uses ``μ = mode π / ω`` for "DD"/"NN" and ``(mode - 1/2) π / ω`` for any
    other pair; returns ``coefficient * sin(μ θ)`` for "DD"/"DN" and
    ``coefficient * cos(μ θ)`` otherwise. ``bc_type`` is case-insensitive.
    """
    kind = bc_type.upper()
    mu = (mode if kind in ("DD", "NN") else mode - 0.5) * math.pi / omega
    profile = np.sin if kind in ("DD", "DN") else np.cos
    return coefficient * profile(mu * theta)


def compute_exponent_error(predicted_mu: float, true_mu: float) -> Tuple[float, float]:
    """Absolute and relative (percent) error of a predicted exponent.

    Returns:
        ``(abs_error, rel_error)``; ``rel_error`` is a percentage of
        ``|true_mu|``, or ``inf`` when ``true_mu`` is zero.
    """
    delta = abs(predicted_mu - true_mu)
    if true_mu == 0:
        return delta, float("inf")
    return delta, delta / abs(true_mu) * 100


def compute_constraint_violation(
    mu: np.ndarray, omega: float, bc_type: str, coeffs: Optional[np.ndarray] = None
) -> float:
    """Weighted residual of the wedge eigenvalue condition for candidate exponents.

    Like-type boundary conditions ("DD", "NN") admit exponents with
    ``sin(μ ω) = 0``; any other pair admits ``cos(μ ω) = 0``. Each exponent's
    violation is the squared residual; the violations are combined with
    weights proportional to ``|coeffs|``, or uniformly when ``coeffs`` is None.

    Args:
        mu: Candidate exponent(s). Scalars and plain sequences are accepted
            and treated as 1-D arrays.
        omega: Wedge opening angle in radians.
        bc_type: Boundary-condition pair, matched case-insensitively.
        coeffs: Optional per-exponent coefficients; only their magnitudes are
            used, normalised to sum to (almost) one.

    Returns:
        The weighted violation as a Python float (0.0 for an empty ``mu``).
    """
    bc_type = bc_type.upper()
    # len() on a 0-d array and ``list * float`` both raise TypeError, so
    # normalise to a 1-D array first (a no-op for 1-D ndarray input).
    mu = np.atleast_1d(np.asarray(mu))

    if coeffs is not None:
        magnitudes = np.abs(np.asarray(coeffs))
        # The epsilon avoids 0/0 when every coefficient is zero.
        weights = magnitudes / (np.sum(magnitudes) + 1e-8)
    else:
        weights = np.ones_like(mu) / len(mu)

    if bc_type in ("DD", "NN"):
        violations = np.sin(mu * omega) ** 2
    else:
        violations = np.cos(mu * omega) ** 2

    return float(np.sum(weights * violations))


def compute_solution_l2_error(
    model,
    omega: float,
    bc_type: str,
    mode: int = 1,
    n_eval: int = 1000,
    device: torch.device = torch.device("cpu"),
    seed: int = 42,
) -> Tuple[float, float]:
    """Root-mean-square error of ``model`` against the analytic corner solution.

    The model is evaluated on ``n_eval`` interior points sampled with
    ``r_min=1e-3``, ``r_power=2.0`` and the given ``seed``. Its predictions are
    compared with ``compute_true_solution`` for ``mode`` (unit coefficient).

    Args:
        model: Torch module (anything with ``eval()`` and ``__call__``) mapping
            an ``(n_eval, 2)`` float32 tensor to one value per point. Outputs
            shaped ``(n_eval,)`` and ``(n_eval, 1)`` are both accepted.
        omega: Wedge opening angle in radians.
        bc_type: Boundary-condition pair passed to the exact solution.
        mode: Mode index of the exact solution.
        n_eval: Number of evaluation points.
        device: Device the input tensor is created on.
        seed: Seed for the evaluation point sampler.

    Returns:
        ``(l2_error, rel_l2_error)``. ``l2_error`` is the RMS error.
        ``rel_l2_error`` is that error as a percentage of the RMS of the true
        solution, or ``inf`` when the true solution is numerically zero.

    Raises:
        ValueError: If the model output does not hold one value per point.
    """
    xy_eval = sample_interior_points(n_eval, omega, r_min=1e-3, r_power=2.0, seed=seed)
    u_true = compute_true_solution(xy_eval, omega, bc_type, mode=mode)

    # Restore the caller's train/eval state afterwards so calling this
    # mid-training does not leave dropout/batch-norm layers in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            xy_t = torch.tensor(xy_eval, device=device)
            # Flatten: an (n, 1) output would otherwise broadcast against the
            # (n,) target into an (n, n) difference and corrupt the error.
            u_pred = model(xy_t).cpu().numpy().reshape(-1)
    finally:
        if was_training:
            model.train()

    if u_pred.shape != u_true.shape:
        raise ValueError(
            f"model output has {u_pred.size} values for {u_true.size} evaluation points"
        )

    diff = u_pred - u_true
    l2_error = float(np.sqrt(np.mean(diff**2)))
    true_norm = float(np.sqrt(np.mean(u_true**2)))
    rel_l2_error = (l2_error / true_norm * 100) if true_norm > 1e-8 else float("inf")

    return l2_error, rel_l2_error


def generate_omega_grid(
    n_omega: int = 30,
    omega_min: float = math.pi / 2,
    omega_max: float = 11 * math.pi / 6,
) -> np.ndarray:
    """Evenly spaced wedge angles in radians, both endpoints included.

    The defaults span 90 degrees to 330 degrees. The grid is computed by
    ``np.linspace`` and returned as float32.
    """
    grid = np.linspace(omega_min, omega_max, n_omega)
    return grid.astype(np.float32)


def get_experiment_configs(
    n_omega: int = 30,
    bc_types: Tuple[str, ...] = ("DD", "NN", "DN", "ND"),
    methods: Tuple[str, ...] = ("naive", "constraint"),
    seeds: Tuple[int, ...] = (0, 1, 2),
    omega_min: float = math.pi / 2,
    omega_max: float = 11 * math.pi / 6,
) -> list:
    """Build the full experiment sweep as a flat list of config dicts.

    The sweep is the Cartesian product of the wedge-angle grid with
    ``bc_types``, ``methods`` and ``seeds``. The angle varies slowest and the
    seed fastest.

    Returns:
        List of dicts with keys ``omega`` (radians, Python float),
        ``bc_type``, ``method`` and ``seed`` (Python int).
    """
    omega_grid = generate_omega_grid(n_omega, omega_min, omega_max)
    return [
        {
            "omega": float(angle),
            "bc_type": bc,
            "method": meth,
            "seed": int(s),
        }
        for angle in omega_grid
        for bc in bc_types
        for meth in methods
        for s in seeds
    ]


def plot_exponent_trajectory(
    logs: Dict,
    true_mu: float,
    title: str = "Exponent Trajectory",
    save_path: Optional[str] = None,
):
    """Plot how each learned exponent evolves over training steps.

    Args:
        logs: Training log. ``logs["mus"]`` holds one entry per logged step,
            either a single exponent or a sequence of exponents. The matching
            x-values come from ``logs["steps"]`` (or ``logs["step"]``); if
            neither is recorded, the entry index is used.
        true_mu: Reference exponent, drawn as a dashed horizontal line.
        title: Axes title.
        save_path: If given, the figure is saved there at 150 dpi and closed.

    Returns:
        The matplotlib figure when ``save_path`` is not given. ``None`` when it
        is, or when ``logs`` contains no exponents.
    """
    mus = np.atleast_1d(np.asarray(logs.get("mus", [])))
    # Bail out before touching matplotlib so an empty log has no side effects.
    if mus.size == 0:
        return None
    # One row per logged step; a 1-D log (one exponent per step) becomes a
    # single column instead of failing on ``mus.shape[1]``.
    mus = mus.reshape(mus.shape[0], -1)

    steps = logs.get("steps", logs.get("step"))
    if steps is None or len(steps) == 0:
        steps = np.arange(mus.shape[0])

    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 5))

    for k in range(mus.shape[1]):
        ax.plot(steps, mus[:, k], linewidth=1.5, alpha=0.7, label=f"μ[{k}]")

    ax.axhline(
        true_mu,
        linestyle="--",
        color="red",
        linewidth=2,
        label=f"Target μ={true_mu:.4f}",
    )
    ax.set_xlabel("Step")
    ax.set_ylabel("Exponent μ")
    ax.set_title(title)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return None
    return fig


def plot_loss_curves(
    logs: Dict,
    title: str = "Loss Curves",
    save_path: Optional[str] = None,
):
    """Plot the total loss and its components on log-scale axes.

    Left panel: ``logs["losses"]`` (or ``logs["loss"]``). Right panel: each of
    ``arc_losses``, ``edge_losses`` and ``constraint_losses`` that is present
    and non-empty. X-values come from ``logs["steps"]`` (or ``logs["step"]``).
    When neither is recorded, each series is plotted against its entry index.

    Args:
        logs: Training log mapping names to per-step sequences (lists or
            NumPy arrays).
        title: Figure super-title.
        save_path: If given, the figure is saved there at 150 dpi and closed.

    Returns:
        The matplotlib figure when ``save_path`` is not given, else ``None``.
    """
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    steps = logs.get("steps", logs.get("step"))
    has_steps = steps is not None and len(steps) > 0

    def _x(series):
        # Logged step numbers when available, otherwise 0..len(series)-1.
        return steps if has_steps else np.arange(len(series))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    losses = logs.get("losses", logs.get("loss", []))
    # len() rather than truthiness: ``if array:`` raises for NumPy arrays.
    if len(losses) > 0:
        axes[0].semilogy(_x(losses), losses, "b-", linewidth=1.5)
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Total Loss")
    axes[0].set_title("Total Loss")
    axes[0].grid(True, alpha=0.3)

    components = (
        ("arc_losses", "Arc BC"),
        ("edge_losses", "Edge BC"),
        ("constraint_losses", "Constraint"),
    )
    plotted_any = False
    for key, label in components:
        series = logs.get(key)
        if series is None or len(series) == 0:
            continue
        axes[1].semilogy(_x(series), series, label=label, linewidth=1.5)
        plotted_any = True
    axes[1].set_xlabel("Step")
    axes[1].set_ylabel("Loss")
    axes[1].set_title("Loss Components")
    # Only add a legend when there is something labeled to show in it.
    if plotted_any:
        axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    fig.suptitle(title)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return None
    return fig


def plot_results_heatmap(
    results_df,
    metric: str = "rel_err_mu",
    method: str = "constraint",
    title: Optional[str] = None,
    save_path: Optional[str] = None,
):
    """Heatmap of ``metric`` by boundary-condition type and wedge angle.

    Values are averaged over all rows that share the same ``bc_type`` and
    ``omega_deg`` (for example, different seeds).

    Args:
        results_df: pandas DataFrame with ``method``, ``bc_type`` and
            ``metric`` columns plus ``omega_deg`` (degrees). If ``omega_deg``
            is missing but ``omega`` (radians) is present, the degree column
            is derived from it.
        metric: Column to display.
        method: Only rows whose ``method`` equals this value are used.
        title: Axes title; defaults to ``"<metric> (<method>)"``.
        save_path: If given, the figure is saved there at 150 dpi and closed.

    Returns:
        The matplotlib figure when ``save_path`` is not given. ``None`` when it
        is, or when no data remain for ``method``.
    """
    df = results_df[results_df["method"] == method].copy()

    if "omega_deg" not in df.columns and "omega" in df.columns:
        # get_experiment_configs records the wedge angle in radians under
        # "omega"; derive the degree column the x-axis is keyed on.
        df["omega_deg"] = np.degrees(df["omega"])

    if df.empty:
        return None

    pivot = df.pivot_table(
        values=metric, index="bc_type", columns="omega_deg", aggfunc="mean"
    )
    if pivot.empty:
        return None

    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(14, 4))

    im = ax.imshow(
        pivot.values,
        aspect="auto",
        # Reversed red-yellow-green so low relative error shows as green.
        cmap="RdYlGn_r" if metric == "rel_err_mu" else "viridis",
    )

    ax.set_xticks(range(len(pivot.columns)))
    ax.set_xticklabels([f"{c:.0f}" for c in pivot.columns], rotation=45)
    ax.set_yticks(range(len(pivot.index)))
    ax.set_yticklabels(pivot.index)

    ax.set_xlabel("Wedge Angle ω (degrees)")
    ax.set_ylabel("BC Type")
    ax.set_title(title if title else f"{metric} ({method})")

    fig.colorbar(im, ax=ax)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return None
    return fig


def aggregate_worker_results(
    output_dir: str,
    n_workers: int,
    output_csv: str = "exp7.csv",
    output_json: str = "exp7_stats.json",
):
    """Merge per-worker result CSVs into a single CSV.

    Reads ``exp7_worker{i}.csv`` for ``i in range(n_workers)`` from
    ``output_dir``, skipping files that are missing or contain no rows. The
    rows are concatenated in worker order and written to
    ``output_dir/output_csv``.

    Args:
        output_dir: Directory holding the worker CSVs; the combined CSV is
            written here too.
        n_workers: Number of worker files to look for.
        output_csv: File name of the combined CSV.
        output_json: Currently unused; kept so call sites that pass it keep
            working. NOTE(review): no stats file is produced yet.

    Returns:
        The combined DataFrame, or ``None`` when no worker produced any rows.
    """
    import os

    import pandas as pd

    frames = []
    for worker_id in range(n_workers):
        csv_path = os.path.join(output_dir, f"exp7_worker{worker_id}.csv")
        if not os.path.exists(csv_path):
            continue
        try:
            df = pd.read_csv(csv_path)
        except pd.errors.EmptyDataError:
            # pandas raises on a 0-byte file; skip it rather than aborting
            # the whole merge.
            print(f"Skipping empty file from worker {worker_id}")
            continue
        print(f"Loaded {len(df)} results from worker {worker_id}")
        if len(df) > 0:
            frames.append(df)

    if not frames:
        print("No results found!")
        # A single None (not a tuple) keeps this path consistent with the
        # single-DataFrame success path, so ``result is None`` checks work.
        return None

    combined_df = pd.concat(frames, ignore_index=True)
    combined_path = os.path.join(output_dir, output_csv)
    combined_df.to_csv(combined_path, index=False)
    print(f"Saved combined results to {combined_path}")

    return combined_df


def get_gpu_info() -> Dict:
    """Summarise the CUDA devices visible to PyTorch.

    Returns:
        ``{"available": False, "count": 0}`` when CUDA is unavailable.
        Otherwise ``{"available": True, "count": N, "devices": [...]}``, where
        each device entry has ``id``, ``name`` and ``memory_gb`` (total
        memory in bytes divided by 1e9).
    """
    if not torch.cuda.is_available():
        return {"available": False, "count": 0}

    n_devices = torch.cuda.device_count()
    devices = []
    for idx in range(n_devices):
        props = torch.cuda.get_device_properties(idx)
        devices.append(
            {"id": idx, "name": props.name, "memory_gb": props.total_memory / 1e9}
        )

    return {"available": True, "count": n_devices, "devices": devices}


if __name__ == "__main__":
    # Smoke test: run the samplers and analytic helpers on a 270-degree
    # Dirichlet-Dirichlet wedge and print the shapes/values they produce.
    omega = 3 * np.pi / 2
    bc_type = "DD"
    omega_deg = np.degrees(omega)
    print(f"Testing utilities for omega={omega:.4f} ({omega_deg:.1f} deg)")

    xy_arc = sample_arc_points(100, omega, seed=0)
    print(f"Arc points shape: {xy_arc.shape}")

    xy_edge0 = sample_edge_points(50, omega, "theta0", seed=0)
    print(f"Edge theta=0 points shape: {xy_edge0.shape}")

    xy_interior = sample_interior_points(500, omega, seed=0)
    print(f"Interior points shape: {xy_interior.shape}")

    true_mus = get_true_exponents(omega, bc_type, n_modes=3)
    print(f"True exponents for {bc_type}: {true_mus}")

    u_true = compute_true_solution(xy_interior, omega, bc_type, mode=1)
    u_lo, u_hi = u_true.min(), u_true.max()
    print(f"True solution shape: {u_true.shape}, range: [{u_lo:.4f}, {u_hi:.4f}]")

    configs = get_experiment_configs(n_omega=5, seeds=(0,))
    n_configs = len(configs)
    print(
        f"Number of experiment configs (5 angles, 4 BCs, 2 methods, 1 seed): {n_configs}"
    )
