import json
import logging
import os

import numpy as np
import pytorch_lightning as pl
import pytorch_lightning.callbacks as pl_callbacks
import sklearn.model_selection
import torch
import torch.utils.data
import torchmetrics.functional as metrics
from tqdm import tqdm

import src.utils.eval as detect_eval
from src import config
from src.selection.benchmarks.argparser import ArgumentsBase
from src.selection.scores import get_selector_score
from src.utils import helpers
from src.utils.vision.datamodules.basic import DefaultDataModule
from src.utils.vision.models import get_model

logger = logging.getLogger(__name__)


class AbstentionArguments(ArgumentsBase):
    """Options for the abstention benchmark, extending ``ArgumentsBase``.

    ``test_size`` defaults to 0.9, and ``results_filename`` points at
    ``abstention.csv`` inside ``config.RESULTS_DIR``.
    """

    test_size: float = 0.9

    @property
    def results_filename(self) -> str:
        """Return the path of the CSV file that receives this benchmark's results."""
        csv_name = "abstention.csv"
        return os.path.join(config.RESULTS_DIR, csv_name)


def _predict_split(predictor, model, dataloader, split_name: str):
    """Run ``model`` over ``dataloader`` and gather logits, targets and accuracy.

    Args:
        predictor: Trainer-like object whose ``predict`` returns a list of
            per-batch mappings holding ``"logits"`` and ``"targets"`` entries.
        model: The model handed to ``predictor.predict``.
        dataloader: The dataloader to run predictions on.
        split_name: Human-readable split label used as the log-message prefix
            (e.g. ``"Train"`` or ``"Test"``).

    Returns:
        A ``(logits, targets, accuracy)`` tuple: the concatenated logits, the
        concatenated targets, and the accuracy as a Python float.
    """
    batches = predictor.predict(
        model,
        dataloaders=dataloader,
        return_predictions=True,
    )

    logits = torch.cat([batch["logits"] for batch in batches])
    targets = torch.cat([batch["targets"] for batch in batches])

    # The predicted label is the argmax over dimension 1 of the logits.
    accuracy = metrics.accuracy(logits.argmax(1), targets).item()
    logger.info("%s accuracy: %s", split_name, accuracy)
    logger.info("%s logits shape: %s", split_name, logits.shape)
    logger.info("%s targets shape: %s", split_name, targets.shape)
    return logits, targets, accuracy


def main(args: AbstentionArguments):
    """Run the abstention benchmark and append one result row per target coverage.

    The pre-trained model is run on the train and test dataloaders, the
    selector score is computed from the test logits, and the test samples are
    split into a calibration part and an evaluation part. Thresholds are read
    off the calibration part for each target coverage; coverage and risk are
    then measured on the evaluation part and appended to
    ``args.results_filename``.

    Args:
        args: Parsed benchmark arguments. ``args.test_size`` is forwarded to
            ``sklearn.model_selection.train_test_split`` as the share of test
            samples kept for evaluation; the remainder is used for calibration.
    """
    # Reproducibility
    pl.seed_everything(args.seed, workers=True)

    args.save(os.path.join(args.output_dir, "args.json"), with_reproducibility=False)

    # pre-trained model
    model = get_model(args.model_name, url=args.url)

    score = get_selector_score(args.score_name, **args.score_kwargs)

    # datamodule
    # NOTE(review): arguments are passed by position; the literal ``False`` and
    # the three ``model.test_transforms`` slots should be checked against
    # DefaultDataModule's signature.
    datamodule = DefaultDataModule(
        args.dataset_name,
        args.data_dir,
        args.num_workers,
        args.val_split,
        args.batch_size,
        args.seed,
        args.pin_memory,
        False,
        model.test_transforms,
        model.test_transforms,
        model.test_transforms,
    )

    # predictions
    callbacks = [pl_callbacks.TQDMProgressBar(refresh_rate=10)]
    tb_logger = False
    predictor = pl.Trainer(
        default_root_dir=args.output_dir,
        logger=tb_logger,
        accelerator="auto",
        enable_checkpointing=False,
        auto_select_gpus=True,
        callbacks=callbacks,
        benchmark=True,
    )

    datamodule.setup()

    # Only the accuracy of the train split is used afterwards.
    _, _, train_acc = _predict_split(
        predictor, model, datamodule.train_dataloader(), "Train"
    )
    test_logits, test_targets, test_acc = _predict_split(
        predictor, model, datamodule.test_dataloader(), "Test"
    )

    # scores
    logger.info("Calculating scores...")

    test_y_scores = score.forward_logits(test_logits)
    # calibration
    # Report the calibration share actually configured instead of a fixed 10%.
    logger.info(
        "Calibrating with %g%% test samples...", 100 * (1 - args.test_size)
    )
    cal_indexes, test_indexes = sklearn.model_selection.train_test_split(
        np.arange(len(test_y_scores)),
        test_size=args.test_size,
        random_state=args.seed,
    )
    logger.info("Calibration indexes len: %s", len(cal_indexes))
    logger.info("Test indexes len: %s", len(test_indexes))

    (
        cal_risks,
        cal_coverages,
        cal_thrs,
    ) = detect_eval.risk_coverage.risks_coverages_selective_net(
        test_y_scores[cal_indexes],
        test_logits[cal_indexes],
        test_targets[cal_indexes],
    )

    # Evaluation-split tensors do not depend on the target coverage.
    eval_scores = test_y_scores[test_indexes]
    eval_logits = test_logits[test_indexes]
    eval_targets = test_targets[test_indexes]

    target_coverages = np.linspace(0.7, 1, 7).tolist()
    for target_coverage in tqdm(target_coverages, "Test target coverages"):
        # NOTE(review): assumes cal_coverages is sorted ascending and aligned
        # with cal_thrs; if target_coverage exceeds every calibration coverage
        # the index equals len(cal_thrs) and this lookup fails — confirm.
        _thr = cal_thrs[torch.searchsorted(cal_coverages, target_coverage)]
        test_sn_coverages = detect_eval.risk_coverage.hard_coverage(
            eval_scores, _thr
        ).item()

        # Risk is reported scaled by 100.
        test_sn_risks = (
            100
            * detect_eval.risk_coverage.selective_net_risk(
                eval_scores,
                eval_logits,
                eval_targets,
                _thr,
            ).item()
        )

        results_obj = {
            "model": args.model_name,
            "dataset": args.dataset_name,
            "score": args.score_name,
            "test_size": args.test_size,
            "kwargs": args.score_kwargs,
            "model_train_acc": train_acc,
            "model_test_acc": test_acc,
            # Record the coverage this row was computed for, not the whole
            # list, so each row can be matched to its target.
            "target_coverages": target_coverage,
            "test_sn_coverages": test_sn_coverages,
            "test_sn_risks": test_sn_risks,
            "seed": args.seed,
        }
        logger.info("Results: %s", json.dumps(results_obj, indent=2))

        # save results
        helpers.append_results_to_file(results_obj, filename=args.results_filename)


if __name__ == "__main__":
    # Arguments are parsed first because the log verbosity depends on args.debug.
    args = AbstentionArguments().parse_args()
    verbosity = logging.DEBUG if args.debug else logging.INFO

    logging.basicConfig(
        level=verbosity,
        format="---> %(levelname)s - %(name)s - %(message)s",
    )

    # Log the parsed configuration, then run the benchmark.
    logger.info(args)
    main(args)
