# clmea_improved.py
# Improved CLMEA with:
#  - classifier ensemble (prob + uncertainty)
#  - surrogate ensemble (neural nets) with mean+var predictions
#  - learned acquisition network approximating expected HV improvement
#
# Dependencies: numpy, scipy, scikit-learn, matplotlib, torch, torchvision

import time
import numpy as np
from scipy.stats import qmc
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# ---------------- Config (fast defaults, change for experiments) ----------------
CFG = dict(
    # --- optimisation problem and evaluation budget ---
    D=30,        # number of decision variables
    M=2,         # number of objectives
    N_init=30,   # size of the initial Latin-hypercube design
    NP=20,       # stored on the optimiser; not otherwise used in this file
    maxFEs=80,   # total budget of true function evaluations
    # --- classifier ensemble ---
    clf_ensembles=3,
    clf_epochs=6,
    clf_batch=64,
    # --- surrogate (regression) ensemble ---
    sur_ensembles=3,
    sur_epochs=30,
    sur_batch=64,
    # --- acquisition network ---
    acq_train_epochs=20,
    acq_batch=64,
    # --- inner search ---
    cand_pool=200,      # candidate pool size for acquisition training/selection
    hv_gen_max=6,       # NOTE(review): not referenced anywhere in this file
    local_gen_max2=3,   # NOTE(review): not referenced anywhere in this file
    # --- reproducibility ---
    seed=42,
    # --- image-classification helpers (load_dataset / CNNClassifier) ---
    img_size=32,        # NOTE(review): not referenced anywhere in this file
    num_classes=100,
    batch_size=128,
)

# ---------------- Problem (ZDT1) ----------------
def zdt1(x):
    """Evaluate the two-objective ZDT1 benchmark (minimisation).

    Args:
        x: array of shape (N, D), one decision vector per row; D must be at
            least 2 because the g-term averages over columns 1..D-1.

    Returns:
        Array of shape (N, 2) whose columns are f1 and f2.
    """
    n_vars = x.shape[1]
    first = x[:, 0]
    tail_sum = np.sum(x[:, 1:], axis=1)
    g = 1 + 9.0 * tail_sum / (n_vars - 1)
    second = g * (1 - np.sqrt(first / g))
    return np.vstack([first, second]).T

# ---------------- Utilities ----------------
def load_dataset():
    """Build a shuffled training DataLoader over CIFAR-100.

    The dataset is downloaded into ``./data`` when missing. Each image is
    randomly cropped (32 px after 4 px padding), randomly flipped
    horizontally, converted to a tensor and normalised per channel. The batch
    size is taken from ``CFG['batch_size']``.
    """
    # NOTE(review): these mean/std values look like the commonly quoted
    # CIFAR-10 statistics rather than CIFAR-100 ones -- confirm this is intended.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    train_data = torchvision.datasets.CIFAR100(
        root='./data',
        train=True,
        download=True,
        transform=transforms.Compose(augmentations),
    )
    loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=CFG['batch_size'],
        shuffle=True,
        num_workers=2,
    )
    return loader
def lhs_samples(n, D, lower=0.0, upper=1.0, seed=None):
    """Draw ``n`` Latin-hypercube points in ``D`` dimensions.

    Points are generated in the unit hypercube and then rescaled to
    ``[lower, upper]``. ``seed`` is forwarded unchanged to
    ``scipy.stats.qmc.LatinHypercube``.

    Returns:
        Array of shape (n, D).
    """
    unit_cube = qmc.LatinHypercube(d=D, seed=seed).random(n)
    return qmc.scale(unit_cube, lower, upper)

def operator_de(parent1, parent2, parent3, CR=0.5, F=0.5, lower=0.0, upper=1.0):
    """Create offspring by differential-evolution crossover plus Gaussian mutation.

    Args:
        parent1, parent2, parent3: arrays of shape (N, D); ``parent1`` is the
            base vector, the other two supply the difference vector.
        CR: per-gene probability of taking the differential value.
        F: scale applied to ``parent2 - parent3``.
        lower, upper: box bounds the offspring are clipped to.

    Returns:
        New (N, D) array; the parents are left untouched. Randomness comes
        from the global ``np.random`` state.
    """
    n_rows, n_dims = parent1.shape
    # Draw order (crossover mask, mutation mask, mutation noise) is fixed so
    # results under a given global seed stay the same.
    crossover = np.random.rand(n_rows, n_dims) < CR
    donor = parent1 + F * (parent2 - parent3)
    child = parent1.copy()
    child[crossover] = donor[crossover]
    child = np.clip(child, lower, upper)

    # Each gene mutates with probability 1/D using small Gaussian noise.
    mutate = np.random.rand(n_rows, n_dims) < 1.0 / n_dims
    n_mutated = np.sum(mutate)
    if n_mutated > 0:
        child[mutate] += np.random.normal(scale=0.03, size=n_mutated)
    return np.clip(child, lower, upper)

def nondominated_sort(objs):
    """Rank points into successive non-dominated fronts (minimisation).

    Args:
        objs: array whose first axis indexes the points (normally (N, M)).

    Returns:
        ``(frontno, fronts)``: ``frontno`` is an int array of length N holding
        each point's 1-based front number; ``fronts`` is a list of index lists,
        one per front, starting with the non-dominated front.
    """
    n_pts = objs.shape[0]
    if n_pts == 0:
        return np.full(0, np.inf).astype(int), []

    # beats[p, q] is True when p is no worse than q everywhere and strictly
    # better somewhere.
    flat = objs.reshape(n_pts, -1)
    no_worse = (flat[:, None, :] <= flat[None, :, :]).all(axis=2)
    better = (flat[:, None, :] < flat[None, :, :]).any(axis=2)
    beats = no_worse & better

    remaining = beats.sum(axis=0).astype(int)   # how many points dominate each point
    victims = [np.flatnonzero(row).tolist() for row in beats]

    frontno = np.zeros(n_pts, dtype=int)
    fronts = []
    level = 1
    front = np.flatnonzero(remaining == 0).tolist()
    while front:
        fronts.append(front)
        frontno[front] = level
        upcoming = []
        # Remove the current front; points left with no dominators form the next one.
        for winner in front:
            for loser in victims[winner]:
                remaining[loser] -= 1
                if remaining[loser] == 0:
                    upcoming.append(loser)
        front = upcoming
        level += 1
    return frontno, fronts

def nondominated_frontpoints(objs):
    """Return the rows of ``objs`` that lie on the first non-dominated front."""
    ranks = nondominated_sort(objs)[0]
    on_first_front = ranks == 1
    return objs[on_first_front]

def hv_2d(points, ref=(1.1, 1.1)):
    """Exact hypervolume of a 2-objective point set (minimisation).

    Points that do not strictly dominate the reference point contribute
    nothing and are discarded up front. The previous implementation skipped
    such points but still used their f1 as the right edge of the next
    rectangle, which inflated the result whenever a front point had
    ``f1 >= ref[0]``.

    Args:
        points: array-like of shape (N, 2); may be empty.
        ref: reference point ``(r1, r2)`` bounding the dominated region.

    Returns:
        The dominated area as a float (0.0 for empty input or when no point
        dominates ``ref``).

    Raises:
        ValueError: if ``points`` is non-empty and not of shape (N, 2).
    """
    pts = np.atleast_2d(np.asarray(points, dtype=float))
    if pts.size == 0:
        return 0.0
    if pts.ndim != 2 or pts.shape[1] != 2:
        raise ValueError("hv_2d expects points of shape (N, 2)")

    r1, r2 = float(ref[0]), float(ref[1])
    # Keep only points inside the reference box (NaNs are dropped as well).
    pts = pts[(pts[:, 0] < r1) & (pts[:, 1] < r2)]
    if len(pts) == 0:
        return 0.0

    # Sweep by increasing f1 (ties: smaller f2 first). A point adds area only
    # if it lowers the best f2 seen so far; dominated points add zero, so no
    # separate non-dominated sort is needed.
    order = np.lexsort((pts[:, 1], pts[:, 0]))
    f1 = pts[order, 0]
    f2 = pts[order, 1]
    best_so_far = np.minimum.accumulate(f2)
    best_before = np.concatenate(([r2], best_so_far[:-1]))
    drop = np.clip(best_before - f2, 0.0, None)
    return float(np.sum((r1 - f1) * drop))

def crowding_distance(objs, front_idx):
    """Compute the crowding distance of the points in one front.

    Args:
        objs: (N_total, M) objective matrix.
        front_idx: indices into ``objs`` selecting the front members.

    Returns:
        Array with one distance per front member, in ``front_idx`` order.
        Extreme points of every objective receive ``inf``. An empty front
        yields an empty array.
    """
    if len(front_idx) == 0:
        return np.array([])
    front = objs[front_idx]
    n_pts, n_obj = front.shape
    dist = np.zeros(n_pts)
    for m in range(n_obj):
        column = front[:, m]
        rank = np.argsort(column)
        lowest, highest = rank[0], rank[-1]
        dist[lowest] = np.inf
        dist[highest] = np.inf
        if column[highest] == column[lowest]:
            # Degenerate objective: no spread to normalise by.
            continue
        # Interior points gain the normalised gap between their two neighbours.
        gaps = column[rank[2:]] - column[rank[:-2]]
        dist[rank[1:-1]] += gaps / (column[highest] - column[lowest])
    return dist

# ---------------- PyTorch small MLP building blocks ----------------
def mlp_block(input_dim, hidden=64, n_out=1, n_layers=2, out_activation=None):
    """Assemble a fully connected network as an ``nn.Sequential``.

    The network has ``n_layers`` hidden ``Linear -> ReLU`` stages of width
    ``hidden`` followed by a final ``Linear`` with ``n_out`` outputs. A
    ``Softmax(dim=1)`` is appended only when ``out_activation == 'softmax'``;
    any other value leaves the output linear.
    """
    widths = [input_dim] + [hidden] * n_layers
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(nn.ReLU())
    modules.append(nn.Linear(widths[-1], n_out))
    if out_activation == 'softmax':
        modules.append(nn.Softmax(dim=1))
    return nn.Sequential(*modules)

# ---------------- Classifier ensemble ----------------
class ClassifierNet(nn.Module):
    """Two-hidden-layer MLP that maps a D-dimensional input to class logits."""

    def __init__(self, D, hidden=64, n_classes=4):
        super().__init__()
        self.net = mlp_block(
            input_dim=D,
            hidden=hidden,
            n_out=n_classes,
            n_layers=2,
        )

    def forward(self, x):
        """Return raw (unnormalised) logits, one column per class."""
        logits = self.net(x)
        return logits

# ---------------- Surrogate regressor ----------------
class SurrogateNet(nn.Module):
    """Three-hidden-layer MLP regressor producing one scalar per input row."""

    def __init__(self, D, hidden=128):
        super().__init__()
        self.net = mlp_block(
            input_dim=D,
            hidden=hidden,
            n_out=1,
            n_layers=3,
        )

    def forward(self, x):
        """Return the prediction with the trailing singleton axis removed."""
        out = self.net(x)
        return out.squeeze(-1)

# ---------------- Acquisition net ----------------
class AcquisitionNet(nn.Module):
    """Two-hidden-layer MLP that scores a candidate's feature vector."""

    def __init__(self, feat_dim, hidden=128):
        super().__init__()
        self.net = mlp_block(
            input_dim=feat_dim,
            hidden=hidden,
            n_out=1,
            n_layers=2,
        )

    def forward(self, x):
        """Return one score per row, with the trailing singleton axis removed."""
        score = self.net(x)
        return score.squeeze(-1)

# ---------------- Main improved CLMEA ----------------
class CLMEA_Improved:
    """Classifier- and surrogate-assisted evolutionary optimiser for ZDT1.

    Every iteration of :meth:`run` retrains three model families on the
    archive of evaluated points:

    * a classifier ensemble predicting the (capped) non-dominated rank,
    * a regression ensemble per objective (mean and variance predictions),
    * a small network regressing the expected hypervolume improvement.

    It then evaluates the single best-scoring candidate with the true
    objective function and adds it to the archive.
    """

    # Non-dominated ranks are capped at this many classes for the classifier.
    N_RANK_CLASSES = 4
    # Reference point for every hypervolume computation (minimisation).
    HV_REF = (1.1, 1.1)

    def __init__(self, cfg):
        """Store the configuration and seed NumPy's and PyTorch's global RNGs.

        Args:
            cfg: mapping with the keys used by ``CFG`` (``D``, ``M``,
                ``N_init``, ``NP``, ``maxFEs``, ``seed`` and the per-model
                training settings).
        """
        self.cfg = cfg
        self.D = cfg["D"]
        self.M = cfg["M"]
        self.N_init = cfg["N_init"]
        self.NP = cfg["NP"]
        self.maxFEs = cfg["maxFEs"]
        self.seed = cfg["seed"]
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # archives of evaluated decision vectors / objective vectors
        self.archive_X = None
        self.archive_Y = None
        # number of true function evaluations spent so far
        self.FEs = 0
        # input scaler, fitted once on the initial design
        self.x_scaler = StandardScaler()
        # logs
        self.clf_loss_log = []
        self.clf_acc_log = []
        self.sur_mse_log = []
        self.acq_loss_log = []
        # model placeholders
        self.clf_ensemble = []
        self.sur_ensemble = []  # list of lists: for each objective, a list of nets
        self.acq_net = None

    # ---- small shared helpers ----
    def _to_tensor(self, X):
        """Scale ``X`` with the fitted scaler and return it as a float32 tensor."""
        return torch.tensor(self.x_scaler.transform(X), dtype=torch.float32)

    @staticmethod
    def _next_lhs_seed():
        """Derive an integer seed from the global NumPy RNG.

        Seeding the Latin-hypercube sampler from the already-seeded global
        stream keeps whole runs reproducible; ``seed=None`` would pull fresh
        OS entropy on every call.
        """
        return int(np.random.randint(0, 2**31 - 1))

    @staticmethod
    def _fit_regressor(net, inputs, targets, epochs, batch, weight_decay=0.0):
        """Train ``net`` with Adam (lr 1e-3) on an MSE loss using minibatches.

        Returns:
            List with the sample-weighted mean loss of every epoch.
        """
        opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=weight_decay)
        loss_fn = nn.MSELoss()
        n_samples = len(inputs)
        history = []
        for _ in range(epochs):
            perm = torch.randperm(n_samples)
            running = 0.0
            for start in range(0, n_samples, batch):
                idx = perm[start:start + batch]
                loss = loss_fn(net(inputs[idx]), targets[idx])
                opt.zero_grad()
                loss.backward()
                opt.step()
                running += loss.item() * len(idx)
            history.append(running / max(n_samples, 1))
        return history

    def _acq_features(self, X):
        """Build the acquisition-network input for candidates ``X``.

        Columns: surrogate mean (M), surrogate variance (M), classifier
        probability of rank class 0, entropy of the mean class distribution,
        and mean ensemble variance of the class probabilities.
        """
        mean, var = self.surrogate_predict_ensemble(X)
        probs, entropy, epvar = self.clf_predict_proba_and_uncertainty(X)
        return np.hstack([mean, var, probs[:, :1],
                          entropy.reshape(-1, 1), epvar.reshape(-1, 1)])

    def initial_sampling(self):
        """Evaluate the initial Latin-hypercube design and fit the input scaler."""
        X0 = lhs_samples(self.N_init, self.D, seed=self.seed)
        Y0 = zdt1(X0)
        self.archive_X = X0.copy()
        self.archive_Y = Y0.copy()
        self.FEs = self.N_init
        self.x_scaler.fit(self.archive_X)

    # ---- classifier ensemble training ----
    def train_classifier_ensemble(self, X, Y):
        """Train a fresh classifier ensemble on capped non-dominated ranks.

        Labels are the front number minus one, with every front from
        ``N_RANK_CLASSES`` onwards merged into the last class. Per-epoch loss
        and accuracy of the first ensemble member are appended to the logs.
        """
        device = torch.device("cpu")
        n_classes = self.N_RANK_CLASSES
        Xt = self._to_tensor(X)
        frontno, _ = nondominated_sort(Y)
        ranks = np.minimum(frontno, n_classes).astype(int) - 1
        yt = torch.tensor(ranks, dtype=torch.long)
        epochs = self.cfg["clf_epochs"]
        batch = self.cfg["clf_batch"]
        n_samples = len(Xt)
        loss_fn = nn.CrossEntropyLoss()

        ensemble = []
        for k in range(self.cfg["clf_ensembles"]):
            net = ClassifierNet(self.D, hidden=64, n_classes=n_classes).to(device)
            opt = optim.Adam(net.parameters(), lr=1e-3)
            for _ in range(epochs):
                perm = torch.randperm(n_samples)
                epoch_loss = 0.0
                correct = 0
                for start in range(0, n_samples, batch):
                    idx = perm[start:start + batch]
                    xb, yb = Xt[idx], yt[idx]
                    logits = net(xb)
                    loss = loss_fn(logits, yb)
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    epoch_loss += loss.item() * len(idx)
                    correct += (logits.argmax(dim=1) == yb).sum().item()
                # log only from the first ensemble member to avoid duplicates
                if k == 0:
                    self.clf_loss_log.append(epoch_loss / max(n_samples, 1))
                    self.clf_acc_log.append(correct / max(n_samples, 1))
            ensemble.append(net)
        self.clf_ensemble = ensemble

    def clf_predict_proba_and_uncertainty(self, X_candidates):
        """Predict class probabilities and two uncertainty measures.

        Returns:
            ``(mean_probs, entropy, var)`` where ``mean_probs`` is (N, C) --
            the ensemble-averaged softmax output, ``entropy`` is (N,) -- the
            entropy of ``mean_probs``, and ``var`` is (N,) -- the variance of
            the probabilities across ensemble members, averaged over classes.

        Raises:
            RuntimeError: if the classifier ensemble has not been trained.
        """
        if len(self.clf_ensemble) == 0:
            raise RuntimeError("Classifier not trained")
        Xt = self._to_tensor(X_candidates)
        with torch.no_grad():
            # torch.softmax is numerically stable; exponentiating raw logits
            # by hand overflows to inf/inf = NaN for large logits.
            member_probs = [torch.softmax(net(Xt), dim=1).numpy()
                            for net in self.clf_ensemble]
        probs_stack = np.stack(member_probs, axis=0)  # (ens, N, C)
        mean_probs = np.mean(probs_stack, axis=0)
        eps = 1e-12
        entropy = -np.sum(mean_probs * np.log(mean_probs + eps), axis=1)
        var = np.var(probs_stack, axis=0).mean(axis=1)
        return mean_probs, entropy, var

    # ---- surrogate ensemble training (per objective) ----
    def train_surrogate_ensemble(self, X, Y):
        """Train a fresh regression ensemble for every objective column of ``Y``.

        Afterwards the training MSE of the ensemble-mean prediction (averaged
        over all objectives) is appended to ``sur_mse_log``.
        """
        device = torch.device("cpu")
        n_ens = self.cfg["sur_ensembles"]
        epochs = max(10, self.cfg["sur_epochs"])  # keep not too small
        batch = self.cfg["sur_batch"]
        Xt = self._to_tensor(X)

        ensembles = []
        for m in range(Y.shape[1]):
            yt = torch.tensor(Y[:, m], dtype=torch.float32)
            members = []
            for _ in range(n_ens):
                net = SurrogateNet(self.D, hidden=128).to(device)
                self._fit_regressor(net, Xt, yt, epochs, batch, weight_decay=1e-5)
                members.append(net)
            ensembles.append(members)
        self.sur_ensemble = ensembles

        # ensemble-mean training MSE across all objectives
        with torch.no_grad():
            preds = np.column_stack([
                np.mean([net(Xt).numpy() for net in members], axis=0)
                for members in self.sur_ensemble
            ])
        self.sur_mse_log.append(float(np.mean((preds - Y) ** 2)))

    def surrogate_predict_ensemble(self, X_candidates, return_samples=False):
        """Predict objectives with the surrogate ensembles.

        Returns:
            ``(mean, var)``, each of shape (N, M), computed across ensemble
            members. With ``return_samples=True`` the per-member predictions
            of shape (ens, N, M) are returned as a third element.

        Raises:
            RuntimeError: if the surrogate ensemble has not been trained.
        """
        if len(self.sur_ensemble) == 0:
            raise RuntimeError("Surrogate not trained")
        Xt = self._to_tensor(X_candidates)
        with torch.no_grad():
            per_objective = [
                np.stack([net(Xt).numpy() for net in members], axis=0)  # (ens, N)
                for members in self.sur_ensemble[:self.M]
            ]
        samples = np.stack(per_objective, axis=2).astype(np.float64)  # (ens, N, M)
        mean = samples.mean(axis=0)
        var = samples.var(axis=0)
        if return_samples:
            return mean, var, samples
        return mean, var

    # ---- learned acquisition: approximate expected HV improvement computed from surrogate samples ----
    def compute_expected_hv_improvement_from_samples(self, samples):
        """Average hypervolume gain of each candidate over ensemble members.

        Args:
            samples: (ens, N, M) surrogate predictions.

        Returns:
            Array of shape (N,): for each candidate, the mean over ensemble
            members of ``HV(archive + prediction) - HV(archive)``.
        """
        ens, n_cand, _ = samples.shape
        ref = self.HV_REF
        # Only the current front determines the hypervolume, so extract it
        # once instead of re-sorting the whole archive for every candidate.
        frontno, _ = nondominated_sort(self.archive_Y)
        front = self.archive_Y[frontno == 1]
        base_hv = hv_2d(front, ref=ref)
        improvements = np.zeros(n_cand)
        for e in range(ens):
            for i in range(n_cand):
                augmented = np.vstack([front, samples[e, i:i + 1, :]])
                improvements[i] += hv_2d(augmented, ref=ref) - base_hv
        improvements /= ens
        return improvements

    def train_acquisition_net(self, candidate_X_pool, acq_targets):
        """Fit a fresh acquisition network on ``candidate_X_pool``.

        Args:
            candidate_X_pool: (N, D) candidate decision vectors.
            acq_targets: (N,) regression targets (expected HV improvement).

        Per-epoch losses are appended to ``acq_loss_log``.
        """
        feat = self._acq_features(candidate_X_pool)
        feat_t = torch.tensor(feat, dtype=torch.float32)
        target_t = torch.tensor(acq_targets, dtype=torch.float32)
        self.acq_net = AcquisitionNet(feat.shape[1], hidden=128)
        losses = self._fit_regressor(
            self.acq_net, feat_t, target_t,
            self.cfg["acq_train_epochs"], self.cfg["acq_batch"],
        )
        self.acq_loss_log.extend(losses)

    def acquisition_score(self, X_cands):
        """Score candidates with the acquisition net, or directly if untrained.

        Without a trained acquisition network the expected hypervolume
        improvement is computed from the surrogate ensemble samples instead.
        """
        if self.acq_net is None:
            _, _, samples = self.surrogate_predict_ensemble(X_cands, return_samples=True)
            return self.compute_expected_hv_improvement_from_samples(samples)
        feat = self._acq_features(X_cands)
        with torch.no_grad():
            return self.acq_net(torch.tensor(feat, dtype=torch.float32)).numpy()

    # ---- candidate generation and selection using combined score ----
    def generate_candidate_pool(self, pool_size):
        """Return ``pool_size`` candidates: 40% LHS samples, the rest DE offspring.

        With fewer than three archived points the pool is pure LHS. DE parents
        are drawn from the archive uniformly with replacement.
        """
        if len(self.archive_X) < 3:
            return lhs_samples(pool_size, self.D, seed=self._next_lhs_seed())
        n_random = int(pool_size * 0.4)
        parts = []
        if n_random > 0:
            parts.append(lhs_samples(n_random, self.D, seed=self._next_lhs_seed()))
        sel_idx = np.random.choice(len(self.archive_X),
                                   size=(pool_size - n_random, 3), replace=True)
        parent1, parent2, parent3 = (self.archive_X[sel_idx[:, j]] for j in range(3))
        parts.append(operator_de(parent1, parent2, parent3))
        return np.vstack(parts)[:pool_size]

    def select_candidates(self, n_select=1):
        """Pick the ``n_select`` best candidates from a freshly generated pool.

        The score is ``p0 * (0.6 * acq + 0.4 * entropy)`` where every term is
        min-max normalised over the pool: ``p0`` is the classifier probability
        of rank class 0, ``acq`` the acquisition score and ``entropy`` the
        classifier entropy.
        """
        pool = self.generate_candidate_pool(self.cfg["cand_pool"])
        probs, entropy, _ = self.clf_predict_proba_and_uncertainty(pool)
        prob0 = probs[:, 0]
        acq_score = self.acquisition_score(pool)

        eps = 1e-12

        def _minmax(values):
            return (values - values.min()) / (values.max() - values.min() + eps)

        combined = _minmax(prob0) * (0.6 * _minmax(acq_score) + 0.4 * _minmax(entropy))
        best_first = np.argsort(combined)[::-1]
        return pool[best_first[:n_select]]

    # ---- main loop ----
    def run(self):
        """Run the optimisation until the evaluation budget is exhausted."""
        print("Starting improved CLMEA run. CFG:", self.cfg)
        t0 = time.time()
        self.initial_sampling()
        iter_no = 0
        while self.FEs < self.maxFEs:
            iter_no += 1
            print(f"\n=== Iter {iter_no} | FEs: {self.FEs}/{self.maxFEs} | archive size: {len(self.archive_X)} ===")
            # 1) train classifier ensemble
            t1 = time.time()
            self.train_classifier_ensemble(self.archive_X, self.archive_Y)
            print(f" Trained classifier ensemble ({len(self.clf_ensemble)} members) in {time.time()-t1:.2f}s")
            # 2) train surrogate ensemble
            t2 = time.time()
            self.train_surrogate_ensemble(self.archive_X, self.archive_Y)
            print(f" Trained surrogate ensemble (per-objective members={self.cfg['sur_ensembles']}) in {time.time()-t2:.2f}s; surrogate MSE {self.sur_mse_log[-1]:.4e}")
            # 3) acquisition training data: expected HV improvement over a pool
            pool = self.generate_candidate_pool(self.cfg["cand_pool"])
            _, _, samples = self.surrogate_predict_ensemble(pool, return_samples=True)
            acq_targets = self.compute_expected_hv_improvement_from_samples(samples)
            # Min-max normalise; constant targets (including all zeros) map to 0.
            acq_targets = (acq_targets - acq_targets.min()) / (acq_targets.max() - acq_targets.min() + 1e-12)
            t3 = time.time()
            self.train_acquisition_net(pool, acq_targets)
            print(f" Trained acquisition net in {time.time()-t3:.2f}s (loss last {self.acq_loss_log[-1]:.4e})")
            # 4) select candidate via combined score and evaluate it
            x_sel = self.select_candidates(n_select=1)
            if x_sel.shape[0] > 0:
                y_new = zdt1(x_sel)
                self.archive_X = np.vstack([self.archive_X, x_sel])
                self.archive_Y = np.vstack([self.archive_Y, y_new])
                self.FEs += x_sel.shape[0]
                print(f" Evaluated {x_sel.shape[0]} new points.")
            else:
                print(" No candidate selected.")
            # safety net against a loop that never consumes budget
            if iter_no > 1000:
                break
        print("Finished run: FEs =", self.FEs, " elapsed:", time.time()-t0)

    # plotting logs
    def plot_metrics(self):
        """Plot whichever training logs are non-empty (one figure per model family)."""
        if self.clf_loss_log:
            _, ax = plt.subplots(1, 2, figsize=(12, 4))
            ax[0].plot(self.clf_loss_log)
            ax[0].set_title("Classifier train loss")
            ax[1].plot(self.clf_acc_log)
            ax[1].set_title("Classifier train acc")
            plt.show()
        if self.sur_mse_log:
            plt.figure()
            plt.plot(self.sur_mse_log)
            plt.title("Surrogate MSE")
            plt.show()
        if self.acq_loss_log:
            plt.figure()
            plt.plot(self.acq_loss_log)
            plt.title("Acquisition train loss")
            plt.show()
# --------------- Training logic ---------------
def train_epoch(model, train_loader, optimizer, criterion):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network whose outputs are class scores of shape (batch, C).
        train_loader: iterable of ``(inputs, targets)`` batches; it need not
            support ``len()``.
        optimizer: optimiser stepping ``model``'s parameters.
        criterion: loss taking ``(outputs, targets)``.

    Returns:
        ``(mean_batch_loss, accuracy)``. Both are 0.0 when the loader yields
        no batches (previously this raised ``ZeroDivisionError``).
    """
    model.train()
    loss_sum = 0.0
    n_batches = 0
    correct = 0
    total = 0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_batches += 1
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Batches are counted while iterating, so generators work and an empty
    # loader cannot cause a division by zero.
    mean_loss = loss_sum / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy
# --------------- Model definition ---------------
class CNNClassifier(nn.Module):
    """Small three-stage CNN for 3-channel images.

    The feature extractor stacks three ``Conv(3x3) -> ReLU`` stages
    (64, 128, 256 channels); the first two are followed by 2x2 max pooling
    and the last by global average pooling. The head is a 256 -> 512 ->
    ``CFG['num_classes']`` MLP with dropout 0.5.
    """

    def __init__(self):
        super().__init__()
        channels = (3, 64, 128, 256)
        last_stage = len(channels) - 2
        stages = []
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            # Downsample after the early stages; collapse to 1x1 after the last.
            if stage == last_stage:
                stages.append(nn.AdaptiveAvgPool2d(1))
            else:
                stages.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*stages)

        head = [
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, CFG['num_classes']),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)
# ---------------- run example ----------------
if __name__ == "__main__":
    # Optimise with the default configuration, show the training curves,
    # then report how many archived points ended up on the first front.
    cl = CLMEA_Improved(CFG)
    cl.run()
    cl.plot_metrics()
    frontno = nondominated_sort(cl.archive_Y)[0]
    pareto = cl.archive_Y[frontno == 1]
    n_pareto = pareto.shape[0]
    print("Pareto size:", n_pareto)