import os
import shlex
import shutil
import subprocess
import sys

import torch
import jax

def _reset_directory(path):
    """Remove ``path`` if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _run_command(args):
    """Run ``args`` as a child process and return its exit status.

    The command is passed as an argument list (no shell), so paths containing
    spaces or shell metacharacters reach the child unchanged. A non-zero exit
    status is reported on stdout instead of being silently discarded; it does
    not raise, so a failed run does not abort the remaining runs.
    """
    printable = " ".join(shlex.quote(arg) for arg in args)
    # Flush so the banner appears before the child's own output when stdout
    # is redirected to a file or pipe.
    print(f"Running command: {printable}", flush=True)
    returncode = subprocess.run(args, check=False).returncode
    if returncode != 0:
        print(f"Command exited with status {returncode}: {printable}", flush=True)
    return returncode


def run_task(
    *,
    flag=1,
    num_layers=6,
    num_heads=4,
    finetune_layer_which="all",
    rope_use=False,
    log_base_dir="/root/log/cifar100",
    all_head_permu=True,
    seeds=(1, 2, 3),
):
    """Run one task of the CIFAR-100 Vision Transformer workflow.

    Each task launches one or more ``python -m src.cifar100.<module>`` child
    processes. Calling ``run_task()`` with no arguments uses the defaults
    below, which match the values that were previously hard-coded.

    Args:
        flag: Task selector: 0 = train, 1 = finetune, 2 = matching,
            3 = analyze_norms.
        num_layers: Value forwarded as ``--num-layers`` and used in directory
            names.
        num_heads: Value forwarded as ``--num-heads`` and used in directory
            names.
        finetune_layer_which: Value forwarded as ``--finetune-layer-which`` and
            used in directory names.
        rope_use: Whether to use Rotary Position Embedding. Adds a ``_rope``
            suffix to directory names (and to the training module name) and
            passes ``--rope-use`` to the finetune and matching modules.
        log_base_dir: Root directory for logs, checkpoints and plots. It is
            created if it does not exist.
        all_head_permu: For the matching task, selects the
            ``..._all_heads_permu`` variant of the matching module and plot
            directory name.
        seeds: Seeds to finetune with (finetune task) and to compare pairwise
            (matching task).

    Raises:
        ValueError: If ``flag`` does not select a known task.
        KeyError: If the finetune task has no pretrained checkpoint registered
            for the requested configuration, or the matching task cannot find
            a fine-tuned checkpoint for one of ``seeds``.
    """
    tasks = ['train', 'finetune', 'matching', 'analyze_norms']
    if not 0 <= flag < len(tasks):
        raise ValueError(
            f"flag must be an index into {tasks} (0-{len(tasks) - 1}), got {flag!r}"
        )
    task = tasks[flag]

    seeds = list(seeds)
    rope_str = "_rope" if rope_use else ""
    # Kept as a list so that no empty argument is passed when RoPE is disabled.
    rope_args = ["--rope-use"] if rope_use else []
    # Launch children with the interpreter running this script, so they see
    # the same installed packages; fall back to PATH lookup if it is unknown.
    python = sys.executable or "python"

    # --- Environment Setup ---
    # Ensure the base directory for logs exists
    os.makedirs(log_base_dir, exist_ok=True)

    # Check for CUDA availability and print GPU details
    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {torch.cuda.current_device()}")
        print(f"GPU Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")

    print("JAX running on:", jax.devices())

    # --- Task Execution ---
    if task == 'train':
        # Start from an empty checkpoint directory on every training run.
        directory = os.path.join(log_base_dir, f'pretrain{num_layers}_{num_heads}{rope_str}')
        _reset_directory(directory)
        _run_command([
            python, '-m', f'src.cifar100.cifar100_vit_train{rope_str}',
            '--seed', '0',
            '--optimizer', 'adam',
            '--learning-rate', '0.005',
            '--num-layers', str(num_layers),
            '--num-heads', str(num_heads),
            '--ckpt-path', directory,
        ])

    elif task == 'finetune':
        # Pretrained checkpoint file name per "<layers>_<heads>[_rope]" configuration.
        pretrained_ckpt_dict = {
            "6_4": "cifar100_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch87_trainloss_1.7795_testloss_2.3273_trainacc_0.5208_testacc_0.4525",
            "6_8": "cifar100_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch49_trainloss_1.2604_testloss_2.4102_trainacc_0.6437_testacc_0.4898",
            "6_4_rope": "cifar100_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_4_hidden_dim_512_embedding_dim_128_epoch83_trainloss_1.7859_testloss_2.2446_trainacc_0.5169_testacc_0.4540",
            "6_8_rope": "cifar100_vit_seed0_opt_adam_lr_0.005_num_layers_6_patch_size_4_num_heads_8_hidden_dim_512_embedding_dim_128_epoch44_trainloss_1.6400_testloss_2.7292_trainacc_0.5492_testacc_0.3886",
        }

        # Resolve the pretrained checkpoint BEFORE wiping the output directory,
        # so an unsupported configuration does not destroy earlier results.
        config_key = f"{num_layers}_{num_heads}{rope_str}"
        if config_key not in pretrained_ckpt_dict:
            raise KeyError(
                f"No pretrained checkpoint is registered for configuration "
                f"{config_key!r}; known configurations: {sorted(pretrained_ckpt_dict)}"
            )
        pretrained_ckpt_path = os.path.join(
            log_base_dir,
            f'pretrain{num_layers}_{num_heads}{rope_str}',
            pretrained_ckpt_dict[config_key],
        )

        directory = os.path.join(log_base_dir, f'ckpts{num_layers}_{num_heads}_finetune_{finetune_layer_which}{rope_str}')
        _reset_directory(directory)

        # One fine-tuning run per seed, all writing into the same directory.
        for seed in seeds:
            _run_command([
                python, '-m', 'src.cifar100.cifar100_vit_finetune_attn',
                *rope_args,
                '--seed', str(seed),
                '--optimizer', 'adam',
                '--learning-rate', '0.003',
                '--num-layers', str(num_layers),
                '--num-heads', str(num_heads),
                '--finetune-layer-which', finetune_layer_which,
                '--ckpt-path', directory,
                '--model-path', pretrained_ckpt_path,
            ])

    elif task == 'matching':
        all_head_permu_str = "_all_heads_permu" if all_head_permu else ""
        ckpt_path = os.path.join(log_base_dir, f'ckpts{num_layers}_{num_heads}_finetune_{finetune_layer_which}{rope_str}')

        # Map seed (as a string) -> checkpoint file name. The listing is sorted
        # so that, if several files exist for one seed, the choice is
        # deterministic (the lexicographically last name wins).
        ckpt_dict = {}
        for filename in sorted(os.listdir(ckpt_path)):
            if 'cifar100_vit_attn_finetune_seed' in filename:
                seed = filename.split('seed')[1].split('_')[0]
                ckpt_dict[seed] = filename

        # Fail before touching any plot directory if a checkpoint is missing.
        missing = [seed for seed in seeds if str(seed) not in ckpt_dict]
        if missing:
            raise KeyError(
                f"No fine-tuned checkpoint found in {ckpt_path!r} for seed(s) {missing}"
            )

        # Compare every unordered pair of seeds exactly once.
        for index, seed_a in enumerate(seeds):
            for seed_b in seeds[index + 1:]:
                plot_path = os.path.join(
                    log_base_dir,
                    f'plots{num_layers}_{num_heads}_finetune_{finetune_layer_which}{rope_str}',
                    f'{seed_a}{seed_b}{all_head_permu_str}',
                )
                _reset_directory(plot_path)

                model_a_path = os.path.join(ckpt_path, ckpt_dict[str(seed_a)])
                model_b_path = os.path.join(ckpt_path, ckpt_dict[str(seed_b)])

                _run_command([
                    python, '-m', f'src.cifar100.cifar100_vit_matching_attn{all_head_permu_str}',
                    *rope_args,
                    '--num-layers', str(num_layers),
                    '--finetune-layer-which', finetune_layer_which,
                    '--num-heads', str(num_heads),
                    '--model-a', model_a_path,
                    '--model-b', model_b_path,
                    '--plot-path', plot_path,
                ])

    elif task == 'analyze_norms':
        # NOTE(review): this branch ignores rope_use (no "_rope" suffix and no
        # --rope-use flag), so it always reads the non-RoPE pretrain directory.
        # Confirm whether that is intentional.
        directory = os.path.join(log_base_dir, f'pretrain{num_layers}_{num_heads}')
        plot_dir = os.path.join(log_base_dir, 'norms_plot', f'{num_layers}_{num_heads}')
        os.makedirs(plot_dir, exist_ok=True)
        _run_command([
            python, '-m', 'src.cifar100.analyze_norms',
            '--ckpt-path', directory,
            '--plot-path', plot_dir,
            '--num-layers', str(num_layers),
            '--num-heads', str(num_heads),
        ])

if __name__ == "__main__":
    # Script entry point: the task runs only when this file is executed
    # directly, not when it is imported as a module.
    #
    # Prerequisites: this file imports torch and jax at module level, so both
    # must be installed before running. flax is not imported here; presumably
    # it is required by the src.cifar100 modules that run_task launches
    # (confirm against those modules).
    #
    # Suggested setup commands (check the JAX installation guide for the
    # variant that matches your CUDA version):
    # pip install --upgrade -q "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    # pip install --upgrade -q flax
    # pip install torch
    
    run_task()