# run inference and save scores
import argparse
import json
import logging
import os

import timm
import timm.data
import torch

import src
from src.config import RESULTS_DIR
import accelerate

_logger = logging.getLogger(__name__)


def main(args):
    """Run a drift-detection pipeline with a pretrained timm model and save its scores.

    Args:
        args: Parsed command-line namespace. Reads ``pipeline``, ``model``,
            ``method``, ``method_kwargs``, ``mix_factor``, ``warmup_size``,
            ``batch_size``, ``seed``, ``limit_fit``, ``workers``, ``prefetch``
            and ``debug``. ``args.method_kwargs`` may be modified in place
            (ViT pooling override below).

    Side effects:
        Unless ``args.debug`` is set, appends one row of results to
        ``RESULTS_DIR/<pipeline>/scores.csv`` via
        ``src.utils.append_results_to_csv_file`` (creating the directory if needed).
    """
    print(f"Running {args.pipeline} pipeline on {args.model} model")
    # set global seed
    accelerate.utils.set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the pretrained model. nn.Module instances start in training mode,
    # so switch to eval mode: dropout is disabled and batch-norm layers use
    # their running statistics instead of per-batch statistics.
    model = timm.create_model(args.model, pretrained=True)
    model.to(device)
    model.eval()

    # Build the evaluation transform from the model's pretrained data config.
    data_config = timm.data.resolve_data_config(model.default_cfg)
    test_transform = timm.data.create_transform(**data_config)
    _logger.info("Test transform: %s", test_transform)

    # create pipeline
    pipeline = src.create_pipeline(
        args.pipeline,
        mix_factor=args.mix_factor,
        warmup_size=args.warmup_size,
        batch_size=args.batch_size,
        seed=args.seed,
        transform=test_transform,
        limit_fit=args.limit_fit,
        num_workers=args.workers,
        prefetch_factor=args.prefetch,
    )
    pipeline.setup()

    # For ViT models, force the detector's pooling op to "getitem" when the
    # detector was given a pooling op at all.
    # NOTE(review): presumably selects a single token (e.g. the class token) — confirm.
    if "vit" in args.model and "pooling_op_name" in args.method_kwargs:
        args.method_kwargs["pooling_op_name"] = "getitem"

    # create detector and run the pipeline
    method = src.create_detector(args.method, model=model, **args.method_kwargs)
    pipeline_score_results = pipeline.run(method)

    # Convert score tensors to plain Python lists for serialization.
    # detach()/cpu() make this safe for tensors on the GPU or still tracking
    # gradients, where Tensor.numpy() would raise.
    pipeline_score_results = {
        k: v.detach().cpu().tolist() for k, v in pipeline_score_results.items()
    }

    if args.debug:
        # Debug runs never persist results.
        return

    # save results to file
    results = {
        "pipeline": args.pipeline,
        "model": args.model,
        "in_dataset_name": pipeline.in_dataset_name,
        "drift_dataset_name": pipeline.drift_datasets_names,
        "method": args.method,
        "method_kwargs": args.method_kwargs,
        "warmup_size": args.warmup_size,
        "seed": args.seed,
        **pipeline_score_results,
    }
    filename = os.path.join(RESULTS_DIR, args.pipeline, "scores.csv")
    # Make sure the per-pipeline results directory exists before appending.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    src.utils.append_results_to_csv_file(results, filename)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Detector selection and its keyword arguments (JSON object string).
    parser.add_argument("--method", type=str, default="msp")
    parser.add_argument("--method_kwargs", type=json.loads, default={}, help='{"temperature":1000, "eps":0.00014}')

    # Pipeline configuration.
    parser.add_argument("--pipeline", type=str, default="drift_benchmark_imagenet_r")
    parser.add_argument("--warmup_size", type=int, default=10000)
    parser.add_argument("--mix_factor", type=float, default=1.0)
    parser.add_argument("--enable_preds", action="store_true")

    # Model, data loading and run configuration.
    parser.add_argument("--model", type=str, default="tv_resnet50")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--limit_fit", type=float, default=0.05)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--prefetch", type=int, default=2)
    # Previously hard-coded to 1; still defaults to 1 but can now be overridden.
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Method/model-specific settings. These overwrite any user-supplied
    # "features_nodes"; for a ViT with a *projection* method the ViT node list wins.
    if args.method == "react_projection":
        args.method_kwargs["features_nodes"] = ["layer1", "layer2", "layer3", "clip", "fc"]
    if "vit" in args.model and "projection" in args.method:
        # NOTE(review): blocks.0 is excluded (range starts at 1) — confirm intentional.
        args.method_kwargs["features_nodes"] = [f"blocks.{i}" for i in range(1, 12)] + ["fc_norm", "head"]
    if args.method == "msp":
        args.enable_preds = True

    # Log the configuration only after the adjustments above, so the log
    # reflects what is actually run.
    _logger.info(json.dumps(vars(args), indent=2))
    main(args)
