import torch
import gurobipy as gp
from gurobipy import GRB
from sklearn.decomposition import PCA


def geospca_solver(A, nc, k, epsilon, maxiter, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Select up to ``k`` columns of ``A`` via a Gurobi MIP with lazy PCA-residual cuts.

    One binary variable per column is maximised against that column's squared
    norm, subject to selecting at most ``k`` columns. Each time Gurobi finds a
    new integer solution, the selected columns are fitted with a PCA of (at
    most) ``nc`` components. If the summed squared reconstruction residual
    exceeds ``epsilon``, a lazy "no-good" cut forbids every solution that
    contains all of those columns. The candidate subset with the largest PCA
    objective seen by the callback is tracked and returned.

    Args:
        A: 2-D tensor of shape ``(n, p)``; columns are the selection
            candidates, rows are the samples seen by sklearn's PCA.
        nc: Requested number of principal components. It is clamped to what
            the selected submatrix supports, i.e. ``min(nc, #cols, n)``.
        k: Maximum number of columns that may be selected.
        epsilon: Threshold on the summed per-column squared residual above
            which a candidate subset is cut off.
        maxiter: Maximum number of candidate solutions the callback inspects;
            once reached, further incumbents are accepted without checks.
        device: Torch device ``A`` is moved to.

    Returns:
        dict with keys:
            ``"Bvalue"``: sum of squared PCA singular values of the best
                subset (``0`` if no non-empty subset was inspected),
            ``"Bindices"``: column indices of that subset,
            ``"eta"``: summed squared residual of that subset, or ``None``,
            ``"iterations"``: number of callback inspections performed,
            ``"A"``: ``A`` on ``device``.
    """
    A = A.to(device)
    n, p = A.shape

    # Squared Euclidean norm of every column: the MIP objective weights.
    # Transfer to host once rather than calling .item() p times (a device
    # sync each on GPU).
    col_weights = torch.sum(A**2, axis=0).cpu().tolist()

    results = {"iteration": 0, "Bindices": [], "Bvalue": 0}

    def n_components_for(num_cols):
        # sklearn's PCA requires n_components <= min(n_samples, n_features).
        return min(nc, num_cols, n)

    def disP(subA, ncA):
        """Fit an ``ncA``-component PCA to ``subA`` on the CPU.

        Returns ``(per-column squared reconstruction residual, sum of squared
        singular values)``; for a submatrix with no columns returns
        ``(empty tensor, 0)``. Note sklearn's PCA mean-centres the data.
        """
        if subA.shape[1] == 0:
            return torch.tensor([]), 0
        # detach() so .numpy() also works on tensors that require grad.
        sub_cpu = subA.detach().cpu()
        pca = PCA(n_components=ncA)
        transformed = pca.fit_transform(sub_cpu.numpy())
        reconstructed = pca.inverse_transform(transformed)
        diff = torch.tensor(reconstructed) - sub_cpu
        objval = torch.sum(torch.tensor(pca.singular_values_)**2)
        norm = torch.sum(diff**2, axis=0)
        return norm, objval

    m = gp.Model()
    try:
        m.setParam('OutputFlag', 0)
        s = m.addVars(range(p), vtype=GRB.BINARY, name='s')
        m.setObjective(gp.quicksum(col_weights[i] * s[i] for i in range(p)), GRB.MAXIMIZE)
        m.addConstr(s.sum('*') <= k)

        def addCut(model, where):
            # Only inspect new integer solutions, and at most maxiter of them.
            if where != GRB.Callback.MIPSOL or results["iteration"] >= maxiter:
                return
            sln = model.cbGetSolution(s)
            indices = [i for i in range(p) if sln[i] > 0.5]
            if indices:
                excess, objval = disP(A[:, indices], n_components_for(len(indices)))
                if objval >= results["Bvalue"]:
                    results["Bvalue"] = objval
                    results["Bindices"] = indices
                if torch.sum(excess) > epsilon:
                    # No-good cut: forbid any solution selecting all of these columns.
                    model.cbLazy(gp.quicksum(s[i] for i in indices) <= len(indices) - 1)
            results["iteration"] += 1

        # Required for cbLazy to be honoured.
        m.Params.LazyConstraints = 1
        m.optimize(addCut)
    finally:
        # Release the Gurobi model's native resources.
        m.dispose()

    best_indices = results["Bindices"]
    best_value = results["Bvalue"]
    iteration_count = results["iteration"]
    if best_indices:
        # Clamp the component count exactly as the callback does; the raw nc
        # would make PCA raise when the subset has fewer than nc columns.
        error, _ = disP(A[:, best_indices], n_components_for(len(best_indices)))
        eta = torch.sum(error)
    else:
        eta = None

    return {
        "Bvalue": best_value,
        "Bindices": best_indices,
        "eta": eta,
        "iterations": iteration_count,
        "A": A
    }

