import argparse
from utils.constants import LIST_OF_MODELS_DC, LIST_OF_DATASETS_DC, PROBE_MODELS, MAXIMAL_VOCAB_SIZE, LIST_OF_DATASETS_HD, LIST_OF_MODELS_HD, LIST_OF_ALL_MODELS, LIST_OF_ALL_DATASETS


def parse_args_token_layers_scores():
    """
    Parse command-line arguments for the token layers scores script.

    Every option has a default, so the function also works with an empty
    command line. Defaults shown in ``--help`` are rendered by argparse from
    the actual ``default=`` values (``%(default)s``), so they cannot drift.

    Returns:
    --------
    argparse.Namespace:
        Parsed arguments with attributes ``LLM``, ``dataset``, ``fold_to_run``,
        ``base_raw_data_dir``, ``input_output_type``, ``num_folds``,
        ``n_samples`` and ``size_limit``.
    """
    parser = argparse.ArgumentParser(
        description="Parse arguments for token layers scores."
    )

    # Model and dataset selection (restricted to the project-wide lists).
    parser.add_argument(
        "--LLM",
        choices=LIST_OF_ALL_MODELS,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Pretrained model of which its logits/activations are used."
    )
    parser.add_argument(
        "--dataset",
        choices=LIST_OF_ALL_DATASETS,
        default='movies',
        help="Dataset to be processed."
    )

    # Cross-validation controls.
    parser.add_argument(
        "--fold_to_run",
        type=int,
        default=0,
        help="fold to run (default: %(default)s)."
    )
    parser.add_argument(
        "--num_folds",
        type=int,
        default=5,
        help="Number of folds to run (default: %(default)s)."
    )

    # Absolute, machine-specific default; override with --base_raw_data_dir.
    parser.add_argument(
        "--base_raw_data_dir",
        type=str,
        default='/mnt/storage/user/raw_data',
        help="Base directory for saving raw data."
    )
    parser.add_argument(
        "--input_output_type",
        type=str,
        default="output",
        choices=["input", "output"],
        help="Usage of input or output."
    )

    # The two help strings below used to be copies of the --num_folds text.
    parser.add_argument(
        "--n_samples",
        type=int,
        default=10000,
        help="Number of samples to use (default: %(default)s)."
    )
    parser.add_argument(
        "--size_limit",
        type=int,
        default=10000,
        help="Size limit (default: %(default)s)."
    )

    return parser.parse_args()

def parse_args_DC():
    """
    Build the DC command-line parser and parse the current command line.

    Registers four options, in this order: ``--LLM``, ``--dataset``,
    ``--take_top_k`` and ``--base_raw_data_dir``. All of them have defaults.

    Returns:
    --------
    argparse.Namespace:
        Parsed arguments with attributes ``LLM``, ``dataset``, ``take_top_k``
        and ``base_raw_data_dir``.
    """
    # (flag, add_argument keyword arguments), registered in listed order.
    option_table = (
        (
            "--LLM",
            {
                "choices": LIST_OF_MODELS_DC,
                "default": "EleutherAI/pythia-6.9b",
                "help": "Pretrained model to use for generating activations.",
            },
        ),
        (
            "--dataset",
            {
                "choices": LIST_OF_DATASETS_DC,
                "default": 'BookMIA_128',
                "help": "Dataset to be processed.",
            },
        ),
        (
            "--take_top_k",
            {
                "type": int,
                "default": MAXIMAL_VOCAB_SIZE,
                "help": "Top-K to use when extracting the raw dataset -- should be max over all vocab sizes (default: 1_000_000).",
            },
        ),
        (
            "--base_raw_data_dir",
            {
                "type": str,
                "default": '/mnt/storage/user/raw_data',
                "help": "Base directory for saving raw data.",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Generate model activations and labels from a specified dataset."
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def parse_args_HD():
    """
    Build the HD command-line parser and parse the current command line.

    Registers ``--LLM``, ``--dataset``, ``--n_samples``, ``--chunk``,
    ``--base_raw_data_dir``, ``--take_top_k`` and ``--save_text_only``, in that
    order. All options have defaults.

    Returns:
    --------
    argparse.Namespace:
        Parsed arguments, one attribute per option listed above.
    """
    cli = argparse.ArgumentParser(
        description="Generate model activations and labels from a specified dataset."
    )
    add = cli.add_argument  # short alias; registration order is preserved

    # Model / dataset selection, restricted to the HD lists.
    add("--LLM", choices=LIST_OF_MODELS_HD,
        default='meta-llama/Meta-Llama-3-8B-Instruct',
        help="Pretrained model to use for generating activations.")
    add("--dataset", choices=LIST_OF_DATASETS_HD,
        default='triviaqa_test',
        help="Dataset to be processed.")

    # Sample count and chunk index (the original gives --chunk no help text).
    add("--n_samples", type=int, default=10_000,
        help='number of examples to use')
    add("--chunk", type=int, default=1)

    # Storage location and extraction width.
    add("--base_raw_data_dir", type=str,
        default='/mnt/storage/user/raw_data',
        help="Base directory for saving raw data.")
    add("--take_top_k", type=int, default=MAXIMAL_VOCAB_SIZE,
        help="Top-K to use when extracting the raw dataset -- should be max over all vocab sizes (default: 1_000_000).")

    # Integer flag: only 0 or 1 is accepted.
    add("--save_text_only", type=int, choices=[0, 1], default=0,
        help="Set to 1 if you want to save the prompts and responses only.")

    return cli.parse_args()


def parse_args_main(foundation=False, few_shot_adaptation=False):
    """
    Parses command-line arguments for the training and evaluation pipeline.

    Defaults shown in ``--help`` are rendered by argparse from the actual
    ``default=`` values (``%(default)s``) so the help text cannot drift from
    the code.

    Args:
        foundation (bool): If True, ``--train_dataset`` / ``--test_dataset`` are
            parsed by ``DatasetComboAction`` into lists of ``(llm, dataset)``
            tuples and no ``--LLM`` option is registered. If False, ``--LLM``
            and plain-string ``--train_dataset`` / ``--test_dataset`` options
            are registered instead.
        few_shot_adaptation (bool): If True, ``--size_limit`` and
            ``--wandb_link`` are registered in addition to the common options.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Parse arguments for training and evaluation pipeline."
    )

    if foundation:
        # Multi-(llm, dataset) mode: each option takes ONE quoted string that
        # DatasetComboAction expands into a list of (llm, dataset) tuples.
        parser.add_argument(
            '--train_dataset',
            action=DatasetComboAction,
            default=[
                ('meta-llama/Meta-Llama-3-8B-Instruct', 'hotpotqa'),
            ],
            help="Training datasets as a single string (format: 'llm:dataset dataset ... llm:dataset ...')"
        )

        parser.add_argument(
            '--test_dataset',
            action=DatasetComboAction,
            default=[
                ('meta-llama/Meta-Llama-3-8B-Instruct', 'hotpotqa'),
            ],
            help="Test datasets as a single string (format: 'llm:dataset dataset ... llm:dataset ...')"
        )
    else:
        # Single-LLM mode: one model plus one train and one test dataset name.
        parser.add_argument(
            "--LLM",
            choices=LIST_OF_ALL_MODELS,
            default="Qwen/Qwen2.5-7B-Instruct",
            help="Pretrained model of which its logits/activations are used."
        )

        parser.add_argument(
            "--train_dataset",
            type=str,
            default="hotpotqa",
            help="Train dataset (default: %(default)s)."
        )

        parser.add_argument(
            "--test_dataset",
            type=str,
            default="hotpotqa_test",
            help="Test dataset (default: %(default)s)."
        )

    if few_shot_adaptation:
        parser.add_argument(
            "--size_limit",
            type=int,
            default=10000,
            help="Total size limit for the CombinedCustomDataset (default: %(default)s). "
                 "Each (llm,dataset) will have this size limit divided by the number "
                 "of (llm,dataset) in the training set."
        )
        # The default is a Weights & Biases sweep URL, not a local checkpoint path.
        parser.add_argument(
            "--wandb_link",
            type=str,
            default="https://wandb.ai/wandb_user/ACT-ViT/sweeps/fqhuz8iu",
            help="Weights & Biases link (default: %(default)s)."
        )

    # ---- Data locations (absolute, machine-specific defaults; override on the CLI) ----
    parser.add_argument(
        "--base_raw_data_dir",
        type=str,
        default='/mnt/storage/user/raw_data',
        help="Base directory for saving raw data."
    )

    parser.add_argument(
        "--base_pre_processed_data_dir",
        type=str,
        default='/home/user/ACT-ViT/pre_processed_data',
        help="Base directory for saving pre processed data."
    )

    # ---- Probe model and input representation ----
    parser.add_argument(
        "--probe_model",
        choices=PROBE_MODELS,
        default="ACT-Vit",
        help="The probing model to use (default: %(default)s)."
    )

    parser.add_argument(
        "--topk_preprocess",
        type=int,
        default=1_000,
        help="Top-K to load the preprocessed dataset -- should be max over all vocab sizes, or 1_000 (default: %(default)s)."
    )

    parser.add_argument(
        "--input_output_type",
        type=str,
        default="output",
        choices=["input", "output"],
        help="Usage of input or output."
    )

    parser.add_argument(
        "--topk_dim",
        type=int,
        default=1000,
        help="Top-K dimension to actually use for the model, should be 1000, or between 10 to 1000 for ablation study, see paper (default: %(default)s)."
    )

    # Kept as a string; this function does not convert it into a tuple.
    parser.add_argument(
        "--patch_size",
        type=str,
        default="(1, 1)",
        help="A tuple parameter as a string. Format: '(x,y)' (default: %(default)s)."
    )

    # ---- Cross-validation ----
    parser.add_argument(
        "--num_folds",
        type=int,
        default=5,
        help="Number of folds to run (default: %(default)s)."
    )

    parser.add_argument(
        "--fold_to_run",
        type=int,
        default=0,
        help="fold to run (default: %(default)s)."
    )

    parser.add_argument(
        "--input_type",
        type=str,
        choices=["LOS", "activations"],
        default="activations",
        help="Input type to use (default: %(default)s)."
    )

    # ---- Runtime ----
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="seed (default: %(default)s)."
    )

    parser.add_argument(
        "--cuda_idx",
        type=int,
        default=0,
        help="cuda index (default: %(default)s)."
    )

    # ---- Model / optimisation hyper-parameters ----
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="batch size (default: %(default)s)."
    )

    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=128,
        help="hidden dimension (default: %(default)s)."
    )

    parser.add_argument(
        "--heads",
        type=int,
        default=4,
        help="number of heads (default: %(default)s)."
    )

    parser.add_argument(
        "--dropout",
        type=float,
        default=0.3,
        help="dropout to use (default: %(default)s)."
    )

    parser.add_argument(
        "--num_layers",
        type=int,
        default=3,
        help="number of layers to use (default: %(default)s)."
    )

    parser.add_argument(
        "--pool",
        type=str,
        default='cls',
        help="pooling (default: %(default)s)."
    )

    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="patience (default: %(default)s)."
    )

    parser.add_argument(
        "--num_epochs",
        type=int,
        default=15,
        help="epochs (default: %(default)s)."
    )

    parser.add_argument(
        "--best_model_path",
        type=str,
        default="/mnt/storage/user/saved_models/saved_models",
        help="Path to save the best model (default: %(default)s)."
    )

    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        help="Learning rate for the optimizer (default: %(default)s)."
    )

    parser.add_argument(
        "--rank_encoding",
        type=str,
        default="scale_encoding",
        help="The way to use the rank encoding (default: %(default)s)."
    )

    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay for regularization (default: %(default)s)."
    )

    # ---- Data loading ----
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers (default: %(default)s)."
    )
    # Integer rather than a boolean flag (default 1).
    parser.add_argument(
        "--pin_memory",
        type=int,
        default=1,
        help="pin memory (default: %(default)s)."
    )

    # ---- Down-sampling ----
    parser.add_argument(
        "--down_sample_strategy",
        type=str,
        default="pool",
        choices=["none", "LG", "pool"],
        help="Down-sampling strategy (default: %(default)s)."
    )

    parser.add_argument(
        "--L_eff",
        type=int,
        default=8,
        help="Effective Layers taken."
    )

    # Help previously duplicated the --L_eff text; parse_args_pre_process
    # describes this same option as "Effective Tokens taken.".
    parser.add_argument(
        "--N_eff",
        type=int,
        default=100,
        help="Effective Tokens taken."
    )

    return parser.parse_args()


def parse_args_pre_process():
    """
    Parse command-line arguments for the pre-processing script.

    Every option has a default. Defaults shown in ``--help`` are rendered by
    argparse from the actual ``default=`` values (``%(default)s``).

    Returns:
    --------
    argparse.Namespace:
        Parsed arguments with attributes ``LLM``, ``dataset``,
        ``base_raw_data_dir``, ``base_pre_processed_data_dir``,
        ``topk_preprocess``, ``input_output_type``, ``down_sample_strategy``,
        ``L_eff``, ``N_eff`` and ``input_type``.
    """
    parser = argparse.ArgumentParser(description="Generate model activations and labels from a specified dataset.")

    # Argument for selecting the model
    parser.add_argument(
        "--LLM",
        choices=LIST_OF_ALL_MODELS,
        default="Qwen/Qwen2.5-7B-Instruct",
        help="Pretrained model to use for generating activations."
    )

    # Argument for selecting the dataset
    parser.add_argument(
        "--dataset",
        choices=LIST_OF_ALL_DATASETS,
        default='hotpotqa',
        help="Dataset to be processed."
    )

    # Absolute, machine-specific defaults; override on the command line.
    parser.add_argument(
        "--base_raw_data_dir",
        type=str,
        default='/mnt/storage/user/raw_data',
        help="Base directory for saving raw data."
    )

    parser.add_argument(
        "--base_pre_processed_data_dir",
        type=str,
        default='/home/user/ACT-ViT/pre_processed_data',
        help="Base directory for saving pre processed data."
    )

    # Help used to claim a default of 1_000_000 while the default is 1_000.
    parser.add_argument(
        "--topk_preprocess",
        type=int,
        default=1_000,
        help="Top-K to use when preprocessing the dataset -- should be max for DCD and 1000 for HD (default: %(default)s)."
    )

    # NOTE(review): unlike the other parsers in this file, no `choices` are
    # enforced here, so any string is accepted — confirm this is intended.
    parser.add_argument(
        "--input_output_type",
        type=str,
        default="output",
        help="Usage of input or output."
    )

    parser.add_argument(
        "--down_sample_strategy",
        type=str,
        default="pool",
        choices=["none", "LG", "pool"],
        help="Down-sampling strategy (default: %(default)s)."
    )

    parser.add_argument(
        "--L_eff",
        type=int,
        default=8,
        help="Effective Layers taken."
    )

    parser.add_argument(
        "--N_eff",
        type=int,
        default=20,
        help="Effective Tokens taken."
    )

    # Help used to be a copy of the "fold to run" text from another option.
    parser.add_argument(
        "--input_type",
        type=str,
        choices=["LOS", "activations"],
        default="activations",
        help="Input type to use (default: %(default)s)."
    )

    return parser.parse_args()



class DatasetComboAction(argparse.Action):
    """Argparse action that expands one 'llm:dataset dataset ...' string into (llm, dataset) tuples."""

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Parse ``values`` and store the resulting list on ``namespace``.

        The string is split on whitespace. A token containing ':' is read as
        ``llm:dataset`` and makes ``llm`` the current model; a token without
        ':' is a dataset name attached to the current model. The list of
        ``(llm, dataset)`` tuples is stored under ``self.dest``; an empty or
        whitespace-only string yields an empty list.

        Raises:
            argparse.ArgumentError: if ``values`` is not a string, a token ends
                with ':', the LLM is not in ``LIST_OF_ALL_MODELS``, or a bare
                dataset token appears before any ``llm:dataset`` token.
        """
        if not isinstance(values, str):
            raise argparse.ArgumentError(
                self, "Input must be a single quoted string (e.g., 'llm1:data1 data2 llm2:data3')")

        pairs = []
        active_llm = None
        for token in values.split():
            if ':' not in token:
                # Bare dataset name: needs a model from an earlier token.
                if active_llm is None:
                    raise argparse.ArgumentError(
                        self, "Must specify LLM before datasets (use 'llm:dataset' format first)")
                pairs.append((active_llm, token))
                continue

            # 'llm:dataset' token: validate it, then switch the current model.
            if token.endswith(':'):
                raise argparse.ArgumentError(
                    self, "Invalid format: LLM specification cannot end with a colon.")
            # Split on the FIRST colon only; later colons stay in the dataset part.
            model_name, _, dataset_name = token.partition(':')
            if model_name not in LIST_OF_ALL_MODELS:
                raise argparse.ArgumentError(
                    self, f"Invalid LLM: {model_name}. Must be one of: {', '.join(LIST_OF_ALL_MODELS)}")
            active_llm = model_name
            pairs.append((active_llm, dataset_name))

        setattr(namespace, self.dest, pairs)