# clmea_improved.py
# Improved CLMEA with:
#  - classifier ensemble (prob + uncertainty)
#  - surrogate ensemble (neural nets) with mean+var predictions
#  - learned acquisition network approximating expected HV improvement
#
# Dependencies: numpy, scipy, scikit-learn, matplotlib, torch

import time
import numpy as np
from scipy.stats import qmc
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# ---------------- Config (fast defaults, change for experiments) ----------------
CFG = dict(
    # problem size and evaluation budget
    D=30,
    M=2,
    N_init=30,
    NP=20,
    maxFEs=80,
    # classifier ensemble
    clf_ensembles=3,
    clf_epochs=6,
    clf_batch=64,
    # surrogate ensemble
    sur_ensembles=3,
    sur_epochs=30,
    sur_batch=64,
    # acquisition net
    acq_train_epochs=20,
    acq_batch=64,
    # inner search parameters
    cand_pool=200,  # candidate pool size for acquisition training/selection
    hv_gen_max=6,
    local_gen_max2=3,
    # random seed (applied to numpy and torch by CLMEA_Improved.__init__)
    seed=42,
)

# ---------------- Problem (ZDT1) ----------------
def zdt1(x):
    """Evaluate the two-objective ZDT1 benchmark on a batch of points.

    Args:
        x: 2-D array with one decision vector per row.

    Returns:
        Array of shape (rows, 2) holding ``[f1, f2]`` for every row of ``x``.
    """
    n_vars = x.shape[1]
    f1 = x[:, 0]
    # g depends on every variable except the first one.
    tail_sum = np.sum(x[:, 1:], axis=1)
    g = 1 + 9.0 * tail_sum / (n_vars - 1)
    f2 = g * (1 - np.sqrt(f1 / g))
    return np.array([f1, f2]).T

# ---------------- Utilities ----------------
def lhs_samples(n, D, lower=0.0, upper=1.0, seed=None):
    """Draw ``n`` Latin-hypercube samples in ``D`` dimensions.

    Args:
        n: number of samples.
        D: number of dimensions.
        lower: lower bound handed to ``qmc.scale``.
        upper: upper bound handed to ``qmc.scale``.
        seed: forwarded unchanged to ``qmc.LatinHypercube``.

    Returns:
        Array of shape (n, D) rescaled from the unit cube to the given bounds.
    """
    unit_cube = qmc.LatinHypercube(d=D, seed=seed).random(n)
    return qmc.scale(unit_cube, lower, upper)

def operator_de(parent1, parent2, parent3, CR=0.5, F=0.5, lower=0.0, upper=1.0):
    """Create offspring by differential-evolution crossover plus Gaussian mutation.

    Random numbers come from the global ``np.random`` state, in this order:
    crossover mask, mutation mask, mutation noise.

    Args:
        parent1: (N, D) base vectors; entries not selected for crossover are
            copied from here.
        parent2, parent3: (N, D) vectors whose difference is scaled by ``F``.
        CR: per-entry crossover probability.
        F: differential weight.
        lower, upper: clipping bounds applied after crossover and after mutation.

    Returns:
        (N, D) array of offspring clipped to ``[lower, upper]``.
    """
    n_rows, n_vars = parent1.shape

    # Crossover: selected entries take the differential mutant value.
    crossover = np.random.rand(n_rows, n_vars) < CR
    mutant = parent1 + F * (parent2 - parent3)
    child = parent1.copy()
    child[crossover] = mutant[crossover]
    child = np.clip(child, lower, upper)

    # Mutation: each entry is perturbed with probability 1/D.
    perturb = np.random.rand(n_rows, n_vars) < 1.0 / n_vars
    n_perturb = np.sum(perturb)
    if n_perturb > 0:
        child[perturb] += np.random.normal(scale=0.03, size=n_perturb)
    return np.clip(child, lower, upper)

def nondominated_sort(objs):
    """Rank points into successive non-dominated fronts (minimization).

    Args:
        objs: array with one objective vector per row.

    Returns:
        A pair ``(frontno, fronts)``: ``frontno`` is an int array giving the
        1-based front number of every row, and ``fronts`` is a list of index
        lists, one per front, starting with the non-dominated front.
    """
    n_pts = objs.shape[0]

    def dominates(a, b):
        # a is no worse than b everywhere and strictly better somewhere
        return np.all(a <= b) and np.any(a < b)

    n_dominators = np.zeros(n_pts, dtype=int)
    dominated_sets = [[] for _ in range(n_pts)]
    # Each unordered pair is inspected once; both directions are recorded.
    # Indices still enter every dominated list in ascending order.
    for i in range(n_pts):
        for j in range(i + 1, n_pts):
            if dominates(objs[i], objs[j]):
                dominated_sets[i].append(j)
                n_dominators[j] += 1
            elif dominates(objs[j], objs[i]):
                dominated_sets[j].append(i)
                n_dominators[i] += 1

    frontno = np.full(n_pts, np.inf)
    fronts = []
    level = 1
    layer = np.flatnonzero(n_dominators == 0).tolist()
    while layer:
        fronts.append(layer)
        frontno[layer] = level
        # Peel the current layer: a point joins the next layer once all of
        # its dominators have been removed.
        upcoming = []
        for i in layer:
            for j in dominated_sets[i]:
                n_dominators[j] -= 1
                if n_dominators[j] == 0:
                    upcoming.append(j)
        layer = upcoming
        level += 1
    return frontno.astype(int), fronts

def nondominated_frontpoints(objs):
    """Return the rows of ``objs`` that lie on the first non-dominated front."""
    ranks = nondominated_sort(objs)[0]
    on_first_front = ranks == 1
    return objs[on_first_front]

def hv_2d(points, ref=(1.1, 1.1)):
    """Exact 2-D hypervolume (minimization) dominated by ``points`` w.r.t. ``ref``.

    Only points that are strictly better than ``ref`` in both objectives can
    contribute, so everything else is discarded first. The previous
    implementation skipped such points but still used their f1 value as the
    right edge of the next strip, which inflated the result whenever a point
    had ``f1 > ref[0]``.

    The non-dominated front is extracted with a sort-and-sweep instead of a
    full non-dominated sort.

    Args:
        points: sequence or array of 2-D objective vectors, one per row.
        ref: reference point ``(r1, r2)``.

    Returns:
        The hypervolume as a float; ``0.0`` for empty input or when no point
        improves on ``ref``.

    Raises:
        ValueError: if the points are not two-dimensional objective vectors.
    """
    pts = np.asarray(points, dtype=float)
    if pts.size == 0:
        return 0.0
    pts = np.atleast_2d(pts)
    if pts.ndim != 2 or pts.shape[1] != 2:
        raise ValueError("hv_2d expects points with exactly two objectives")

    ref_f1, ref_f2 = float(ref[0]), float(ref[1])
    # Comparisons with NaN are False, so NaN rows are dropped here as well.
    inside = (pts[:, 0] < ref_f1) & (pts[:, 1] < ref_f2)
    pts = pts[inside]
    if len(pts) == 0:
        return 0.0

    # Sweep by ascending f1 (ties broken by ascending f2). A point adds area
    # only if it lowers the best f2 seen so far, i.e. it is non-dominated.
    order = np.lexsort((pts[:, 1], pts[:, 0]))
    hv = 0.0
    best_f2 = ref_f2
    for f1, f2 in pts[order]:
        if f2 < best_f2:
            hv += (ref_f1 - f1) * (best_f2 - f2)
            best_f2 = f2
    return float(hv)

def crowding_distance(objs, front_idx):
    """Compute the crowding distance of the points selected by ``front_idx``.

    Args:
        objs: array of objective vectors, one per row.
        front_idx: indices into ``objs`` of the points to score.

    Returns:
        1-D array aligned with ``front_idx``. The extreme points of every
        objective get ``inf``; other points accumulate the normalised gap
        between their two neighbours in each objective. An empty
        ``front_idx`` yields an empty array.
    """
    if len(front_idx) == 0:
        return np.array([])

    front_objs = objs[front_idx]
    n_pts, _ = front_objs.shape
    distance = np.zeros(n_pts)
    for column in front_objs.T:
        rank = np.argsort(column)
        first, last = rank[0], rank[-1]
        distance[first] = np.inf
        distance[last] = np.inf
        lo, hi = column[first], column[last]
        if hi == lo:
            # degenerate objective: no spread to normalise by
            continue
        # Neighbour gap for every interior point of the sorted order.
        gaps = column[rank[2:]] - column[rank[:-2]]
        distance[rank[1:-1]] += gaps / (hi - lo)
    return distance

# ---------------- PyTorch small MLP building blocks ----------------
def mlp_block(input_dim, hidden=64, n_out=1, n_layers=2, out_activation=None):
    """Build a fully connected ReLU network as an ``nn.Sequential``.

    Args:
        input_dim: number of input features.
        hidden: width of every hidden layer.
        n_out: number of output units.
        n_layers: number of hidden (Linear + ReLU) layers.
        out_activation: pass ``'softmax'`` to append ``nn.Softmax(dim=1)``;
            any other value leaves the output layer linear.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    widths = [input_dim] + [hidden] * n_layers
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    modules.append(nn.Linear(widths[-1], n_out))
    if out_activation == 'softmax':
        modules.append(nn.Softmax(dim=1))
    return nn.Sequential(*modules)

# ---------------- Classifier ensemble ----------------
class ClassifierNet(nn.Module):
    """MLP classifier over decision vectors; outputs one raw logit per class."""

    def __init__(self, D, hidden=64, n_classes=4):
        """Create a two-hidden-layer MLP mapping ``D`` inputs to ``n_classes`` logits."""
        super().__init__()
        # No output activation: callers apply softmax / cross-entropy themselves.
        self.net = mlp_block(input_dim=D, hidden=hidden, n_out=n_classes, n_layers=2)

    def forward(self, x):
        """Return the logits produced by the wrapped network for ``x``."""
        logits = self.net(x)
        return logits

# ---------------- Surrogate regressor ----------------
class SurrogateNet(nn.Module):
    """MLP regressor predicting a single scalar objective per input row."""

    def __init__(self, D, hidden=128):
        """Create a three-hidden-layer MLP mapping ``D`` inputs to one output."""
        super().__init__()
        self.net = mlp_block(input_dim=D, hidden=hidden, n_out=1, n_layers=3)

    def forward(self, x):
        """Return predictions with the trailing size-1 output axis removed."""
        prediction = self.net(x)
        return prediction.squeeze(-1)

# ---------------- Acquisition net ----------------
class AcquisitionNet(nn.Module):
    """MLP mapping a candidate's feature vector to a scalar acquisition score."""

    def __init__(self, feat_dim, hidden=128):
        """Create a two-hidden-layer MLP mapping ``feat_dim`` inputs to one output."""
        super().__init__()
        self.net = mlp_block(input_dim=feat_dim, hidden=hidden, n_out=1, n_layers=2)

    def forward(self, x):
        """Return scores with the trailing size-1 output axis removed."""
        score = self.net(x)
        return score.squeeze(-1)

# ---------------- Main improved CLMEA ----------------
class CLMEA_Improved:
    def __init__(self, cfg):
        """Store the configuration, seed the RNGs and create empty state.

        Args:
            cfg: mapping providing at least the keys ``D``, ``M``, ``N_init``,
                ``NP``, ``maxFEs`` and ``seed``; it is kept by reference as
                ``self.cfg`` and consulted later for training settings.
        """
        self.cfg = cfg
        self.D, self.M = cfg["D"], cfg["M"]
        self.N_init, self.NP = cfg["N_init"], cfg["NP"]
        self.maxFEs = cfg["maxFEs"]
        self.seed = cfg["seed"]

        # Seed both the global numpy RNG and torch.
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # Archive of evaluated points (filled by initial_sampling).
        self.archive_X = None
        self.archive_Y = None

        # Input scaler shared by every model.
        self.x_scaler = StandardScaler()

        # Training logs.
        self.clf_loss_log = []
        self.clf_acc_log = []
        self.sur_mse_log = []
        self.acq_loss_log = []

        # Model placeholders.
        self.clf_ensemble = []
        self.sur_ensemble = []  # one list of nets per objective
        self.acq_net = None

    def initial_sampling(self):
        """Evaluate an initial Latin-hypercube design and seed the archive.

        Sets ``archive_X``, ``archive_Y`` and ``FEs`` and fits ``x_scaler`` on
        the initial design.
        """
        design = lhs_samples(self.N_init, self.D, seed=self.seed)
        responses = zdt1(design)
        self.archive_X, self.archive_Y = design.copy(), responses.copy()
        # Every initial point costs one function evaluation.
        self.FEs = self.N_init
        self.x_scaler.fit(self.archive_X)

    # ---- classifier ensemble training ----
    def train_classifier_ensemble(self, X, Y):
        """Train a fresh ensemble of front-rank classifiers on ``(X, Y)``.

        Labels are the non-dominated front number of each row of ``Y``, capped
        at 4 and shifted to the range 0..3. Each member is trained from
        scratch with Adam and cross-entropy; the per-epoch loss and accuracy
        of the first member only are appended to ``clf_loss_log`` and
        ``clf_acc_log``. The result replaces ``self.clf_ensemble``.

        Args:
            X: decision vectors, one per row (scaled with ``self.x_scaler``).
            Y: objective vectors matching ``X`` row by row.
        """
        n_members = self.cfg["clf_ensembles"]
        cpu = torch.device("cpu")
        self.clf_ensemble = []

        inputs = torch.tensor(self.x_scaler.transform(X), dtype=torch.float32)
        front_rank = nondominated_sort(Y)[0]
        class_ids = np.minimum(front_rank, 4).astype(int) - 1
        labels = torch.tensor(class_ids, dtype=torch.long)
        n_samples = len(inputs)

        for member in range(n_members):
            model = ClassifierNet(self.D, hidden=64, n_classes=4).to(cpu)
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            criterion = nn.CrossEntropyLoss()
            n_epochs = self.cfg["clf_epochs"]
            batch_size = self.cfg["clf_batch"]
            for _ in range(n_epochs):
                shuffled = torch.randperm(n_samples)
                loss_sum, n_correct = 0.0, 0
                for start in range(0, n_samples, batch_size):
                    rows = shuffled[start:start + batch_size]
                    xb, yb = inputs[rows], labels[rows]
                    logits = model(xb)
                    loss = criterion(logits, yb)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    loss_sum += loss.item() * len(rows)
                    n_correct += (logits.argmax(dim=1) == yb).sum().item()
                mean_loss = loss_sum / n_samples
                accuracy = n_correct / n_samples
                # Log only the first member to avoid duplicated curves.
                if member == 0:
                    self.clf_loss_log.append(mean_loss)
                    self.clf_acc_log.append(accuracy)
            self.clf_ensemble.append(model)
        return

    def clf_predict_proba_and_uncertainty(self, X_candidates):
        # returns mean_prob (N, n_classes), entropy (N,), epistemic_uncertainty via ensemble variance
        if len(self.clf_ensemble)==0:
            raise RuntimeError("Classifier not trained")
        Xs = self.x_scaler.transform(X_candidates)
        Xt = torch.tensor(Xs, dtype=torch.float32)
        logits_all = []
        with torch.no_grad():
            for net in self.clf_ensemble:
                logits = net(Xt).numpy()  # shape (N, C)
                probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
                logits_all.append(probs)
        probs_stack = np.stack(logits_all, axis=0)  # (ens, N, C)
        mean_probs = np.mean(probs_stack, axis=0)
        # entropy of mean probs
        eps = 1e-12
        entropy = -np.sum(mean_probs * np.log(mean_probs + eps), axis=1)
        # epistemic: variance across ensemble in prob for predicted class
        var = np.var(probs_stack, axis=0).mean(axis=1)  # average variance across classes
        return mean_probs, entropy, var

    # ---- surrogate ensemble training (per objective) ----
    def train_surrogate_ensemble(self, X, Y):
        """Train one ensemble of regression nets per objective column of ``Y``.

        Every member is trained from scratch with Adam and MSE loss for at
        least 10 epochs. The result replaces ``self.sur_ensemble`` (a list
        with one list of nets per objective). Afterwards the training MSE of
        the ensemble-mean prediction, averaged over all objectives, is
        appended to ``self.sur_mse_log``.

        Compared with the previous version, the per-epoch loss bookkeeping
        that was computed and then discarded has been removed, and the target
        and input tensors are built once instead of once per member.

        Args:
            X: decision vectors, one per row (scaled with ``self.x_scaler``).
            Y: objective values, shape (rows of X, number of objectives).
        """
        n_ens = self.cfg["sur_ensembles"]
        self.sur_ensemble = []
        Xs = self.x_scaler.transform(X)
        Xt = torch.tensor(Xs, dtype=torch.float32)
        device = torch.device("cpu")
        n_samples = len(Xt)

        def fit_member(targets):
            # Train a single regressor on (Xt, targets) and return it.
            net = SurrogateNet(self.D, hidden=128).to(device)
            opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)
            loss_fn = nn.MSELoss()
            epochs = max(10, self.cfg["sur_epochs"])  # keep not too small
            batch = self.cfg["sur_batch"]
            for _ in range(epochs):
                perm = torch.randperm(n_samples)
                for start in range(0, n_samples, batch):
                    idx = perm[start:start + batch]
                    loss = loss_fn(net(Xt[idx]), targets[idx])
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
            return net

        for m in range(Y.shape[1]):
            # Targets are identical for every member of this objective.
            yt = torch.tensor(Y[:, m], dtype=torch.float32)
            self.sur_ensemble.append([fit_member(yt) for _ in range(n_ens)])

        # Training MSE of the ensemble mean, averaged across objectives.
        preds = []
        with torch.no_grad():
            for members in self.sur_ensemble:
                preds_m = np.vstack([mem(Xt).numpy() for mem in members]).T
                preds.append(preds_m.mean(axis=1))
        preds = np.vstack(preds).T
        mse = np.mean((preds - Y) ** 2)
        self.sur_mse_log.append(mse)
        return

    def surrogate_predict_ensemble(self, X_candidates, return_samples=False):
        # returns mean (N,M) and var (N,M). If return_samples True returns samples shape (ens, N, M)
        Xs = self.x_scaler.transform(X_candidates)
        N = Xs.shape[0]
        M = self.M
        ens = self.cfg["sur_ensembles"]
        samples = np.zeros((ens, N, M))
        with torch.no_grad():
            for m in range(M):
                for e, net in enumerate(self.sur_ensemble[m]):
                    pred = net(torch.tensor(Xs, dtype=torch.float32)).numpy()
                    samples[e, :, m] = pred
        mean = samples.mean(axis=0)
        var = samples.var(axis=0)
        if return_samples:
            return mean, var, samples
        return mean, var

    # ---- learned acquisition: approximate expected HV improvement computed from surrogate samples ----
    def compute_expected_hv_improvement_from_samples(self, samples, ref=(1.1, 1.1)):
        """Average hypervolume improvement of each candidate over ensemble members.

        For every ensemble member and every candidate, the candidate's
        predicted objective vector is added to the current archive front and
        the resulting hypervolume gain is recorded; gains are then averaged
        over members.

        The archive is reduced to its non-dominated front once, up front.
        Dominated archive points never contribute to the hypervolume, so the
        result is unchanged, but they are no longer re-filtered in each of the
        ``ens * N`` hypervolume evaluations.

        Args:
            samples: array of shape (ens, N, M) with member predictions.
            ref: hypervolume reference point passed to ``hv_2d``.

        Returns:
            Array of shape (N,) with the mean improvement per candidate.
        """
        ens, N, _ = samples.shape
        current_front = nondominated_frontpoints(self.archive_Y)
        base_hv = hv_2d(current_front, ref=ref)
        improvements = np.zeros(N)
        for member_preds in samples:  # (N, M) predictions of one member
            for i in range(N):
                augmented = np.vstack([current_front, member_preds[i:i + 1, :]])
                improvements[i] += hv_2d(augmented, ref=ref) - base_hv
        improvements /= ens
        return improvements  # shape (N,)

    def train_acquisition_net(self, candidate_X_pool, acq_targets):
        """Fit a fresh acquisition network on a candidate pool.

        Each candidate's feature vector is
        ``[surrogate mean, surrogate variance, classifier prob of class 0,
        classifier entropy, classifier ensemble variance]``. A new
        ``AcquisitionNet`` replaces ``self.acq_net`` and is trained with Adam
        and MSE loss; the mean loss of every epoch is appended to
        ``self.acq_loss_log``.

        Args:
            candidate_X_pool: (N, D) candidate decision vectors.
            acq_targets: (N,) regression targets for the candidates.
        """
        mean, var, _ = self.surrogate_predict_ensemble(candidate_X_pool, return_samples=True)
        probs, entropy, epvar = self.clf_predict_proba_and_uncertainty(candidate_X_pool)
        columns = [mean, var, probs[:, :1], entropy.reshape(-1, 1), epvar.reshape(-1, 1)]
        feat = np.hstack(columns)

        features = torch.tensor(feat, dtype=torch.float32)
        targets = torch.tensor(acq_targets, dtype=torch.float32)
        n_rows = len(features)

        self.acq_net = AcquisitionNet(feat.shape[1], hidden=128)
        optimizer = optim.Adam(self.acq_net.parameters(), lr=1e-3)
        criterion = nn.MSELoss()
        batch_size = self.cfg["acq_batch"]
        n_epochs = self.cfg["acq_train_epochs"]
        for _ in range(n_epochs):
            shuffled = torch.randperm(n_rows)
            loss_sum = 0.0
            for start in range(0, n_rows, batch_size):
                rows = shuffled[start:start + batch_size]
                loss = criterion(self.acq_net(features[rows]), targets[rows])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # weight by batch size so the epoch value is a per-sample mean
                loss_sum += loss.item() * len(rows)
            self.acq_loss_log.append(loss_sum / n_rows)
        return

    def acquisition_score(self, X_cands):
        # compute learned acquisition score if trained else fallback to surrogate-based expected hv improvement
        if self.acq_net is None:
            # fallback: compute expected hv improvement by sampling surrogate ensemble directly
            mean, var, samples = self.surrogate_predict_ensemble(X_cands, return_samples=True)
            return self.compute_expected_hv_improvement_from_samples(samples)
        else:
            mean, var, samples = self.surrogate_predict_ensemble(X_cands, return_samples=True)
            probs, entropy, epvar = self.clf_predict_proba_and_uncertainty(X_cands)
            feat = np.hstack([mean, var, probs[:, :1], entropy.reshape(-1,1), epvar.reshape(-1,1)])
            with torch.no_grad():
                s = self.acq_net(torch.tensor(feat, dtype=torch.float32)).numpy()
            return s

    # ---- candidate generation and selection using combined score ----
    def generate_candidate_pool(self, pool_size):
        """Build a candidate pool of random LHS points and DE offspring.

        With fewer than three archived points the pool is pure Latin-hypercube
        sampling. Otherwise 40% of the pool (rounded down) is LHS and the rest
        is produced by ``operator_de`` from parent triples drawn with
        replacement from the archive.

        The LHS seeds are drawn from the global numpy RNG, which ``__init__``
        seeds from ``cfg["seed"]``. Previously ``seed=None`` was passed, which
        makes SciPy create a freshly seeded generator and made runs
        non-reproducible despite the configured seed.

        Args:
            pool_size: number of candidates to return.

        Returns:
            Array with at most ``pool_size`` rows and ``self.D`` columns.
        """
        def next_seed():
            # Tie the LHS stream to the (seeded) global numpy RNG.
            return int(np.random.randint(0, 2**31 - 1))

        if len(self.archive_X) < 3:
            return lhs_samples(pool_size, self.D, seed=next_seed())

        n_random = int(pool_size * 0.4)
        n_offspring = pool_size - n_random
        parts = []
        if n_random > 0:
            parts.append(lhs_samples(n_random, self.D, seed=next_seed()))

        # Each offspring gets its own (parent1, parent2, parent3) triple.
        triples = np.random.choice(len(self.archive_X), size=(n_offspring, 3), replace=True)
        parent1 = self.archive_X[triples[:, 0]]
        parent2 = self.archive_X[triples[:, 1]]
        parent3 = self.archive_X[triples[:, 2]]
        parts.append(operator_de(parent1, parent2, parent3))

        pool = np.vstack(parts)
        return pool[:pool_size]

    def select_candidates(self, n_select=1):
        """Pick the ``n_select`` best candidates from a freshly generated pool.

        Each candidate is scored as
        ``p0 * (0.6 * acquisition + 0.4 * entropy)``, where every component is
        min-max normalised over the pool and ``p0`` is the classifier's
        probability for class 0.

        Args:
            n_select: number of candidates to return.

        Returns:
            Array holding the selected rows of the pool, best first.
        """
        pool = self.generate_candidate_pool(self.cfg["cand_pool"])
        probs, entropy, _ = self.clf_predict_proba_and_uncertainty(pool)
        acq_score = self.acquisition_score(pool)

        eps = 1e-12

        def unit_scale(values):
            # min-max normalisation; eps keeps a constant vector from dividing by zero
            return (values - values.min()) / (values.max() - values.min() + eps)

        p0_n = unit_scale(probs[:, 0])
        acq_n = unit_scale(acq_score)
        ent_n = unit_scale(entropy)
        combined = p0_n * (0.6 * acq_n + 0.4 * ent_n)

        best_first = np.argsort(combined)[::-1]
        return pool[best_first[:n_select]]

    # ---- main loop ----
    def run(self):
        """Run the optimisation loop until the evaluation budget is spent.

        After the initial sampling, every iteration retrains the classifier
        and surrogate ensembles on the archive, trains the acquisition net on
        a fresh candidate pool, selects one candidate, evaluates it with
        ``zdt1`` and appends it to the archive. Progress is printed to stdout.
        The loop also stops after more than 1000 iterations as a safeguard.
        """
        print("Starting improved CLMEA run. CFG:", self.cfg)
        t0 = time.time()
        self.initial_sampling()
        iter_no = 0
        while self.FEs < self.maxFEs:
            iter_no += 1
            print(f"\n=== Iter {iter_no} | FEs: {self.FEs}/{self.maxFEs} | archive size: {len(self.archive_X)} ===")

            # 1) classifier ensemble
            t1 = time.time()
            self.train_classifier_ensemble(self.archive_X, self.archive_Y)
            print(f" Trained classifier ensemble ({len(self.clf_ensemble)} members) in {time.time()-t1:.2f}s")

            # 2) surrogate ensemble
            t2 = time.time()
            self.train_surrogate_ensemble(self.archive_X, self.archive_Y)
            print(f" Trained surrogate ensemble (per-objective members={self.cfg['sur_ensembles']}) in {time.time()-t2:.2f}s; surrogate MSE {self.sur_mse_log[-1]:.4e}")

            # 3) acquisition targets: HV improvement of a fresh pool under the surrogates
            pool = self.generate_candidate_pool(self.cfg["cand_pool"])
            _, _, samples = self.surrogate_predict_ensemble(pool, return_samples=True)
            acq_targets = self.compute_expected_hv_improvement_from_samples(samples)
            # Min-max normalise unless every target is zero (left untouched then).
            if not np.all(acq_targets == 0):
                lowest, highest = acq_targets.min(), acq_targets.max()
                acq_targets = (acq_targets - lowest) / (highest - lowest + 1e-12)
            t3 = time.time()
            self.train_acquisition_net(pool, acq_targets)
            print(f" Trained acquisition net in {time.time()-t3:.2f}s (loss last {self.acq_loss_log[-1]:.4e})")

            # 4) select and evaluate one candidate
            x_sel = self.select_candidates(n_select=1)
            n_new = x_sel.shape[0]
            if n_new > 0:
                y_new = zdt1(x_sel)
                self.archive_X = np.vstack([self.archive_X, x_sel])
                self.archive_Y = np.vstack([self.archive_Y, y_new])
                self.FEs += n_new
                print(f" Evaluated {x_sel.shape[0]} new points.")
            else:
                print(" No candidate selected.")

            # Safeguard against a loop that never consumes budget.
            if iter_no > 1000:
                break
        print("Finished run: FEs =", self.FEs, " elapsed:", time.time()-t0)
        return

    # plotting logs
    def plot_metrics(self):
        """Plot the recorded training logs; logs that are empty are skipped."""
        if self.clf_loss_log:
            fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
            loss_ax.plot(self.clf_loss_log)
            loss_ax.set_title("Classifier train loss")
            acc_ax.plot(self.clf_acc_log)
            acc_ax.set_title("Classifier train acc")
            plt.show()
        if self.sur_mse_log:
            plt.figure()
            plt.plot(self.sur_mse_log)
            plt.title("Surrogate MSE")
            plt.show()
        if self.acq_loss_log:
            plt.figure()
            plt.plot(self.acq_loss_log)
            plt.title("Acquisition train loss")
            plt.show()

# ---------------- run example ----------------
if __name__ == "__main__":
    # Run the optimizer with the default configuration and show its logs.
    cl = CLMEA_Improved(CFG)
    cl.run()
    cl.plot_metrics()
    # Report how many archived points ended up on the first front.
    pareto = nondominated_frontpoints(cl.archive_Y)
    print("Pareto size:", pareto.shape[0])
