import os
import copy
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt


# ---------------------------
# Model
# ---------------------------
class MLPBinary2Logits(nn.Module):
    """
    Feed-forward binary classifier that emits one logit per class.

    The network is ``depth`` repetitions of Linear -> ReLU (followed by
    Dropout when ``dropout > 0``), capped by a final Linear layer with two
    outputs, so a batched forward pass yields a tensor of shape (batch, 2).
    """
    def __init__(self, d: int, hidden: int = 128, depth: int = 2, dropout: float = 0.0):
        super().__init__()
        use_dropout = dropout > 0
        modules = []
        width = d  # input width of the next Linear layer
        for _ in range(depth):
            modules.append(nn.Linear(width, hidden))
            modules.append(nn.ReLU())
            if use_dropout:
                modules.append(nn.Dropout(dropout))
            width = hidden
        # Output head: one logit for class 0 and one for class 1.
        modules.append(nn.Linear(width, 2))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits produced by the network for ``x``."""
        return self.net(x)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Hard labels for a batch: the class-1 softmax probability of every row,
        rounded to 0. or 1. (the result keeps the floating dtype of the logits).
        """
        class_probs = torch.softmax(self.forward(x), dim=1)
        return class_probs[:, 1].round()

    def single_predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Same as ``predict`` but the softmax runs over dim 0, i.e. it is meant
        for a single unbatched sample whose logits form a 1-D tensor of size 2.
        """
        class_probs = torch.softmax(self.forward(x), dim=0)
        return class_probs[1].round()


# ---------------------------
# Training / Evaluation
# ---------------------------
@dataclass
class TrainConfig:
    """
    Hyperparameters consumed by ``train_to_optimum``.
    """
    # Mini-batch size of the shuffled training DataLoader.
    batch_size: int = 32
    # Learning rate handed to the Adam optimizer.
    lr: float = 0.001
    # Weight decay handed to the Adam optimizer.
    weight_decay: float = 0.0
    # Number of full passes over the training data.
    epochs: int = 50
    # Torch device string; the default is resolved once, when this class body runs.
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
    # Worker-process count for the training DataLoader.
    num_workers: int = 0


@torch.no_grad()
def dataset_loss_ce(model: nn.Module, loader: DataLoader, device: str) -> float:
    """
    Average per-sample cross-entropy of ``model`` over everything ``loader`` yields.

    The model is switched to eval mode and is NOT switched back afterwards.
    Gradients are disabled for the whole call. An empty loader gives 0.0.
    """
    criterion = nn.CrossEntropyLoss(reduction="sum")
    model.eval()

    summed_loss = 0.0
    n_samples = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # reduction="sum" -> each batch contributes the sum of its sample losses.
        summed_loss += float(criterion(model(inputs), targets).item())
        n_samples += inputs.size(0)

    # Avoid dividing by zero when the loader produced no samples.
    denominator = n_samples if n_samples > 0 else 1
    return summed_loss / denominator


def train_to_optimum(
    X: np.ndarray,
    y: np.ndarray,
    model: nn.Module,
    cfg: TrainConfig
) -> Tuple[nn.Module, float, List[float]]:
    """
    Fit ``model`` on (X, y) with Adam and cross-entropy for ``cfg.epochs`` epochs.

    Returns a triple:
      * the trained model (moved to ``cfg.device``),
      * its mean cross-entropy over the full dataset, measured after training,
      * the per-epoch training loss (sample-weighted mean of the batch losses).
    """
    assert X.ndim == 2, "X must be (n, d)"
    assert y.ndim == 1 and y.shape[0] == X.shape[0], "y must be shape (n,)"
    assert set(np.unique(y)).issubset({0, 1}), "y must be binary {0,1}"

    device = cfg.device
    model = model.to(device)

    dataset = TensorDataset(
        torch.from_numpy(X).float(),
        torch.from_numpy(y).long(),
    )
    train_loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=cfg.num_workers,
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    model.train()
    epoch_losses: List[float] = []
    for _epoch in range(cfg.epochs):
        weighted_loss_sum, seen = 0.0, 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_n = inputs.size(0)

            batch_loss = criterion(model(inputs), targets)
            optimizer.zero_grad(set_to_none=True)
            batch_loss.backward()
            optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size
            # so the epoch figure is a true per-sample average.
            weighted_loss_sum += float(batch_loss.item()) * batch_n
            seen += batch_n
        epoch_losses.append(weighted_loss_sum / max(seen, 1))

    # Final loss is measured in eval mode over the unshuffled data, with
    # batches of at least 512 samples.
    eval_loader = DataLoader(
        dataset,
        batch_size=max(cfg.batch_size, 512),
        shuffle=False,
        drop_last=False,
    )
    return model, dataset_loss_ce(model, eval_loader, device), epoch_losses


# ---------------------------
# Per-point ascent under loss constraint
# ---------------------------
def maximize_fx_i_with_constraint(
    start_state: Dict[str, torch.Tensor],
    X: np.ndarray,
    y: np.ndarray,
    model_ctor: Callable[[], nn.Module],
    i: int,
    base_loss: float,
    epsilon: float,
    target_class: int = 1,
    ascent_lr: float = 1e-2,
    max_steps: int = 2000,
    eval_every: int = 10,
    device: Optional[str] = None,
    prob_objective_eps: float = 1e-12,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, object]]:
    """
    Push the softmax probability of ``target_class`` at ``X[i]`` as high as
    possible while the dataset cross-entropy stays below ``base_loss + epsilon``.

    A fresh model from ``model_ctor`` is loaded with ``start_state`` and updated
    by plain SGD on ``-log(p_target(X[i]) + prob_objective_eps)``. Every
    ``eval_every`` steps the mean cross-entropy over (X, y) is re-measured: if it
    is still under the budget the weights are snapshotted, otherwise the loop
    stops. The returned weights are those of the last snapshot.

    If no snapshot was ever recorded (the first evaluation already exceeded the
    budget, or ``max_steps`` is smaller than ``eval_every``) but the start state
    itself is within the budget, the start state is returned instead of failing.

    Args:
        start_state: state dict the ascent starts from.
        X, y: full dataset used to measure the loss constraint.
        model_ctor: zero-argument factory for a model compatible with ``start_state``.
        i: row of ``X`` whose target-class probability is maximized.
        base_loss: reference loss; the budget is ``base_loss + epsilon``.
        epsilon: allowed loss increase over ``base_loss``.
        target_class: index of the class whose probability is maximized.
        ascent_lr: SGD learning rate.
        max_steps: upper bound on the number of SGD steps.
        eval_every: number of steps between constraint evaluations.
        device: torch device string; defaults to CUDA when available, else CPU.
        prob_objective_eps: constant added inside the log to avoid ``log(0)``.

    Returns:
        ``(state_dict, stats)`` where ``stats`` holds the step count, losses,
        final logits/probabilities at ``X[i]``, every snapshot, the step of each
        snapshot (``"snapshot_steps"``) and whether the start state had to be
        used as the result (``"fell_back_to_start"``).

    Raises:
        AssertionError: if no snapshot exists and the start state itself does
            not satisfy ``loss < base_loss + epsilon``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    f = model_ctor().to(device)
    f.load_state_dict(start_state, strict=True)

    X_t = torch.from_numpy(X).float()
    y_t = torch.from_numpy(y).long()
    ds = TensorDataset(X_t, y_t)
    eval_loader = DataLoader(ds, batch_size=512, shuffle=False)

    xi = torch.from_numpy(X[i]).float().unsqueeze(0).to(device)

    opt = torch.optim.SGD(f.parameters(), lr=ascent_lr)

    loss_budget = base_loss + epsilon
    current_loss = dataset_loss_ce(f, eval_loader, device)
    # Remembered so the start state can serve as a result if no later
    # evaluation ever lands inside the budget.
    start_is_feasible = current_loss < loss_budget

    snapshot_states: List[Dict[str, torch.Tensor]] = []
    snapshot_steps: List[int] = []

    step = 0
    while (current_loss < loss_budget) and (step < max_steps):
        # dataset_loss_ce leaves the model in eval mode; switch back before each update.
        f.train()
        logits_i = f(xi)                 # (1, 2)
        probs_i = torch.softmax(logits_i, dim=1)  # (1, 2)
        p_t = probs_i[:, target_class].mean()

        # Minimizing -log(p_t) is gradient ascent on p_t; the log form is used
        # as a less-saturating objective than the raw probability.
        objective = -torch.log(p_t + prob_objective_eps)

        opt.zero_grad(set_to_none=True)
        objective.backward()
        opt.step()

        step += 1
        if step % eval_every == 0:
            current_loss = dataset_loss_ce(f, eval_loader, device)
            if current_loss < loss_budget:
                snapshot_states.append(copy.deepcopy(f.state_dict()))
                snapshot_steps.append(step)
            else:
                print(current_loss, base_loss, step)

    fell_back_to_start = False
    if snapshot_states:
        # Steps taken after the last feasible evaluation are discarded.
        f.load_state_dict(snapshot_states[-1], strict=True)
    elif start_is_feasible:
        f.load_state_dict(start_state, strict=True)
        fell_back_to_start = True
    else:
        # Explicit raise (not an assert statement) so the check survives `python -O`.
        raise AssertionError(
            "no feasible model: the start state does not satisfy "
            f"loss < base_loss + epsilon ({current_loss} vs {loss_budget})"
        )

    final_loss = dataset_loss_ce(f, eval_loader, device)
    with torch.no_grad():
        logits_i = f(xi).squeeze(0)
        probs_i = torch.softmax(logits_i, dim=0)
        final_logits = logits_i.detach().cpu().numpy()
        final_probs = probs_i.detach().cpu().numpy()

    stats = {
        "i": i,
        "steps": step,
        "base_loss": float(base_loss),
        "base_plus_eps": float(loss_budget),
        "final_loss": float(final_loss),
        "target_class": int(target_class),
        "final_logit_class0": float(final_logits[0]),
        "final_logit_class1": float(final_logits[1]),
        "final_prob_class0": float(final_probs[0]),
        "final_prob_class1": float(final_probs[1]),
        "snapshots": snapshot_states,
        "snapshot_steps": snapshot_steps,
        "fell_back_to_start": fell_back_to_start,
    }
    return f.state_dict(), stats


# ---------------------------
# Full algorithm runner
# ---------------------------
def run_algorithm(
    X: np.ndarray,
    y: np.ndarray,
    epsilon: float,
    save_dir: str,
    reset_each_i: bool = True,
    num_classes: int = 2,
    model_hidden: int = 128,
    model_depth: int = 2,
    dropout: float = 0.0,
    base_train_cfg: TrainConfig = TrainConfig(),
    opt_num_attempts: int = 30,  # Number of random restarts for optimization
    ascent_lr: float = 1e-2,
    max_steps: int = 2000,
    eval_every: int = 10,
    seed: Optional[int] = None,
    shuffle: bool = True,
    shuffle_seed: Optional[int] = None,
    relative_epsilon: bool = False,
    reuse_base_model: bool = False,
    num_data_points: int = -1,
) -> Dict[str, float]:
    """
    Train (or reload) a base model, then run the constrained per-point ascent
    for every selected data point and every target class.

    Step 1 trains ``opt_num_attempts`` freshly initialized models and keeps the
    one with the lowest dataset loss, saving it to ``{save_dir}/base_model.pt``
    and one loss curve per attempt under ``{save_dir}/loss_plots``. With
    ``reuse_base_model=True`` that checkpoint is loaded instead of training.

    Steps 2-5 call ``maximize_fx_i_with_constraint`` for each selected index
    ``i`` and each class, writing ``{save_dir}/model_i={i}_c={class}.pt``.

    NOTE: when ``shuffle`` is true, ``X`` and ``y`` are permuted first, so the
    index ``i`` in the saved file names refers to the permuted order, not to
    the order of the arrays passed in.

    Args:
        X: feature matrix; its shape is unpacked as ``(n, d)``.
        y: labels aligned with the rows of ``X``.
        epsilon: loss slack; multiplied by the base loss if ``relative_epsilon``.
        save_dir: output directory (created if missing).
        reset_each_i: start every ascent from the base model (True) or from the
            result of the previous ascent (False).
        num_classes: number of target classes to iterate over; at most 2
            because the model has exactly two logits.
        model_hidden, model_depth, dropout: architecture of ``MLPBinary2Logits``.
        base_train_cfg: training hyperparameters and device.
        opt_num_attempts: number of random restarts for the base model.
        ascent_lr, max_steps, eval_every: forwarded to the ascent routine.
        seed: seeds torch and the sub-sampling of data points.
        shuffle, shuffle_seed: whether and how to permute the dataset.
        relative_epsilon: interpret ``epsilon`` as a fraction of the base loss.
        reuse_base_model: load ``base_model.pt`` instead of training.
        num_data_points: number of points to attack, or -1 for all of them.

    Returns:
        Dict with ``base_loss``, the effective ``epsilon``, ``n`` and ``d``.

    Raises:
        ValueError: if ``num_classes > 2`` or, when training, ``opt_num_attempts < 1``.
        RuntimeError: if no training attempt produced a usable model.
    """
    if num_classes > 2:
        # Fail before any training: target classes >= 2 cannot be indexed in a
        # two-logit output.
        raise ValueError(f"num_classes must be <= 2 for a two-logit model, got {num_classes}")
    if not reuse_base_model and opt_num_attempts < 1:
        raise ValueError(f"opt_num_attempts must be >= 1, got {opt_num_attempts}")

    os.makedirs(f"{save_dir}/loss_plots", exist_ok=True)
    n, d = X.shape
    device = base_train_cfg.device

    if shuffle:
        rng = np.random.default_rng(shuffle_seed)
        perm = rng.permutation(n)
        X = X[perm]
        y = y[perm]

    if seed is not None:
        torch.manual_seed(seed)

    def ctor() -> nn.Module:
        return MLPBinary2Logits(d=d, hidden=model_hidden, depth=model_depth, dropout=dropout)

    # Step 1: Base model and BaseLoss
    best_loss = float('inf')
    best_state = None
    best_accuracy = -1.0

    if not reuse_base_model:
        # Loop-invariant tensors for the accuracy check.
        X_dev = torch.from_numpy(X).float().to(device)
        y_dev = torch.from_numpy(y).to(device)

        for attempt in range(opt_num_attempts):
            base_model = ctor()
            base_model, base_loss, train_losses = train_to_optimum(X, y, base_model, base_train_cfg)

            with torch.no_grad():
                base_accuracy = (base_model.predict(X_dev) == y_dev).float().mean().item()

            print(base_loss, base_accuracy)

            if base_loss < best_loss:
                best_loss = base_loss
                best_state = copy.deepcopy(base_model.state_dict())
                best_accuracy = base_accuracy

            if train_losses:
                plt.figure()
                plt.plot(range(1, len(train_losses) + 1), train_losses)
                plt.title("Training Loss")
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.tight_layout()
                plt.savefig(os.path.join(save_dir, "loss_plots", f"training_loss_attempt={attempt}.png"))
                plt.close()

        # Release the last trained model and the device copies of the data
        # before the ascent phase.
        del base_model, X_dev, y_dev

        if best_state is None:
            # Happens when every attempt's loss compared false against +inf (e.g. NaN).
            raise RuntimeError("no training attempt produced a usable base model")

        torch.save(
            {
                "state_dict": best_state,
                "base_loss": float(best_loss),
                "base_accuracy": float(best_accuracy),
                "epsilon": float(epsilon),
                "reset_each_i": bool(reset_each_i),
            },
            os.path.join(save_dir, "base_model.pt"),
        )
    else:
        # map_location lets a checkpoint written on a CUDA machine be reused
        # on a host whose configured device differs (e.g. CPU-only).
        data = torch.load(os.path.join(save_dir, "base_model.pt"), map_location=device)
        best_state = data["state_dict"]
        best_loss = data["base_loss"]
        # Older checkpoints do not carry the accuracy; keep the -1.0 placeholder.
        best_accuracy = float(data.get("base_accuracy", -1.0))

    print(best_loss)
    print(best_accuracy)

    if relative_epsilon:
        epsilon = best_loss * epsilon

    current_state = copy.deepcopy(best_state)

    if num_data_points != -1:
        rng = np.random.default_rng(seed)
        adversarial_inds = rng.choice(n, size=num_data_points, replace=False).tolist()
    else:
        adversarial_inds = list(range(n))

    # Steps 2-5
    for i in adversarial_inds:
        for target_class in range(num_classes):
            start_state = best_state if reset_each_i else current_state

            state_i, stats_i = maximize_fx_i_with_constraint(
                start_state=start_state,
                X=X,
                y=y,
                model_ctor=ctor,
                i=i,
                base_loss=best_loss,
                epsilon=epsilon,
                target_class=target_class,
                ascent_lr=ascent_lr,
                max_steps=max_steps,
                eval_every=eval_every,
                device=device,
            )

            torch.save(
                {"state_dict": state_i, "stats": stats_i, "target_class": target_class},
                os.path.join(save_dir, f"model_i={i}_c={target_class}.pt"),
            )

            if not reset_each_i:
                current_state = copy.deepcopy(state_i)

    return {"base_loss": float(best_loss), "epsilon": float(epsilon), "n": float(n), "d": float(d)}
