import torch
from dattri.model_util.retrain import retrain_lds
import logging
import sys
from dattri.metric.ground_truth import calculate_lds_ground_truth
from sklearn.linear_model import LogisticRegression
from dattri.benchmark.load import load_benchmark
from torch.utils.data import DataLoader
import torch.nn.functional as F
from dattri.benchmark.datasets.mnist import create_lr_model
from dattri.func.utils import flatten_params, _unflatten_params

logger = logging.getLogger("Train")
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Prefer the GPU, but fall back to CPU so the script still runs on machines
# without CUDA (a bare torch.device("cuda") only fails later, at .to(device)).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Benchmark bundle with train/test datasets and samplers; `groundtruth` is
# presumably the benchmark's precomputed LDS ground truth — confirm if used.
model_details, groundtruth = load_benchmark(model="lr", dataset="mnist", metric="lds")

# NOTE(review): batch_size=5000 presumably lets the whole sampled training set
# arrive as a single batch — confirm against model_details["train_sampler"].
loader_train = DataLoader(
    model_details["train_dataset"],
    batch_size=5000,
    sampler=model_details["train_sampler"],
)
# NOTE(review): batch_size=500 presumably matches the number of sampled test
# points — confirm against model_details["test_sampler"].
loader_test = DataLoader(
    model_details["test_dataset"], batch_size=500, sampler=model_details["test_sampler"]
)


def _concat_batches(dataloader):
    """Concatenate every ``(inputs, labels)`` batch of ``dataloader`` into one pair."""
    inputs, labels = [], []
    for x_batch, y_batch in dataloader:
        inputs.append(x_batch)
        labels.append(y_batch)
    return torch.cat(inputs), torch.cat(labels)


def _logit_margin(logits, labels):
    """Return ``log(p / (1 - p))`` where ``p`` is the softmax probability of each true label.

    Computed as ``z_y - logsumexp_{j != y} z_j``. This is algebraically equal to
    ``log p - log(1 - exp(log p))`` but stays finite when ``p`` rounds to 1, where
    the direct formula evaluates ``log(0)`` and yields ``+inf``.

    Args:
        logits: ``(N, C)`` tensor of unnormalised class scores.
        labels: ``(N,)`` int64 tensor of true class indices.

    Returns:
        ``(N,)`` tensor of margins.
    """
    true_logit = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
    true_mask = F.one_hot(labels, num_classes=logits.shape[1]).bool()
    # Exclude the true class so logsumexp runs over the competing classes only.
    other_lse = torch.logsumexp(logits.masked_fill(true_mask, float("-inf")), dim=1)
    return true_logit - other_lse


def retrain(seed, subset_ratio):
    """Retrain MNIST logistic-regression models on data subsets and save LDS ground truth.

    Phase 1 calls ``retrain_lds`` on ``loader_train`` with 50 subsets (one run each)
    at the given ``subset_ratio``. Checkpoints are written under
    ``../mnist_lr_{subset_ratio}_seed{seed}/checkpoints`` (relative to the working
    directory). Phase 2 passes those checkpoints and ``loader_test`` to
    ``calculate_lds_ground_truth`` and saves the result to
    ``../mnist_lr_{subset_ratio}_seed{seed}/ground_truth.pt``.

    Args:
        seed: Seed forwarded to ``retrain_lds`` (presumably controls subset sampling).
        subset_ratio: Fraction of the training data each retrained model sees.
    """
    save_name = f"mnist_lr_{subset_ratio}_seed{seed}"
    ckpt_dir = f"../{save_name}/checkpoints"
    # Template model whose parameter layout the fitted sklearn weights are mapped into.
    model = create_lr_model("mnist").eval()

    def train_func(dataloader, seed=None):
        """Fit a multinomial LR on all of ``dataloader`` and return it as model params.

        ``seed`` is accepted for the ``retrain_lds`` callback interface but is unused.
        The lbfgs solver does not consume ``random_state``, so it is fixed at 0.
        """
        logger.info("Start train func")
        X, y = _concat_batches(dataloader)
        # Flatten each sample to a feature vector using the real sample count.
        # int(subset_ratio * 5000) can round to the wrong size (e.g. 0.29 -> 1449).
        X = X.reshape(X.shape[0], -1)
        # multi_class is omitted: it is deprecated in scikit-learn >= 1.5, and
        # multinomial is already the default for lbfgs on multi-class targets.
        # C=1e8 makes L2 regularisation negligible.
        log_reg = LogisticRegression(
            solver="lbfgs",
            max_iter=1000,
            random_state=0,
            C=1e8,
        )
        log_reg.fit(X.cpu().numpy(), y.cpu().numpy())
        # Flattened layout: all weights (classes x features, row-major), then intercepts.
        w = torch.from_numpy(log_reg.coef_).reshape(-1)
        b = torch.from_numpy(log_reg.intercept_).reshape(-1)
        wb = torch.cat((w, b), dim=0).to(torch.float32)
        return _unflatten_params(wb, model)

    def target_func(ckpt_path, dataloader):
        """Return the logit margin of one checkpoint on every sample in ``dataloader``.

        Returns:
            1-D CPU tensor with one margin per test sample.

        Raises:
            ValueError: if the checkpoint size is not ``classes * (features + 1)``.
        """
        logger.info("Start target func")
        X_test, y_test = _concat_batches(dataloader)
        X_test = X_test.reshape(X_test.shape[0], -1).to(device)
        y_test = y_test.to(device)

        theta = flatten_params(torch.load(ckpt_path)).to(device)
        n_features = X_test.shape[1]
        # Derive the class count from the parameter vector ([W.flatten(), b]).
        # Hard-coding 10 x 784 would tie this to MNIST-LR.
        num_classes, remainder = divmod(theta.numel(), n_features + 1)
        if remainder:
            raise ValueError(
                f"checkpoint has {theta.numel()} parameters, which is not a "
                f"multiple of n_features + 1 = {n_features + 1}"
            )
        n_weights = num_classes * n_features
        w = theta[:n_weights].reshape(num_classes, n_features).to(torch.float32)
        b = theta[n_weights:].to(torch.float32)
        logits = X_test @ w.T + b  # (N, num_classes)
        return _logit_margin(logits, y_test).cpu()  # (N,)

    # Phase 1: retrain on random subsets and store the checkpoints.
    retrain_lds(
        train_func=train_func,
        dataloader=loader_train,
        path=ckpt_dir,
        seed=seed,
        num_subsets=50,
        subset_ratio=subset_ratio,
        num_runs_per_subset=1,
        only_weights=True,
    )

    # Phase 2: evaluate every checkpoint on the test set and save the ground truth.
    ground_truth = calculate_lds_ground_truth(
        target_func=target_func,
        retrain_dir=ckpt_dir,
        test_dataloader=loader_test,
    )
    torch.save(ground_truth, f"../{save_name}/ground_truth.pt")


if __name__ == "__main__":
    # (seed, subset_ratio) configurations to generate ground truth for.
    # Other subset ratios explored previously: 0.25, 0.5, 0.75, 1.0
    # (each with seeds 0 and 100).
    runs = ((0, 1.0), (100, 0.5))
    for run_seed, run_ratio in runs:
        retrain(run_seed, run_ratio)
