import torch
from dattri.model_util.retrain import retrain_lds
from dattri.benchmark.datasets.cifar import train_cifar_resnet9
import logging
import sys
from dattri.metric.ground_truth import calculate_lds_ground_truth
from dattri.benchmark.utils import SubsetSampler
from dattri.benchmark.load import load_benchmark
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Module logger; basicConfig routes INFO-level records to stdout.
logger = logging.getLogger("Train")
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Device used for both training and checkpoint evaluation. Falls back to CPU
# when CUDA is unavailable so the script does not crash on CPU-only hosts
# (training will be much slower there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def retrain(seed, subset_ratio, train_size, test_size):
    """Retrain ResNet-9 on CIFAR-2 subsets and save the LDS ground truth.

    Phase 1 calls ``retrain_lds`` to train 50 subset models (one run per
    subset, 50 epochs each) drawn from the first ``train_size`` training
    samples, writing checkpoints under ``../<save_name>/checkpoints``.
    Phase 2 scores every checkpoint on the first ``test_size`` test samples
    via ``calculate_lds_ground_truth`` and saves the result to
    ``../<save_name>/ground_truth.pt``.

    Args:
        seed: Seed forwarded to ``retrain_lds``.
        subset_ratio: Subset ratio forwarded to ``retrain_lds``.
        train_size: Number of leading training samples to use.
        test_size: Number of leading test samples to evaluate on.
    """
    save_name = f"cifar2_resnet9_{subset_ratio}_seed{seed}_size{train_size}_{test_size}"
    print(f"Start task '{save_name}'")
    # Output layout shared by both phases, defined once.
    ckpt_dir = f"../{save_name}/checkpoints"
    ground_truth_path = f"../{save_name}/ground_truth.pt"

    model_details = load_benchmark(model="resnet9", dataset="cifar2", metric="lds")[0]
    # Restrict both splits to their leading samples.
    model_details["train_sampler"] = SubsetSampler(range(train_size))
    model_details["test_sampler"] = SubsetSampler(range(test_size))
    # Each loader yields its whole selected split as a single batch.
    # NOTE(review): this makes the training loader full-batch; confirm this
    # is the intended batch size for the retrained models.
    loader_train = DataLoader(
        model_details["train_dataset"],
        batch_size=train_size,
        sampler=model_details["train_sampler"],
    )
    loader_test = DataLoader(
        model_details["test_dataset"],
        batch_size=test_size,
        sampler=model_details["test_sampler"],
    )

    def phase1():
        """Train the subset models and write their checkpoints."""

        def train_func(dataloader, seed=None):
            logger.info("Start train func")
            return train_cifar_resnet9(
                dataloader=dataloader, seed=seed, device=device, num_epochs=50
            )

        retrain_lds(
            train_func=train_func,
            dataloader=loader_train,
            path=ckpt_dir,
            seed=seed,
            num_subsets=50,
            subset_ratio=subset_ratio,
            num_runs_per_subset=1,
        )

    def phase2():
        """Score every checkpoint on the test batch and save the ground truth."""

        def target_func(ckpt_path, dataloader):
            """Return the per-sample log(p / (1 - p)) of the true class."""
            logger.info("Start target func")

            # The test loader holds the whole test subset in one batch.
            X_test, y_test = next(iter(dataloader))
            X_test = X_test.to(device)
            y_test = y_test.to(device)

            # map_location keeps loading correct regardless of the device
            # the checkpoint was saved from.
            params = torch.load(ckpt_path, map_location=device)
            model = model_details["model"].to(device)
            model.load_state_dict(params)
            model.eval()
            with torch.no_grad():
                yhat = model(X_test)

            # log(p / (1 - p)) equals the true-class logit minus the
            # logsumexp of all other logits. Computing it this way avoids
            # log(1 - exp(logp)), which becomes log(0) = -inf (so the margin
            # is inf) once a confident prediction makes exp(logp) round to 1.
            labels = y_test.unsqueeze(1)
            correct = yhat.gather(1, labels).squeeze(1)
            others = yhat.scatter(1, labels, float("-inf"))
            return (correct - torch.logsumexp(others, dim=1)).cpu()

        ground_truth = calculate_lds_ground_truth(
            target_func=target_func,
            retrain_dir=ckpt_dir,
            test_dataloader=loader_test,
        )
        torch.save(ground_truth, ground_truth_path)

    phase1()
    phase2()


if __name__ == "__main__":
    # Two runs with the same seed and data sizes; only the subset ratio differs.
    for ratio in (0.5, 1.0):
        retrain(seed=0, subset_ratio=ratio, train_size=1000, test_size=100)
