import os
import copy
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple, List

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


# ---------------------------
# Model
# ---------------------------
class MLPBinary2Logits(nn.Module):
    """
    Binary classifier that returns logits for both classes: shape (batch, 2).

    The network is ``depth`` hidden blocks of ``Linear -> ReLU`` (each followed by
    ``Dropout`` when ``dropout > 0``) and a final ``Linear`` layer with two outputs.

    Args:
        d: Number of input features.
        hidden: Width of every hidden layer.
        depth: Number of hidden blocks.
        dropout: Dropout probability; no dropout layers are created when it is 0.
    """
    def __init__(self, d: int, hidden: int = 128, depth: int = 2, dropout: float = 0.0):
        super().__init__()
        layers = []
        in_dim = d
        for _ in range(depth):
            layers += [nn.Linear(in_dim, hidden), nn.ReLU()]
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_dim = hidden
        layers.append(nn.Linear(in_dim, 2))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits for ``x``; the last dimension has size 2."""
        return self.net(x)  # (batch, 2)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict class labels (0 or 1) for a batch.

        Returns a float tensor of shape (batch,) obtained by rounding the
        softmax probability of class 1.
        """
        logits = self.forward(x)
        probs = torch.softmax(logits, dim=1)
        return torch.round(probs[:, 1])

    def single_predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict the class label for a single sample.

        Args:
            x: One sample, either as a 1-D tensor of shape (d,) or as a
                batch of exactly one row, shape (1, d).

        Returns:
            A scalar float tensor: the rounded softmax probability of class 1.

        Raises:
            ValueError: If ``x`` holds more than one sample.
        """
        logits = self.forward(x)
        if logits.dim() == 2 and logits.size(0) == 1:
            # Drop the batch axis so the softmax below runs over the two
            # classes rather than over the (single) sample.
            logits = logits.squeeze(0)
        if logits.dim() != 1:
            raise ValueError(
                f"single_predict expects one sample, got logits of shape {tuple(logits.shape)}"
            )
        probs = torch.softmax(logits, dim=0)
        return torch.round(probs[1])


# ---------------------------
# Training / Evaluation
# ---------------------------
@dataclass
class TrainConfig:
    """Hyperparameters read by ``train_to_optimum`` when fitting a model."""
    batch_size: int = 256  # mini-batch size of the training DataLoader
    lr: float = 1e-3  # learning rate handed to the Adam optimizer
    weight_decay: float = 0.0  # weight decay handed to the Adam optimizer
    epochs: int = 50  # number of full passes over the training data
    # The default is computed once, when this class body is executed:
    # "cuda" if CUDA is available at that moment, otherwise "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 0  # forwarded to the training DataLoader


@torch.no_grad()
def dataset_loss_ce(model: nn.Module, loader: DataLoader, device: str) -> float:
    """
    Average cross-entropy of ``model`` over every sample yielded by ``loader``.

    The model is switched to eval mode and left there. Per-batch losses are
    summed (not averaged) so that batches of unequal size are weighted by
    their sample count. An empty loader yields 0.0.

    Args:
        model: Network mapping a feature batch to class logits.
        loader: Iterable of ``(features, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Summed cross-entropy divided by the number of samples seen.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction="sum")

    summed_loss = 0.0
    samples_seen = 0
    for features, labels in loader:
        features = features.to(device)
        labels = labels.to(device)
        batch_loss = criterion(model(features), labels)
        summed_loss += float(batch_loss.item())
        samples_seen += features.size(0)

    # Guard against division by zero when the loader produced no samples.
    denominator = samples_seen if samples_seen >= 1 else 1
    return summed_loss / denominator


def train_to_optimum(X: np.ndarray, y: np.ndarray, model: nn.Module, cfg: TrainConfig) -> Tuple[nn.Module, float]:
    """
    Fit ``model`` on ``(X, y)`` with Adam and report its final training loss.

    Args:
        X: Feature matrix of shape (n, d).
        y: Label vector of shape (n,) containing only 0 and 1.
        model: Network to train; it is moved to ``cfg.device`` and updated in place.
        cfg: Batch size, learning rate, weight decay, epoch count, device and
            DataLoader worker count.

    Returns:
        The trained model and its mean cross-entropy over the full training
        set, measured by ``dataset_loss_ce`` after the last epoch.
    """
    assert X.ndim == 2, "X must be (n, d)"
    assert y.ndim == 1 and y.shape[0] == X.shape[0], "y must be shape (n,)"
    assert set(np.unique(y)) <= {0, 1}, "y must be binary {0,1}"

    target_device = cfg.device
    model = model.to(target_device)

    dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).long())
    train_loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=cfg.num_workers,
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    model.train()
    for _epoch in range(cfg.epochs):
        for features, labels in train_loader:
            features = features.to(target_device)
            labels = labels.to(target_device)
            optimizer.zero_grad(set_to_none=True)
            batch_loss = criterion(model(features), labels)
            batch_loss.backward()
            optimizer.step()

    # The final loss is measured in a fixed order with a batch of at least 512 samples.
    full_pass_loader = DataLoader(
        dataset,
        batch_size=max(cfg.batch_size, 512),
        shuffle=False,
        drop_last=False,
    )
    return model, dataset_loss_ce(model, full_pass_loader, target_device)


# ---------------------------
# Per-point ascent under loss constraint
# ---------------------------
def maximize_fx_i_with_constraint(
    start_state: Dict[str, torch.Tensor],
    X: np.ndarray,
    y: np.ndarray,
    model_ctor: Callable[[], nn.Module],
    i: int,
    base_loss: float,
    epsilon: float,
    target_class: int = 1,
    ascent_lr: float = 1e-2,
    max_steps: int = 2000,
    eval_every: int = 10,
    device: Optional[str] = None,
    prob_objective_eps: float = 1e-12,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, float]]:
    """
    Push the model's probability of ``target_class`` at point ``X[i]`` as high
    as possible while the dataset loss stays below ``base_loss + epsilon``.

    Starting from ``start_state``, plain SGD minimises ``-log(p_target + eps)``
    at the single point ``X[i]``. The mean cross-entropy over ``(X, y)`` is
    re-measured every ``eval_every`` steps; the ascent stops once it reaches
    the budget or ``max_steps`` is hit. The weights returned are always the
    most recent ones whose dataset loss was measured below the budget. If no
    such weights were reached, the unchanged ``start_state`` weights are
    returned, so an overshot model is never handed back.

    Args:
        start_state: State dict loaded (strictly) into a fresh ``model_ctor()`` model.
        X: Feature matrix of shape (n, d) used for the loss constraint.
        y: Labels of shape (n,) used for the loss constraint.
        model_ctor: Zero-argument factory for the model architecture.
        i: Row index in ``X`` of the point whose probability is maximised.
        base_loss: Reference loss of the base model.
        epsilon: Additive slack on ``base_loss`` defining the budget.
        target_class: Class whose probability is maximised.
        ascent_lr: SGD learning rate.
        max_steps: Upper bound on the number of SGD steps.
        eval_every: Number of SGD steps between constraint evaluations.
        device: Device to run on; CUDA if available, else CPU, when ``None``.
        prob_objective_eps: Constant added inside the log for numerical safety.

    Returns:
        ``(state_dict, stats)`` where ``stats`` records the point index, the
        number of SGD steps taken, the loss values and the final logits and
        probabilities at ``X[i]``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    f = model_ctor().to(device)
    f.load_state_dict(start_state, strict=True)

    X_t = torch.from_numpy(X).float()
    y_t = torch.from_numpy(y).long()
    ds = TensorDataset(X_t, y_t)
    eval_loader = DataLoader(ds, batch_size=512, shuffle=False)

    xi = torch.from_numpy(X[i]).float().unsqueeze(0).to(device)

    opt = torch.optim.SGD(f.parameters(), lr=ascent_lr)

    loss_budget = base_loss + epsilon
    current_loss = dataset_loss_ce(f, eval_loader, device)

    # Only the latest verified weights are ever needed, so keep a single copy
    # instead of a growing list. It starts as the initial weights, which makes
    # them the fallback when the very first evaluation already overshoots.
    last_valid_state: Dict[str, torch.Tensor] = copy.deepcopy(f.state_dict())
    last_valid_step = 0

    step = 0
    while (current_loss < loss_budget) and (step < max_steps):
        f.train()
        logits_i = f(xi)                 # (1, 2)
        probs_i = torch.softmax(logits_i, dim=1)  # (1, 2)
        p_t = probs_i[:, target_class].mean()

        # Minimising -log(p) is gradient ascent on p with a less-saturating
        # gradient than using -p directly.
        objective = -torch.log(p_t + prob_objective_eps)

        opt.zero_grad(set_to_none=True)
        objective.backward()
        opt.step()

        step += 1
        if step % eval_every == 0:
            current_loss = dataset_loss_ce(f, eval_loader, device)
            if current_loss < loss_budget:
                last_valid_state = copy.deepcopy(f.state_dict())
                last_valid_step = step

    # Steps taken after the last scheduled evaluation have not been checked
    # yet (this happens when max_steps is not a multiple of eval_every).
    if step % eval_every != 0:
        current_loss = dataset_loss_ce(f, eval_loader, device)
        if current_loss < loss_budget:
            last_valid_state = copy.deepcopy(f.state_dict())
            last_valid_step = step

    # Always return to the last verified weights rather than a possibly
    # overshot final state.
    f.load_state_dict(last_valid_state, strict=True)
    if last_valid_step > 0:
        print(f"  Point {i}, class {target_class}: Used snapshot at step {last_valid_step}/{step}")

    final_loss = dataset_loss_ce(f, eval_loader, device)
    with torch.no_grad():
        logits_i = f(xi).squeeze(0)
        probs_i = torch.softmax(logits_i, dim=0)
        final_logits = logits_i.detach().cpu().numpy()
        final_probs = probs_i.detach().cpu().numpy()

    stats = {
        "i": i,
        "steps": step,
        "base_loss": float(base_loss),
        "base_plus_eps": float(base_loss + epsilon),
        "final_loss": float(final_loss),
        "target_class": int(target_class),
        "final_logit_class0": float(final_logits[0]),
        "final_logit_class1": float(final_logits[1]),
        "final_prob_class0": float(final_probs[0]),
        "final_prob_class1": float(final_probs[1]),
    }
    return f.state_dict(), stats


# ---------------------------
# Full algorithm runner
# ---------------------------
def run_algorithm(
    X: np.ndarray,
    y: np.ndarray,
    epsilon: float,
    save_dir: str,
    reset_each_i: bool = True,
    num_classes: int = 2,
    model_hidden: int = 128,
    model_depth: int = 2,
    dropout: float = 0.0,
    base_train_cfg: TrainConfig = TrainConfig(),
    opt_num_attempts: int = 30,  # Multiple random restarts for base model
    ascent_lr: float = 1e-2,
    max_steps: int = 2000,
    eval_every: int = 10,
    seed: Optional[int] = None,  # PyTorch random seed
    shuffle: bool = True,  # Shuffle data order
    shuffle_seed: Optional[int] = None,  # Reproducible shuffle
    relative_epsilon: bool = False,  # Use multiplicative epsilon
) -> Dict[str, float]:
    """
    Train a base model, then run the constrained per-point ascent for every
    (point, class) pair and save each resulting model to ``save_dir``.

    Files written:
        - ``base_model.pt``: best base state dict, base loss and epsilon.
        - ``model_i={i}_c={target_class}.pt``: one per point and target class.

    Args:
        X: Feature matrix of shape (n, d).
        y: Binary labels of shape (n,).
        epsilon: Loss slack; additive, or a fraction of the base loss when
            ``relative_epsilon`` is True.
        save_dir: Output directory (created if missing).
        reset_each_i: If True every ascent starts from the base model;
            otherwise each ascent continues from the previous result.
        num_classes: Number of target classes looped over per point. The
            model built here has two outputs, so values above 2 are not
            expected to work.
        model_hidden, model_depth, dropout: Architecture of ``MLPBinary2Logits``.
        base_train_cfg: Training configuration for the base model; its
            ``device`` is also used for the ascent.
        opt_num_attempts: Number of random restarts; the lowest-loss run is kept.
        ascent_lr, max_steps, eval_every: Forwarded to
            ``maximize_fx_i_with_constraint``.
        seed: Seed passed to ``torch.manual_seed`` when not ``None``.
        shuffle: Whether to permute the rows of ``X`` and ``y`` first.
        shuffle_seed: Seed of the permutation generator.
        relative_epsilon: Interpret ``epsilon`` as a multiple of the base loss.

    Returns:
        Dict with ``base_loss``, the effective additive ``epsilon``, ``n`` and ``d``.

    Raises:
        RuntimeError: If no base-training attempt produced a usable model
            (no attempts requested, or every attempt ended with a non-finite loss).
    """
    os.makedirs(save_dir, exist_ok=True)
    n, d = X.shape

    # Shuffle data if requested
    if shuffle:
        rng = np.random.default_rng(shuffle_seed)
        perm = rng.permutation(n)
        X = X[perm]
        y = y[perm]

    # Set PyTorch seed
    if seed is not None:
        torch.manual_seed(seed)

    def ctor() -> nn.Module:
        return MLPBinary2Logits(d=d, hidden=model_hidden, depth=model_depth, dropout=dropout)

    # Step 1: Base model and BaseLoss, keeping the best of several random restarts
    best_loss = float('inf')
    best_state = None

    print(f"Training base model with {opt_num_attempts} random restarts...")
    for attempt in range(opt_num_attempts):
        base_model = ctor()
        base_model, base_loss = train_to_optimum(X, y, base_model, base_train_cfg)

        if attempt % 5 == 0:
            print(f"  Attempt {attempt + 1}/{opt_num_attempts}: loss = {base_loss:.4f}")

        if base_loss < best_loss:
            best_loss = base_loss
            best_state = copy.deepcopy(base_model.state_dict())

    if best_state is None:
        raise RuntimeError(
            "No base model was obtained: opt_num_attempts must be >= 1 and at "
            "least one attempt must finish with a finite loss"
        )

    base_loss = best_loss
    print(f"Best base loss: {base_loss:.4f}")

    # Apply relative epsilon if requested. The requested fraction is kept so
    # the message does not have to divide by a base loss that may be zero.
    if relative_epsilon:
        relative_fraction = epsilon
        epsilon = base_loss * epsilon
        print(f"Using relative epsilon: {epsilon:.4f} ({relative_fraction:.2%} of base loss)")

    torch.save(
        {
            "state_dict": best_state,
            "base_loss": float(base_loss),
            "epsilon": float(epsilon),
            "reset_each_i": bool(reset_each_i),
        },
        os.path.join(save_dir, "base_model.pt"),
    )

    current_state = copy.deepcopy(best_state)

    # Steps 2-5
    for i in range(n):
        for target_class in range(num_classes):
            # Either restart from the base model or chain from the last result.
            start_state = best_state if reset_each_i else current_state

            state_i, stats_i = maximize_fx_i_with_constraint(
                start_state=start_state,
                X=X,
                y=y,
                model_ctor=ctor,
                i=i,
                base_loss=base_loss,
                epsilon=epsilon,
                target_class=target_class,
                ascent_lr=ascent_lr,
                max_steps=max_steps,
                eval_every=eval_every,
                device=base_train_cfg.device,
            )

            torch.save(
                {"state_dict": state_i, "stats": stats_i, "target_class": target_class},
                os.path.join(save_dir, f"model_i={i}_c={target_class}.pt"),
            )

            if not reset_each_i:
                current_state = copy.deepcopy(state_i)

    return {"base_loss": float(base_loss), "epsilon": float(epsilon), "n": float(n), "d": float(d)}


# ---------------------------
# Rashomon Set Generation
# ---------------------------
def generate_rashomon_set(
    X: np.ndarray,
    y: np.ndarray,
    epsilon: float,
    num_models: int,
    save_dir: str,
    model_hidden: int = 128,
    model_depth: int = 2,
    dropout: float = 0.0,
    base_train_cfg: TrainConfig = TrainConfig(),
    opt_num_attempts: int = 1,
    ascent_lr: float = 1e-2,
    max_steps: int = 2000,
    eval_every: int = 10,
    diversity_strategy: str = "random_point_class",
    seed: Optional[int] = None,
    shuffle: bool = False,
    shuffle_seed: Optional[int] = None,
    relative_epsilon: bool = False,
) -> Dict[str, float]:
    """
    Generate a Rashomon set: num_models diverse models that satisfy BaseLoss + epsilon constraint.

    Files written to ``save_dir``:
        - ``rashomon_base.pt``: base state dict plus run metadata.
        - ``rashomon_model_{k}.pt``: one per generated model.

    Args:
        X: Feature matrix of shape (n, d).
        y: Binary labels of shape (n,).
        epsilon: Loss slack (see ``relative_epsilon``).
        num_models: Number of models to generate; at most ``num_models * 5``
            generation attempts are made.
        save_dir: Output directory (created if missing).
        model_hidden, model_depth, dropout: Architecture of ``MLPBinary2Logits``.
        base_train_cfg: Training configuration; its ``device`` is also used
            for the ascent.
        opt_num_attempts: Number of random restarts for base model optimization
        ascent_lr, max_steps, eval_every: Forwarded to
            ``maximize_fx_i_with_constraint``.
        diversity_strategy:
            - "random_point_class": Randomly sample (point, class) pairs to maximize
            - "multi_init": Train from multiple random initializations and
              keep the runs whose loss is within the budget
        seed: Seeds both NumPy's global generator and PyTorch when not ``None``.
        shuffle: Whether to shuffle data before training
        shuffle_seed: Random seed for shuffling
        relative_epsilon: If True, epsilon is multiplicative (epsilon * base_loss), otherwise additive

    Returns:
        Dict with ``base_loss``, the effective additive ``epsilon``, the number
        of models actually generated (``num_models``), ``n`` and ``d``.

    Raises:
        ValueError: If ``diversity_strategy`` is not one of the supported names.
        RuntimeError: If no base-training attempt produced a usable model.
    """
    # Reject an unsupported strategy before any training or file output happens.
    supported_strategies = ("random_point_class", "multi_init")
    if diversity_strategy not in supported_strategies:
        raise ValueError(f"Unknown diversity_strategy: {diversity_strategy}")

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    os.makedirs(save_dir, exist_ok=True)
    n, d = X.shape

    # Shuffle data if requested
    if shuffle:
        rng = np.random.default_rng(shuffle_seed)
        perm = rng.permutation(n)
        X = X[perm]
        y = y[perm]

    def ctor() -> nn.Module:
        return MLPBinary2Logits(d=d, hidden=model_hidden, depth=model_depth, dropout=dropout)

    # Train base model with multiple random restarts
    best_loss = float('inf')
    best_state = None

    print(f"Training base model with {opt_num_attempts} random restart(s)...")
    for attempt in range(opt_num_attempts):
        base_model = ctor()
        base_model, base_loss = train_to_optimum(X, y, base_model, base_train_cfg)

        if attempt % 5 == 0 or opt_num_attempts <= 5:
            print(f"  Attempt {attempt + 1}/{opt_num_attempts}: loss = {base_loss:.4f}")

        if base_loss < best_loss:
            best_loss = base_loss
            best_state = copy.deepcopy(base_model.state_dict())

    if best_state is None:
        raise RuntimeError(
            "No base model was obtained: opt_num_attempts must be >= 1 and at "
            "least one attempt must finish with a finite loss"
        )

    base_loss = best_loss
    base_state = best_state
    print(f"Best base loss: {base_loss:.4f}")

    # Apply relative epsilon if requested
    original_epsilon = epsilon
    if relative_epsilon:
        epsilon = base_loss * epsilon
        print(f"Using relative epsilon: {epsilon:.4f} ({original_epsilon:.2%} of base loss)")

    torch.save(
        {
            "state_dict": base_state,
            "base_loss": float(base_loss),
            "epsilon": float(epsilon),
            "original_epsilon": float(original_epsilon),
            "relative_epsilon": bool(relative_epsilon),
            "num_models": int(num_models),
            "diversity_strategy": diversity_strategy,
        },
        os.path.join(save_dir, "rashomon_base.pt"),
    )

    print(f"Generating {num_models} models in Rashomon set...")

    models_generated = 0
    attempts = 0
    # Cap the work: "multi_init" may reject runs, so allow a few retries per model.
    max_attempts = num_models * 5

    while models_generated < num_models and attempts < max_attempts:
        attempts += 1

        if diversity_strategy == "random_point_class":
            # Randomly select a point and class to maximize
            i = np.random.randint(0, n)
            target_class = np.random.randint(0, 2)

            state_i, stats_i = maximize_fx_i_with_constraint(
                start_state=base_state,
                X=X,
                y=y,
                model_ctor=ctor,
                i=i,
                base_loss=base_loss,
                epsilon=epsilon,
                target_class=target_class,
                ascent_lr=ascent_lr,
                max_steps=max_steps,
                eval_every=eval_every,
                device=base_train_cfg.device,
            )

        else:  # "multi_init" (the strategy name was validated on entry)
            # Train from random initialization
            model = ctor()
            model, model_loss = train_to_optimum(X, y, model, base_train_cfg)

            # Skip runs that do not satisfy the loss constraint
            if model_loss > base_loss + epsilon:
                continue

            state_i = model.state_dict()
            stats_i = {
                "model_idx": models_generated,
                "loss": float(model_loss),
                "base_loss": float(base_loss),
                "epsilon": float(epsilon),
                "strategy": "multi_init",
            }

        # Save model
        torch.save(
            {"state_dict": state_i, "stats": stats_i, "model_idx": models_generated},
            os.path.join(save_dir, f"rashomon_model_{models_generated}.pt"),
        )

        models_generated += 1
        if models_generated % 10 == 0:
            print(f"Generated {models_generated}/{num_models} models")

    print(f"Successfully generated {models_generated} models in {attempts} attempts")

    return {
        "base_loss": float(base_loss),
        "epsilon": float(epsilon),
        "num_models": int(models_generated),
        "n": int(n),
        "d": int(d),
    }


# ---------------------------
# Ensemble Construction
# ---------------------------
@torch.no_grad()
def ensemble_predict(
    models: List[nn.Module],
    X: torch.Tensor,
    device: str = "cpu",
) -> torch.Tensor:
    """
    Average the softmax outputs of several models on the same inputs.

    Every model is put in eval mode and moved to ``device`` as a side effect.

    Args:
        models: Ensemble members; each maps ``X`` to per-class logits.
        X: Input batch.
        device: Device used for the inputs and all members.

    Returns:
        Averaged probabilities of shape (batch, 2)
    """
    inputs = X.to(device)

    def _member_probabilities(member: nn.Module) -> torch.Tensor:
        # Prepare one member and return its class probabilities.
        member.eval()
        member.to(device)
        return torch.softmax(member(inputs), dim=1)

    per_member = [_member_probabilities(member) for member in models]

    # Stack along a new leading "member" axis and average it away.
    return torch.stack(per_member).mean(dim=0)


@torch.no_grad()
def evaluate_ensemble(
    models: List[nn.Module],
    X: np.ndarray,
    y: np.ndarray,
    device: str = "cpu",
) -> Dict[str, float]:
    """
    Score an ensemble on labelled data.

    The ensemble's averaged class probabilities (from ``ensemble_predict``)
    are turned into hard predictions by argmax for the accuracy, and into
    log-probabilities (with a 1e-12 offset) that are fed to a cross-entropy
    criterion for the loss.

    Args:
        models: Ensemble members.
        X: Feature matrix as a NumPy array.
        y: Integer class labels as a NumPy array.
        device: Device the ensemble is evaluated on.

    Returns:
        Dict with ``accuracy``, ``loss`` and ``num_models``.
    """
    features = torch.from_numpy(X).float()
    labels = torch.from_numpy(y).long()

    avg_probs = ensemble_predict(models, features, device)

    # Accuracy is compared on the CPU because the labels never left it.
    hard_predictions = avg_probs.argmax(dim=1).cpu()
    hits = (hard_predictions == labels).float()
    accuracy = hits.mean().item()

    # Loss of the averaged distribution; the offset avoids log(0).
    log_probs = torch.log(avg_probs + 1e-12)
    criterion = nn.CrossEntropyLoss()
    ensemble_loss = criterion(log_probs, labels.to(device)).item()

    metrics = {
        "accuracy": float(accuracy),
        "loss": float(ensemble_loss),
        "num_models": len(models),
    }
    return metrics


def create_random_ensemble(
    rashomon_dir: str,
    num_models: int,
    ensemble_size: int,
    model_ctor: Callable[[], nn.Module],
    X: np.ndarray,
    y: np.ndarray,
    device: str = "cpu",
    seed: Optional[int] = None,
) -> Dict[str, float]:
    """
    Draw a random ensemble from saved Rashomon models and evaluate it.

    Model indices are drawn without replacement from ``range(num_models)``
    using NumPy's global generator, which is re-seeded when ``seed`` is given.

    Args:
        rashomon_dir: Directory containing rashomon_model_*.pt files
        num_models: Total number of models in Rashomon set
        ensemble_size: Number of models to sample for ensemble
        model_ctor: Function to create model instance
        X, y: Data the ensemble is evaluated on
        device: Device to run on
        seed: Random seed

    Returns:
        The metrics from ``evaluate_ensemble`` plus ``sampled_indices``, the
        list of model indices that were drawn.
    """
    if seed is not None:
        np.random.seed(seed)

    chosen = np.random.choice(num_models, size=ensemble_size, replace=False)

    # Rebuild each chosen member from its checkpoint.
    members = []
    for model_idx in chosen:
        member = model_ctor()
        checkpoint_path = os.path.join(rashomon_dir, f"rashomon_model_{model_idx}.pt")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        member.load_state_dict(checkpoint["state_dict"])
        members.append(member)

    metrics = evaluate_ensemble(members, X, y, device)
    metrics["sampled_indices"] = chosen.tolist()
    return metrics


def ensemble_experiment(
    rashomon_dir: str,
    num_models: int,
    ensemble_sizes: List[int],
    model_ctor: Callable[[], nn.Module],
    X: np.ndarray,
    y: np.ndarray,
    num_trials: int = 10,
    device: str = "cpu",
) -> Dict[int, Dict[str, List[float]]]:
    """
    Evaluate randomly sampled ensembles of several sizes.

    For every size, ``num_trials`` ensembles are drawn with
    ``create_random_ensemble`` using the trial number as the seed, so the
    same seeds are reused for each size.

    Args:
        rashomon_dir: Directory holding the saved Rashomon models.
        num_models: Total number of models available in that directory.
        ensemble_sizes: List of ensemble sizes to try (e.g., [1, 5, 10, 20, 50])
        model_ctor: Factory for the model architecture.
        X, y: Evaluation data.
        num_trials: Number of random samples per ensemble size
        device: Device to run on.

    Returns:
        Dictionary mapping ensemble_size -> per-trial ``accuracies`` and
        ``losses`` plus their mean and standard deviation.
    """
    summary = {}

    for ens_size in ensemble_sizes:
        print(f"\nEvaluating ensemble size: {ens_size}")

        trial_metrics = [
            create_random_ensemble(
                rashomon_dir=rashomon_dir,
                num_models=num_models,
                ensemble_size=ens_size,
                model_ctor=model_ctor,
                X=X,
                y=y,
                device=device,
                seed=trial,
            )
            for trial in range(num_trials)
        ]
        accuracies = [metrics["accuracy"] for metrics in trial_metrics]
        losses = [metrics["loss"] for metrics in trial_metrics]

        entry = {
            "accuracies": accuracies,
            "losses": losses,
            "mean_accuracy": float(np.mean(accuracies)),
            "std_accuracy": float(np.std(accuracies)),
            "mean_loss": float(np.mean(losses)),
            "std_loss": float(np.std(losses)),
        }
        summary[ens_size] = entry

        print(f"  Accuracy: {entry['mean_accuracy']:.4f} ± {entry['std_accuracy']:.4f}")
        print(f"  Loss: {entry['mean_loss']:.4f} ± {entry['std_loss']:.4f}")

    return summary


# ---------------------------
# Per-Point Model Integration
# ---------------------------
def load_per_point_models(
    save_dir: str,
    X: np.ndarray,
    model_hidden: int = 128,
    model_depth: int = 2,
    dropout: float = 0.0,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> List[nn.Module]:
    """
    Load the per-point checkpoints written by run_algorithm().

    Checkpoints are looked up as model_i={i}_c={target_class}.pt for every
    row index i of ``X`` and target_class in {0, 1}; files that do not exist
    are skipped silently.

    Args:
        save_dir: Directory containing the checkpoints.
        X: Data whose shape (n, d) fixes the index range and the input width.
        model_hidden, model_depth, dropout: Architecture of the models to rebuild.
        device: Device the models are loaded onto.

    Returns:
        List of loaded models in eval mode (up to n × 2 models)
    """
    n, d = X.shape

    # Lazily enumerate candidate paths in (point, class) order.
    candidate_paths = (
        os.path.join(save_dir, f"model_i={i}_c={target_class}.pt")
        for i in range(n)
        for target_class in (0, 1)
    )

    loaded = []
    for checkpoint_path in candidate_paths:
        if not os.path.exists(checkpoint_path):
            continue
        net = MLPBinary2Logits(d=d, hidden=model_hidden, depth=model_depth, dropout=dropout)
        saved = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(saved["state_dict"])
        net.to(device)
        net.eval()
        loaded.append(net)

    print(f"Loaded {len(loaded)} per-point models from {save_dir}")
    return loaded


def sample_ensemble_from_per_point_models(
    all_models: List[nn.Module],
    ensemble_size: int,
    seed: Optional[int] = None,
) -> List[nn.Module]:
    """
    Pick a random subset of models, without replacement, to form an ensemble.

    The draw uses NumPy's global generator, which is re-seeded when ``seed``
    is given. Requests larger than the pool are clamped to the pool size.

    Args:
        all_models: Pool of available models
        ensemble_size: Number of models to sample
        seed: Random seed for reproducibility

    Returns:
        List of sampled models
    """
    if seed is not None:
        np.random.seed(seed)

    pool_size = len(all_models)
    # Never ask for more members than the pool can provide.
    draw_count = ensemble_size if ensemble_size <= pool_size else pool_size
    picked = np.random.choice(pool_size, size=draw_count, replace=False)

    subset = []
    for position in picked:
        subset.append(all_models[position])
    return subset
