import itertools
import os
import shlex
import shutil
import subprocess

import torch
import jax

# Registered pretrained checkpoint filenames, keyed by
# "<num_layers>_<num_heads>" plus "_rope" when Rotary Position Embedding is used.
PRETRAINED_CKPTS = {
    "2_4": "dbpedia_bert_seed0_opt_adamw_lr_0.0005_num_layers_2_num_heads_4_hidden_dim_192_embedding_dim_48_epoch4_trainloss_0.2972_testloss_0.2892_trainacc_0.9174_testacc_0.9230",
    "2_8": "dbpedia_bert_seed0_opt_adam_lr_0.001_num_layers_2_num_heads_8_hidden_dim_192_embedding_dim_48_epoch5_trainloss_0.1560_testloss_0.2547_trainacc_0.9537_testacc_0.9327",
    "6_4": "dbpedia_bert_seed0_opt_adam_lr_0.001_num_layers_6_num_heads_4_hidden_dim_192_embedding_dim_48_epoch5_trainloss_0.1540_testloss_0.2513_trainacc_0.9542_testacc_0.9347",
    "6_8": "dbpedia_bert_seed0_opt_adam_lr_0.001_num_layers_6_num_heads_8_hidden_dim_192_embedding_dim_48_epoch4_trainloss_0.1933_testloss_0.2628_trainacc_0.9432_testacc_0.9318",
    "2_4_rope": "dbpedia_bert_seed0_opt_adamw_lr_0.0005_num_layers_2_num_heads_4_hidden_dim_192_embedding_dim_48_epoch5_trainloss_0.3008_testloss_0.3349_trainacc_0.9142_testacc_0.9111",
    "2_8_rope": "dbpedia_bert_seed0_opt_adamw_lr_0.0005_num_layers_2_num_heads_8_hidden_dim_192_embedding_dim_48_epoch5_trainloss_0.3154_testloss_0.3451_trainacc_0.9108_testacc_0.9080",
    "6_4_rope": "dbpedia_bert_seed0_opt_adamw_lr_0.0005_num_layers_6_num_heads_4_hidden_dim_192_embedding_dim_48_epoch5_trainloss_0.2998_testloss_0.3370_trainacc_0.9144_testacc_0.9093",
    "6_8_rope": "dbpedia_bert_seed0_opt_adamw_lr_0.0005_num_layers_6_num_heads_8_hidden_dim_192_embedding_dim_48_epoch5_trainloss_0.3019_testloss_0.3534_trainacc_0.9132_testacc_0.9060",
}

# Task names, indexed by the ``flag`` argument of ``run_task``.
TASKS = ("train", "finetune", "matching")

# Substring identifying finetuned checkpoint files inside the ckpts directory.
_FINETUNE_CKPT_MARKER = "dbpedia_bert_attn_finetune_seed"


def _reset_dir(path):
    """Delete ``path`` if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _run_command(args):
    """Print and run ``args`` (an argv list, no shell); return its exit status.

    A non-zero status is reported but does not abort the caller, so the
    remaining runs of a seed sweep still execute.
    """
    print(f"Running command: {shlex.join(args)}")
    returncode = subprocess.run(args, check=False).returncode
    if returncode != 0:
        print(f"WARNING: command exited with status {returncode}: {shlex.join(args)}")
    return returncode


def _report_devices():
    """Print PyTorch CUDA availability/GPU details and the devices JAX sees."""
    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        current = torch.cuda.current_device()
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {current}")
        print(f"GPU Name: {torch.cuda.get_device_name(current)}")
    print("JAX running on:", jax.devices())


def run_task(
    *,
    flag=2,
    num_layers=6,
    num_heads=4,
    finetune_layer_which="all",
    rope_use=True,
    log_base_dir="/root/log/dbpedia",
    data_dir=None,
    all_head_permu=False,
    seeds=(1, 2, 3),
):
    """Run one stage for a BERT-style model on the DBpedia dataset.

    The stages are pretraining, attention finetuning over several seeds, or
    pairwise matching of the finetuned seeds. Each stage is launched as a
    ``python -m src.dbpedia.<module>`` subprocess.

    All arguments are keyword-only. Their defaults reproduce the script's
    original hard-coded configuration, so ``run_task()`` behaves as before.

    Args:
        flag: Stage index into ``TASKS``: 0 = train, 1 = finetune, 2 = matching.
        num_layers: Number of transformer layers.
        num_heads: Number of attention heads.
        finetune_layer_which: Layer selector forwarded to the finetune and
            matching scripts (e.g. an int or ``"all"``).
        rope_use: Whether to use Rotary Position Embedding. This selects the
            ``_rope`` train module and checkpoint names, and passes
            ``--rope-use`` to the finetune/matching scripts.
        log_base_dir: Root directory for checkpoints and plots.
        data_dir: Directory holding ``train.parquet`` and ``val.parquet``.
            Defaults to ``<log_base_dir>/data``.
        all_head_permu: Matching only; selects the ``_all_heads_permu``
            matching module and plot-directory suffix.
        seeds: Finetuning seeds. Matching runs on every unordered pair of them.

    Raises:
        ValueError: If ``flag`` is not a valid index into ``TASKS``.
        KeyError: (finetune) If no pretrained checkpoint is registered for
            this layer/head/RoPE configuration.
        FileNotFoundError: (matching) If the finetuned checkpoint directory
            does not exist.
    """
    if not 0 <= flag < len(TASKS):
        raise ValueError(f"flag must be in range 0..{len(TASKS) - 1}, got {flag!r}")
    if data_dir is None:
        data_dir = os.path.join(log_base_dir, "data")

    # --- Environment Setup ---
    os.makedirs(log_base_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    # Suffix for RoPE variants in module/checkpoint names. The CLI flag is
    # only added when enabled; this assumes the finetune/matching scripts
    # accept --rope-use.
    rope_str = "_rope" if rope_use else ""
    rope_args = ["--rope-use"] if rope_use else []

    _report_devices()

    task = TASKS[flag]
    print(f"\n--- Running Task: {task.upper()} ---")

    train_path = os.path.join(data_dir, "train.parquet")
    val_path = os.path.join(data_dir, "val.parquet")
    pretrain_dir = os.path.join(log_base_dir, f"pretrain{num_layers}_{num_heads}{rope_str}")
    finetune_dir = os.path.join(
        log_base_dir, f"ckpts{num_layers}_{num_heads}_{finetune_layer_which}{rope_str}"
    )

    if task == "train":
        _reset_dir(pretrain_dir)
        # The RoPE variant lives in a separate module instead of behind a flag.
        _run_command([
            "python", "-m", f"src.dbpedia.dbpedia_bert_train{rope_str}",
            "--seed", "0",
            "--num-layers", str(num_layers),
            "--num-heads", str(num_heads),
            "--ckpt-path", pretrain_dir,
            "--train-dataset-path", train_path,
            "--test-dataset-path", val_path,
        ])

    elif task == "finetune":
        # Resolve the pretrained checkpoint BEFORE wiping the output directory,
        # so an unsupported configuration does not destroy existing results.
        config_key = f"{num_layers}_{num_heads}{rope_str}"
        try:
            ckpt_filename = PRETRAINED_CKPTS[config_key]
        except KeyError:
            raise KeyError(
                f"No pretrained checkpoint registered for config {config_key!r}; "
                f"known configs: {sorted(PRETRAINED_CKPTS)}"
            ) from None
        pretrained_ckpt_path = os.path.join(pretrain_dir, ckpt_filename)

        _reset_dir(finetune_dir)
        for seed in seeds:
            _run_command([
                "python", "-m", "src.dbpedia.dbpedia_bert_finetune_attn",
                *rope_args,
                "--seed", str(seed),
                "--optimizer", "adam",
                "--learning-rate", "0.001",
                "--num-layers", str(num_layers),
                "--num-heads", str(num_heads),
                "--finetune-layer-which", str(finetune_layer_which),
                "--model-path", pretrained_ckpt_path,
                "--ckpt-path", finetune_dir,
                "--train-dataset-path", train_path,
                "--test-dataset-path", val_path,
            ])

    else:  # task == "matching"
        all_head_permu_str = "_all_heads_permu" if all_head_permu else ""
        if not os.path.isdir(finetune_dir):
            raise FileNotFoundError(
                f"Finetuned checkpoint directory not found: {finetune_dir} "
                "(run the 'finetune' task first)"
            )

        # Map seed (as a string, parsed from the filename) -> checkpoint filename.
        # Sorting makes the choice deterministic if a seed has several files.
        ckpt_dict = {}
        for filename in sorted(os.listdir(finetune_dir)):
            if _FINETUNE_CKPT_MARKER in filename:
                seed = filename.split("seed")[1].split("_")[0]
                ckpt_dict[seed] = filename
        print(f"Found finetuned models for seeds: {sorted(ckpt_dict)}")

        plots_root = os.path.join(
            log_base_dir, f"plots{num_layers}_{num_heads}_{finetune_layer_which}{rope_str}"
        )
        for seed_a, seed_b in itertools.combinations(seeds, 2):
            missing = [s for s in (seed_a, seed_b) if str(s) not in ckpt_dict]
            if missing:
                print(f"WARNING: skipping pair ({seed_a}, {seed_b}); "
                      f"no finetuned checkpoint for seed(s) {missing}")
                continue

            plot_path = os.path.join(plots_root, f"{seed_a}{seed_b}{all_head_permu_str}")
            _reset_dir(plot_path)

            # NOTE(review): both dataset arguments point at val.parquet, as in
            # the original configuration — confirm this is intended.
            _run_command([
                "python", "-m", f"src.dbpedia.dbpedia_bert_matching_attn{all_head_permu_str}",
                *rope_args,
                "--num-layers", str(num_layers),
                "--finetune-layer-which", str(finetune_layer_which),
                "--num-heads", str(num_heads),
                "--model-a", os.path.join(finetune_dir, ckpt_dict[str(seed_a)]),
                "--model-b", os.path.join(finetune_dir, ckpt_dict[str(seed_b)]),
                "--plot-path", plot_path,
                "--train-dataset-path", val_path,
                "--test-dataset-path", val_path,
            ])

if __name__ == "__main__":
    # Script entry point: run the stage selected by the default configuration.
    run_task()