import random
import itertools
from dataclasses import dataclass
from typing import Optional, Tuple, Literal
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F


def set_seed(seed: int = 0):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random`` module, to PyTorch's default
    generator, and to the generators of all CUDA devices, in that order.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


def make_random_points(n: int, d_star: int, device="cpu") -> torch.Tensor:
    """Draw ``n`` ground-truth points on the unit sphere of R^{d_star}.

    Standard-normal samples are divided by their row norms, which yields
    directions distributed uniformly on the sphere. A tiny constant is added
    to each norm to guard against division by zero.

    Returns:
        Tensor of shape ``(n, d_star)`` on ``device``.
    """
    gaussian = torch.randn(n, d_star, device=device)
    row_norms = gaussian.norm(dim=1, keepdim=True)
    return gaussian / (row_norms + 1e-12)


@torch.no_grad()
def sample_triplets_from_gt(
    X: torch.Tensor,
    m: int,
    margin_ratio: Optional[float] = None,
    pick: Literal["min", "max"] = "min",
    replacement: bool = False,
    max_tries: int = 50,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample labelled triplets ``(i, j, k)`` from ground-truth points.

    For a triplet with anchor ``i``, the label is ``1`` when point ``j`` is
    strictly closer to ``i`` than point ``k`` (squared Euclidean distance in
    ``X``), and ``0`` otherwise. Exact ties are rejected.

    Args:
        X: Ground-truth points, one per row.
        m: Number of triplets requested. ``m <= 0`` yields empty tensors.
        margin_ratio: Optional filter on ``r = max(d_ij, d_ik) / min(d_ij, d_ik)``.
            ``None`` disables the filter.
        pick: With ``margin_ratio`` set, ``"min"`` keeps triplets with
            ``r <= margin_ratio`` and ``"max"`` keeps ``r >= margin_ratio``.
        replacement: If ``False`` (default) each unordered triplet
            ``(i, {j, k})`` appears at most once; if ``True`` triplets are
            drawn independently and may repeat.
        max_tries: Rejection-sampling budget. Without replacement (large
            pools) at most ``m * max_tries`` draws are made in total; with
            replacement each of the ``m`` slots gets ``max_tries`` draws.

    Returns:
        ``(i, j, k, y)``: three ``long`` index tensors and one ``float32``
        label tensor, all on ``X.device``.

        Without replacement, when the pool holds at most 3e6 triplets it is
        enumerated exhaustively, and the result may contain FEWER than ``m``
        triplets if ties or the margin filter reject too many. With
        replacement, a slot that exhausts ``max_tries`` is filled with an
        unfiltered random triplet, so exactly ``m`` triplets are returned.

    Raises:
        ValueError: if ``pick`` is invalid while ``margin_ratio`` is set, if
            ``m`` exceeds the number of distinct triplets (no replacement),
            or if the large-pool rejection sampler runs out of attempts.
    """
    if margin_ratio is not None and pick not in ("min", "max"):
        # Any other value would silently reject every candidate triplet.
        raise ValueError(f"pick must be 'min' or 'max', got {pick!r}")

    n = X.shape[0]
    device = X.device
    i_list, j_list, k_list, y_list = [], [], [], []

    def as_tensors():
        """Pack the accumulated triplets into tensors on X's device."""
        return (torch.tensor(i_list, dtype=torch.long, device=device),
                torch.tensor(j_list, dtype=torch.long, device=device),
                torch.tensor(k_list, dtype=torch.long, device=device),
                torch.tensor(y_list, dtype=torch.float32, device=device))

    if m <= 0:
        # Nothing requested. Previously the exhaustive path still returned
        # one triplet here because it only checked the quota after appending.
        return as_tensors()

    def triplet_label(i, j, k):
        """Return (d_ij, d_ik, label) using squared Euclidean distances."""
        xi = X[i]
        dij = (xi - X[j]).pow(2).sum().item()
        dik = (xi - X[k]).pow(2).sum().item()
        return dij, dik, 1 if dij < dik else 0

    def within_margin(dij, dik):
        """Apply the optional distance-ratio filter."""
        if margin_ratio is None:
            return True
        r = max(dij, dik) / (min(dij, dik) + 1e-12)
        if pick == "max":
            return r >= margin_ratio
        return r <= margin_ratio

    def record(i, j, k, y):
        i_list.append(i)
        j_list.append(j)
        k_list.append(k)
        y_list.append(y)

    def try_accept(i, j, k):
        """Record the triplet unless it is a tie or fails the margin filter."""
        dij, dik, y = triplet_label(i, j, k)
        if dij == dik or not within_margin(dij, dik):
            return False
        record(i, j, k, y)
        return True

    if not replacement:
        # Anchor i plus an unordered pair {j, k} drawn from the other n - 1 points.
        total_triplets = n * (n - 1) * (n - 2) // 2
        if m > total_triplets:
            raise ValueError(f"m ({m}) is greater than the number of possible triplets ({total_triplets})")
        if total_triplets <= 3 * 10**6:
            # Small pool: enumerate everything, shuffle, and take the first
            # m candidates that pass the filters.
            indices = [
                (i, j, k)
                for i in range(n)
                for j, k in itertools.combinations((x for x in range(n) if x != i), 2)
            ]
            random.shuffle(indices)
            count = 0
            for i, j, k in tqdm(indices, desc="Sampling triplets (no replacement v1)"):
                if try_accept(i, j, k):
                    count += 1
                    if count >= m:
                        break
        else:
            # Large pool: rejection-sample and deduplicate with a set.
            seen = set()
            count, attempt = 0, 0
            max_attempts = m * max_tries
            with tqdm(total=m, desc="Sampling triplets (no replacement v2)") as pbar:
                while count < m and attempt < max_attempts:
                    # random.sample returns three distinct indices.
                    i, j, k = random.sample(range(n), k=3)
                    attempt += 1
                    key = (i, min(j, k), max(j, k))
                    if key in seen:
                        continue
                    if try_accept(i, j, k):
                        seen.add(key)
                        count += 1
                        pbar.update(1)
                if count < m:
                    raise ValueError(f"Could not sample {m} unique triplets after {m * max_tries} attempts")
    else:
        for _ in tqdm(range(m), desc="Sampling triplets (with replacement)"):
            for _try in range(max_tries):
                i, j, k = random.sample(range(n), k=3)
                if try_accept(i, j, k):
                    break
            else:
                # Budget exhausted: fall back to an unfiltered triplet so the
                # output still has exactly m entries.
                i, j, k = random.sample(range(n), k=3)
                _, _, y = triplet_label(i, j, k)
                record(i, j, k, y)

    return as_tensors()


def triplet_accuracy(scores: torch.Tensor, y: torch.Tensor) -> float:
    """Return the fraction of triplets whose predicted label matches ``y``.

    A strictly positive score predicts label 1 (``j`` closer); any other
    score predicts label 0.
    """
    predicted = torch.gt(scores, 0).to(torch.float32)
    hits = torch.eq(predicted, y).to(torch.float32)
    return hits.mean().item()

class FreeEmbeddingModel(nn.Module):
    """Embedding with one directly learnable ``d``-dimensional row per point."""

    def __init__(self, n: int, d: int, init_scale: float = 0.01):
        """Create an ``(n, d)`` parameter filled with scaled Gaussian noise."""
        super().__init__()
        initial = torch.randn(n, d) * init_scale
        self.Z = nn.Parameter(initial)

    def embed(self) -> torch.Tensor:
        """Return the parameter matrix itself (not a copy)."""
        return self.Z


@dataclass
class TrainConfig:
    """Hyperparameters for one embedding training run."""

    # Optimisation schedule.
    steps: int = 100_000          # maximum number of minibatch updates
    batch_size: int = 256         # triplets drawn per update
    lr: float = 0.003             # initial AdamW learning rate
    weight_decay: float = 1e-6    # AdamW weight decay

    # Hinge-loss shape.
    margin: float = 1.0           # hinge margin hyperparameter
    temperature: float = 1.0      # divisor applied to the signed scores

    # Embedding handling and logging.
    normalize: bool = True        # L2-normalise embedding rows
    eval_every: int = 250         # evaluate/log every this many steps


def compute_scores(Z: torch.Tensor, i: torch.Tensor, j: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Score triplets as ``||z_i - z_k||^2 - ||z_i - z_j||^2``.

    A positive score means ``j`` is closer to the anchor ``i`` than ``k`` is
    in the embedding ``Z``. Gradients flow through ``Z``.
    Assumes ``i``, ``j``, ``k`` are equal-length 1-D index tensors into the
    rows of ``Z``.
    """
    anchors = Z[i]
    sq_dist_to_j = torch.sum(torch.pow(anchors - Z[j], 2), dim=1)
    sq_dist_to_k = torch.sum(torch.pow(anchors - Z[k], 2), dim=1)
    return sq_dist_to_k - sq_dist_to_j


@torch.no_grad()
def compute_scores_eval(Z, i, j, k):
    """Gradient-free triplet scores ``||z_i - z_k||^2 - ||z_i - z_j||^2``.

    Same quantity as the training-time score, but computed under
    ``torch.no_grad()`` so the result is detached from autograd.
    """
    def squared_distance(a, b):
        delta = a - b
        return delta.pow(2).sum(dim=1)

    anchor = Z[i]
    return squared_distance(anchor, Z[k]) - squared_distance(anchor, Z[j])


def triplet_hinge_loss(
    scores: torch.Tensor,
    y: torch.Tensor,
    margin: float = 1.0,
    temperature: float = 1.0
) -> torch.Tensor:
    """
    Mean hinge loss over triplets with labels y in {0,1}.

    The label is mapped to a sign s = 2y - 1 in {-1,+1}, so the loss pushes
    score > 0 when y=1 and score < 0 when y=0:
    Loss = mean(max(0, margin - s * score / temperature))
    """
    target_sign = y * 2.0 - 1.0
    agreement = target_sign * scores / temperature
    shortfall = margin - agreement
    return torch.mean(F.relu(shortfall))


@torch.no_grad()
def row_normalize_(Z: torch.Tensor, eps: float = 1e-12):
    """Rescale every row of ``Z`` in place to (approximately) unit L2 norm.

    ``eps`` is added to each row norm so all-zero rows stay zero instead of
    producing NaNs. Returns ``None``; ``Z`` itself is modified.
    """
    row_norms = Z.norm(dim=1, keepdim=True)
    Z.div_(row_norms + eps)


def train_one_setting(
    X: torch.Tensor,
    trip_train: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    trip_test: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    d: int,
    cfg: TrainConfig,
    run=None
) -> Tuple[float, float]:
    """Fit a free ``d``-dimensional embedding to training triplets.

    Trains a ``FreeEmbeddingModel`` with AdamW on the triplet hinge loss,
    evaluating and printing train/test triplet accuracy every
    ``cfg.eval_every`` steps and on the final step. Training stops early when
    train accuracy reaches 0.99995, the last minibatch loss drops to 0.001 or
    below, or the learning rate has decayed to the scheduler's minimum.

    Args:
        X: Ground-truth points; only its device and row count are used.
        trip_train: ``(i, j, k, y)`` training triplets.
        trip_test: ``(i, j, k, y)`` held-out triplets.
        d: Embedding dimension.
        cfg: Training hyperparameters.
        run: Accepted for interface compatibility; not used by this function.

    Returns:
        ``(train_acc, test_acc)`` from the most recent evaluation. If no
        training step runs (``cfg.steps == 0``) these are the accuracies of
        the freshly initialised model.

    Note:
        Calls ``set_seed(0)``, which resets the global Python and PyTorch
        random number generators.
    """
    set_seed(0)
    device = X.device
    n, _ = X.shape
    model = FreeEmbeddingModel(n=n, d=d).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, "min", patience=2, factor=0.5, min_lr=1e-6)

    i_tr, j_tr, k_tr, y_tr = trip_train
    i_te, j_te, k_te, y_te = trip_test

    def eval_and_log(step, which="initial"):
        """Print and return (train_acc, test_acc) for the current embedding."""
        with torch.no_grad():
            Z_eval = model.embed()
            if cfg.normalize:
                Z_eval = Z_eval / (Z_eval.norm(dim=1, keepdim=True) + 1e-12)
            tr_scores = compute_scores_eval(Z_eval, i_tr, j_tr, k_tr)
            te_scores = compute_scores_eval(Z_eval, i_te, j_te, k_te)
            tr_acc = triplet_accuracy(tr_scores, y_tr)
            te_acc = triplet_accuracy(te_scores, y_te)
            print(
                f"[d={d:3d} | free] {which} step {step:4d} "
                f"train_acc {tr_acc:.3f} test_acc {te_acc:.3f}"
            )
        return tr_acc, te_acc

    # The initial evaluation doubles as the return value when the loop below
    # never executes (cfg.steps == 0), which previously raised UnboundLocalError.
    tr_acc, te_acc = eval_and_log(0)
    s_tr = i_tr.shape[0]

    for step in range(1, cfg.steps + 1):
        # Minibatch of training triplets, sampled with replacement.
        idx = torch.randint(0, s_tr, (cfg.batch_size,), device=device)
        bi, bj, bk, by = i_tr[idx], j_tr[idx], k_tr[idx], y_tr[idx]
        Z = model.embed()
        if cfg.normalize:
            Z = Z / (Z.norm(dim=1, keepdim=True) + 1e-12)
        scores = compute_scores(Z, bi, bj, bk)
        # Pass the configured margin; it used to be silently ignored in
        # favour of the loss function's default.
        loss = triplet_hinge_loss(scores, by, margin=cfg.margin, temperature=cfg.temperature)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        if cfg.normalize:
            # Project the parameters back onto the unit sphere after the update.
            with torch.no_grad():
                row_normalize_(model.embed())

        if step % cfg.eval_every == 0 or step == cfg.steps:
            tr_acc, te_acc = eval_and_log(step, which="")
            # Read the loss once; it is the loss of the last minibatch only.
            loss_value = loss.item()
            scheduler.step(loss_value)
            lr_exhausted = opt.param_groups[0]["lr"] <= scheduler.min_lrs[0] + 1e-12
            if tr_acc >= 0.99995 or loss_value <= 0.001 or lr_exhausted:
                print("Early stopping triggered.")
                break

    return tr_acc, te_acc


# ======== Example usage ========
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    set_seed(0)
    n = 50
    d_star = 5
    d = 5
    m_train = 5000
    m_test = 2000
    cfg = TrainConfig(steps=100000, batch_size=256, lr=0.003, weight_decay=1e-6, margin=1.0, temperature=1.0, normalize=True, eval_every=250)
    X = make_random_points(n=n, d_star=d_star, device=device)
    # Draw train and test triplets in ONE no-replacement pass and split the
    # result. Two independent calls sample from the same pool, so some test
    # triplets would also be in the training set and inflate test accuracy.
    trip_all = sample_triplets_from_gt(X, m=m_train + m_test)
    trip_train = tuple(t[:m_train] for t in trip_all)
    trip_test = tuple(t[m_train:] for t in trip_all)
    tr_acc, te_acc = train_one_setting(X=X, trip_train=trip_train, trip_test=trip_test, d=d, cfg=cfg)
    print("\nFinal Results")
    print(f"Train accuracy: {tr_acc:.4f}")
    print(f"Test accuracy:  {te_acc:.4f}")