import itertools
import os
import shlex
import shutil
import subprocess
import sys

import jax
import torch

# Task names, indexed by the integer ``flag`` accepted by run_task.
_TASKS = ("train", "finetune", "matching")

# Pretrained checkpoint filenames inside ``<log_base_dir>/pretrain<L>_<H>[_rope]``,
# keyed by "<num_layers>_<num_heads>[_rope]".
_PRETRAINED_CKPTS = {
    "2_4": "imdb_bert_seed0_opt_adam_lr_0.0005_L2_H4_epoch3__testloss0.3055_testacc_0.8749.flax",
    "2_8": "imdb_bert_seed0_opt_adam_lr_0.0005_L2_H8_epoch3_valloss_0.3260_valacc_0.8701.flax",
    "2_6": "imdb_bert_seed0_opt_adam_lr_0.0005_L2_H6_epoch4_testacc_0.8774.flax",
    "6_4": "imdb_bert_seed0_opt_adam_lr_0.0005_L6_H4_epoch3__testloss0.3179_testacc_0.8696.flax",
    "6_8": "imdb_bert_seed0_opt_adam_lr_0.0005_L6_H8_epoch4_testacc_0.8748.flax",
    "2_4_rope": "imdb_bert_rope_seed0_opt_adam_lr_0.0003_L2_H4_epoch4_testacc_0.8734.flax",
    "2_8_rope": "imdb_bert_rope_seed0_opt_adam_lr_0.0003_L2_H8_epoch5_testacc_0.8761.flax",
    "6_4_rope": "imdb_bert_rope_seed0_opt_adam_lr_0.0003_L6_H4_epoch5_testacc_0.8699.flax",
    "6_8_rope": "imdb_bert_rope_seed0_opt_adam_lr_0.0003_L6_H8_epoch5_testacc_0.8730.flax",
}

# Filename marker identifying finetuned checkpoints in the ``ckpts...`` directory;
# the text between the marker and the next "_" is taken as the seed.
_FINETUNE_CKPT_MARKER = "imdb_bert_attn_finetune_seed"


def _python_module_cmd(module, *args):
    """Return an argv list that runs ``python -m module *args``."""
    # sys.executable (instead of a bare "python" looked up on PATH) keeps the
    # child process in the same interpreter/environment as this launcher.
    return [sys.executable, "-m", module, *args]


def _run_command(cmd):
    """Echo ``cmd`` (an argv list) and run it, raising if it exits non-zero."""
    print(f"Running command: {shlex.join(cmd)}")
    # An argv list needs no shell quoting, so paths with spaces are passed intact.
    # check=True turns a failing sub-run into CalledProcessError, so a failed
    # stage is never silently followed by the next one.
    subprocess.run(cmd, check=True)


def _reset_dir(path):
    """Delete ``path`` if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _print_device_info():
    """Print PyTorch CUDA availability/GPU details and the devices JAX sees."""
    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        device = torch.cuda.current_device()
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {device}")
        print(f"GPU Name: {torch.cuda.get_device_name(device)}")
    print("JAX running on:", jax.devices())


def _find_finetuned_ckpts(ckpt_dir):
    """Map seed (as a string) -> checkpoint filename for files in ``ckpt_dir``.

    Raises:
        FileNotFoundError: ``ckpt_dir`` does not exist.
    """
    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError(
            f"Finetuned checkpoint directory not found: {ckpt_dir} "
            "(run the 'finetune' task first)"
        )
    ckpts = {}
    # Sorted so the pick is deterministic when one seed has several files.
    for filename in sorted(os.listdir(ckpt_dir)):
        _, found, rest = filename.partition(_FINETUNE_CKPT_MARKER)
        if not found:
            continue
        seed = rest.split("_")[0]
        if seed in ckpts:
            print(f"Warning: multiple checkpoints for seed {seed}; "
                  f"using {filename} instead of {ckpts[seed]}")
        ckpts[seed] = filename
    return ckpts


def _run_train(log_base_dir, dataset_path, num_layers, num_heads, rope_str):
    """Pretrain from scratch (seed 0) into a freshly emptied ``pretrain...`` dir."""
    directory = os.path.join(log_base_dir, f"pretrain{num_layers}_{num_heads}{rope_str}")
    _reset_dir(directory)
    # The pretraining module name itself encodes the RoPE variant.
    cmd = _python_module_cmd(
        f"src.imdbreview.imdbreview_bert_train{rope_str}",
        "--seed", "0",
        "--num-layers", str(num_layers),
        "--num-heads", str(num_heads),
        "--save-dir", directory,
        "--data-path", dataset_path,
    )
    _run_command(cmd)


def _run_finetune(log_base_dir, dataset_path, num_layers, num_heads,
                  finetune_layer_which, rope_str, rope_args, seeds):
    """Finetune the registered pretrained checkpoint once per seed.

    Raises:
        ValueError: no pretrained checkpoint is registered for the config.
        FileNotFoundError: the registered checkpoint does not exist on disk.
    """
    config_key = f"{num_layers}_{num_heads}{rope_str}"
    try:
        ckpt_filename = _PRETRAINED_CKPTS[config_key]
    except KeyError:
        raise ValueError(
            f"No pretrained checkpoint registered for config '{config_key}'; "
            f"known configs: {sorted(_PRETRAINED_CKPTS)}"
        ) from None
    pretrain_dir = os.path.join(log_base_dir, f"pretrain{num_layers}_{num_heads}{rope_str}")
    pretrained_ckpt_path = os.path.join(pretrain_dir, ckpt_filename)
    # Validate inputs BEFORE wiping the output directory, so a bad config or a
    # missing pretrained model does not destroy earlier finetuning results.
    if not os.path.exists(pretrained_ckpt_path):
        raise FileNotFoundError(f"Pretrained checkpoint not found: {pretrained_ckpt_path}")

    directory = os.path.join(
        log_base_dir, f"ckpts{num_layers}_{num_heads}_{finetune_layer_which}{rope_str}"
    )
    _reset_dir(directory)
    for seed in seeds:
        cmd = _python_module_cmd(
            "src.imdbreview.imdbreview_bert_finetune_attn",
            *rope_args,
            "--seed", str(seed),
            "--optimizer", "adam",
            "--learning-rate", "0.0005",
            "--num-layers", str(num_layers),
            "--num-heads", str(num_heads),
            "--finetune-layer-which", str(finetune_layer_which),
            "--model-path", pretrained_ckpt_path,
            "--ckpt-path", directory,
            "--data-path", dataset_path,
        )
        _run_command(cmd)


def _run_matching(log_base_dir, dataset_path, num_layers, num_heads,
                  finetune_layer_which, rope_str, rope_args, seeds, all_head_permu):
    """Run the matching script on every unordered pair of finetuned seeds.

    Raises:
        FileNotFoundError: the checkpoint directory, or a checkpoint for one of
            ``seeds``, is missing.
    """
    all_head_permu_str = "_all_heads_permu" if all_head_permu else ""
    suffix = f"{num_layers}_{num_heads}_{finetune_layer_which}{rope_str}"
    ckpt_path = os.path.join(log_base_dir, f"ckpts{suffix}")
    ckpt_dict = _find_finetuned_ckpts(ckpt_path)
    print(f"Found finetuned models for seeds: {list(ckpt_dict.keys())}")

    # Check every seed up front instead of failing part-way through the pairs.
    missing = [seed for seed in seeds if str(seed) not in ckpt_dict]
    if missing:
        raise FileNotFoundError(
            f"No finetuned checkpoint in {ckpt_path} for seeds {missing}"
        )

    for seed_a, seed_b in itertools.combinations(seeds, 2):
        plot_path = os.path.join(
            log_base_dir, f"plots{suffix}", f"{seed_a}{seed_b}{all_head_permu_str}"
        )
        _reset_dir(plot_path)
        cmd = _python_module_cmd(
            f"src.imdbreview.imdbreview_bert_matching_attn{all_head_permu_str}",
            *rope_args,
            "--num-layers", str(num_layers),
            "--finetune-layer-which", str(finetune_layer_which),
            "--num-heads", str(num_heads),
            "--model-a", os.path.join(ckpt_path, ckpt_dict[str(seed_a)]),
            "--model-b", os.path.join(ckpt_path, ckpt_dict[str(seed_b)]),
            "--plot-path", plot_path,
            "--data-path", dataset_path,
        )
        _run_command(cmd)


def run_task(*, flag=2, num_layers=6, num_heads=4, finetune_layer_which="all",
             rope_use=True, all_head_permu=False, seeds=(1, 2, 3),
             log_base_dir="/root/log/imdbreview", data_dir=None):
    """Run training, finetuning or matching for a BERT-style model on IMDB reviews.

    Every argument is keyword-only and defaults to the script's built-in
    configuration, so ``run_task()`` behaves like running the script directly.

    Args:
        flag: Task selector: 0 = train, 1 = finetune, 2 = matching.
        num_layers: Number of layers forwarded to the sub-scripts.
        num_heads: Number of attention heads forwarded to the sub-scripts.
        finetune_layer_which: Value forwarded as ``--finetune-layer-which``
            (e.g. 0, 1, ... or "all", depending on the finetune script).
        rope_use: Select the RoPE variants: ``_rope`` module/directory suffix
            and the ``--rope-use`` flag for finetune/matching.
        all_head_permu: Matching only; use the ``_all_heads_permu`` module.
        seeds: Finetuning seeds; matching runs on every unordered pair of them.
        log_base_dir: Root directory for logs, checkpoints and plots.
        data_dir: Directory expected to contain ``IMDB_Dataset.csv``;
            defaults to ``<log_base_dir>/data``.

    Raises:
        ValueError: ``flag`` is not 0, 1 or 2, or (finetune) no pretrained
            checkpoint is registered for the model configuration.
        FileNotFoundError: A required input checkpoint or directory is missing.
        subprocess.CalledProcessError: A launched sub-script exited non-zero.
    """
    # Reject bad selectors before touching the filesystem; a negative flag
    # would otherwise silently index from the end of the task list.
    if not 0 <= flag < len(_TASKS):
        raise ValueError(
            f"flag must be 0 (train), 1 (finetune) or 2 (matching); got {flag!r}"
        )
    task = _TASKS[flag]
    if data_dir is None:
        data_dir = os.path.join(log_base_dir, "data")

    # --- Environment Setup ---
    os.makedirs(log_base_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)
    dataset_path = os.path.join(data_dir, "IMDB_Dataset.csv")
    if not os.path.isfile(dataset_path):
        print(f"Warning: dataset file not found at {dataset_path}")

    # Suffix for RoPE-specific module names/directories, and the matching CLI flag.
    rope_str = "_rope" if rope_use else ""
    # Assumes the finetune/matching scripts accept this flag.
    rope_args = ["--rope-use"] if rope_use else []

    _print_device_info()

    # --- Task Execution ---
    print(f"\n--- Running Task: {task.upper()} ---")
    if task == "train":
        _run_train(log_base_dir, dataset_path, num_layers, num_heads, rope_str)
    elif task == "finetune":
        _run_finetune(log_base_dir, dataset_path, num_layers, num_heads,
                      finetune_layer_which, rope_str, rope_args, seeds)
    else:
        _run_matching(log_base_dir, dataset_path, num_layers, num_heads,
                      finetune_layer_which, rope_str, rope_args, seeds,
                      all_head_permu)

if __name__ == '__main__':
    # Script entry point: run whichever task run_task is configured to select.
    run_task()