# Load precomputed scores from scores.csv and compute metrics for a range of mix factors.
import accelerate
import argparse
import json
import logging
import os
import pandas as pd

import torch
from tqdm import tqdm

import src
from src.config import RESULTS_DIR

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


def _load_scores(args: argparse.Namespace) -> tuple:
    """Load and validate the stored score vectors selected by ``args``.

    Reads ``<RESULTS_DIR>/<args.pipeline>/scores.csv``, keeps the rows whose
    ``model``, ``method``, ``warmup_size`` and ``pipeline`` columns match the
    corresponding attributes of ``args``, and decodes the JSON-encoded
    ``warmup_scores``, ``in_scores`` and ``drift_scores`` columns of the first
    matching row into float32 tensors.

    Returns:
        A ``(warmup_scores, in_scores, drift_scores)`` tuple of tensors.

    Raises:
        ValueError: if no row matches, or if any score vector contains NaN.
    """
    scores_path = os.path.join(RESULTS_DIR, args.pipeline, "scores.csv")
    df = pd.read_csv(scores_path)
    # query scores for the given model and method
    df = df.query(
        "model==@args.model and method==@args.method and warmup_size==@args.warmup_size and pipeline==@args.pipeline"
    ).drop_duplicates(keep="last")
    if df.empty:
        raise ValueError(
            f"No scores found in {scores_path} for model={args.model!r}, method={args.method!r}, "
            f"warmup_size={args.warmup_size!r}, pipeline={args.pipeline!r}"
        )

    # NOTE(review): drop_duplicates only collapses fully identical rows; if several
    # distinct rows match, the first one is used -- confirm that this is intended.
    loaded = {}
    for column in ("warmup_scores", "in_scores", "drift_scores"):
        scores = torch.tensor(json.loads(df[column].values[0]), dtype=torch.float32)
        # Explicit raise instead of `assert`, which is skipped under `python -O`.
        if torch.isnan(scores).any():
            raise ValueError(f"NaN values found in {column} loaded from {scores_path}")
        loaded[column] = scores

    _logger.debug("Loaded scores for model: %s, method: %s", args.model, args.method)
    _logger.debug("In scores: %s", loaded["in_scores"])
    _logger.debug("Warmup scores: %s", loaded["warmup_scores"])
    _logger.debug("Drift scores: %s", loaded["drift_scores"])
    return loaded["warmup_scores"], loaded["in_scores"], loaded["drift_scores"]


def main(args):
    """Compute metrics from stored scores for every mix factor and record them.

    For each mix factor a pipeline is created through ``src.create_pipeline``,
    its ``postprocess`` method is applied to the stored scores using
    ``args.criterion``, and one result row is appended to
    ``<RESULTS_DIR>/<args.pipeline>/results.csv``. After all mix factors have
    been processed, the report produced by ``pipeline.report`` for each of
    them is printed.

    Args:
        args: parsed command-line arguments; the attributes read are
            ``pipeline``, ``model``, ``method``, ``method_kwargs``, ``seed``,
            ``num_samples``, ``warmup_size`` and ``criterion``.

    Raises:
        ValueError: if no stored scores match ``args`` or they contain NaN.
    """
    accelerate.utils.set_seed(args.seed)
    mix_factors = [0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]
    window_sizes = [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100, 200, 300, 500, 800, 1000]

    # The stored scores and the output path do not depend on the mix factor,
    # so read/validate/build them once instead of once per iteration.
    warmup_scores, in_scores, drift_scores = _load_scores(args)
    filename = os.path.join(RESULTS_DIR, args.pipeline, "results.csv")

    results_strings = []
    for mix_factor in tqdm(mix_factors):
        pipeline = src.create_pipeline(
            args.pipeline,
            window_sizes=window_sizes,
            mix_factor=mix_factor,
            warmup_size=args.warmup_size,
            num_samples=args.num_samples,
            seed=args.seed,
            transform=None,
        )

        # compute metrics
        # postprocess is opaque from here; pass copies so that each iteration
        # starts from pristine scores even if it were to modify its inputs.
        metrics = pipeline.postprocess(
            warmup_scores.clone(),
            in_scores.clone(),
            drift_scores.clone(),
            criterion=args.criterion,
        )
        results_strings.append(pipeline.report(metrics))

        results = {
            "pipeline": args.pipeline,
            "model": args.model,
            "method": args.method,
            "method_kwargs": args.method_kwargs,
            "seed": args.seed,
            "num_samples": args.num_samples,
            "warmup_size": args.warmup_size,
            "mix_factor": mix_factor,
            "criterion": args.criterion,
            "results": json.dumps(metrics),
        }
        src.utils.append_results_to_csv_file(results, filename)

    for mix_factor, results_string in zip(mix_factors, results_strings):
        print(f"Mix factor: {mix_factor}")
        print(results_string)


if __name__ == "__main__":
    # Command-line entry point: parse the options, configure logging, run main().
    cli = argparse.ArgumentParser()

    # Scoring method and its JSON-encoded keyword arguments.
    cli.add_argument("--method", type=str, default="msp")
    cli.add_argument("--method_kwargs", type=json.loads, default={}, help='{"temperature":1000, "eps":0.00014}')

    # Pipeline selection and evaluation settings.
    cli.add_argument("--pipeline", type=str, default="drift_benchmark_imagenet_r")
    cli.add_argument("--warmup_size", type=int, default=10000)
    cli.add_argument("--num_samples", type=int, default=10000)
    cli.add_argument("--criterion", type=str, default="ks_2samp")

    # Model, seed and logging verbosity switches.
    cli.add_argument("--model", type=str, default="tv_resnet50")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--debug", action="store_true")
    cli.add_argument("--no_verbose", action="store_true")
    args = cli.parse_args()

    # --debug raises the root log level to DEBUG; otherwise INFO is used.
    if args.debug:
        root_level = logging.DEBUG
    else:
        root_level = logging.INFO
    logging.basicConfig(level=root_level)

    # --no_verbose restricts this module's logger to errors only.
    if args.no_verbose:
        _logger.setLevel(logging.ERROR)

    # Log the parsed arguments before starting the run.
    _logger.info(json.dumps(vars(args), indent=2))
    main(args)
