# -*- coding: utf-8 -*-
"""idbdisc_CHD49.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/18GapnUIHXCMpCXIQ1Aw_OfyyEwYINN7P
"""

# iDB-PD vs. adaptive discretization (AD) baseline on the CHD49 data set -- install dependencies first
!pip -q install matplotlib numpy torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
!pip -q install git+https://github.com/cooper-org/cooper

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
import os
import time
from google.colab import files

import cooper
from cooper import ConstrainedMinimizationProblem, Constraint, ConstraintType
from cooper.optim import SimultaneousOptimizer
from cooper.multipliers import DenseMultiplier

# Load the data set from Google Drive (Colab only); change the mount point / path as needed.
from google.colab import drive
drive.mount('/content/drive')

# Location of the MATLAB file on the mounted drive
mat_path = '/content/drive/MyDrive/CHD_49.mat'

# Load the .mat file into a dict keyed by MATLAB variable name
mat = scipy.io.loadmat(mat_path)

# Fix the torch and NumPy seeds so model initialisation and random index draws are repeatable
torch.manual_seed(2025)
np.random.seed(2025)

#helper function for saving plots
def save_and_show(fname: str):
    """Write the current matplotlib figure to ``plots_ad/<fname>`` and close it.

    Despite the name, the figure is only saved (300 dpi, tight bounding box);
    it is never displayed. The ``plots_ad`` directory is created on demand and
    the destination path is printed once the file has been written.
    """
    target_dir = "plots_ad"
    os.makedirs(target_dir, exist_ok=True)
    destination = os.path.join(target_dir, fname)

    plt.tight_layout()
    plt.savefig(destination, dpi=300, bbox_inches="tight")
    # Close the figure so open figures do not accumulate in the notebook.
    plt.close()

    print(f"Saved: {destination}")


# Raw feature matrix and target matrix from the loaded .mat file
X = mat['data']
Y = mat['targets']

# Shift/halve the targets; presumably they are stored as -1/+1 so this yields 0/1 -- TODO confirm
Y_binary = (Y + 1) // 2

# First three target columns give the objective labels, the remaining columns the constraint labels
label_A = np.argmax(Y_binary[:, :3], axis=1)
label_B = np.argmax(Y_binary[:, 3:], axis=1)
# Objective and constraint are evaluated on the same samples, hence identical counts
n_samples = label_A.shape[0]
m_samples = label_A.shape[0]
# Convert to PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
label_A_tensor = torch.tensor(label_A, dtype=torch.long)
label_B_tensor = torch.tensor(label_B, dtype=torch.long)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Move the existing tensor instead of re-wrapping it with torch.tensor(), which
# warns about copy-constructing from a tensor and makes a needless extra copy.
X = X.to(device)

y_true = torch.tensor(label_A, dtype=torch.long, device=device)  # objective class labels

# Separate copy of the features used for the constraint
X_const = X.clone()

y_true_const = torch.tensor(label_B, dtype=torch.long, device=device)  # constraint class labels





n_classes = 3      # Number of output classes of the classifier
input_dim = X.shape[1]     # Feature dimension (number of columns of X)
hidden_dim = 60    # Hidden layer size
lambda_reg = 1e-3  # Weight of the quadratic regulariser on the sample weights
gamma0 = 0.001  # Step size of the pre-training parameter updates
gamma = 0.001/22     # Step size of the iDB-PD parameter updates
alpha = 0.25/gamma  # Base value of the decaying coefficient alphak used for the iDB-PD multiplier
epochs = 3000  # Number of iDB-PD iterations; pre-training runs for half as many
sigma = 0.4/(lambda_reg*max(m_samples,n_samples))  # Step size of the sample-weight ascent updates


# Model
class SimpleNet(nn.Module):
    """Two-layer perceptron: linear -> tanh -> linear, returning raw logits.

    Layer sizes are read from the module-level ``input_dim``, ``hidden_dim``
    and ``n_classes`` when an instance is constructed.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return the logits produced for the input batch ``x``."""
        hidden = self.fc1(x).tanh()
        logits = self.fc2(hidden)
        return logits

# Network used only for pre-training
model0 = SimpleNet().to(device)
# Uniform initial sample weights w on the probability simplex for the constraint term.
# requires_grad is intentionally not set: the weights are updated in place from a
# closed-form gradient, never through autograd, and the former
# `requires_grad=True` followed by a division only produced a non-leaf tensor.
w_weights = torch.ones(m_samples, device=device) / m_samples


# Histories recorded during pre-training (loss value and gradient norm per epoch)
loss_history0 = []
gradpsi_history = []
# Map onto the probability simplex via softmax (note: this is not the Euclidean projection)
def project_onto_simplex(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Normalise ``z`` with a softmax along ``dim``.

    The result has non-negative entries that sum to one along ``dim``.
    """
    simplex_point = z.softmax(dim=dim)
    return simplex_point

# Pre-training: minimise the weighted constraint loss alone; the magnitude of the
# last recorded loss becomes the threshold r used by the constrained runs below.
n_pretrain_epochs = int(epochs/2)
for epoch in range(n_pretrain_epochs):
    model0.train()

    # Forward pass for constraint
    logitss = model0(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')  # L_i(x)

    # DRO constraint: weighted loss minus a quadratic penalty on the weights
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const

    # Backward for x
    grad_const_x = torch.autograd.grad(total_loss_const, model0.parameters(), retain_graph=False, create_graph=False)

    with torch.no_grad():
        # Plain gradient-descent step on the network parameters
        for param, grad2 in zip(model0.parameters(), grad_const_x):
            param -= gamma0 * grad2
        # Norm of the FULL parameter gradient. Previously the leftover loop variable
        # `grad2` was used, i.e. only the gradient of the last parameter tensor.
        grad_norm = torch.norm(torch.cat([g.reshape(-1) for g in grad_const_x])).item()

    # Backward for w (gradient ascent): Mk softmax-normalised steps in log-weight space
    Mk = int(10*np.ceil(np.log(epoch+2)))
    with torch.no_grad():
        for i in range(Mk):
            # NOTE(review): the exact derivative of the penalty above would carry a
            # factor 2 (2*lambda_reg*(m*w - 1)); it is absent here -- confirm intended.
            grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples,device=device))
            w_weights.copy_(project_onto_simplex(torch.log(w_weights) + sigma * grad_const_w))

    loss_history0.append(total_loss_const.item())
    # Stored as a Python float so the history can be plotted regardless of device
    gradpsi_history.append(grad_norm)

    if epoch % 100 == 0 or epoch == n_pretrain_epochs - 1:
        print(f"Epoch {epoch}: loss = {total_loss_const.item():.4f} norm_grad = {grad_norm:.4f}")

# Threshold for the constraint: magnitude of the loss evaluated in the final pre-training epoch
r = np.abs(total_loss_const.item())
# Fresh network for the iDB-PD run
model = SimpleNet().to(device)
# Uniform initial sample weights on the simplex: y for the objective, w for the constraint.
# requires_grad is intentionally not set: both vectors are updated in place from
# closed-form gradients, never through autograd.
y_weights = torch.ones(n_samples, device=device) / n_samples
w_weights = torch.ones(m_samples, device=device) / m_samples


# Per-iteration metric histories of the iDB-PD run
loss_history = []
dk_history = []
infeas_history = []
slack_history = []

#idb-PD: descent on the network parameters along grad(objective) + multiplier * grad(constraint),
# followed by softmax-normalised ascent steps on the two sample-weight vectors.
for epoch in range(epochs):
    model.train()

    # Forward pass for objective
    logits = model(X)
    loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')  # L_i(x)

    # Forward pass for constraint
    logitss = model(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')  # L_i(x)

    # DRO objective
    weighted_loss = torch.dot(y_weights, loss_per_sample)
    reg_term = lambda_reg / n_samples * torch.sum((n_samples * y_weights - 1) ** 2)
    total_loss = weighted_loss - reg_term

    # DRO constraint, shifted by the threshold r (a positive value means violated)
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const - r
    constraint_value = total_loss_const.item()

    # Backward for x
    grad_obj_x = torch.autograd.grad(total_loss, model.parameters(), create_graph=False)
    grad_const_x = torch.autograd.grad(total_loss_const, model.parameters(), create_graph=False)
    alphak = alpha / ((epoch+2)**1.001)  # decaying coefficient
    grad_obj_vector = torch.cat([g.view(-1) for g in grad_obj_x])
    grad_const_vector = torch.cat([g.view(-1) for g in grad_const_x])

    with torch.no_grad():
        param_vector = torch.cat([p.data.view(-1) for p in model.parameters()])
        norm_grad2 = torch.norm(grad_const_vector)
        if max(0, constraint_value) * norm_grad2 > 0:
            # Clamp at zero. The previous `torch.max(value, 0)` interpreted 0 as the
            # reduction dimension and returned (values, indices), so `lam[0]` was the
            # UNCLAMPED value and the multiplier could become negative.
            lam = torch.clamp(-torch.dot(grad_const_vector, grad_obj_vector) + alphak * norm_grad2, min=0)
            lam_x_grad2 = (lam * grad_const_vector) / (norm_grad2 ** 2)
        else:
            # Constraint satisfied (or zero constraint gradient): no correction term
            lam = torch.tensor(0.0)
            lam_x_grad2 = torch.zeros_like(grad_const_vector)
        direction = grad_obj_vector + lam_x_grad2
        dk_norm = torch.norm(direction).item()
        updated_vector = param_vector - gamma * direction

        # Scatter the flat updated vector back into the individual parameter tensors
        pointer = 0
        for p in model.parameters():
            numel = p.numel()
            p.data.copy_(updated_vector[pointer:pointer+numel].view_as(p))
            pointer += numel

    with torch.no_grad():
        # Backward for y (gradient ascent)
        Nk = int(2*np.ceil(np.log(epoch+2)))
        for i in range(Nk):
            grad_obj_y = loss_per_sample - lambda_reg*(n_samples * y_weights - torch.ones(n_samples,device=device))
            y_weights.copy_(project_onto_simplex(torch.log(y_weights) + sigma * grad_obj_y))

        # Backward for w (gradient ascent)
        Mk = int(10*np.ceil(np.log(epoch+2)))
        for i in range(Mk):
            grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples,device=device))
            w_weights.copy_(project_onto_simplex(torch.log(w_weights) + sigma * grad_const_w))

    # Metrics use the values computed before this iteration's updates
    loss_history.append(total_loss.item())
    dk_history.append(dk_norm)
    infeas_history.append(max(0, constraint_value))
    # NOTE(review): `lam` is the value before division by ||grad||^2 (the update
    # direction applies that division) -- confirm this is the multiplier intended
    # for the slackness metric.
    slack_history.append(torch.abs(lam * constraint_value).item())

    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}: norm_dk = {dk_norm:.4f} Feasibility = {max(0,constraint_value):.4f}")




#adaptive discretization (COOPER) baseline
class DiscreteMinimaxCMP(ConstrainedMinimizationProblem):
    """Discretised minimax problem handed to Cooper.

    The loss is the largest per-sample cross-entropy over the indices in
    ``Y_set``; the single inequality constraint is the largest per-sample
    constraint cross-entropy over the indices in ``W_set`` minus ``r``.
    ``Y_set`` and ``W_set`` are stored by reference, not copied.
    """

    def __init__(self, model, X, y_true, X_const, y_const, Y_set, W_set, r):
        super().__init__()
        self.model = model
        self.X = X
        self.y_true = y_true
        self.Xc = X_const
        self.yc = y_const
        self.Y_set = Y_set
        self.W_set = W_set
        self.r = r
        # One multiplier for the single inequality constraint, on the data's device
        dual_variable = DenseMultiplier(num_constraints=1, device=X.device)
        self.constraint = Constraint(
            multiplier=dual_variable,
            constraint_type=ConstraintType.INEQUALITY,
        )

    def compute_cmp_state(self, model=None, inputs=None, targets=None):
        """Evaluate loss and constraint on the stored data.

        The ``model``, ``inputs`` and ``targets`` arguments are ignored; the
        values stored on the instance are used instead.
        """
        objective_losses = F.cross_entropy(self.model(self.X), self.y_true, reduction='none')
        phi = objective_losses[self.Y_set].max()

        constraint_losses = F.cross_entropy(self.model(self.Xc), self.yc, reduction='none')
        psi = constraint_losses[self.W_set].max() - self.r

        violation_state = cooper.ConstraintState(violation=psi)
        return cooper.CMPState(
            loss=phi,
            observed_constraints={self.constraint: violation_state},
        )

def adaptive_discretization_solver(model, X, y_true, Xc, yc, r, max_iters=3000, eps=1e-3, lr=1e-3):
    """Adaptive-discretization baseline built on Cooper.

    Starting from one randomly drawn index per set, every iteration performs one
    ``roll`` of a Cooper ``SimultaneousOptimizer`` on the problem restricted to
    the current index sets, then adds the sample with the largest objective
    loss to ``Y_set`` and the one with the largest constraint loss to ``W_set``.

    Args:
        model: network to optimise (updated in place).
        X, y_true: objective inputs and labels.
        Xc, yc: constraint inputs and labels.
        r: constraint threshold.
        max_iters: maximum number of outer iterations.
        eps: tolerance for the stopping test.
        lr: learning rate of both the primal (Adam) and dual (SGD) optimizers.

    Returns:
        ``(model, Y_set, W_set, dk_history, infeas_history, slack_history,
        loss_history)``; each history holds one float per executed iteration.
    """
    n_samples, m_samples = X.shape[0], Xc.shape[0]
    dk_history_c, infeas_history_c, slack_history_c, loss_history_c = [], [], [], []
    Y_set = [torch.randint(0, n_samples, (1,)).item()]
    W_set = [torch.randint(0, m_samples, (1,)).item()]

    # Build the problem and optimizers ONCE. They used to be re-created in every
    # iteration, which reset the multiplier to its initial value and discarded
    # Adam's moment estimates before each step. The problem keeps references to
    # Y_set / W_set, which are grown in place below, so it always sees the
    # current discretization.
    cmp = DiscreteMinimaxCMP(model, X, y_true, Xc, yc, Y_set, W_set, r)
    primal_opt = torch.optim.Adam(model.parameters(), lr=lr)
    dual_opt = torch.optim.SGD(cmp.dual_parameters(), lr=lr, maximize=True)
    coop_opt = SimultaneousOptimizer(cmp, primal_opt, dual_opt)

    for k in range(max_iters):
        model.train()  # undo the eval() of the previous iteration before updating
        coop_opt.roll(compute_cmp_state_kwargs={})

        # Norm of whatever gradients the roll left on the parameters
        with torch.no_grad():
            grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]
            dk = torch.norm(torch.cat(grads)).item() if grads else 0.0

        model.eval()
        with torch.no_grad():
            full_loss = F.cross_entropy(model(X), y_true, reduction='none')
            full_const = F.cross_entropy(model(Xc), yc, reduction='none')
            # Worst objective loss over the CURRENT set, taken before the set is
            # enlarged. It used to be computed after appending y_new, which made
            # it equal to phi_new and the gap below identically zero.
            phi_old = torch.max(full_loss[Y_set]).item()

        # Worst-case indices and values over all samples
        y_new = torch.argmax(full_loss).item()
        w_new = torch.argmax(full_const).item()
        phi_new = full_loss[y_new].item()
        psi_new = full_const[w_new].item()

        if y_new not in Y_set:
            Y_set.append(y_new)
        if w_new not in W_set:
            W_set.append(w_new)

        # Logged metrics use the maxima over all samples, which is what the
        # previous code recorded in effect (its maxima included the new indices).
        infeas = max(0.0, psi_new - r)
        lambda_k = cmp.constraint.multiplier().item()
        slack = abs(lambda_k * (psi_new - r))

        dk_history_c.append(dk)
        infeas_history_c.append(infeas)
        slack_history_c.append(slack)
        loss_history_c.append(phi_new)

        outer_phi = phi_new - phi_old  # how much the discretized objective underestimates the true maximum
        outer_psi = psi_new - r        # constraint value at the worst sample

        if k % 1000 == 0:
            print(f"[Adaptive Discretization] Iter {k}: phi_max={phi_new:.4f}, psi_max={psi_new:.4f}, "
                  f"outer_phi={outer_phi:.4f}, outer_psi={outer_psi:.4f}", flush=True)

        if outer_phi <= eps and outer_psi <= eps:
            print("[Adaptive Discretization] Converged at iteration", k)
            break

    return model, Y_set, W_set, dk_history_c, infeas_history_c, slack_history_c, loss_history_c

# Train a fresh network with the adaptive-discretization baseline, using the same threshold r
model_ad = SimpleNet().to(device)
ad_outputs = adaptive_discretization_solver(
    model_ad,
    X,
    y_true,
    X_const,
    y_true_const,
    r,
    max_iters=3000,
    eps=1e-3,
    lr=1e-3,
)
(
    model_ad,
    Y_set,
    W_set,
    dk_history_ad,
    infeas_history_ad,
    slack_history_ad,
    loss_history_ad,
) = ad_outputs


#plots
def plot_series(metric_name, ys, labels, yscale='log', ylabel=None, fname=None):
    """Draw several series against their iteration index and save the figure.

    Args:
        metric_name: fallback y-axis label and basis of the default file name.
        ys: sequences to plot, one line each.
        labels: legend entries, paired with ``ys`` in order.
        yscale: y-axis scale passed to matplotlib; a falsy value keeps the default.
        ylabel: y-axis label; ``metric_name`` is used when falsy.
        fname: output file name; when falsy it is derived from ``metric_name``.
    """
    plt.figure()
    for series, legend_label in zip(ys, labels):
        plt.plot(range(len(series)), series, label=legend_label, linewidth=2.0)

    if yscale:
        plt.yscale(yscale)

    plt.xlabel("Iteration", fontsize=14)
    axis_label = ylabel or metric_name
    plt.ylabel(axis_label, fontsize=14)

    # Minor grid lines are only drawn on a logarithmic axis
    grid_which = "both" if yscale == 'log' else "major"
    plt.grid(True, which=grid_which, linestyle="--", alpha=0.7)
    plt.legend()

    if fname:
        out_name = fname
    else:
        out_name = f"{metric_name.lower().replace(' ', '_')}_ad.png"
    save_and_show(out_name)

# Comparison figures, one row per plot:
# (metric name, iDB-PD series, adaptive-discretization series, y-scale, y-label, file name)
method_labels = ["iDB-PD", "Adaptive Discretization"]
figure_specs = [
    # Infeasibility
    ("Infeasibility", infeas_history, infeas_history_ad, 'log',
     r'Infeasibility ($[\psi(x_k,w_k)]_+$)', "infeasibility_ad.png"),
    # Stationarity
    ("Stationarity", dk_history, dk_history_ad, 'log',
     r'Stationarity ($\|d_k\|$)', "stationarity_ad.png"),
    # Slackness
    ("Slackness", slack_history, slack_history_ad, 'log',
     r'Slackness ($|\lambda_k \cdot \psi(x_k,w_k)|$)', "slackness_ad.png"),
    # Objective Training Loss (linear axis)
    ("Objective Training Loss", loss_history, loss_history_ad, None,
     "Objective Training Loss", "objective_loss_ad.png"),
]
for metric, idb_series, ad_series, scale, axis_label, out_file in figure_specs:
    plot_series(
        metric,
        [idb_series, ad_series],
        method_labels,
        yscale=scale,
        ylabel=axis_label,
        fname=out_file,
    )

# Zip & download all plots together
!zip -r -q plots_ad.zip plots_ad
files.download('plots_ad.zip')