import os
import torch
import random
import numpy as np
import yaml
from train import train
import warnings
from transformers import logging as hf_logging

from transformers import CanineModel, BertModel, CanineTokenizer

# Absolute directory containing this script, and that directory's parent.
current_dir = os.path.dirname(os.path.abspath(__file__))
one_levels_up = os.path.dirname(current_dir)


# os.environ["CUDA_VISIBLE_DEVICES"] = '3'


def set_seed(seed):
    """
    Seed every random number generator used by this script.

    Applies the same seed to Python's ``random`` module, NumPy's global RNG,
    PyTorch's default generator, and all CUDA device generators, so that
    runs started with the same seed are reproducible.

    Args:
    seed (int): The seed value applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def load_config(config_file):
    """
    Load the configuration from a YAML file.

    The file is parsed with ``yaml.safe_load``, which does not construct
    arbitrary Python objects. If present, ``learning_rate`` is coerced to
    ``float``. Presumably this is because PyYAML (YAML 1.1) reads exponent
    literals without a decimal point, such as ``1e-4``, as strings.

    Args:
    config_file (str): Path to the configuration file.

    Returns:
    dict: Configuration dictionary. An empty (or comment-only) file yields
    an empty dict.

    Raises:
    TypeError: If the top-level YAML value is not a mapping.
    """
    # YAML documents are Unicode; read them as UTF-8 instead of relying on
    # the platform's locale-dependent default encoding.
    with open(config_file, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document; without this check the
    # membership test below would fail with an unhelpful TypeError.
    if config is None:
        config = {}
    elif not isinstance(config, dict):
        raise TypeError(
            f"Config file '{config_file}' must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}."
        )

    if 'learning_rate' in config:
        config['learning_rate'] = float(config['learning_rate'])
    return config


def print_args(args):
    """
    Print every attribute of ``args`` as ``key: value``, one per line.

    Args:
    args: Any object that supports ``vars()`` (e.g. an ``argparse.Namespace``
        or a simple attribute container). Its attributes are printed in
        ``vars()`` order, below an "Arguments:" header line.
    """
    report = ["Arguments:"]
    report.extend(f"{name}: {val}" for name, val in vars(args).items())
    print("\n".join(report))


def resolve_config_path(config_arg):
    """
    Accept either a filename or a full/relative path to a config file.

    Candidates are tried in order, and the first existing file wins:
      1. ``config_arg`` as given (absolute, or relative to the current
         working directory).
      2. ``configs/<config_arg>`` relative to the current working directory.
      3. ``configs/<config_arg>`` relative to the directory containing this
         script. This lets the bare-filename form work when the script is
         launched from a different working directory.

    Args:
    config_arg (str): A config filename (looked up in ``configs/``) or a
        direct path to a config file.

    Returns:
    str: The first candidate path that refers to an existing file.

    Raises:
    FileNotFoundError: If none of the candidates exists.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = (
        config_arg,
        os.path.join('configs', config_arg),
        # Fallback only reached where the original lookup would have failed.
        os.path.join(script_dir, 'configs', config_arg),
    )
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate

    raise FileNotFoundError(
        f"Config file not found: '{config_arg}'. "
        "Pass a file in configs/ (e.g., config_v10.yaml) or a valid path."
    )


if __name__ == "__main__":
    import argparse

    # Attribute-style view of the loaded YAML mapping, so the training code
    # can read e.g. ``config.seed`` instead of ``config['seed']``.
    class Config:
        def __init__(self, **entries):
            self.__dict__.update(entries)

    cli = argparse.ArgumentParser(
        description="Run training with a manual seed for reproducibility."
    )
    cli.add_argument(
        '--config',
        type=str,
        default='config_v10.yaml',
        help='Config filename (in configs/) or a direct path',
    )
    cli_args = cli.parse_args()

    # Load the YAML settings and seed the RNGs before training starts.
    settings = load_config(resolve_config_path(cli_args.config))
    set_seed(settings['seed'])
    config = Config(**settings)

    # Turn off HuggingFace tokenizers parallelism, ignore the related
    # UserWarning, and restrict transformers logging to errors only.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message=".*parallelism has already been used.*",
    )
    hf_logging.set_verbosity_error()

    # Echo the effective configuration, then hand it to the trainer.
    print_args(config)
    train(config)
