import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import Union, Optional, List
from dataclasses import dataclass, asdict
import torch
import numpy as np

from patching_gemma import logger
from patching_gemma.tasks import NAME_TO_TASK
from patching_gemma.models import NAME_TO_MODEL

@dataclass
class Config:
    """Parameters for a single ``evaluate`` run (one model on one task).

    Attributes are consumed by ``evaluate``:
      * ``task`` / ``model_name`` select entries of ``NAME_TO_TASK`` /
        ``NAME_TO_MODEL``;
      * ``batch_size`` and ``limit`` are forwarded to ``model.run``;
      * ``output_path`` is the directory that receives ``run_params.json``
        and ``logging_dir/``;
      * ``prune_using_imp_scores`` and ``prune_k`` are parallel lists (one k
        per importance-score name) and must have equal length when given;
      * ``affect_whom`` is forwarded to ``model.run`` only when not ``None``.
    """

    task: str
    model_name: str
    batch_size: int
    output_path: str
    limit: Optional[int] = None
    prune_using_imp_scores: Optional[List[str]] = None
    prune_k: Optional[List[int]] = None
    # Element type is model-specific; it is passed through to ``model.run`` untouched.
    affect_whom: Optional[list] = None

# Registry of named presets selectable via ``--config``: maps a preset name to
# the sequence of ``Config`` objects that are evaluated one after another.
CONFIGS = {}

# Name tag for the stdout handler installed by ``evaluate``. ``__main__`` calls
# ``evaluate`` once per config, so the handler is reused rather than stacked
# (stacking would print every record once per previous call).
_STDOUT_HANDLER_NAME = "patching_gemma.evaluate.stdout"


def _ensure_stdout_logging() -> None:
    """Set ``logger`` to DEBUG and attach exactly one DEBUG-level stdout handler."""
    logger.setLevel(logging.DEBUG)
    if any(handler.get_name() == _STDOUT_HANDLER_NAME for handler in logger.handlers):
        return
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.set_name(_STDOUT_HANDLER_NAME)
    logger.addHandler(ch)


def _validate_config(config: Config) -> None:
    """Reject incomplete/inconsistent configs before any file or model is created.

    Raises:
        ValueError: missing output path or task, unknown task/model name, or
            pruning requested without a matching list of k values.
    """
    if not config.output_path:
        raise ValueError("Specify --output_path")
    if not config.task:
        raise ValueError("No tasks specified, or no tasks found. Please verify the task names.")
    # Check names up front so a typo fails fast instead of after model creation.
    if config.task not in NAME_TO_TASK:
        raise ValueError(f"Unknown task {config.task!r}; available: {list(NAME_TO_TASK)}")
    if config.model_name not in NAME_TO_MODEL:
        raise ValueError(f"Unknown model {config.model_name!r}; available: {list(NAME_TO_MODEL)}")
    if config.prune_using_imp_scores is not None:
        if config.prune_k is None:
            raise ValueError("provide k according to which to prune")
        if len(config.prune_using_imp_scores) != len(config.prune_k):
            raise ValueError("provide as many k values as imp scores to prune according to")


def evaluate(config: Config) -> None:
    """Run ``config.model_name`` on ``config.task`` and write the run's artifacts.

    Files written under ``config.output_path``:
      * ``run_params.json`` - the parameters of this run;
      * ``logging_dir/<task config name>`` - passed to ``model.run`` as ``log_dir``;
      * ``logging_dir/model_logs.json`` - ``model.model_logs`` after the run.

    Raises:
        ValueError: if the config fails validation (see ``_validate_config``);
            in that case nothing is written to disk.
    """
    _ensure_stdout_logging()
    logger.info("Verbosity set to DEBUG")
    # NOTE(review): presumably silences HuggingFace tokenizers' fork/parallelism
    # warning - confirm.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    _validate_config(config)

    path = Path(config.output_path)
    if path.exists():
        logger.warning(f"output path {str(path)} already exists, files there will be slowly overwritten")

    path.mkdir(parents=True, exist_ok=True)
    run_params_file = path / "run_params.json"
    logging_dir = path / "logging_dir"
    logging_dir.mkdir(parents=True, exist_ok=True)

    task_name = config.task
    batch_size = config.batch_size

    with open(run_params_file, "w", encoding="utf-8") as file:
        json.dump({
            "model_name": config.model_name,
            "task_name": task_name,
            "batch_size": batch_size,
            "limit": config.limit,
            "prune_using_imp_scores": config.prune_using_imp_scores,
            "prune_k": config.prune_k,
            "affect_whom": str(config.affect_whom)
        }, file, indent=4)

    model = NAME_TO_MODEL[config.model_name]().create_model()
    logger.debug("Created model")

    task_instance = NAME_TO_TASK[task_name]()
    logger.debug(f"Spotted task {task_instance.config.name}")
    logger.debug(f"Task params:\nfewshot{task_instance.fewshot}\nseed{task_instance.fewshot_seed}\n" +
                 f"fewshot_space{repr(task_instance.fewshot_space)}\nsep{repr(task_instance.sep)}")

    save_path = logging_dir / task_instance.config.name
    # Needed when the task name contains path separators (parent != logging_dir).
    save_path.parent.mkdir(parents=True, exist_ok=True)

    args_to_model = {
        "task": task_instance,
        "limit": config.limit,
        "batch_size": batch_size,
        "log_dir": save_path,
    }
    # Optional arguments are forwarded only when set, so the model's own defaults apply otherwise.
    if config.affect_whom is not None:
        args_to_model["affect_whom"] = config.affect_whom
    if config.prune_using_imp_scores is not None:
        args_to_model["prune_using_imp_scores"] = config.prune_using_imp_scores
    if config.prune_k is not None:
        args_to_model["prune_k"] = config.prune_k

    logger.debug("Start running the model")
    model.run(**args_to_model)

    with open(logging_dir / "model_logs.json", "w", encoding="utf-8") as file:
        json.dump(model.model_logs, file, indent=4)

    model.break_out()

    del model

if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    # required + choices: a missing or misspelled preset becomes a usage error
    # listing the valid names, instead of a KeyError traceback on CONFIGS lookup.
    parser.add_argument(
        "--config",
        "-c",
        required=True,
        choices=sorted(CONFIGS),
        help="name of a preset in CONFIGS;\nevery Config in it is evaluated in order",
    )
    args = parser.parse_args()
    for config in CONFIGS[args.config]:
        evaluate(config)

