from argparse import ArgumentParser
from logging import Logger
from typing import Dict

import torch
import torchmetrics
import yaml
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Lambda, PILToTensor, ToTensor
from tqdm import tqdm

from sde.datasets import IndexedImageFolder
from sde.models import load_synthetic_model
from sde.utils import DictAction, merge_from_options, read_config, setup_logger


def parse_args():
    """Parse the command-line arguments of the synthetic-setting evaluation script.

    Returns:
        argparse.Namespace with ``config`` (str, path to the config file), ``gpu_id``
        (int, default 0) and ``cfg_options`` (whatever ``DictAction`` stores; ``None``
        when ``--cfg-options`` is not given).
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("config", type=str, help="path to synthetic setting config file.")
    arg_parser.add_argument("--gpu-id", type=int, default=0, help="GPU device ID.")
    arg_parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="Key value pairs xxx=yyy to override config options.",
    )
    namespace = arg_parser.parse_args()
    return namespace


def _build_loader_kwargs(loader_cfg: Dict) -> Dict:
    """Return a copy of ``loader_cfg`` with a valid ``persistent_workers`` setting.

    ``DataLoader`` raises ``ValueError`` for ``persistent_workers=True`` when
    ``num_workers`` is 0 (its default). An explicit value in ``loader_cfg`` is
    respected; otherwise workers persist only when there are worker processes.
    The input dict is not modified.
    """
    kwargs = dict(loader_cfg)
    kwargs.setdefault('persistent_workers', kwargs.get('num_workers', 0) > 0)
    return kwargs


def _round_predictions(preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Round model outputs to class indices on CPU, shaped to be comparable with ``labels``.

    Rounding fixes numerical noise: due to the non-uniform init of the accumulator,
    predictions are sometimes like 9.001 while the label is 9.

    When there is one output per sample, the result is reshaped to ``labels.shape``
    and cast to ``long``. A bare ``squeeze()`` would turn a final batch of size 1 into
    a 0-d tensor, which torchmetrics' shape validation rejects against the 1-element
    ``labels``. Any other output shape is returned squeezed, without the reshape.
    """
    rounded = torch.round(preds).cpu()
    if rounded.numel() == labels.numel():
        return rounded.reshape(labels.shape).long()
    return rounded.squeeze()


@torch.no_grad()
def eval_synthetic_setup(cfg: Dict, device: torch.device, logger: Logger) -> None:
    """Evaluate a synthetic model on an image-folder dataset and log its accuracy.

    Args:
        cfg: Config dict. Reads:
            - ``cfg['model']``: keyword arguments for ``load_synthetic_model``.
            - ``cfg['dataset']``: ``root`` and ``num_channels``.
            - ``cfg['data_loader']``: extra ``DataLoader`` keyword arguments.
        device: Device the model and the input batches are moved to.
        logger: Receives the final ``Accuracy: x.xxxx`` line.
    """
    # load model
    model = load_synthetic_model(**cfg['model']).to(device).eval()

    dataset_cfg = cfg['dataset']
    if dataset_cfg['num_channels'] == 3:
        # 3-channel images keep their raw 0-255 pixel values (cast to float);
        # otherwise ToTensor scales pixels to [0, 1].
        transform = Compose([PILToTensor(), Lambda(lambda x: x.float())])
    else:
        transform = ToTensor()
    dataset = IndexedImageFolder(
        root=dataset_cfg["root"], num_channels=dataset_cfg['num_channels'], transform=transform)
    num_classes = len(dataset.classes)
    dataloader = DataLoader(dataset, **_build_loader_kwargs(cfg['data_loader']))

    # check if the model has 1 accuracy on the dataset
    metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
    for inputs, labels in tqdm(dataloader):
        preds = model(inputs.to(device))
        # update() only accumulates state; calling the metric would also compute an
        # unused per-batch accuracy.
        metric.update(_round_predictions(preds, labels), labels)
    acc = metric.compute()
    logger.info(f'Accuracy: {acc:.4f}')
    metric.reset()


if __name__ == "__main__":
    args = parse_args()
    cfg = read_config(args.config)
    if args.cfg_options is not None:
        # key=value pairs from --cfg-options override values read from the config file.
        cfg = merge_from_options(cfg, args.cfg_options)

    logger = setup_logger("sde")
    logger.info(f'Using config:\n{yaml.dump(cfg, indent=4, sort_keys=False)}\n' + '-' * 60)

    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu_id}')
    else:
        # Constructing torch.device('cuda:N') succeeds even without CUDA, but moving the
        # model onto it fails later; evaluate on CPU instead and say so explicitly.
        logger.warning('CUDA is not available; falling back to CPU.')
        device = torch.device('cpu')
    eval_synthetic_setup(cfg, device=device, logger=logger)
