import argparse
import os
import sys

import torch
from fastargs import get_current_config, Param, Section
from fastargs.decorators import param
from fastargs.validation import OneOf, File, Or, Folder, And

# Append the current working directory to sys.path so the project-local
# `src` package resolves. This assumes the script is launched from the
# repository root. The original appended "." twice; once is sufficient.
sys.path.append(".")
from src.tools.sharpness import measure_sharpness

# Model / output configuration.
Section('cfg').params(
    architecture=Param(OneOf(['resnet18', 'resnet50', 'resnet101']), 'torchvision ResNet architecture',
                       required=True),
    pretrained_ckpt=Param(File(), 'pretrained checkpoint path', required=True),
    # main() appends f"{metric}.pth" to this string verbatim, so it is a prefix,
    # not a directory (include a trailing separator to write into a folder).
    write_path=Param(str, 'output path prefix; the result is saved to <write_path><metric>.pth',
                     required=True),
)

# Data-loading configuration.
Section('dataset').params(
    train_path=Param(Or(File(), Folder()), 'FFCV file or dataset folder to measure sharpness on',
                     required=True),
    batch_size=Param(int, 'the batch size', default=1024),
    num_workers=Param(int, 'the number of workers', default=12),
    # NOTE(review): declared for CLI compatibility but not read by main() in this script.
    in_memory=Param(And(int, OneOf([0, 1])), 'does the dataset fit in memory? (0/1)', default=0),
)

# Which sharpness metric to compute (forwarded to measure_sharpness).
Section('sharpness').params(
    metric=Param(
        OneOf(["sam", "shannon", "low_pass", "frob", "fishr", "entropy", "entropy_grad", "eig_avg", "max_eig"]),
        'sharpness metric name passed to measure_sharpness',
        required=True),
)


@param('dataset.train_path')
@param('dataset.batch_size')
@param('dataset.num_workers')
@param('cfg.architecture')
@param('cfg.pretrained_ckpt')
@param('cfg.write_path')
@param('sharpness.metric')
def main(train_path, num_workers, batch_size, architecture, pretrained_ckpt, write_path, metric):
    """Measure the sharpness of a pretrained ResNet on a training set and cache it.

    All arguments are injected by fastargs from the ``dataset``, ``cfg`` and
    ``sharpness`` config sections.

    The result is stored at ``write_path + f"{metric}.pth"``. If that file
    already exists, the cached value is printed and nothing is recomputed.

    Args:
        train_path: FFCV file or dataset folder used to build the train loader.
        num_workers: data-loader worker count.
        batch_size: data-loader batch size.
        architecture: one of ``resnet18``, ``resnet50``, ``resnet101``.
        pretrained_ckpt: checkpoint file; weights are read from its
            ``["state_dicts"]["network"]`` entry.
        write_path: output path prefix (concatenated with the file name).
        metric: sharpness metric name forwarded to ``measure_sharpness``.

    Raises:
        NotImplementedError: if ``architecture`` is not a supported ResNet.
    """
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # ``write_path`` is a prefix: the metric file name is appended verbatim
    # (e.g. "out/run1_" + "sam.pth" -> "out/run1_sam.pth").
    file_name = write_path + f"{metric}.pth"
    if os.path.isfile(file_name):
        result = torch.load(file_name, map_location=device)
        print(f"Found existing results in {file_name}: {result.cpu().item()}")
        return

    supported_architectures = ("resnet18", "resnet50", "resnet101")
    if architecture not in supported_architectures:
        raise NotImplementedError(f"{architecture} is not supported")

    # Create the output directory before the expensive computation so that a
    # missing directory does not make torch.save fail at the very end.
    out_dir = os.path.dirname(file_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Look up the torchvision builder by name (randomly initialised; weights
    # are loaded from the checkpoint below).
    import torchvision.models
    network = getattr(torchvision.models, architecture)().to(device)
    network.eval()

    state_dict = torch.load(pretrained_ckpt, map_location=device)["state_dicts"]["network"]
    network.load_state_dict(state_dict)

    import src.data.utils
    if src.data.utils.check_ffcv_available_from_path(train_path):
        from src.data.ffcv_downstream import get_train_loader
        # scale/ratio pinned to (1, 1) and flip_probability=0 presumably
        # disable random-resized-crop jitter and horizontal flips, so the
        # metric sees deterministic views. Confirm against ffcv_downstream.
        train_loader, _ = get_train_loader(
            path=train_path, num_workers=num_workers, batch_size=batch_size, res=224, device=device,
            decoder_kwargs={
                'scale': (1, 1),
                'ratio': (1, 1),
            },
            flip_probability=0.
        )
    else:
        train_loader, _ = src.data.utils.get_train_loader_from_path(train_path, num_workers, batch_size, 224, augments=False)

    sharpness = measure_sharpness(metric, network, train_loader)
    # torch.as_tensor avoids the copy and UserWarning that torch.tensor() emits
    # when the metric already returns a tensor. detach() preserves the
    # graph-free result torch.tensor() produced, and cpu() makes the saved
    # file device-independent.
    sharpness = torch.as_tensor(sharpness).detach().cpu()
    print(sharpness.item(), file_name)
    torch.save(sharpness, file_name)


if __name__ == '__main__':
    # Expose the fastargs sections as CLI flags, parse and validate them,
    # summarise the resulting config, then run main() with injected params.
    config = get_current_config()
    # The first positional argument of ArgumentParser is `prog` (the program
    # name in usage lines); the human-readable title belongs in `description`.
    parser = argparse.ArgumentParser(description="Sharpness Evaluation")
    config.augment_argparse(parser)
    config.collect_argparse_args(parser)
    config.validate()
    config.summary()
    main()
