import torch
from dattri.model_util.retrain import retrain_lds
from dattri.benchmark.datasets.mnist import train_mnist_mlp
import logging
import sys
from dattri.metric.ground_truth import calculate_lds_ground_truth
from dattri.benchmark.load import load_benchmark
from torch.utils.data import DataLoader
import torch.nn.functional as F

logger = logging.getLogger("Train")
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Prefer the GPU, but fall back to CPU (with a visible warning) instead of
# failing at the first `.to(device)` call on a machine without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    logger.warning("CUDA is not available; falling back to CPU.")
    device = torch.device("cpu")

# Benchmark bundle for the MNIST MLP / LDS setting. This script reads the
# "model", "train_dataset", "train_sampler", "test_dataset" and "test_sampler"
# entries of `model_details`; the second return value is not used here.
model_details, groundtruth = load_benchmark(model="mlp", dataset="mnist", metric="lds")

loader_train = DataLoader(
    model_details["train_dataset"],
    batch_size=5000,
    sampler=model_details["train_sampler"],
)
# NOTE: `retrain`'s target function evaluates only the first batch of this
# loader, so batch_size=500 determines how many test points are scored.
loader_test = DataLoader(
    model_details["test_dataset"], batch_size=500, sampler=model_details["test_sampler"]
)


def _correct_class_margin(logits, labels):
    """Return the per-sample margin ``log(p_y / (1 - p_y))`` for the true class.

    This equals ``logp - log(1 - exp(logp))`` with ``logp`` the log-softmax
    probability of the true class. It is computed as
    ``z_y - logsumexp_{j != y} z_j`` so it stays finite and accurate when the
    model is (numerically) certain. In that case ``logp`` rounds to 0 and the
    naive form evaluates ``log(0)``, returning ``inf``.

    Args:
        logits: ``(batch, num_classes)`` tensor of unnormalized model outputs.
        labels: ``(batch,)`` integer tensor of true class indices.

    Returns:
        ``(batch,)`` tensor of margins.
    """
    index = labels.unsqueeze(1)
    correct = logits.gather(1, index).squeeze(1)
    # Mask the true class with -inf so logsumexp covers only the other classes.
    others = logits.scatter(1, index, float("-inf"))
    return correct - torch.logsumexp(others, dim=1)


def retrain(seed, subset_ratio):
    """Retrain MNIST MLPs on training subsets and compute the LDS ground truth.

    Phase 1 calls ``retrain_lds`` to train 50 models (one run per subset), each
    on a ``subset_ratio`` fraction of ``loader_train``. Subset selection is done
    inside ``retrain_lds``. Checkpoints are written to
    ``../mnist_mlp_{subset_ratio}_seed{seed}/checkpoints``.

    Phase 2 scores every checkpoint on the first batch of ``loader_test`` via
    ``calculate_lds_ground_truth``. It saves the result to
    ``../mnist_mlp_{subset_ratio}_seed{seed}/ground_truth.pt``.

    Args:
        seed: Seed forwarded to ``retrain_lds``.
        subset_ratio: Fraction of the training set used for each subset.
    """
    save_name = f"mnist_mlp_{subset_ratio}_seed{seed}"
    # Shared by both phases: phase 1 writes here, phase 2 reads from here.
    retrain_dir = f"../{save_name}/checkpoints"

    def phase1():
        def train_func(dataloader, seed=None):
            logger.info("Start train func")
            return train_mnist_mlp(
                dataloader=dataloader, seed=seed, device=device, epoch_num=50
            )

        retrain_lds(
            train_func=train_func,
            dataloader=loader_train,
            path=retrain_dir,
            seed=seed,
            num_subsets=50,
            subset_ratio=subset_ratio,
            num_runs_per_subset=1,
        )

    def phase2():
        def target_func(ckpt_path, dataloader):
            logger.info("Start target func")

            # Only the first batch of the test loader is evaluated.
            X_test, y_test = next(iter(dataloader))
            X_test = X_test.to(device)
            y_test = y_test.to(device)

            # map_location lets a checkpoint saved on another device load here.
            params = torch.load(ckpt_path, map_location=device)
            model = model_details["model"].to(device)
            model.load_state_dict(params)
            model.eval()
            with torch.no_grad():
                yhat = model(X_test)
            # Stable form of logp - log(1 - exp(logp)); no inf for confident preds.
            margin = _correct_class_margin(yhat, y_test)  # (500,)
            assert margin.shape == (500,)
            return margin.cpu()  # (500,)

        ground_truth = calculate_lds_ground_truth(
            target_func=target_func,
            retrain_dir=retrain_dir,
            test_dataloader=loader_test,
        )
        torch.save(ground_truth, f"../{save_name}/ground_truth.pt")

    phase1()
    phase2()


if __name__ == "__main__":
    # Sweep over the fraction of training data used per retraining subset.
    # For each fraction, run with seed 0 and then seed 100. Ratio 1.0
    # presumably serves as a full-data baseline (confirm with the experiment
    # design).
    for subset_fraction in (0.25, 0.5, 0.75, 1.0):
        for run_seed in (0, 100):
            retrain(run_seed, subset_fraction)
