import itertools
import os
import shlex
import shutil
import subprocess
import sys

import torch
import jax

# Task names selectable through the ``flag`` argument of ``run_task``.
_TASKS = ('train', 'finetune', 'matching')

# Seeds used both for the fine-tuning runs and for the pairwise matching runs.
_SEEDS = (1, 2, 3)

# Substring identifying a fine-tuned checkpoint file; the seed follows it directly.
_FINETUNE_MARKER = 'agnews_bert_attn_finetune_seed'

# Pretrained checkpoint file names, keyed by "<num_layers>_<num_heads>[_rope]".
# An empty string means no checkpoint has been recorded for that configuration.
_PRETRAINED_CKPTS = {
    "2_4": "agnews_bert_seed0_opt_adam_lr_0.001_L2_H4_epoch3_trainloss_0.2496_testloss_0.2861_trainacc_0.9124_testacc_0.9064.flax",
    "2_8": "agnews_bert_seed0_opt_adam_lr_0.001_L2_H8_epoch4_trainloss_0.2137_testloss_0.2818_trainacc_0.9250_testacc_0.9051.flax",
    "6_4": "agnews_bert_seed0_opt_adam_lr_0.001_L6_H4_epoch5_trainloss_0.1950_testloss_0.3011_trainacc_0.9314_testacc_0.9052.flax",
    "6_8": "agnews_bert_seed0_opt_adam_lr_0.001_L6_H8_epoch5_trainloss_0.1937_testloss_0.2937_trainacc_0.9325_testacc_0.9047.flax",
    "2_4_rope": "agnews_bert_rope_seed0_opt_adam_lr_0.001_L2_H4_epoch4_trainloss_0.2180_testloss_0.2869_trainacc_0.9225_testacc_0.9052.flax",
    "2_8_rope": "agnews_bert_rope_seed0_opt_adam_lr_0.001_L2_H8_epoch3_trainloss_0.2522_testloss_0.2845_trainacc_0.9107_testacc_0.9035.flax",
    "6_4_rope": "",
    "6_8_rope": "",
}


def _run_module(module, args):
    """Run ``python -m <module> <args...>`` as a child process and return its exit status.

    The command is passed as an argument list (no shell), so paths containing
    spaces or shell metacharacters are forwarded intact. A non-zero exit status
    is reported but does not raise, so the remaining runs of a sweep still start.
    """
    command = [sys.executable or 'python', '-m', module] + [str(arg) for arg in args]
    print("Running command: " + ' '.join(shlex.quote(part) for part in command))
    returncode = subprocess.run(command).returncode
    if returncode != 0:
        print(f"Command exited with status {returncode}")
    return returncode


def _report_devices():
    """Print CUDA availability and GPU details (via torch) and the devices JAX reports."""
    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        current = torch.cuda.current_device()
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {current}")
        print(f"GPU Name: {torch.cuda.get_device_name(current)}")

    print("JAX running on:", jax.devices())


def _run_train(log_base_dir, data_dir, num_layers, num_heads, rope_str):
    """Launch the pretraining module, writing checkpoints under ``pretrain<L>_<H>[_rope]``."""
    directory = os.path.join(log_base_dir, f'pretrain{num_layers}_{num_heads}{rope_str}')
    os.makedirs(directory, exist_ok=True)
    # No --rope-use flag is passed here: rope_str selects the training module
    # (agnews_bert_train vs. agnews_bert_train_rope) and the output directory.
    _run_module(f'src.agnews.agnews_bert_train{rope_str}', [
        '--seed', 0, '--optimizer', 'adam', '--learning-rate', '0.001',
        '--num-layers', num_layers, '--num-heads', num_heads,
        '--ckpt-path', directory, '--data-path', data_dir,
    ])


def _run_finetune(log_base_dir, data_dir, num_layers, num_heads,
                  finetune_layer_which, rope_str, rope_args):
    """Fine-tune from the recorded pretrained checkpoint, once per seed in ``_SEEDS``."""
    # NOTE(review): this lookup ignores rope_str, so the "*_rope" entries of
    # _PRETRAINED_CKPTS are never used and a --rope-use fine-tune starts from the
    # non-RoPE pretrained checkpoint. Behavior kept as-is; confirm it is intended.
    key = f"{num_layers}_{num_heads}"
    ckpt_name = _PRETRAINED_CKPTS.get(key)
    if not ckpt_name:
        raise ValueError(f"No pretrained checkpoint recorded for configuration '{key}'")
    pretrained_ckpt_path = os.path.join(log_base_dir, f'pretrain{num_layers}_{num_heads}', ckpt_name)
    if not os.path.exists(pretrained_ckpt_path):
        raise FileNotFoundError(f"Pretrained checkpoint not found: {pretrained_ckpt_path}")

    # Only wipe earlier fine-tuning results once the inputs have been validated.
    directory = os.path.join(log_base_dir, f'ckpts{num_layers}_{num_heads}_{finetune_layer_which}{rope_str}')
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory, exist_ok=True)

    for seed in _SEEDS:
        _run_module('src.agnews.agnews_bert_finetune_attn', rope_args + [
            '--seed', seed, '--optimizer', 'adam', '--learning-rate', '0.001',
            '--num-layers', num_layers, '--num-heads', num_heads,
            '--finetune-layer-which', finetune_layer_which,
            '--model-path', pretrained_ckpt_path,
            '--ckpt-path', directory, '--data-path', data_dir,
        ])


def _find_finetuned_ckpts(ckpt_path):
    """Return ``{seed string: file name}`` for the fine-tuned checkpoints in ``ckpt_path``."""
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(
            f"Checkpoint directory not found: {ckpt_path} (run the finetune task first)")
    found = {}
    # Sorted so that, if several files share a seed, the chosen one is deterministic.
    for filename in sorted(os.listdir(ckpt_path)):
        if _FINETUNE_MARKER in filename:
            seed = filename.split(_FINETUNE_MARKER, 1)[1].split('_', 1)[0]
            found[seed] = filename
    return found


def _run_matching(log_base_dir, data_dir, num_layers, num_heads,
                  finetune_layer_which, rope_str, rope_args, all_head_permu):
    """Run the matching module on every unordered pair of fine-tuned seeds."""
    all_head_permu_str = "_all_heads_permu" if all_head_permu else ""
    suffix = f'{num_layers}_{num_heads}_{finetune_layer_which}{rope_str}'
    ckpt_path = os.path.join(log_base_dir, f'ckpts{suffix}')
    ckpt_dict = _find_finetuned_ckpts(ckpt_path)

    # Validate every seed up front so no plot directory is wiped for a run that cannot start.
    missing = [seed for seed in _SEEDS if str(seed) not in ckpt_dict]
    if missing:
        raise FileNotFoundError(f"No fine-tuned checkpoint for seed(s) {missing} in {ckpt_path}")

    for seed_a, seed_b in itertools.combinations(_SEEDS, 2):
        plot_path = os.path.join(log_base_dir, f'plots{suffix}', f'{seed_a}{seed_b}{all_head_permu_str}')
        if os.path.exists(plot_path):
            shutil.rmtree(plot_path)
        os.makedirs(plot_path, exist_ok=True)

        model_a_path = os.path.join(ckpt_path, ckpt_dict[str(seed_a)])
        model_b_path = os.path.join(ckpt_path, ckpt_dict[str(seed_b)])

        _run_module(f'src.agnews.agnews_bert_matching_attn{all_head_permu_str}', rope_args + [
            '--num-layers', num_layers, '--num-heads', num_heads,
            '--finetune-layer-which', finetune_layer_which,
            '--model-a', model_a_path, '--model-b', model_b_path,
            '--data-path', data_dir, '--plot-path', plot_path,
        ])


def run_task(*, flag: int = 2, num_layers: int = 6, num_heads: int = 8,
             finetune_layer_which=0, rope_use: bool = True,
             all_head_permu: bool = True,
             log_base_dir: str = "/root/log/agnews",
             data_dir: str = "/root/log/agnews/data") -> None:
    """Launch one stage of the BERT-style AG News pipeline as child processes.

    Every setting is a keyword argument whose default is the value that was
    previously hard-coded, so calling ``run_task()`` runs the same configuration.

    Args:
        flag: 0 for train, 1 for finetune, 2 for matching.
        num_layers: Forwarded as ``--num-layers``; also part of directory names.
        num_heads: Forwarded as ``--num-heads``; also part of directory names.
        finetune_layer_which: Forwarded verbatim as ``--finetune-layer-which``
            and used in directory names (presumably a layer index or "all" --
            confirm against the fine-tuning module).
        rope_use: Whether to use Rotary Position Embedding. Adds the ``_rope``
            suffix to directory names and to the training module name, and
            passes ``--rope-use`` to the finetune and matching modules.
        all_head_permu: Matching only. Selects the ``_all_heads_permu`` variant
            of the matching module and of the plot directory name.
        log_base_dir: Root directory for checkpoints and plots.
        data_dir: Directory passed to the modules as ``--data-path``.

    Raises:
        ValueError: If ``flag`` is not 0, 1 or 2, or no pretrained checkpoint is
            recorded for the requested configuration (finetune).
        FileNotFoundError: If a required checkpoint file or directory is missing.
    """
    if flag not in range(len(_TASKS)):
        raise ValueError(f"flag must be 0 (train), 1 (finetune) or 2 (matching), got {flag!r}")
    task = _TASKS[flag]

    # Suffix for directory/module names and the CLI flag understood by the
    # finetune and matching modules.
    rope_str = "_rope" if rope_use else ""
    rope_args = ["--rope-use"] if rope_use else []

    # Ensure the base directories for logs and data exist.
    os.makedirs(log_base_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    _report_devices()

    if task == 'train':
        _run_train(log_base_dir, data_dir, num_layers, num_heads, rope_str)
    elif task == 'finetune':
        _run_finetune(log_base_dir, data_dir, num_layers, num_heads,
                      finetune_layer_which, rope_str, rope_args)
    else:
        _run_matching(log_base_dir, data_dir, num_layers, num_heads,
                      finetune_layer_which, rope_str, rope_args, all_head_permu)

# Script entry point: run the configured task only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_task()