import yaml
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
from matplotlib.axis import GRIDLINE_INTERPOLATION_STEPS
from torch import cuda

from src.pipelines.evaluation import evaluation
from src.pipelines.evaluation_ray_tune_errica import errica_evaluation_ray_tune
from src.pipelines.evaluation_ray_tune import evaluation_ray_tune
from src.pipelines.evaluation_early_stopping import evaluation_cross_val
from src.pipelines.errica import errica_cross_testing
from src.pipelines.cross_validation import cross_validation_pipeline
from src.pipelines.training import training
from src.pipelines.grid_search import grid_search
from src.utils.config_utils import validate_config
from src.utils.command_parser import setup_parser, GRID_SEARCH_COMMAND, EVAL_COMMAND, ERRICA_COMMAND, EVAL_RAY_COMMAND, \
    EVAL_ERRICA_COMMAND
from src.utils.command_parser import TRAIN_COMMAND, CROSSVAL_COMMAND
from src.utils.seed import set_seed


def _resolve_device(device_spec: str) -> torch.device:
    """Turn the ``cuda_device`` config value into a usable ``torch.device``.

    If ``device_spec`` contains ``"cuda"``, the device is probed. The probe queries
    its name and allocates a one-element tensor on it. Any failure falls back to
    the CPU instead of aborting the run. Any other value selects the CPU directly.

    Args:
        device_spec: Device string from the config (e.g. ``"cuda:0"`` or ``"cpu"``).

    Returns:
        The CUDA device when it is usable, otherwise ``torch.device('cpu')``.
    """
    if "cuda" not in device_spec:
        print("[MAIN]: GPU was deactivated, using the CPU for the started pipeline.")
        return torch.device('cpu')

    device = device_spec
    try:
        device = torch.device(device_spec)
        device_name = cuda.get_device_name(device)
        print(f"[MAIN]: Using the device '{device_name}' for the started pipeline.")
        if str(device_name) == "NVIDIA RTX A6000":
            # NOTE: the get_device_name() call above has already initialised
            # CUDA in this process, so this assignment cannot change which GPUs
            # *this* process sees. It only affects processes started later that
            # inherit the environment (presumably Ray workers - confirm).
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
            print("[MAIN]: Using GPU 0 (CUDA_VISIBLE_DEVICES=0) for the started pipeline.")
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            print(f"[MAIN]: CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        # Allocate a tiny tensor to confirm the device is actually usable.
        torch.tensor([1.0], device=device)
        print(f"[MAIN]: Device {device} is available.")
    except (RuntimeError, AssertionError) as e:
        # The catch is broader than RuntimeError (missing driver, bad device
        # string) because torch raises AssertionError when it is built without
        # CUDA or when the device index is invalid.
        print(f"[MAIN]: Device {device} is not available:", e)
        print("[MAIN]: Using the CPU for the started pipeline.")
        return torch.device('cpu')
    return device


def main() -> None:
    """
    Entry point of the script.

    Parses the command line, loads and validates the YAML config, selects the
    compute device, seeds torch and runs the pipeline named by ``--pipeline``.

    The config is mutated before it reaches the pipeline:
    ``run_name`` is popped from the top level, and ``seed`` is popped from
    ``training_parameters``.

    Usage examples (NOTE(review): module path inferred from the ``src.*`` imports):
        python -m src.main --config data/test_configs/training_config.yaml --pipeline train
        python -m src.main -c data/test_configs/training_config.yaml -p train
    """
    # Parse command line arguments; unrecognised extra arguments are ignored.
    parser = setup_parser()
    args, _ = parser.parse_known_args()

    # Pipeline name -> (function, takes scheduling options). Every pipeline
    # receives config/run_name/device. Those flagged True additionally receive
    # ``resources`` and ``max_concurrent_trails``.
    pipelines = {
        TRAIN_COMMAND: (training, False),
        CROSSVAL_COMMAND: (cross_validation_pipeline, False),
        EVAL_COMMAND: (evaluation, False),
        GRID_SEARCH_COMMAND: (grid_search, True),
        ERRICA_COMMAND: (errica_cross_testing, True),
        EVAL_RAY_COMMAND: (evaluation_ray_tune, True),
        EVAL_ERRICA_COMMAND: (errica_evaluation_ray_tune, True),
    }
    # Fail fast on an unrecognised pipeline name instead of silently doing nothing.
    if args.pipeline not in pipelines:
        parser.error(
            f"unknown pipeline {args.pipeline!r}; expected one of: "
            f"{', '.join(map(str, pipelines))}"
        )

    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Check validity of the config file values
    validate_config(config)

    run_name = config.pop("run_name")

    # Setting up GPU based on availability and usage preference
    device = _resolve_device(config['cuda_device'])

    # Scheduling options, forwarded only to the pipelines flagged above.
    # The misspelling "trails" is kept: it is both the config key and the
    # keyword the pipeline functions accept.
    resources = config.get("resources", {"gpu": 1})
    max_concurrent_trails = config.get("max_concurrent_trails", 1)

    # Setting random seed for torch
    seed = config["training_parameters"].pop("seed")
    set_seed(seed, device)

    pipeline_fn, takes_scheduling_options = pipelines[args.pipeline]
    kwargs = {"config": config, "run_name": run_name, "device": device}
    if takes_scheduling_options:
        kwargs["resources"] = resources
        kwargs["max_concurrent_trails"] = max_concurrent_trails
    pipeline_fn(**kwargs)





# Run the CLI only when executed as a script/module, not when imported.
if __name__ == "__main__":
    main()
