
import time
import numpy as np
from scipy.stats import qmc
from sklearn.kernel_ridge import KernelRidge
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Small-budget settings so a complete run finishes quickly.
FAST_CONFIG = dict(
    D=30,                  # decision-space dimension
    M=2,                   # objectives count; stored by CLMEA_Fast, zdt1 always returns 2
    N_init=30,             # initial Latin-hypercube samples (true evaluations)
    NP=20,                 # population size used by the infill searches
    maxFEs=80,             # total budget of true function evaluations
    clf_epochs=5,          # classifier training epochs per iteration
    clf_batch=64,          # classifier mini-batch size
    hv_gen_max=8,          # generations of the global surrogate search
    local_gen_max2=3,      # generations of each local surrogate search
    surrogate_alpha=1e-2,  # KernelRidge regularization strength
)

# ---------------- Problem: ZDT1 ----------------
def zdt1(x):
    """Evaluate the two-objective ZDT1 benchmark (minimization).

    Args:
        x: decision vectors of shape (n, D), or a single vector of shape (D,).
            Values are expected in [0, 1]; a negative first coordinate makes
            ``sqrt(f1 / g)`` undefined.

    Returns:
        Array of shape (n, 2) with columns (f1, f2).
    """
    x = np.atleast_2d(np.asarray(x, dtype=float))
    D = x.shape[1]
    f1 = x[:, 0]
    if D > 1:
        g = 1.0 + 9.0 * np.sum(x[:, 1:], axis=1) / (D - 1)
    else:
        # With no tail variables the sum is empty; avoid 0/0 and use g = 1.
        g = np.ones_like(f1)
    h = 1.0 - np.sqrt(f1 / g)
    f2 = g * h
    return np.column_stack([f1, f2])

# ---------------- Utilities ----------------
def lhs_samples(n, D, lower=0.0, upper=1.0, seed=None):
    """Draw ``n`` Latin-hypercube points in ``D`` dimensions scaled to [lower, upper].

    ``seed`` is forwarded to ``scipy.stats.qmc.LatinHypercube`` so equal seeds
    give equal designs. Returns an array of shape (n, D).
    """
    unit_design = qmc.LatinHypercube(d=D, seed=seed).random(n)
    return qmc.scale(unit_design, lower, upper)

def operator_de(parent1, parent2, parent3, CR=0.5, F=0.5, lower=0.0, upper=1.0,
                mut_scale=0.03):
    """DE/rand-style variation followed by sparse Gaussian mutation.

    For each entry selected with probability ``CR`` the offspring becomes
    ``parent1 + F * (parent2 - parent3)``; other entries copy ``parent1``.
    The result is clipped to [lower, upper]. Then each entry is perturbed with
    probability 1/D by N(0, mut_scale**2) noise and clipped again.

    Args:
        parent1, parent2, parent3: arrays of shape (N, D).
        CR: per-entry crossover probability.
        F: differential weight.
        lower, upper: bounds passed to ``np.clip``; scalars or per-dimension
            arrays of length D.
        mut_scale: standard deviation of the Gaussian mutation (default 0.03).

    Returns:
        Offspring array of shape (N, D), always floating point.
    """
    N, D = parent1.shape
    site = np.random.rand(N, D) < CR
    # Work in floating point so the differential step is not truncated when
    # integer-typed parents are passed in.
    if np.issubdtype(parent1.dtype, np.floating):
        offspring = parent1.copy()
    else:
        offspring = parent1.astype(float)
    offspring[site] = offspring[site] + F * (parent2[site] - parent3[site])
    offspring = np.clip(offspring, lower, upper)
    mut_mask = np.random.rand(N, D) < 1.0 / D
    n_mut = int(mut_mask.sum())
    if n_mut > 0:
        offspring[mut_mask] += np.random.normal(scale=mut_scale, size=n_mut)
    return np.clip(offspring, lower, upper)

def nondominated_sort(objs):
    """Rank points into Pareto fronts (minimization) by non-dominated sorting.

    Point p dominates q when ``objs[p] <= objs[q]`` in every objective and
    ``objs[p] < objs[q]`` in at least one.

    Args:
        objs: array of shape (N, M) with objective values; a 1-D array is
            treated as a single objective.

    Returns:
        (frontno, fronts): ``frontno[i]`` is the 1-based front index of point
        i (int array of length N); ``fronts`` is a list with one list of point
        indices per front, in increasing front order.

    The pairwise dominance test is vectorized; it materializes an (N, N, M)
    comparison, i.e. O(N^2 * M) memory.
    """
    objs = np.asarray(objs)
    if objs.ndim == 1:
        objs = objs[:, None]
    N = objs.shape[0]
    # dom[p, q] is True when p dominates q.
    le = np.all(objs[:, None, :] <= objs[None, :, :], axis=2)
    lt = np.any(objs[:, None, :] < objs[None, :, :], axis=2)
    dom = le & lt
    # Number of points dominating each point (column sums).
    dom_count = dom.sum(axis=0)
    # Points dominated by each p, in ascending index order.
    dominates = [np.flatnonzero(row) for row in dom]

    frontno = np.zeros(N, dtype=int)
    fronts = []
    current = np.flatnonzero(dom_count == 0).tolist()
    f = 1
    while current:
        fronts.append(current)
        frontno[current] = f
        nxt = []
        # Peel the current front off: points whose dominators are all gone
        # form the next front.
        for p in current:
            for q in dominates[p]:
                dom_count[q] -= 1
                if dom_count[q] == 0:
                    nxt.append(int(q))
        current = nxt
        f += 1
    return frontno, fronts

def crowding_distance(objs, front_idx):
    """Return NSGA-II crowding distances for the rows ``objs[front_idx]``.

    For every objective the extreme members get ``inf``; interior members
    accumulate the gap between their sorted neighbours, normalized by that
    objective's range. Objectives with zero range only set the extremes.
    The result is aligned with ``front_idx``; an empty index yields an empty
    array.
    """
    if len(front_idx) == 0:
        return np.array([])
    members = objs[front_idx]
    n_pts, n_obj = members.shape
    dist = np.zeros(n_pts)
    for k in range(n_obj):
        col = members[:, k]
        rank = np.argsort(col)
        lo_val = col[rank[0]]
        hi_val = col[rank[-1]]
        dist[rank[[0, -1]]] = np.inf
        if hi_val == lo_val:
            continue
        # rank holds distinct indices, so the fancy-indexed += is safe.
        dist[rank[1:-1]] += (col[rank[2:]] - col[rank[:-2]]) / (hi_val - lo_val)
    return dist

# ---------------- Simple MLP ----------------
class SimpleMLP(nn.Module):
    """Fully connected classifier with two ReLU hidden layers of width ``hidden``.

    Inputs have last dimension ``D``; the output has last dimension
    ``n_classes`` and holds unnormalized logits.
    """

    def __init__(self, D, hidden=64, n_classes=4):
        super().__init__()
        widths = [D, hidden, hidden]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(hidden, n_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for ``x``."""
        return self.net(x)

# ---------------- CLMEA Fast ----------------
class CLMEA_Fast:
    """Small-budget classifier- and surrogate-assisted multi-objective EA on ZDT1.

    Every iteration of :meth:`run` fits an MLP front-rank classifier and one
    KernelRidge surrogate per objective on the evaluated archive. It then
    proposes at most three new points and evaluates them with ``zdt1``:
    classifier-assisted DE, a surrogate-driven global search, and a local
    surrogate search around the least-crowded Pareto point. The evaluated
    points are appended to ``archive_X`` / ``archive_Y``.
    """

    def __init__(self, cfg):
        """Store the configuration and seed the global NumPy and torch RNGs.

        Args:
            cfg: configuration dict.
                Required keys: "D", "M", "N_init", "NP", "maxFEs",
                "clf_epochs", "clf_batch", "hv_gen_max", "local_gen_max2".
                Optional keys:
                - "surrogate_alpha" (default 1e-2)
                - "surrogate_gamma" (default None, i.e. sklearn's RBF default
                  of 1 / n_features)
                - "max_iters" (default 500)
                - "seed" (default 0)
        """
        self.D = cfg["D"]
        self.M = cfg["M"]  # stored only; zdt1 always returns 2 objectives
        self.N_init = cfg["N_init"]
        self.NP = cfg["NP"]
        self.maxFEs = cfg["maxFEs"]
        self.cfg = cfg
        seed = cfg.get("seed", 0)
        np.random.seed(seed)
        # Classifier weight init and mini-batch shuffling draw from torch's
        # RNG; seeding only NumPy left runs non-reproducible.
        torch.manual_seed(seed)
        self.archive_X = None
        self.archive_Y = None
        self.FEs = 0
        self.class_loss_log = []
        self.class_acc_log = []
        self.surrogate_mse_log = []

    def initial_sampling(self):
        """Evaluate ``N_init`` Latin-hypercube points and use them as the archive."""
        X0 = lhs_samples(self.N_init, self.D, seed=42)
        Y0 = zdt1(X0)
        self.archive_X = X0.copy()
        self.archive_Y = Y0.copy()
        self.FEs = self.N_init

    @staticmethod
    def _predict(models, scaler, X):
        """Stack per-objective surrogate predictions for X into shape (len(X), n_models)."""
        Xs = scaler.transform(X)
        return np.vstack([m.predict(Xs) for m in models]).T

    def _is_new(self, X, tol=1e-12):
        """Return a mask: True for rows of X farther than ``tol`` from every archive point."""
        if len(X) == 0:
            return np.zeros(0, dtype=bool)
        d = np.min(np.linalg.norm(X[:, None, :] - self.archive_X[None, :, :], axis=2), axis=1)
        return d > tol

    def _evaluate_and_add(self, X_new, source, t_start):
        """Evaluate X_new with zdt1, append it to the archive, and count the FEs."""
        if X_new.shape[0] == 0:
            return
        Y_new = zdt1(X_new)
        self.archive_X = np.vstack([self.archive_X, X_new])
        self.archive_Y = np.vstack([self.archive_Y, Y_new])
        self.FEs += X_new.shape[0]
        print(f"Added {X_new.shape[0]} candidate(s) from {source} (time {time.time()-t_start:.2f}s).")

    def train_classifier(self, X_train, Y_train):
        """Train an MLP to predict each point's non-dominated front, capped at 4.

        Labels are ``min(front, 4) - 1``, i.e. classes 0..3, where class 0 is
        the first front. Inputs are standardized before training.

        Returns:
            (model, scaler): the trained ``SimpleMLP`` and the fitted
            ``StandardScaler`` to apply to inputs before prediction.

        Side effects: appends one loss and one accuracy value per epoch to
        ``class_loss_log`` / ``class_acc_log``.
        """
        n_classes = 4
        frontno, _ = nondominated_sort(Y_train)
        ranks = np.minimum(frontno, n_classes).astype(int) - 1
        scaler = StandardScaler()
        Xs = scaler.fit_transform(X_train)
        model = SimpleMLP(self.D, hidden=64, n_classes=n_classes)
        loss_fn = nn.CrossEntropyLoss()
        opt = optim.Adam(model.parameters(), lr=1e-3)
        dataset = torch.tensor(Xs, dtype=torch.float32)
        labels = torch.tensor(ranks, dtype=torch.long)
        epochs = self.cfg["clf_epochs"]
        batch = self.cfg["clf_batch"]
        n = len(dataset)
        for _ in range(epochs):
            perm = torch.randperm(n)
            epoch_loss = 0.0
            correct = 0
            for start in range(0, n, batch):
                idx = perm[start:start + batch]
                xb = dataset[idx]
                yb = labels[idx]
                logits = model(xb)
                loss = loss_fn(logits, yb)
                opt.zero_grad()
                loss.backward()
                opt.step()
                epoch_loss += loss.item() * len(idx)
                # Accuracy of the pre-step logits, i.e. a running training accuracy.
                correct += (logits.argmax(dim=1) == yb).sum().item()
            self.class_loss_log.append(epoch_loss / n)
            self.class_acc_log.append(correct / n)
        return model, scaler

    def train_surrogates(self, X_train, Y_train, log=True):
        """Fit one RBF KernelRidge model per objective on standardized inputs.

        Args:
            X_train, Y_train: training inputs (n, D) and objectives (n, M).
            log: when True, append the training MSE to ``surrogate_mse_log``.

        Returns:
            (kr_models, scalerX): list of fitted models (one per objective) and
            the input ``StandardScaler``.
        """
        scalerX = StandardScaler().fit(X_train)
        Xs = scalerX.transform(X_train)
        alpha = self.cfg.get("surrogate_alpha", 1e-2)
        # gamma=None makes sklearn use 1 / n_features. The former fixed
        # gamma=0.5 on standardized inputs gives kernel values of roughly
        # exp(-D) between distinct points (about 1e-13 for D=30), so
        # predictions away from the training data collapsed to ~0.
        # NOTE(review): KernelRidge has no intercept, so far-field predictions
        # still decay toward 0.
        gamma = self.cfg.get("surrogate_gamma", None)
        kr_models = []
        for m in range(Y_train.shape[1]):
            kr = KernelRidge(alpha=alpha, kernel='rbf', gamma=gamma)
            kr.fit(Xs, Y_train[:, m])
            kr_models.append(kr)
        if log:
            preds = self._predict(kr_models, scalerX, X_train)
            self.surrogate_mse_log.append(np.mean((preds - Y_train) ** 2))
        return kr_models, scalerX

    def classifier_assisted_infill(self, model, scaler, N=20, num_infill=1):
        """Propose offspring that the classifier predicts to be first-front.

        Steps:
        1. Take the N archive points with the lowest front numbers and split
           them into parents predicted class 0 and class <= 1.
        2. Create N DE offspring from those parents.
        3. Keep the offspring predicted as class 0; if fewer than
           ``max(1, int(0.5 * N))`` qualify, keep that many with the highest
           class-0 probability instead.
        4. Return the ``num_infill`` kept offspring farthest, in decision
           space, from the archive.

        Returns:
            Array of shape (k, D) with k <= num_infill.
        """
        frontno, _ = nondominated_sort(self.archive_Y)
        P_idx = np.argsort(frontno)[:N]
        P = self.archive_X[P_idx]
        with torch.no_grad():
            logits = model(torch.tensor(scaler.transform(P), dtype=torch.float32)).numpy()
        labels = np.argmax(logits, axis=1)
        rank1 = P[labels == 0]
        rank2 = P[labels <= 1]
        if len(rank1) == 0:
            # No parent predicted first-front: resample from all of P.
            rank1 = P[np.random.randint(0, len(P), size=(max(1, len(P)),))]
        if len(rank2) == 0:
            rank2 = P
        Off = operator_de(
            parent1=rank1[np.random.randint(0, len(rank1), size=(N,))],
            parent2=rank1[np.random.randint(0, len(rank1), size=(N,))],
            parent3=rank2[np.random.randint(0, len(rank2), size=(N,))],
            CR=0.5, F=0.5,
        )
        with torch.no_grad():
            off_logits = model(torch.tensor(scaler.transform(Off), dtype=torch.float32)).numpy()
        off_labels = np.argmax(off_logits, axis=1)
        # At least one survivor, even for N < 2 (int(0.5 * N) would be 0).
        n_keep = max(1, int(0.5 * N))
        chosen = np.where(off_labels == 0)[0]
        if len(chosen) < n_keep:
            # Numerically stable softmax: shift by the row max before exp.
            shifted = off_logits - off_logits.max(axis=1, keepdims=True)
            expd = np.exp(shifted)
            probs = expd / expd.sum(axis=1, keepdims=True)
            chosen = np.argsort(probs[:, 0])[-min(len(probs), n_keep):]
        # Prefer offspring far from evaluated points (least explored).
        dists = np.min(np.linalg.norm(Off[:, None, :] - self.archive_X[None, :, :], axis=2), axis=1)
        chosen_uncertain_idx = chosen[np.argsort(dists[chosen])[::-1]]
        return Off[chosen_uncertain_idx[:num_infill]]

    def hv_based_search(self, kr_models, scalerX, N=20, num_infill=1):
        """Evolve archive points on the global surrogates; return predicted-front points.

        Despite the name, no hypervolume is computed. Survivors of each
        generation are chosen by predicted non-dominated front and crowding
        distance. Points already present in the archive are never returned,
        so no true evaluation is spent re-evaluating a known point.

        Returns:
            Array of shape (k, D) with k <= num_infill (k may be 0).
        """
        X_parent_idx = self.select_train_data(N)
        x_parent = self.archive_X[X_parent_idx]
        y_parent = self._predict(kr_models, scalerX, x_parent)
        for _ in range(self.cfg["hv_gen_max"]):
            if len(x_parent) < 2:
                break
            idx1 = np.random.permutation(len(x_parent))
            idx2 = np.random.permutation(len(x_parent))
            x_off = operator_de(x_parent, x_parent[idx1], x_parent[idx2])
            y_off = self._predict(kr_models, scalerX, x_off)
            med_dec = np.vstack([x_parent, x_off])
            med_obj = np.vstack([y_parent, y_off])
            keep_idx = self._select_by_front_and_crowding(med_obj, min(len(med_dec), N))
            x_parent = med_dec[keep_idx]
            y_parent = med_obj[keep_idx]
        front_pred = x_parent[nondominated_sort(y_parent)[0] == 1]
        # The parents start as archive points and may survive selection.
        front_pred = front_pred[self._is_new(front_pred)]
        return front_pred[:num_infill]

    def local_infill(self, kr_models, scalerX, num_infill=1, N=80, k_local=10):
        """Search locally around the least-crowded first-front archive points.

        Each of the ``num_infill`` first-front points with the largest
        crowding distance is processed as follows:
        1. Fit local surrogates on its N nearest archive points (distance in
           objective space).
        2. Evolve up to ``k_local`` of them for ``local_gen_max2`` generations
           with DE anchored at that Pareto point and clipped to the local
           data's bounding box.
        3. Keep the evolved, not-yet-evaluated point whose predicted
           objectives are farthest from every archive objective vector.

        ``kr_models`` and ``scalerX`` are accepted for interface compatibility
        but are not used.

        Returns:
            Array of shape (k, D) with k <= num_infill (k may be 0).
        """
        frontno, _ = nondominated_sort(self.archive_Y)
        pareto_idx = np.where(frontno == 1)[0]
        if len(pareto_idx) == 0:
            return np.empty((0, self.D))
        Pareto = self.archive_X[pareto_idx]
        ParetoObj = self.archive_Y[pareto_idx]
        cds = crowding_distance(ParetoObj, list(range(len(Pareto))))
        if len(cds) == 0:
            return np.empty((0, self.D))
        order = np.argsort(cds)[::-1]
        sel_idx = order[:min(num_infill, len(order))]
        candidates = []
        for i in sel_idx:
            ref_obj = ParetoObj[i]
            dists = np.linalg.norm(self.archive_Y - ref_obj, axis=1)
            nearest_idx = np.argsort(dists)[:N]
            x_train = self.archive_X[nearest_idx]
            y_train = self.archive_Y[nearest_idx]
            # Local fits are not logged so surrogate_mse_log keeps one entry
            # per global surrogate fit.
            kr_local, scaler_loc = self.train_surrogates(x_train, y_train, log=False)
            x_parent_idx = self.select_train_data(min(k_local, len(x_train)), local_X=x_train, local_Y=y_train)
            if len(x_parent_idx) > 0:
                x_parent = x_train[x_parent_idx]
            else:
                x_parent = x_train[:min(k_local, len(x_train))].copy()
            if len(x_parent) == 0:
                continue
            y_parent = self._predict(kr_local, scaler_loc, x_parent)
            lower = np.min(x_train, axis=0)
            upper = np.max(x_train, axis=0)
            # Shorter local evolution.
            for _ in range(self.cfg["local_gen_max2"]):
                if len(x_parent) < 2:
                    break
                n_par = len(x_parent)
                x_off = operator_de(
                    parent1=np.tile(Pareto[i], (n_par, 1)),
                    parent2=x_parent[np.random.randint(0, n_par, size=(n_par,))],
                    parent3=x_parent[np.random.randint(0, n_par, size=(n_par,))],
                    CR=0.5, F=0.5,
                    lower=lower, upper=upper,
                )
                y_off = self._predict(kr_local, scaler_loc, x_off)
                med_dec = np.vstack([x_parent, x_off])
                med_obj = np.vstack([y_parent, y_off])
                keep_idx = self._select_by_front_and_crowding(med_obj, min(len(x_parent), med_dec.shape[0]))
                x_parent = med_dec[keep_idx]
                y_parent = med_obj[keep_idx]
            # y_parent already holds the local-surrogate predictions for x_parent.
            fresh = self._is_new(x_parent)
            if not fresh.any():
                continue
            x_fresh = x_parent[fresh]
            y_fresh = y_parent[fresh]
            dists_to_eval = np.min(np.linalg.norm(self.archive_Y[:, None, :] - y_fresh[None, :, :], axis=2), axis=0)
            candidates.append(x_fresh[np.argmax(dists_to_eval)])
        if len(candidates) == 0:
            return np.empty((0, self.D))
        return np.array(candidates)

    def select_train_data(self, N, local_X=None, local_Y=None):
        """Return indices of up to N points chosen by front, then crowding distance.

        Selects from the archive objectives, or from ``local_Y`` when
        ``local_X`` is given (``local_X`` itself is not used for selection).
        """
        tr_y = self.archive_Y if local_X is None else local_Y
        return self._select_by_front_and_crowding(tr_y, N)

    def _select_by_front_and_crowding(self, objs, N):
        """Return int indices of up to N rows of objs: whole fronts first, then the
        most crowded-distant members of the front that does not fit entirely."""
        if N <= 0 or len(objs) == 0:
            return np.empty(0, dtype=int)
        frontno, _ = nondominated_sort(objs)
        max_front = frontno.max()
        chosen = []
        f = 1
        while len(chosen) < N and f <= max_front:
            idxs = np.where(frontno == f)[0]
            if len(chosen) + len(idxs) <= N:
                chosen += idxs.tolist()
            else:
                cds = crowding_distance(objs, idxs)
                pick_order = np.argsort(cds)[::-1]
                need = N - len(chosen)
                chosen += idxs[pick_order[:need]].tolist()
            f += 1
        # Integer dtype even when empty so the result is always usable as an index.
        return np.array(chosen[:N], dtype=int)

    def run(self):
        """Run the optimization loop until ``maxFEs`` true evaluations are spent.

        Stops early after ``cfg.get("max_iters", 500)`` iterations. Progress
        and timings are printed to stdout.
        """
        print("Starting CLMEA fast run with config:", self.cfg)
        t0 = time.time()
        self.initial_sampling()
        max_iters = self.cfg.get("max_iters", 500)
        iter_no = 0
        while self.FEs < self.maxFEs:
            iter_no += 1
            iter_t0 = time.time()
            print(f"\n=== Iter {iter_no} | FEs so far: {self.FEs}/{self.maxFEs} ===")
            # 1) classifier
            t1 = time.time()
            clf_model, clf_scaler = self.train_classifier(self.archive_X, self.archive_Y)
            print(f"Classifier trained (time {time.time()-t1:.2f}s).")
            # 2) global surrogate
            t2 = time.time()
            kr_models, scalerX = self.train_surrogates(self.archive_X, self.archive_Y)
            print(f"Surrogates trained (time {time.time()-t2:.2f}s).")
            # 3) classifier-assisted infill
            t3 = time.time()
            x_cand1 = self.classifier_assisted_infill(clf_model, clf_scaler, N=self.NP, num_infill=1)
            self._evaluate_and_add(x_cand1, "classifier", t3)
            if self.FEs >= self.maxFEs:
                break
            # 4) surrogate-driven global search
            t4 = time.time()
            x_cand2 = self.hv_based_search(kr_models, scalerX, N=self.NP, num_infill=1)
            self._evaluate_and_add(x_cand2, "HV search", t4)
            if self.FEs >= self.maxFEs:
                break
            # 5) local infill
            t5 = time.time()
            x_cand3 = self.local_infill(kr_models, scalerX, num_infill=1, N=min(80, len(self.archive_X)), k_local=10)
            self._evaluate_and_add(x_cand3, "local infill", t5)
            print(f"Iteration {iter_no} finished in {time.time()-iter_t0:.2f}s. Total elapsed {time.time()-t0:.2f}s.")
            if iter_no > max_iters:
                break
        print("Finished: FEs =", self.FEs, " total time:", time.time()-t0)

    def plot_metrics(self):
        """Plot per-epoch classifier loss and accuracy; then, if any were logged,
        the surrogate training MSE. Each figure is shown with ``plt.show()``."""
        fig, ax = plt.subplots(1,2, figsize=(12,4))
        ax[0].plot(self.class_loss_log, label='classifier loss')
        ax[0].set_title('Classifier Loss (per epoch)')
        ax[0].legend()
        ax[1].plot(self.class_acc_log, label='classifier acc')
        ax[1].set_title('Classifier Accuracy (per epoch)')
        ax[1].legend()
        plt.show()
        if len(self.surrogate_mse_log)>0:
            plt.figure()
            plt.plot(self.surrogate_mse_log, label='surrogate train MSE')
            plt.title('Surrogate train MSE')
            plt.legend()
            plt.show()

# ---------- run ----------
if __name__ == "__main__":
    run_cfg = dict(FAST_CONFIG)
    optimizer = CLMEA_Fast(run_cfg)
    optimizer.run()
    optimizer.plot_metrics()
    # Report how many evaluated points lie on the first non-dominated front.
    archive_fronts = nondominated_sort(optimizer.archive_Y)[0]
    pareto = optimizer.archive_Y[archive_fronts == 1]
    print("Pareto size:", pareto.shape[0])
