import argparse
import itertools
import os
import re
import shlex
import shutil
import subprocess
import sys

import torch
import jax

# Stage names, indexed by run_task's ``flag`` argument.
_TASKS = ("train", "finetune", "matching")

# Pre-trained checkpoints that fine-tuning starts from, keyed by dataset name.
_PRETRAINED_CKPTS = {
    "cifar10": "/root/log/imgnetcifar10/pretrain12/imgnetcifar10_vit_epoch15_trainloss_0.3907_testloss_0.0579_trainacc_0.8656_testacc_0.9823.npz",
    "cifar100": "/root/log/imgnetcifar100/pretrain12/imgnetcifar100_vit_epoch8_trainloss_1.3956_testloss_0.8260_trainacc_0.7387_testacc_0.9030.npz",
}

# Extracts the seed number from a checkpoint filename. It reads the token
# after the first underscore that follows "seed", as the previous split-based
# parsing did, but keeps every digit so seeds >= 10 do not collide.
# Presumably filenames look like "..._seed_3_..." -- confirm against the
# finetune script.
_SEED_RE = re.compile(r"seed[^_]*_(\d+)")


def _print_environment():
    """Print torch's CUDA details and the devices JAX sees."""
    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        current = torch.cuda.current_device()
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {current}")
        print(f"GPU Name: {torch.cuda.get_device_name(current)}")
    print("JAX running on:", jax.devices())


def _run_command(args):
    """Run ``args`` (an argv list) as a subprocess and return its exit status.

    A non-zero status is reported on stderr but does not raise, so the
    remaining runs (other seeds or seed pairs) still execute.
    """
    print(f"Running command: {shlex.join(args)}")
    status = subprocess.run(args, check=False).returncode
    if status != 0:
        print(f"WARNING: command exited with status {status}: {shlex.join(args)}",
              file=sys.stderr)
    return status


def _find_seed_checkpoints(ckpt_dir):
    """Return a ``{seed: filename}`` map of the seed checkpoints in ``ckpt_dir``.

    Files whose name contains "seed" but no parsable seed number are skipped
    with a warning. If several files share a seed, the lexicographically last
    one is used, also with a warning.
    """
    found = {}
    # Sorted so the choice among duplicates is deterministic; os.listdir
    # order is arbitrary.
    for filename in sorted(os.listdir(ckpt_dir)):
        if "seed" not in filename:
            continue
        match = _SEED_RE.search(filename)
        if match is None:
            print(f"WARNING: could not parse a seed number from {filename!r}; skipping",
                  file=sys.stderr)
            continue
        seed = int(match.group(1))
        if seed in found:
            print(f"WARNING: multiple checkpoints for seed {seed}; using {filename!r} "
                  f"instead of {found[seed]!r}", file=sys.stderr)
        found[seed] = filename
    return found


def _train():
    """Launch training from scratch, which prepares the pre-trained checkpoint."""
    print("Training from scratch...")
    # NOTE(review): ``dataset`` is not forwarded here; presumably the training
    # script selects its dataset itself -- confirm.
    _run_command([sys.executable, "-m", "imgnetcifar.train_imgnetcifar"])


def _finetune(dataset, pretrained_ckpt, ckpt_dir, finetune_layer_which, seeds):
    """Fine-tune ``pretrained_ckpt`` once per seed, passing ``ckpt_dir`` as --ckpt-path."""
    print("Finetuning a pre-trained ViT model...")
    # Start from an empty directory so matching only sees this run's checkpoints.
    if os.path.exists(ckpt_dir):
        shutil.rmtree(ckpt_dir)
    os.makedirs(ckpt_dir, exist_ok=True)
    # NOTE(review): --rope-use is forwarded to matching but not to finetuning,
    # so rope_use only changes ckpt_dir's name here. Confirm that the finetune
    # script does not need the flag.
    for seed in seeds:
        _run_command([
            sys.executable, "-m", "imgnetcifar.finetune_imgnetcifar_attn",
            "--dataset", dataset,
            "--finetune-layer-which", finetune_layer_which,
            "--seed", str(seed),
            "--learning-rate", "0.1",  # 0.0005 was the alternative noted previously
            "--ckpt-path", ckpt_dir,
            "--model-path", pretrained_ckpt,
            "--num-epochs", "20",
        ])


def _matching(dataset, ckpt_dir, plot_root, finetune_layer_which, rope_args, seeds):
    """Run the matching script on every pair of the per-seed checkpoints in ``ckpt_dir``."""
    ckpts = _find_seed_checkpoints(ckpt_dir)
    # Fail before launching anything rather than midway through the pairs.
    missing = [seed for seed in seeds if seed not in ckpts]
    if missing:
        raise FileNotFoundError(f"No checkpoint found for seed(s) {missing} in {ckpt_dir}")
    for seed_a, seed_b in itertools.combinations(seeds, 2):
        # The folder name concatenates the two seeds (e.g. "12"), as before.
        # Existing plots in it are not cleared.
        plot_path = os.path.join(plot_root, f"{seed_a}{seed_b}")
        os.makedirs(plot_path, exist_ok=True)
        _run_command([
            sys.executable, "-m", "imgnetcifar.matching_imgnetcifar_attn",
            *rope_args,
            "--dataset", dataset,
            "--finetune-layer-which", finetune_layer_which,
            "--model-a", os.path.join(ckpt_dir, ckpts[seed_a]),
            "--model-b", os.path.join(ckpt_dir, ckpts[seed_b]),
            "--plot-path", plot_path,
        ])


def run_task(*, flag=2, dataset="cifar100", finetune_layer_which="6,7,8,9,10,11",
             rope_use=False, seeds=(1, 2, 3)):
    """Run one stage of the Vision Transformer CIFAR-10/100 pipeline.

    Each stage is launched as ``<this interpreter> -m imgnetcifar.<script>``
    subprocesses. Beforehand, the process's working directory is changed to
    /root/src/vision_transformer. A failing subprocess is reported on stderr,
    but the remaining runs still execute.

    Args:
        flag: 0 = train, 1 = finetune, 2 = matching.
        dataset: Dataset name, e.g. "cifar10" or "cifar100". Logs go to
            /root/log/imgnet<dataset>. Finetuning requires a dataset that has
            a pre-trained checkpoint configured (cifar10 or cifar100).
        finetune_layer_which: Layer spec forwarded as --finetune-layer-which,
            e.g. "0,1". Presumably "all" is also accepted by the scripts --
            confirm.
        rope_use: Appends "_rope" to the checkpoint/plot directory names.
            When set, matching also receives --rope-use.
        seeds: Seeds to fine-tune with. Matching compares their checkpoints
            pairwise.

    Raises:
        ValueError: ``flag`` is not 0, 1 or 2, or finetuning was requested
            for a dataset with no configured pre-trained checkpoint.
        FileNotFoundError: Matching found no checkpoint for one of ``seeds``.
    """
    # Validate first, so bad arguments fail before any printing or chdir.
    if not isinstance(flag, int) or not 0 <= flag < len(_TASKS):
        raise ValueError(f"flag must be 0 (train), 1 (finetune) or 2 (matching), got {flag!r}")
    task = _TASKS[flag]
    if task == "finetune" and dataset not in _PRETRAINED_CKPTS:
        raise ValueError(f"No pre-trained checkpoint configured for dataset {dataset!r}; "
                         f"expected one of {sorted(_PRETRAINED_CKPTS)}")
    seeds = tuple(int(seed) for seed in seeds)  # materialize: iterated more than once

    _print_environment()

    log_base_dir = f"/root/log/imgnet{dataset}"
    rope_str = "_rope" if rope_use else ""
    ckpt_dir = os.path.join(log_base_dir, f"ckpts_{finetune_layer_which}{rope_str}")

    # The `-m imgnetcifar...` modules are resolved from this directory.
    os.chdir("/root/src/vision_transformer")

    if task == "train":
        _train()
    elif task == "finetune":
        _finetune(dataset, _PRETRAINED_CKPTS[dataset], ckpt_dir, finetune_layer_which, seeds)
    else:
        plot_root = os.path.join(log_base_dir, f"plots_{finetune_layer_which}{rope_str}")
        rope_args = ["--rope-use"] if rope_use else []
        _matching(dataset, ckpt_dir, plot_root, finetune_layer_which, rope_args, seeds)


if __name__ == "__main__":
    # Script entry point: merely importing this module launches no jobs.
    run_task()