import os
import shlex
import shutil
import subprocess

import torch
import jax

# Checkpoint name of the pretrained model to fine-tune from, keyed by
# "<num_layers>_<num_heads>" with a "_rope" suffix for the RoPE variants.
_PRETRAINED_CKPTS = {
    "2_4": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_2_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch96_trainloss_0.9038_testloss_0.9211_trainacc_0.6793_testacc_0.6829",
    "2_8": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_2_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch59_trainloss_0.6593_testloss_0.7845_trainacc_0.7686_testacc_0.7351",
    "4_4": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_4_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch74_trainloss_0.5393_testloss_0.7173_trainacc_0.8098_testacc_0.7720",
    "4_8": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_4_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch66_trainloss_0.5450_testloss_0.7242_trainacc_0.8089_testacc_0.7714",
    "6_4": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch81_trainloss_0.7400_testloss_0.8410_trainacc_0.7442_testacc_0.7284",
    "6_8": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch71_trainloss_0.4499_testloss_0.7939_trainacc_0.8439_testacc_0.7718",
    "2_4_rope": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_2_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch65_trainloss_0.8995_testloss_1.0527_trainacc_0.6786_testacc_0.6394",
    "2_8_rope": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_2_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch87_trainloss_0.7666_testloss_1.0938_trainacc_0.7268_testacc_0.6518",
    "4_4_rope": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_4_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch97_trainloss_0.5922_testloss_1.0133_trainacc_0.7904_testacc_0.6919",
    "4_8_rope": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_4_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch99_trainloss_0.5965_testloss_1.0345_trainacc_0.7905_testacc_0.6814",
    "6_4_rope": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch83_trainloss_0.5986_testloss_0.9501_trainacc_0.7905_testacc_0.6982",
    "6_8_rope": "cifar10_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch68_trainloss_0.6206_testloss_0.9770_trainacc_0.7803_testacc_0.6901",
}

# Task names, indexed by the ``flag`` argument of run_task().
_TASKS = ('train', 'finetune', 'matching', 'analyze_norms')


def _fresh_dir(path):
    """Delete *path* if it already exists, then create it again empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _run_command(args):
    """Echo and run one child command; return its exit status.

    The arguments are handed to the child as a list (no shell in between), so
    paths containing spaces or shell metacharacters arrive unchanged. A
    non-zero exit status is reported but does not raise, so the remaining
    commands of a task still run.
    """
    printable = " ".join(shlex.quote(arg) for arg in args)
    print(f"Running command: {printable}")
    status = subprocess.run(args, check=False).returncode
    if status != 0:
        print(f"Command exited with status {status}: {printable}")
    return status


def run_task(
    flag=2,
    num_layers=6,
    num_heads=4,
    finetune_layer_which="all",
    rope_use=False,
    log_base_dir="/root/log/cifar10",
    all_head_permu=False,
):
    """
    Run one Vision Transformer task on CIFAR-10 by launching the matching
    ``src.cifar10`` module(s) as child processes.

    Called without arguments it uses the defaults below, so the configuration
    can be changed either by passing arguments or by editing the defaults.

    Args:
        flag: Task selector: 0 for train, 1 for finetune, 2 for matching,
            3 for analyze_norms.
        num_layers: Forwarded as ``--num-layers``; also part of the log
            directory names.
        num_heads: Forwarded as ``--num-heads``; also part of the log
            directory names.
        finetune_layer_which: Forwarded as ``--finetune-layer-which`` by the
            finetune and matching tasks; also part of their directory names.
        rope_use: Whether to use Rotary Position Embedding. Adds a "_rope"
            suffix to directory names (and to the training module name) and
            passes ``--rope-use`` to the finetune and matching modules.
        log_base_dir: Root directory for logs, checkpoints and plots.
            Created if it does not exist.
        all_head_permu: Matching task only. Selects the "_all_heads_permu"
            variant of the matching module and of the plot directory name.

    Raises:
        ValueError: If ``flag`` is not a valid task index.
        KeyError: If the finetune task has no registered pretrained checkpoint
            for the configuration, or the matching task finds no fine-tuned
            checkpoint for one of the seeds.
        FileNotFoundError: If the matching task's checkpoint directory does
            not exist (raised by ``os.listdir``).
    """
    # Validate first so that a bad flag never touches the file system.
    if not isinstance(flag, int) or not 0 <= flag < len(_TASKS):
        raise ValueError(
            f"flag must be an integer in 0..{len(_TASKS) - 1}, got {flag!r}"
        )
    task = _TASKS[flag]

    rope_str = "_rope" if rope_use else ""
    # A list rather than a possibly empty string: an empty string would reach
    # the child process as an (empty) positional argument.
    rope_args = ["--rope-use"] if rope_use else []
    config_tag = f"{num_layers}_{num_heads}"

    # --- Environment Setup ---
    # Ensure the base directory for logs exists
    os.makedirs(log_base_dir, exist_ok=True)

    # Check for CUDA availability and print GPU details
    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {torch.cuda.current_device()}")
        print(f"GPU Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")

    print("JAX running on:", jax.devices())

    # --- Task Execution ---
    if task == 'train':
        # Training always starts from an empty checkpoint directory.
        directory = os.path.join(log_base_dir, f'pretrain{config_tag}{rope_str}')
        _fresh_dir(directory)
        _run_command([
            'python', '-m', f'src.cifar10.cifar10_vit_train{rope_str}',
            '--seed', '0',
            '--optimizer', 'adam',
            '--learning-rate', '0.005',
            '--num-layers', str(num_layers),
            '--num-heads', str(num_heads),
            '--ckpt-path', directory,
        ])

    elif task == 'finetune':
        # Resolve the pretrained checkpoint BEFORE wiping the output
        # directory, so an unsupported configuration cannot destroy the
        # results of an earlier fine-tuning run.
        ckpt_key = f"{config_tag}{rope_str}"
        if ckpt_key not in _PRETRAINED_CKPTS:
            raise KeyError(
                f"no pretrained checkpoint registered for configuration {ckpt_key!r}"
            )
        pretrained_ckpt_path = os.path.join(
            log_base_dir, f'pretrain{config_tag}{rope_str}', _PRETRAINED_CKPTS[ckpt_key]
        )

        directory = os.path.join(
            log_base_dir, f'ckpts{config_tag}_finetune_{finetune_layer_which}{rope_str}'
        )
        _fresh_dir(directory)

        # One fine-tuning run per seed, all written to the same directory.
        for seed in range(1, 4):
            _run_command([
                'python', '-m', 'src.cifar10.cifar10_vit_finetune_attn',
                *rope_args,
                '--seed', str(seed),
                '--optimizer', 'adam',
                '--learning-rate', '0.005',
                '--num-layers', str(num_layers),
                '--num-heads', str(num_heads),
                '--finetune-layer-which', finetune_layer_which,
                '--ckpt-path', directory,
                '--model-path', pretrained_ckpt_path,
            ])

    elif task == 'matching':
        all_head_permu_str = "_all_heads_permu" if all_head_permu else ""
        run_tag = f'{config_tag}_finetune_{finetune_layer_which}{rope_str}'
        ckpt_path = os.path.join(log_base_dir, f'ckpts{run_tag}')

        # Map seed (as a string) -> checkpoint file name. os.listdir() returns
        # entries in arbitrary order, so sort to make the choice reproducible
        # when several files share a seed (the last one in sorted order wins).
        ckpt_dict = {}
        for filename in sorted(os.listdir(ckpt_path)):
            if 'cifar10_vit_attn_finetune_seed' in filename:
                seed = filename.split('seed')[1].split('_')[0]
                ckpt_dict[seed] = filename

        seeds = [1, 2, 3]

        # Fail before launching anything instead of part-way through the pairs.
        missing = [seed for seed in seeds if str(seed) not in ckpt_dict]
        if missing:
            raise KeyError(
                f"no fine-tuned checkpoint found in {ckpt_path} for seed(s) {missing}"
            )

        # Every unordered pair of seeds is matched exactly once.
        for position, seed_a in enumerate(seeds):
            for seed_b in seeds[position + 1:]:
                plot_path = os.path.join(
                    log_base_dir, f'plots{run_tag}', f'{seed_a}{seed_b}{all_head_permu_str}'
                )
                # Existing plots are kept; the directory is only created if missing.
                os.makedirs(plot_path, exist_ok=True)

                model_a_path = os.path.join(ckpt_path, ckpt_dict[str(seed_a)])
                model_b_path = os.path.join(ckpt_path, ckpt_dict[str(seed_b)])

                _run_command([
                    'python', '-m', f'src.cifar10.cifar10_vit_matching_attn{all_head_permu_str}',
                    *rope_args,
                    '--num-layers', str(num_layers),
                    '--finetune-layer-which', finetune_layer_which,
                    '--num-heads', str(num_heads),
                    '--model-a', model_a_path,
                    '--model-b', model_b_path,
                    '--plot-path', plot_path,
                ])

    elif task == 'analyze_norms':
        # NOTE(review): unlike the train task, this directory name carries no
        # RoPE suffix, so with rope_use=True it still points at the non-RoPE
        # pretraining run. Left unchanged; confirm whether that is intended.
        directory = os.path.join(log_base_dir, f'pretrain{config_tag}')
        plot_dir = os.path.join(log_base_dir, 'norms_plot', config_tag)
        os.makedirs(plot_dir, exist_ok=True)
        _run_command([
            'python', '-m', 'src.cifar10.analyze_norms',
            '--ckpt-path', directory,
            '--plot-path', plot_dir,
            '--num-layers', str(num_layers),
            '--num-heads', str(num_heads),
        ])

# Script entry point.
#
# The required packages must be installed before running. The following
# commands may be needed in your terminal:
#   pip install --upgrade -q "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#   pip install --upgrade -q flax
#   pip install torch
if __name__ == "__main__":
    run_task()