import random
import os
import torch
import argparse
import yaml
import wandb
import random
import copy
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
from torch.utils.data import DataLoader, DistributedSampler, Subset
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.amp import GradScaler, autocast
from timm import create_model
from tqdm import tqdm
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from transformers import set_seed
from timm.data import resolve_model_data_config, create_transform


def load_config(args: argparse.Namespace) -> dict:
    """Load the YAML configuration file and apply command-line overrides.

    Args:
        args: Parsed command-line arguments. ``args.config`` is the path of the
            YAML file; ``args.learning_rate_task1`` and
            ``args.learning_rate_task2`` replace the matching configuration
            entries whenever they are not ``None``.

    Returns:
        dict: The configuration with the learning-rate overrides applied, the
        learning rates appended to ``wandb_name`` (when that key exists) and
        ``slurm_job_id`` added when the ``SLURM_JOB_ID`` environment variable
        is set.

    Raises:
        KeyError: If ``wandb_name`` is present but one of the learning-rate
            keys is missing from the configuration.
    """
    with open(args.config, "r") as f:
        # An empty YAML document parses to None; fall back to an empty dict so
        # the lookups below operate on a mapping.
        config = yaml.safe_load(f) or {}

    # Override the learning rates if passed as arguments. The comparison is
    # against None (not truthiness) so that an explicit 0.0 is honoured.
    for task_name in ("task1", "task2"):
        key = f"learning_rate_{task_name}"
        override = getattr(args, key)
        if override is not None:
            config[key] = override
            print(
                f"[INFO] Overriding learning rate for {task_name} to: {config[key]}")

    if 'wandb_name' in config:
        config['wandb_name'] += f"_lr_task1_{config['learning_rate_task1']}_lr_task2_{config['learning_rate_task2']}"
        print(f"[INFO] Updated wandb_name to: {config['wandb_name']}")

    # Get SLURM job ID if available and add it to the config
    slurm_job_id = os.environ.get("SLURM_JOB_ID", None)
    if slurm_job_id:
        config['slurm_job_id'] = slurm_job_id
    return config


def setup_save_dir(configs) -> str:
    """
    Create a timestamped run directory below the configured checkpoint root.

    The run directory is named
    ``<YYYYMMDD_HHMMSS>_<model_name>_<optimizer_name>_lr1_<lr1>_lr2_<lr2>``
    and is created under ``configs['model_save_dir']`` (default
    ``./forgetting_checkpoints``).

    Args:
        configs (dict): Configuration dictionary. Its ``model_save_dir`` entry
            is overwritten with the path of the new run directory.
    Returns:
        str: Path of the run directory that was created.
    """
    root_dir = configs.get('model_save_dir', './forgetting_checkpoints')
    os.makedirs(root_dir, exist_ok=True)

    # Run name = timestamp (YYYYMMDD_HHMMSS) followed by the identifying
    # hyper-parameters; missing entries fall back to placeholder strings.
    run_name = "{}_{}_{}_lr1_{}_lr2_{}".format(
        datetime.now().strftime("%Y%m%d_%H%M%S"),
        configs.get('model_name', 'model'),
        configs.get('optimizer_name', 'optimizer'),
        configs.get('learning_rate_task1', 'lr1'),
        configs.get('learning_rate_task2', 'lr2'),
    )

    run_dir = os.path.join(root_dir, run_name)
    os.makedirs(run_dir, exist_ok=True)
    print(f"[INFO] Save directory set to: {run_dir}")

    # Later stages read the run directory back from the configuration.
    configs['model_save_dir'] = run_dir
    return run_dir


def init_distributed_mode() -> tuple:
    """Select the CUDA device and, under a distributed launcher, join the process group.

    When both ``RANK`` and ``WORLD_SIZE`` are present in the environment, an
    NCCL process group is initialised with those values and the device is taken
    from ``LOCAL_RANK`` (default 0). Otherwise single-process values are used.

    Returns:
        tuple: ``(rank, world_size, local_rank)``.
    """
    env = os.environ
    launched_distributed = "RANK" in env and "WORLD_SIZE" in env

    if not launched_distributed:
        # Single-process fallback: one rank on device 0.
        rank, world_size, local_rank = 0, 1, 0
        torch.cuda.set_device(local_rank)
        return rank, world_size, local_rank

    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    # The device is bound before the process group is created.
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    return rank, world_size, local_rank


def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return the element-wise mean of ``tensor`` over all processes.

    The input is summed across processes in place by ``all_reduce``; the
    returned tensor is that sum divided by the world size.
    """
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    num_processes = dist.get_world_size()
    return tensor / num_processes


def setup_optimizer_and_scheduler(model, configs, task, task_train_size):
    """
    Set up the loss, optimizer and learning-rate scheduler for a given task.

    Args:
        model (torch.nn.Module): The model to optimize.
        configs (dict): Configuration dictionary. Reads the per-task keys
            ``label_smoothing_<task>``, ``learning_rate_<task>``,
            ``num_epochs_<task>`` and ``scheduler_name_<task>``, the shared
            keys ``optimizer_name`` and ``weight_decay``, plus the
            optimizer-specific keys used in the branches below.
        task (str): Task name ("task1" or "task2").
        task_train_size (int): Number of training samples for the task; used
            only to derive the total number of optimizer steps for NanoAdam.

    Returns:
        criterion, optimizer, scheduler: The cross-entropy loss, the configured
        optimizer, and the scheduler (``None`` when no supported scheduler is
        named).

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    print("[INFO] Setting up loss function and optimizer...")
    # Any task other than "task1" uses the task2 settings, as before.
    suffix = "task1" if task == "task1" else "task2"
    criterion = nn.CrossEntropyLoss(
        label_smoothing=configs[f'label_smoothing_{suffix}'])

    learning_rate = configs[f'learning_rate_{task}']
    optimizer_name = configs['optimizer_name'].lower()
    weight_decay = configs['weight_decay']
    total_epoch = configs[f'num_epochs_{suffix}']

    # Normalise the scheduler name the same way as the optimizer name so that
    # e.g. "CosineAnnealingLR" does not silently disable the scheduler.
    raw_scheduler_name = configs[f'scheduler_name_{suffix}']
    scheduler_name = "" if raw_scheduler_name is None else str(
        raw_scheduler_name).strip().lower()

    if optimizer_name == "adamw":
        print(f"[INFO] Using AdamW optimizer for {task}...")
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                weight_decay=weight_decay)
    elif optimizer_name == "nanoadam":
        print(f"[INFO] Using NanoAdam optimizer for {task}...")
        import math
        from Nanoadam import NanoAdam
        param_groups = param_groups_with_name(model)
        total_gpus = int(os.environ.get("WORLD_SIZE", 1))
        # Total optimizer steps: samples are split across ranks and both the
        # per-rank share and the number of batches per epoch round up (the
        # loaders in this file do not drop the last partial batch).
        # NOTE(review): NanoAdam's expectations for ``total_steps`` are not
        # visible here; an exact integer step count is passed - confirm.
        samples_per_rank = math.ceil(task_train_size / total_gpus)
        steps_per_epoch = math.ceil(samples_per_rank / configs['batch_size'])
        total_steps = steps_per_epoch * total_epoch
        optimizer = NanoAdam(
            param_groups,
            lr=learning_rate,
            k_init=configs['k_init'],
            largest=configs['largest'],
            betas=(configs['beta1'], configs['beta2']),
            weight_decay=weight_decay,
            eps=configs['eps'],
            log_every=configs['log_every'],
            total_steps=total_steps,
            mask_interval=configs['mask_interval'],
            dynamic_density=configs['dynamic_density'],
            density_interval=configs['density_interval'],
            exclude_layers=set(configs['exclude_layers']),
            mask_criterion=configs['mask_criterion'],
        )
    elif optimizer_name == "microadam":
        print(f"[INFO] Using MicroAdam optimizer for {task}...")
        from microadam import MicroAdam
        param_groups = param_groups_with_name(model)
        optimizer = MicroAdam(
            param_groups,
            m=configs['NGRADS'],
            lr=learning_rate,
            quant_block_size=configs['QUANT_BLOCK_SIZE'],
            k_init=configs['k_init'],
            betas=(configs['beta1'], configs['beta2']),
            weight_decay=weight_decay,
            eps=configs['eps'],
            log_every=configs['log_every'],
        )
    elif optimizer_name == "adamw8b":
        print(f"[INFO] Using 8bit-AdamW optimizer for {task}...")
        import bitsandbytes as bnb
        param_groups = param_groups_weight_decay(
            model, weight_decay, excluded_layers=[])
        optimizer = bnb.optim.AdamW8bit(
            param_groups,
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(configs['beta1'], configs['beta2']),
            eps=configs['eps'],
            optim_bits=8,
        )
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    # Scheduler setup
    if scheduler_name == "cosineannealinglr":
        print(f"[INFO] Using CosineAnnealingLR scheduler for {task}...")
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=total_epoch)
    elif scheduler_name == "steplr":
        print(f"[INFO] Using StepLR scheduler for {task}...")
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=5, gamma=0.1)
    else:
        # Any other name keeps the learning rate constant; say so when the
        # name looks like a typo rather than an explicit "no scheduler".
        if scheduler_name not in ("", "none", "null"):
            print(
                f"[WARN] Unrecognized scheduler '{raw_scheduler_name}' for {task}; no scheduler will be used.")
        scheduler = None

    return criterion, optimizer, scheduler


def is_excluded_layer(layer_name, excluded_layers):
    """Return True if any entry of ``excluded_layers`` is a substring of ``layer_name``."""
    return any(pattern in layer_name for pattern in excluded_layers)


def param_groups_weight_decay(
    model: nn.Module,
    weight_decay=1e-5,
    excluded_layers=(),
    no_weight_decay_list=(),
):
    """
    Split the trainable parameters of ``model`` into decay / no-decay groups.

    This method is copied from timm.optim.optim_factory.param_groups_weight_decay
    What's new:
        - excluded_layers parameter
        - the additional if statement that calls method is_excluded_layer

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        weight_decay: Weight decay assigned to the "decay" group.
        excluded_layers: Iterable of substrings; a parameter whose name
            contains any of them is left out of both groups.
        no_weight_decay_list: Parameter names that are forced into the
            no-decay group.

    Returns:
        list[dict]: Two parameter groups, in this order: the no-decay group
        (``weight_decay`` 0.0) and the decay group.
    """
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    size_layers_kept = 0
    size_layers_ignored = 0
    for name, param in model.named_parameters():
        # Frozen parameters are never handed to the optimizer.
        if not param.requires_grad:
            continue

        if is_excluded_layer(name, excluded_layers):
            print(f"Excluding layer {name} from optimizer")
            size_layers_ignored += param.numel()
            continue

        size_layers_kept += param.numel()

        # Parameters with at most one dimension, ".bias" entries and
        # explicitly listed names get no weight decay.
        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            no_decay.append(param)
        else:
            decay.append(param)

    print(
        f"\n\tPARAMETER GROUPS:\n\t\t{size_layers_kept=}\n\t\t{size_layers_ignored=}\n"
    )

    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


def param_groups_with_name(
    model: nn.Module,
):
    """Return one parameter group holding every parameter of ``model`` and its name.

    The group has two parallel lists, ``"params"`` and ``"names"``, in the
    order produced by ``model.named_parameters()``. No parameter is filtered
    out.
    """
    named = list(model.named_parameters())
    group = {
        "params": [tensor for _, tensor in named],
        "names": [label for label, _ in named],
    }
    return [group]


def train_per_epoch(
        cur_epoch, total_epoch, model, train_loader, train_sampler, optimizer, scheduler, criterion, distributed, configs):
    """Run one training epoch and return the un-normalised statistics.

    Args:
        cur_epoch: Zero-based index of the current epoch.
        total_epoch: Total number of epochs (used for the progress-bar label).
        model: Model to train; put into training mode here.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        train_sampler: Sampler whose ``set_epoch`` is called when
            ``distributed`` is true.
        optimizer: Optimizer stepped once per batch.
        scheduler: Scheduler stepped once at the end of the epoch, or ``None``.
        criterion: Loss callable applied to ``(outputs, labels)``.
        distributed: Whether the run is distributed.
        configs: Configuration dictionary; reads ``fp16`` and ``bf16``.

    Returns:
        tuple: ``(loss_sum, correct_count, sample_count)`` where ``loss_sum``
        is the batch loss weighted by batch size, summed over the epoch.
    """
    if distributed:
        train_sampler.set_epoch(cur_epoch)
    model.train()

    # bfloat16 does not need GradScaler, but kept for compatibility
    scaler = GradScaler(enabled=configs['fp16'])

    def amp_context():
        # bf16 takes precedence over fp16; with neither flag set autocast is
        # created in its disabled state.
        bf16_flag = configs['bf16']
        if bf16_flag:
            return autocast(device_type='cuda', enabled=bf16_flag, dtype=torch.bfloat16)
        fp16_flag = configs['fp16']
        amp_dtype = torch.float16 if fp16_flag else torch.float32
        return autocast(device_type='cuda', enabled=fp16_flag, dtype=amp_dtype)

    loss_sum = 0.0
    correct_count = 0
    sample_count = 0

    progress = tqdm(train_loader, desc=f"Epoch {cur_epoch+1}/{total_epoch}")
    for images, labels in progress:
        images = images.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()

        with amp_context():
            outputs = model(images)
            loss = criterion(outputs, labels)

        if not configs['fp16']:
            # Direct backward pass for bfloat16 or full precision
            loss.backward()
            optimizer.step()
        else:
            # float16 goes through the gradient scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        correct_count += (outputs.argmax(dim=1) == labels).sum().item()
        sample_count += batch_size

    if scheduler is not None:
        scheduler.step()

    return loss_sum, correct_count, sample_count


def eval_per_epoch(
        model, test_loader, criterion, configs):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Args:
        model: Model to evaluate; put into evaluation mode here.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        configs: Configuration dictionary; reads ``bf16`` and ``fp16``.

    Returns:
        tuple: ``(loss_sum, correct_count, sample_count)`` where ``loss_sum``
        is the batch loss weighted by batch size, summed over all batches.
    """
    model.eval()

    def amp_context():
        # bf16 takes precedence over fp16; with neither flag set autocast is
        # created in its disabled state.
        bf16_flag = configs['bf16']
        if bf16_flag:
            return autocast(device_type='cuda', enabled=bf16_flag, dtype=torch.bfloat16)
        fp16_flag = configs['fp16']
        amp_dtype = torch.float16 if fp16_flag else torch.float32
        return autocast(device_type='cuda', enabled=fp16_flag, dtype=amp_dtype)

    loss_sum = 0.0
    correct_count = 0
    sample_count = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.cuda()
            labels = labels.cuda()
            with amp_context():
                outputs = model(images)
                loss = criterion(outputs, labels)

            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            correct_count += (outputs.argmax(dim=1) == labels).sum().item()
            sample_count += batch_size

    return loss_sum, correct_count, sample_count


def save_model_and_head(configs, model, task_name, save_dir):
    """
    Write the full model weights and the classification-head weights to disk.

    Two files are produced in ``save_dir``: ``<task_name>_model.pth`` (state
    dict of the whole model) and ``<task_name>_head.pth`` (state dict of the
    classification head only).

    Args:
        configs (dict): Configuration dictionary; ``model_name`` selects which
            attribute holds the head (``head`` for "vit", ``fc`` for "resnet").
        model (torch.nn.Module): The model to save; a ``.module`` attribute
            (e.g. from DDP) is unwrapped first.
        task_name (str): The name of the task (e.g., "task1", "task2").
        save_dir (str): The directory to save the model and head.
    Raises:
        ValueError: If no head attribute matches the model name. The full
            model file has already been written at that point.
    """
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, f"{task_name}_model.pth")
    head_path = os.path.join(save_dir, f"{task_name}_head.pth")

    # Unwrap a DDP-style container so the saved keys carry no "module." prefix.
    bare_model = getattr(model, 'module', model)

    torch.save(bare_model.state_dict(), model_path)

    # Architecture keyword -> attribute that holds the classification head.
    arch = configs['model_name'].lower()
    for keyword, head_attr in (('vit', 'head'), ('resnet', 'fc')):
        if keyword in arch and hasattr(bare_model, head_attr):
            torch.save(getattr(bare_model, head_attr).state_dict(), head_path)
            break
    else:
        raise ValueError(
            "Unsupported model architecture for saving classification head.")

    print(
        f"[INFO] Saved {task_name} model to {model_path} and head to {head_path}")


def load_head(configs, model, head_path, num_classes):
    """
    Replace the model's classification head with one loaded from disk.

    A fresh ``nn.Linear`` with ``num_classes`` outputs is created on the same
    device as the head it replaces, filled from ``head_path`` and only then
    attached to the model, so a failed load leaves the model unchanged.

    Args:
        configs (dict): Configuration dictionary; ``model_name`` selects which
            attribute holds the head (``head`` for "vit", ``fc`` for "resnet").
        model (torch.nn.Module): The model to load the head into; a
            ``.module`` attribute (e.g. from DDP) is unwrapped first.
        head_path (str): The path to the saved head file.
        num_classes (int): The number of classes for the classification head.
    Raises:
        ValueError: If no head attribute matches the model name.
    """
    # Access the underlying model if wrapped in DDP
    model_to_load = model.module if hasattr(model, 'module') else model

    arch = configs['model_name'].lower()
    if 'vit' in arch and hasattr(model_to_load, 'head'):
        head_attr = 'head'
    elif 'resnet' in arch and hasattr(model_to_load, 'fc'):
        head_attr = 'fc'
    else:
        raise ValueError(
            "Unsupported model architecture for loading classification head.")

    old_head = getattr(model_to_load, head_attr)
    # Keep the new head next to the rest of the model instead of relying on
    # the process-wide current CUDA device.
    device = old_head.weight.device
    new_head = nn.Linear(old_head.in_features, num_classes).to(device)
    # map_location avoids materialising the checkpoint on the device it was
    # saved from (every rank would otherwise deserialise onto that device).
    new_head.load_state_dict(torch.load(head_path, map_location=device))
    setattr(model_to_load, head_attr, new_head)

    print(f"[INFO] Loaded classification head from {head_path}")


def random_init_head(configs, model, num_classes):
    """
    Replace the model's classification head with a freshly initialised one.

    The new ``nn.Linear`` keeps the input width of the head it replaces, has
    ``num_classes`` outputs and is placed on the same device as the old head.

    Args:
        configs (dict): Configuration dictionary; ``model_name`` selects which
            attribute holds the head (``head`` for "vit", ``fc`` for "resnet").
        model (torch.nn.Module): The model to update; a ``.module`` attribute
            (e.g. from DDP) is unwrapped first.
        num_classes (int): Number of outputs of the new head.
    Raises:
        ValueError: If no head attribute matches the model name.
    """
    # Access the underlying model if wrapped in DDP
    model_to_update = model.module if hasattr(model, 'module') else model

    arch = configs['model_name'].lower()
    if 'vit' in arch and hasattr(model_to_update, 'head'):
        head_attr = 'head'
    elif 'resnet' in arch and hasattr(model_to_update, 'fc'):
        head_attr = 'fc'
    else:
        raise ValueError(
            "Unsupported model architecture for updating classification head.")

    old_head = getattr(model_to_update, head_attr)
    # Follow the device of the head being replaced rather than the
    # process-wide current CUDA device.
    new_head = nn.Linear(old_head.in_features, num_classes).to(
        old_head.weight.device)
    setattr(model_to_update, head_attr, new_head)


def _build_dataset(name, root, is_train, transform, download):
    """Instantiate the train or test split of one supported torchvision dataset."""
    if name == "cifar10":
        return torchvision.datasets.CIFAR10(
            root=root, train=is_train, transform=transform, download=download)
    if name == "flowers":
        return torchvision.datasets.Flowers102(
            root=root, split='train' if is_train else 'test', transform=transform, download=download)
    raise ValueError(f"Unsupported dataset: {name}")


def load_dataset(configs, transform_train, transform_eval, local_rank, distributed):
    """
    Build the train/test datasets for both tasks.

    Only the CIFAR10/Flowers102 pairing (in either order) is supported. The
    process with ``local_rank == 0`` constructs the datasets with
    ``download=True``; when ``distributed`` is true the other processes wait
    at a barrier and then construct them with ``download=False``.

    Args:
        configs (dict): Reads ``task1_dataset``, ``task2_dataset``,
            ``data1_root`` and ``data2_root``.
        transform_train: Transform applied to both training splits.
        transform_eval: Transform applied to both test splits.
        local_rank (int): Rank of this process on its node.
        distributed (bool): Whether a process group is active.

    Returns:
        tuple: ``(task1_train, task1_test, task2_train, task2_test)``.

    Raises:
        ValueError: If the dataset pairing is not supported.
    """
    print("[INFO] Loading datasets...")
    task1_name = configs['task1_dataset']
    task2_name = configs['task2_dataset']

    # Validate first so a bad configuration fails before any download starts.
    if (task1_name, task2_name) not in (("cifar10", "flowers"), ("flowers", "cifar10")):
        # Unsupported dataset combination
        raise ValueError(
            "Unsupported dataset combination. Please use CIFAR10 and Flowers102.")

    # (dataset name, root, is_train, transform) in the order of the return value.
    specs = (
        (task1_name, configs['data1_root'], True, transform_train),
        (task1_name, configs['data1_root'], False, transform_eval),
        (task2_name, configs['data2_root'], True, transform_train),
        (task2_name, configs['data2_root'], False, transform_eval),
    )

    built = None
    if local_rank == 0:
        # Only local rank 0 downloads; the objects it builds are reused below
        # instead of being constructed a second time.
        built = tuple(_build_dataset(*spec, download=True) for spec in specs)

    # Synchronize all processes to ensure datasets are downloaded
    if distributed:
        dist.barrier()

    if built is None:
        built = tuple(_build_dataset(*spec, download=False) for spec in specs)

    return built


def _cast_for_precision(model, configs):
    """Apply the precision flags from ``configs`` to ``model`` and return it."""
    if configs['bf16']:
        print("[INFO] Using bfloat16 precision...")
        return model.to(dtype=torch.bfloat16)
    if configs['fp16']:
        print("[INFO] Using float16 precision...")
        # Weights deliberately stay in float32: the training loop drives fp16
        # through autocast + GradScaler, and GradScaler refuses to unscale
        # float16 gradients ("Attempting to unscale FP16 gradients").
    return model


def _average_across_ranks(values, distributed):
    """Return ``values`` as Python floats, averaged over ranks when distributed."""
    if not distributed:
        return [float(v) for v in values]
    # float64 keeps the sample counts and loss sums exact enough after the
    # reduction; .tolist() hands plain floats to the logging code.
    stacked = torch.tensor(values, dtype=torch.float64, device='cuda')
    return reduce_tensor(stacked).tolist()


def _fit_one_epoch(task_id, epoch, num_epochs, model, train_loader, train_sampler,
                   test_loader, optimizer, scheduler, criterion, distributed,
                   local_rank, configs):
    """Train and validate for one epoch, then log the rank-averaged metrics."""
    print(f"[INFO] Task {task_id} - Epoch {epoch+1}/{num_epochs}...")
    train_stats = train_per_epoch(
        epoch, num_epochs, model, train_loader, train_sampler, optimizer, scheduler, criterion, distributed, configs)
    val_stats = eval_per_epoch(model, test_loader, criterion, configs)

    (train_loss, train_correct, train_samples,
     val_loss, val_correct, val_samples) = _average_across_ranks(
        [*train_stats, *val_stats], distributed)

    train_loss = train_loss / train_samples
    train_acc = train_correct / train_samples
    val_loss = val_loss / val_samples
    val_acc = val_correct / val_samples

    if local_rank == 0:
        wandb.log({
            f"loss/task{task_id}_train": train_loss,
            f"acc/task{task_id}_train": train_acc,
            f"loss/task{task_id}_val": val_loss,
            f"acc/task{task_id}_val": val_acc,
            f"steps/task{task_id}_epoch": epoch,
            f"lr/task{task_id}": optimizer.param_groups[0]['lr'],
        })
        print(
            f"[INFO] Task {task_id} - Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc*100:.2f}%, Val Loss={val_loss:.4f}, Val Acc={val_acc*100:.2f}%")


def task(rank, world_size, local_rank, configs):
    """
    Train on Task 1, then fine-tune on Task 2 while tracking Task 1 accuracy.

    Steps:
        1. Train the model on the Task 1 dataset and save the full model and
           its classification head (``task1_model.pth`` / ``task1_head.pth``).
        2. Rebuild the model from the Task 1 checkpoint, give it a randomly
           initialised head with ``task2_num_classes`` outputs and train it on
           the Task 2 dataset.
        3. Every ``eval_task1_interval`` epochs (and on the last epoch) save
           the Task 2 model, copy it, restore the saved Task 1 head on the copy
           and evaluate the copy on the Task 1 test set.

    Metrics are logged to Weights & Biases by the process with
    ``local_rank == 0``.

    Args:
        rank (int): Global rank of this process (not used in this function).
        world_size (int): Number of processes (not used in this function).
        local_rank (int): Rank on the local node; selects the CUDA device for
            DDP and gates logging and checkpoint writing.
        configs (dict): Configuration dictionary.

    Raises:
        ValueError: If ``model_name`` contains neither "vit" nor "resnet".
    """
    distributed = configs['ddp']
    set_seed(configs['seed'])
    print(f"[INFO] Using seed: {configs['seed']}")
    if local_rank == 0:
        print("[INFO] Initializing Weights & Biases...")
        wandb.init(project=configs['wandb_project'],
                   name=configs['wandb_name'], config=configs)

    # Load pretrained model
    print("[INFO] Loading pretrained model...")
    arch = configs['model_name'].lower()
    if 'vit' not in arch and 'resnet' not in arch:
        raise ValueError(
            f"Unsupported model architecture: {configs['model_name']}")
    model = create_model(configs['model_name'], pretrained=configs['pretrained'],
                         num_classes=configs['task1_num_classes']).cuda()
    data_config = resolve_model_data_config(model)
    transform_train = create_transform(**data_config, is_training=True)
    transform_eval = create_transform(**data_config, is_training=False)

    # Dataset loading
    task1_train_dataset, task1_test_dataset, task2_train_dataset, task2_test_dataset = load_dataset(
        configs, transform_train, transform_eval, local_rank, distributed)

    print("[INFO] Creating data loaders...")
    task1_train_sampler = DistributedSampler(
        task1_train_dataset, seed=configs['seed']) if distributed else None
    task2_train_sampler = DistributedSampler(
        task2_train_dataset, seed=configs['seed']) if distributed else None
    # Without a sampler the loader must shuffle itself; otherwise
    # single-process training would visit the samples in a fixed order.
    task1_train_loader = DataLoader(
        task1_train_dataset, batch_size=configs['batch_size'],
        sampler=task1_train_sampler, shuffle=task1_train_sampler is None)
    task1_test_loader = DataLoader(
        task1_test_dataset, batch_size=configs['batch_size'], shuffle=False)
    task2_train_loader = DataLoader(
        task2_train_dataset, batch_size=configs['batch_size'],
        sampler=task2_train_sampler, shuffle=task2_train_sampler is None)
    task2_test_loader = DataLoader(
        task2_test_dataset, batch_size=configs['batch_size'], shuffle=False)

    model = _cast_for_precision(model, configs)
    if distributed:
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=True)
    print("[INFO] Model architecture:")
    print(model)

    # Loss and optimizer
    criterion, optimizer, scheduler = setup_optimizer_and_scheduler(
        model, configs, "task1", len(task1_train_dataset))
    # Kept so that Task 1 losses measured during Task 2 use the same loss
    # settings as the Task 1 validation losses logged here.
    task1_criterion = criterion

    # Training loop for Task 1
    print("[INFO] Starting training for Task 1...")
    for epoch in range(configs['num_epochs_task1']):
        _fit_one_epoch(1, epoch, configs['num_epochs_task1'], model,
                       task1_train_loader, task1_train_sampler, task1_test_loader,
                       optimizer, scheduler, criterion, distributed, local_rank, configs)

    # Save the model and classification head after Task 1
    if local_rank == 0:
        print("[INFO] Saving Task 1 model and head...")
        save_model_and_head(configs, model, "task1", configs['model_save_dir'])

    if distributed:
        print("[INFO] Synchronizing model state across processes...")
        dist.barrier()
    # The Task 1 model is rebuilt from its checkpoint below; release it along
    # with the optimizer state before allocating the Task 2 model.
    del model, optimizer, scheduler, criterion
    torch.cuda.empty_cache()

    # Task 2
    print("[INFO] Starting training for Task 2...")
    # pretrained=False: every weight is overwritten by the Task 1 checkpoint
    # right below, so fetching the pretrained weights again would be wasted.
    model_for_task2 = create_model(configs['model_name'], pretrained=False,
                                   num_classes=configs['task1_num_classes']).cuda()
    # map_location='cpu' keeps every rank from deserialising the checkpoint
    # onto the device it was saved from.
    model_for_task2.load_state_dict(torch.load(
        os.path.join(configs['model_save_dir'], "task1_model.pth"),
        map_location='cpu',
    ))
    # The head swap and the dtype cast must happen BEFORE wrapping in DDP:
    # DDP registers the parameters it synchronises at construction time, so a
    # head attached afterwards would never have its gradients all-reduced.
    random_init_head(configs, model_for_task2, configs['task2_num_classes'])
    model_for_task2 = _cast_for_precision(model_for_task2, configs)
    if distributed:
        model_for_task2 = DDP(model_for_task2, device_ids=[local_rank],
                              find_unused_parameters=True)
        print("[INFO] Synchronizing model state across processes...")
        dist.barrier()

    # Loss and optimizer
    criterion, optimizer, scheduler = setup_optimizer_and_scheduler(
        model_for_task2, configs, "task2", len(task2_train_dataset))

    task1_head_path = os.path.join(configs['model_save_dir'], "task1_head.pth")
    last_epoch = configs['num_epochs_task2'] - 1
    for epoch in range(configs['num_epochs_task2']):
        _fit_one_epoch(2, epoch, configs['num_epochs_task2'], model_for_task2,
                       task2_train_loader, task2_train_sampler, task2_test_loader,
                       optimizer, scheduler, criterion, distributed, local_rank, configs)

        # Periodically evaluate on Task 1
        if epoch % configs['eval_task1_interval'] != 0 and epoch != last_epoch:
            continue

        print(
            f"[INFO] Evaluating Task 1 during Task 2 training at epoch {epoch+1}...")
        print("[INFO] saving Task 2 classification model and head...")
        if local_rank == 0:
            # Save the model and classification head
            save_model_and_head(configs, model_for_task2, "task2",
                                configs['model_save_dir'])
        if distributed:
            print("[INFO] Synchronizing model state across processes...")
            dist.barrier()

        print("[INFO] Restoring Task 1 classification head...")
        # Probe a copy of the bare module (not the DDP wrapper) so the model
        # being trained keeps its Task 2 head.
        bare_task2_model = model_for_task2.module if hasattr(
            model_for_task2, 'module') else model_for_task2
        probe = copy.deepcopy(bare_task2_model)
        load_head(configs, probe, task1_head_path,
                  configs['task1_num_classes'])

        task1_stats = eval_per_epoch(
            probe, task1_test_loader, task1_criterion, configs)
        del probe

        task1_val_loss, task1_val_correct, task1_val_samples = _average_across_ranks(
            list(task1_stats), distributed)
        task1_val_loss = task1_val_loss / task1_val_samples
        task1_val_acc = task1_val_correct / task1_val_samples

        if local_rank == 0:
            wandb.log({
                "loss/task1_val_during_FT_task2": task1_val_loss, "acc/task1_val_during_FT_task2": task1_val_acc,
            })
            print(
                f"[INFO] Task 1 During Task 2 - Epoch {epoch+1}: Val Loss={task1_val_loss:.4f}, Val Acc={task1_val_acc*100:.2f}%")

    if local_rank == 0:
        print("[INFO] Finishing training and closing Weights & Biases...")
        save_model_and_head(configs, model_for_task2, "task2",
                            configs['model_save_dir'])
        wandb.finish()
    if distributed:
        print("[INFO] Synchronizing model state across processes...")
        dist.barrier()
        dist.destroy_process_group()


def main():
    """Parse the command line, prepare the configuration and run the training task.

    The timestamped checkpoint directory is created by global rank 0 only and
    its path is broadcast to the other ranks, so that all processes read and
    write checkpoints in the same directory even if their clocks or start-up
    times differ.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, default="huggingface_glue_mnli/yaml/finetune_vit_cifar10.yaml",
                        help="Path to the config YAML file")
    parser.add_argument('--learning_rate_task1', type=float, required=False,
                        default=None, help="Override learning rate")
    parser.add_argument('--learning_rate_task2', type=float, required=False,
                        default=None, help="Override learning rate")
    args = parser.parse_args()

    # Load the config from the passed path
    configs = load_config(args)

    # A distributed launcher exports both RANK and WORLD_SIZE.
    configs['ddp'] = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    distributed = configs['ddp']
    if distributed:
        rank, world_size, local_rank = init_distributed_mode()
    else:
        rank, world_size, local_rank = 0, 1, 0

    # Add timestamp to save_dir and ensure it exists. The directory name
    # embeds the current time, so it is computed once (on rank 0) and shared.
    if rank == 0:
        setup_save_dir(configs)
    if distributed:
        shared_dir = [configs.get('model_save_dir')]
        dist.broadcast_object_list(shared_dir, src=0)
        configs['model_save_dir'] = shared_dir[0]

    # Now you can use the config in your training code
    print(f"Using config: {configs}")

    # adjust batch size for distributed training: the configured value is the
    # global batch size, each process gets an equal share (at least 1)
    configs['batch_size'] = max(
        1, int(configs['batch_size'] / world_size))

    task(rank, world_size, local_rank, configs)


if __name__ == "__main__":
    main()
