import bz2
import os
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import LinearOperator
from scipy.sparse.linalg import eigsh
from tqdm import tqdm
import urllib.request
import torch
from iht.linalg import accelerated_astra
from iht.linalg import accelerated_iht_algorithm
from iht.linalg import accelerated_strong_astra

# from iht.linalg import accelerated_wanda_algorithm
from iht.linalg import sparse_gpt_pruner

from tabulate import tabulate
from tests.old.utils import measure

from sklearn.datasets import load_svmlight_file

# Fractions of the feature count to keep. In the benchmark loop each kappa is
# multiplied by the number of features to obtain the sparsity budget k.
KAPPAS = [0.01, 0.1, 0.25, 0.5, 0.75]
# KAPPAS = [0.01]

# Directory where downloaded LIBSVM files are stored and looked up.
LIBSVM_PATH = "/buckets/datasets/libsvm"
# Maps a dataset name to the URL of its bz2-compressed LIBSVM file. The
# benchmark iterates over every entry, so commented-out entries are skipped.
DATASETS = {
    "gisette": "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/gisette_scale.bz2",
    # "ledgar": "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/ledgar_lexglue_tfidf_train.svm.bz2",
    # "real-sim": "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/real-sim.bz2",
    # this one is ~ 4GB, so it might take a while to download and decompress
    # "epsilon": "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/epsilon_normalized.bz2",
}


def ensure_dataset_exists(name):
    """Return the local path of a dataset, downloading and extracting it if needed.

    The dataset URL is looked up in ``DATASETS``. If the extracted file is
    not already present under ``LIBSVM_PATH``, the bz2 archive is downloaded
    there, decompressed, and then deleted.

    Args:
        name: Key of the dataset in ``DATASETS``.

    Returns:
        Path of the extracted dataset file inside ``LIBSVM_PATH``.

    Raises:
        ValueError: If ``name`` is not a key of ``DATASETS``.
    """
    # Local import: only this function needs it.
    import shutil

    if name not in DATASETS:
        raise ValueError(f"Unknown dataset: {name}")

    url = DATASETS[name]
    filename = os.path.basename(url)
    # Strip only a trailing ".bz2"; str.replace would also remove the
    # substring from the middle of a file name.
    if filename.endswith(".bz2"):
        extracted_filename = filename[: -len(".bz2")]
    else:
        extracted_filename = filename

    local_path = os.path.join(LIBSVM_PATH, extracted_filename)
    if os.path.exists(local_path):
        return local_path

    compressed_local_path = os.path.join(LIBSVM_PATH, filename)
    # Decompress into a temporary name first so an interrupted run never
    # leaves a truncated file that a later run would accept as complete.
    partial_path = local_path + ".part"

    print(f"Dataset {name} not found locally. Downloading from {url}...")
    os.makedirs(LIBSVM_PATH, exist_ok=True)

    with tqdm(
        unit="B", unit_scale=True, miniters=1, desc=f"Downloading {name}"
    ) as t:

        def report_hook(block_num, block_size, total_size):
            if t.total is None and total_size > 0:
                t.total = total_size
            # urlretrieve passes the cumulative block count and also calls
            # the hook once with block_num == 0 before any data arrives, so
            # advance the bar to the absolute position instead of adding a
            # block on every call.
            t.update(block_num * block_size - t.n)

        urllib.request.urlretrieve(
            url, compressed_local_path, reporthook=report_hook
        )

    print(f"Decompressing {compressed_local_path}...")
    try:
        with bz2.BZ2File(compressed_local_path, "rb") as f_in:
            with open(partial_path, "wb") as f_out:
                # Stream in chunks rather than holding the whole
                # decompressed dataset in memory.
                shutil.copyfileobj(f_in, f_out)
        os.replace(partial_path, local_path)
    finally:
        # After a successful os.replace the temporary file is gone and this
        # is a no-op; on failure it removes the incomplete output.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Decompression complete.")
    os.remove(compressed_local_path)
    return local_path


def hard_threshold(vec, k):
    """Keep the k largest-magnitude entries of a vector and zero the rest.

    Args:
        vec: One-dimensional NumPy array.
        k: Number of entries to keep; must be non-negative.

    Returns:
        ``vec`` itself (unchanged, not copied) when ``k >= len(vec)``;
        otherwise a new float array of the same shape in which only the
        ``k`` entries of largest absolute value are retained.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k >= len(vec):
        return vec
    result = np.zeros_like(vec, dtype=float)
    if k == 0:
        # The slice [-0:] would select every index, so handle "keep
        # nothing" explicitly.
        return result
    indices = np.argsort(np.abs(vec))[-k:]
    result[indices] = vec[indices]
    return result


def soft_threshold(vec, threshold):
    """Shrink every entry of ``vec`` towards zero by ``threshold``.

    Each magnitude is reduced by ``threshold`` and clipped at zero, and the
    original sign is then reapplied.
    """
    magnitudes = np.abs(vec)
    shrunk = np.maximum(magnitudes - threshold, 0)
    return np.sign(vec) * shrunk


def objective_function(x, A, y):
    """Return half the mean squared residual, ``0.5 * mean((A @ x - y) ** 2)``."""
    residual = A @ x - y
    mean_sq_residual = np.mean(residual ** 2)
    return 0.5 * mean_sq_residual


def get_gradient(x, A, y):
    """Return ``A.T @ (A @ x - y)``.

    This is the gradient of the un-averaged least-squares loss
    ``0.5 * ||A @ x - y||^2``; it is not divided by the number of rows.
    """
    residual = A @ x - y
    return A.T @ residual


def at_operator(A, v):
    """Apply ``A.T @ A`` to ``v`` as two products, without forming ``A.T @ A``."""
    projected = A @ v
    return A.T @ projected


if __name__ == "__main__":
    # Benchmark driver. For every dataset in DATASETS:
    #   1. load the LIBSVM file and solve the full least-squares problem,
    #   2. derive the step size eta = 1 / lambda_max(X^T X),
    #   3. for every kappa in KAPPAS, run each pruning method with budget k,
    #      refit least squares on each selected support ("+" variants), and
    #      collect the mean squared errors into a table.
    for dataset_name in DATASETS:
        data_path = ensure_dataset_exists(dataset_name)
        print(f"Loading dataset {dataset_name} from {data_path}...")
        sp_X, sp_y = load_svmlight_file(data_path)  # type: ignore
        sp_X: sp.csr_matrix = sp.csr_matrix(sp_X)
        sp_y = np.array(sp_y)
        print(
            f"Dataset {dataset_name} loaded with shapes: {sp_X.shape}, {sp_y.shape}"
        )
        unique_labels = np.unique(sp_y)
        print(
            "Unique labels:",
            unique_labels,
            "num_classes:",
            len(unique_labels),
        )

        # Unconstrained least-squares solution on the full feature set.
        sp_sol = sp.linalg.lsqr(
            sp_X, sp_y, atol=1e-6, btol=1e-6, iter_lim=1000
        )[0]
        # NOTE(review): the device is hard-coded, so a CUDA GPU is required.
        x_sol = torch.tensor(sp_sol, dtype=torch.float64, device="cuda")
        print("Solution norm:", np.linalg.norm(sp_sol))
        print("Objective value:", objective_function(sp_sol, sp_X, sp_y))

        # diag(X^T X): column-wise sums of the squared entries.
        s_X = sp_X.copy()
        s_X.data = s_X.data**2
        diag_H = s_X.sum(axis=0).A1

        def at_op(v):
            return sp_X.T @ (sp_X @ v)

        n_features = sp_X.shape[1]  # type: ignore
        T_op = LinearOperator((n_features, n_features), matvec=at_op)  # type: ignore

        # k=1 means we only want the largest eigenvalue.
        # which='LM' specifies to find the largest magnitude eigenvalue.
        largest_eigenvalue = float(
            eigsh(T_op, k=1, which="LM", return_eigenvectors=False)[0]
        )
        eta = 1.0 / largest_eigenvalue
        print(f"Step size eta is : {eta}")

        print("Converting to dense format for PyTorch...")
        X = torch.tensor(sp_X.toarray(), dtype=torch.float64, device="cuda")
        y = torch.tensor(sp_y, dtype=torch.float64, device="cuda")
        # Column-wise squared norms of X do not depend on kappa, so compute
        # them once here instead of rebuilding X**2 for every solver call.
        col_sq_norms = (X**2).sum(0)
        results = []

        for kappa in KAPPAS:
            max_iter = 20000
            # Clamp to at least one feature: with k == 0 the slice [-k:]
            # used for the Wanda support below would select every feature.
            k = max(1, int(n_features * kappa))
            print(
                f"Running with k={k} and eta={eta} for {max_iter} iterations..."
            )

            # w_iht = accelerated_iht_algorithm(
            #     X, y, k, eta=eta, max_iter=max_iter, rtol=1e-7,
            #     #   w_0=x_sol
            # )

            # w_iht = accelerated_wanda_algorithm(
            #     X, y, (X**2).sum(0), k, eta=eta, max_iter=max_iter, rtol=1e-6
            # )

            # Each call receives its own clone, as in the original per-call
            # (X**2).sum(0); whether the solvers modify this argument in
            # place is not visible from this file.
            w_astra = accelerated_astra(
                X,
                y,
                col_sq_norms.clone(),
                k,
                eta=eta,
                beta=1e-3,
                max_iter=max_iter * 2,
                # w_0=x_sol,
                rtol=1e-8,
                # x_sol=x_sol,
                # recovery_lambda=1e-3,
            )

            # NOTE(review): this result is reported under the "IHT" label but
            # is produced by accelerated_strong_astra - confirm the intended
            # algorithm or rename the label.
            w_iht = accelerated_strong_astra(
                X,
                y,
                col_sq_norms.clone(),
                k,
                eta=eta,
                beta=1e-3,
                max_iter=max_iter * 2,
                # w_0=x_sol,
                rtol=1e-8,
                # x_sol=x_sol,
                # recovery_lambda=1e-3,
            )

            iht_err = torch.mean((X @ w_iht - y) ** 2).item()
            astra_err = torch.mean((X @ w_astra - y) ** 2).item()

            print(f"IHT error: {iht_err}")
            print(f"Astra error: {astra_err}")

            # Indices of the non-zero coefficients selected by each method.
            iht_support = (
                torch.nonzero(w_iht, as_tuple=False).flatten().cpu().numpy()
            )
            astra_support = (
                torch.nonzero(w_astra, as_tuple=False).flatten().cpu().numpy()
            )

            iht_X = sp_X[:, iht_support]
            astra_X = sp_X[:, astra_support]

            # Refit least squares restricted to each support, warm-started
            # from the pruned coefficients.
            with measure("IHT+ on reduced problem with w_0"):
                iht_sol = sp.linalg.lsqr(
                    iht_X,
                    sp_y,
                    atol=1e-6,
                    btol=1e-6,
                    iter_lim=1000,
                    x0=w_iht.cpu().numpy()[iht_support],
                )[0]

            with measure("ASTRA+ on reduced problem with w_0"):
                astra_sol = sp.linalg.lsqr(
                    astra_X,
                    sp_y,
                    atol=1e-6,
                    btol=1e-6,
                    iter_lim=1000,
                    x0=w_astra.cpu().numpy()[astra_support],
                )[0]

            iht_err_p = np.mean((iht_X @ iht_sol - sp_y) ** 2)
            print(f"IHT+ Error on the reduced problem: {iht_err_p}")
            astra_err_p = np.mean((astra_X @ astra_sol - sp_y) ** 2)
            print(f"ASTRA+ Error on the reduced problem: {astra_err_p}")

            # with measure("IHT+ on reduced problem w/o w_0"):
            #     iht_sol = sp.linalg.lsqr(
            #         iht_X, sp_y, atol=1e-6, btol=1e-6, iter_lim=1000
            #     )[0]

            # with measure("ASTRA+ on reduced problem w/o w_0"):
            #     astra_sol = sp.linalg.lsqr(
            #         astra_X, sp_y, atol=1e-6, btol=1e-6, iter_lim=1000
            #     )[0]
            # iht_err_p = np.mean((iht_X @ iht_sol - sp_y) ** 2)
            # print(f"IHT+ Error on the reduced problem: {iht_err_p}")
            # astra_err_p = np.mean((astra_X @ astra_sol - sp_y) ** 2)
            # print(f"ASTRA+ Error on the reduced problem: {astra_err_p}")

            # Wanda baseline: score each feature by |w_i| * sqrt(diag(X^T X)_i)
            # using the dense solution and keep the k highest scores.
            scores = np.abs(sp_sol) * np.sqrt(diag_H)
            wanda_support = np.argsort(scores)[-k:]
            wanda_X = sp_X[:, wanda_support]

            wanda_err = np.mean((wanda_X @ sp_sol[wanda_support] - sp_y) ** 2)
            print(f"Wanda err on the reduced problem: {wanda_err}")

            with measure("Wanda on reduced problem"):
                wanda_sol = sp.linalg.lsqr(
                    wanda_X,
                    sp_y,
                    atol=1e-6,
                    btol=1e-6,
                    iter_lim=1000,
                    x0=sp_sol[wanda_support],
                )[0]
            wanda_err_p = np.mean((wanda_X @ wanda_sol - sp_y) ** 2)
            print(f"Wanda+ on the reduced problem: {wanda_err_p}")

            with measure("Sparse GPT Pruner"):
                w_gpt = sparse_gpt_pruner(X, y, x_sol, k)

            gpt_support = (
                torch.nonzero(w_gpt, as_tuple=False).flatten().cpu().numpy()
            )
            # Keep the raw pruning error so it can be reported in the table
            # next to the refit error.
            gpt_err = torch.mean((X @ w_gpt - y) ** 2).item()
            print(f"GPT Error: {gpt_err}")
            gpt_X = sp_X[:, gpt_support]
            with measure("GPT+ on reduced problem"):
                gpt_sol = sp.linalg.lsqr(
                    gpt_X,
                    sp_y,
                    atol=1e-6,
                    btol=1e-6,
                    iter_lim=1000,
                    x0=w_gpt.cpu().numpy()[gpt_support],
                )[0]
            gpt_err_p = np.mean((gpt_X @ gpt_sol - sp_y) ** 2)
            print(f"GPT+ Error on the reduced problem: {gpt_err_p}")

            results.append(
                [
                    kappa,
                    k,
                    iht_err,
                    wanda_err,
                    astra_err,
                    iht_err_p,
                    wanda_err_p,
                    astra_err_p,
                    gpt_err,
                    gpt_err_p,
                ]
            )
        # "X Error" columns are errors of the pruned weights themselves;
        # "X+ Error" columns are errors after the least-squares refit.
        print(
            tabulate(
                results,
                headers=[
                    "Kappa",
                    "k",
                    "IHT Error",
                    "Wanda Error",
                    "ASTRA Error",
                    "IHT+ Error",
                    "Wanda+ Error",
                    "ASTRA+ Error",
                    "Sparse GPT Error",
                    "Sparse GPT+ Error",
                ],
            )
        )
