import numpy as np
import scipy.linalg
import sklearn.linear_model
import sklearn.metrics
import torch as th

from learned_planners.interp.train_probes import TrainProbeConfig
from learned_planners.interp.utils import get_metrics
from learned_planners.notebooks.emacs_plotly_render import set_plotly_renderer

set_plotly_renderer("emacs")
th.set_grad_enabled(False)

# %%

# `acts_ds` is an indexable object with a `.data` attribute (not a plain tensor/state dict), so it
# must be fully unpickled. Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which
# refuses such objects, so opt out explicitly.
# SECURITY: `weights_only=False` executes arbitrary pickle code; only load trusted, locally-produced files.
acts_ds = th.load(
    "/training/activations_dataset/hard/0ts_boxes_future_direction_map_5000_skip5.pt",
    weights_only=False,
)
# %% extract training data from the dataset & construct train/test


def process_acts_ds(acts_ds, keys_for_dataset, idx, *, multioutput: bool = False, num_classes: int = 4):
    """Build a feature matrix and a label array from selected samples of an activations dataset.

    Each sample ``acts_ds[i]`` is indexed as ``(activations, label)``. The arrays
    ``activations[k]`` for ``k`` in ``keys_for_dataset`` are concatenated (in that order)
    into a single feature row per sample.

    Args:
        acts_ds: Indexable dataset; ``acts_ds[i][0]`` maps key -> array and ``acts_ds[i][1]``
            is the class label (an integer, as required by the one-hot indexing below).
        keys_for_dataset: Activation keys to concatenate, in order.
        idx: Iterable of sample indices. It is iterated exactly once, so a one-shot iterator
            works. Must be non-empty (``np.stack`` raises ``ValueError`` on an empty list).
        multioutput: If True, return labels one-hot encoded over ``num_classes`` columns, with
            rows whose label is ``-1`` set to all zeros. Default False returns the stacked
            labels unchanged.
        num_classes: Width of the one-hot encoding; only used when ``multioutput`` is True.

    Returns:
        ``(acts_X, acts_y)``, each with one row per index in ``idx``.
    """
    # Look each sample up once; features and label come from the same lookup.
    samples = [acts_ds[i] for i in idx]
    acts_X = np.stack([np.concatenate([sample[0][k] for k in keys_for_dataset]) for sample in samples])
    acts_y = np.stack([sample[1] for sample in samples])
    if not multioutput:
        return acts_X, acts_y

    # np.eye(n)[-1] selects the *last* row, so rows labelled -1 must be zeroed explicitly.
    acts_y_multioutput = np.eye(num_classes)[acts_y]
    acts_y_multioutput[acts_y == -1, :] = 0
    return acts_X, acts_y_multioutput


keys_for_dataset = [
    # For each of cell_list layers 0..2: its `hook_h` activation followed by its `hook_c` activation.
    f"features_extractor.cell_list.{layer}.{hook}"
    for layer in range(3)
    for hook in ("hook_h", "hook_c")
]

N_TRAIN = 100_000  # shuffled samples used for training; the remainder is the test split
SPLIT_SEED = 0  # fixed so the split (and every metric computed from it) is reproducible

all_data_idx = np.random.default_rng(SPLIT_SEED).permutation(len(acts_ds))
train_idx = all_data_idx[:N_TRAIN]
test_idx = all_data_idx[len(train_idx) :]
if len(test_idx) == 0:
    # Fail clearly here rather than with an opaque np.stack error inside process_acts_ds.
    raise ValueError(
        f"Dataset has only {len(acts_ds)} samples; need more than {N_TRAIN} to leave a non-empty test split."
    )

train_X, train_y = process_acts_ds(acts_ds, keys_for_dataset, train_idx)
test_X, test_y = process_acts_ds(acts_ds, keys_for_dataset, test_idx)

# %% Train the probe

# Hyperparameters for the L1-regularised logistic-regression probe.
args = TrainProbeConfig(
    weight_decay_type="l1",
    weight_decay=1,
    sklearn_class_weight=None,
    sklearn_solver="saga",
    sklearn_n_jobs=4,
    sklearn_l1_ratio=1.0,
)


def _logistic_probe_from_config(cfg):
    """Return an unfitted, intercept-free LogisticRegression configured from ``cfg``.

    ``cfg.weight_decay`` is turned into sklearn's inverse regularisation strength ``C``.
    NOTE: sklearn only uses ``l1_ratio`` when ``penalty="elasticnet"``; with other penalties
    it is ignored (with a warning).
    """
    inverse_reg_strength = 1 / cfg.weight_decay
    return sklearn.linear_model.LogisticRegression(
        fit_intercept=False,
        penalty=cfg.weight_decay_type,
        C=inverse_reg_strength,
        l1_ratio=cfg.sklearn_l1_ratio,
        solver=cfg.sklearn_solver,
        class_weight=cfg.sklearn_class_weight,
        n_jobs=cfg.sklearn_n_jobs,
    )


probe = _logistic_probe_from_config(args)
probe.fit(train_X, train_y)

# %% calculate metrics

# Predictions of the original probe on each split (train first, then test).
train_preds, test_preds = (probe.predict(split_X) for split_X in (train_X, test_X))

test_metrics = get_metrics(test_preds, test_y, True, "test")
train_metrics = get_metrics(train_preds, train_y, True, "train")

# Merge into one dict; on a key clash the train entry wins.
metrics = dict(test_metrics)
metrics.update(train_metrics)
metrics

# %%

assert np.all(probe.intercept_ == 0)

probe_v = probe.coef_  # probe weight matrix W: one row per class
norm_probe = probe_v / np.linalg.norm(probe_v, 2, axis=1)[:, None]
assert np.allclose(np.linalg.norm(norm_probe, 2, axis=1), 1)

# Probe logits on the training set (its decision function; the intercept is zero).
direction_mask = train_X @ probe_v.T + probe.intercept_
# %% Apply LEACE to erase probe
# Whitening with L = chol(X^T X) (lower triangular): whitened_X = X L^{-T}, so that
# whitened_X^T @ whitened_X = I and whitened_X @ L^T recovers train_X.
whiten_X = np.linalg.cholesky(train_X.T @ train_X)
whitened_X = scipy.linalg.solve_triangular(whiten_X, train_X.T, lower=True).T
assert np.allclose(whitened_X.T @ whitened_X, np.eye(train_X.shape[1]), atol=1e-4)
assert whitened_X.shape == train_X.shape

# The logits are X W^T = whitened_X (L^T W^T), so in whitened coordinates the probe reads the
# column span of L^T W^T. Project that span out there, then map back to the original space with L^T.
# Use a reduced, rank-revealing basis (`orth`). Taking V^T V from a *full* SVD of W
# (np.linalg.svd defaults to full_matrices=True) gives the identity matrix, which would
# erase all of train_X.
whitened_probe_basis = scipy.linalg.orth(whiten_X.T @ probe_v.T)  # (n_features, rank of W)
P = whitened_probe_basis @ whitened_probe_basis.T
eliminated_information = train_X - whitened_X @ P @ whiten_X.T

# After erasure the original probe's logits must vanish; the tolerance is relative to the
# largest pre-erasure logit, to absorb float error from the Cholesky / triangular solve.
assert np.allclose(
    eliminated_information @ probe_v.T, 0, atol=1e-3 * np.abs(direction_mask).max()
), "LEACE erasure left signal along the probe directions"

# %%


# Second probe with the same hyperparameters, trained on the erased activations.
probe2 = sklearn.linear_model.LogisticRegression(
    penalty=args.weight_decay_type,
    C=1 / args.weight_decay,
    class_weight=args.sklearn_class_weight,
    solver=args.sklearn_solver,
    n_jobs=args.sklearn_n_jobs,
    l1_ratio=args.sklearn_l1_ratio,
    fit_intercept=False,
)
probe2.fit(eliminated_information, train_y)
# %% calc metrics

train_preds = probe2.predict(train_X)
eliminated_preds = probe2.predict(eliminated_information)
test_preds = probe2.predict(test_X)


def sk_get_metrics(preds, true_y, _unused, name) -> dict:
    """Return macro-averaged F1/precision/recall and accuracy, keyed ``"{name}/<metric>"``.

    ``_unused`` is accepted and ignored, so calls can use the same argument list as the
    ``get_metrics(...)`` calls. ``zero_division=0`` gives the same value as sklearn's default
    ``"warn"`` (which acts as 0) but without a warning when a class is never predicted.
    """
    return {
        f"{name}/f1": sklearn.metrics.f1_score(true_y, preds, average="macro", zero_division=0),
        f"{name}/precision": sklearn.metrics.precision_score(true_y, preds, average="macro", zero_division=0),
        f"{name}/recall": sklearn.metrics.recall_score(true_y, preds, average="macro", zero_division=0),
        f"{name}/accuracy": sklearn.metrics.accuracy_score(true_y, preds),
    }


train_metrics = sk_get_metrics(train_preds, train_y, True, "train")
eliminated_metrics = sk_get_metrics(eliminated_preds, train_y, True, "eliminated")
test_metrics = sk_get_metrics(test_preds, test_y, True, "test")
metrics = {**train_metrics, **test_metrics, **eliminated_metrics}
metrics
# %% check whether probe2 learned directions independent of the erased probe
# Singular values of the stacked [probe; probe2] weight matrix. A (near-)zero value means the
# stacked rows are linearly dependent, e.g. probe2 has all-zero rows or re-learned a combination
# of probe's directions.

s = np.linalg.svd(np.concatenate([probe.coef_, probe2.coef_], axis=0), compute_uv=False)
s


# %%

acts_ds.data[0].keys()
