import argparse
import datetime
import json
import os
import time
import warnings
from functools import partial
from pathlib import Path
from contextlib import nullcontext
import re

import yaml
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torchvision import transforms

import fourm.utils
from fourm.models import fm_vit
from fourm.models.output_heads import ConvNeXtHead
from fourm.data.modality_transforms import UnifiedDataTransform
from fourm.models.lora_utils import (
    inject_trainable_LoRA,
    unfreeze_all_LoRA_layers,
    get_LoRA_module_names
)
from fourm.utils import (
    NativeScalerWithGradNormCount as NativeScaler,
    create_model,
    load_safetensors,
    collect_results_cpu,
    LayerDecayValueAssigner,
    create_optimizer,
    interpolate_pos_embed_vit,
    interpolate_pos_embed_beit,
    interpolate_rgb_pos_emb_fm,
)
from fourm.data import (
    PreTokenizedImageAugmenter,
    CenterCropImageAugmenter,
    RandomCropImageAugmenter
)
from fourm.data.unified_datasets import build_fm_transfer_dataset

from semseg_metrics import mean_iou, save_metrics
from log_images import log_semseg_wandb, save_semseg_preds
from semseg_transforms import TRANSFER_MODALITY_INFO, TRANSFER_MODALITY_TRANSFORMS

def get_args():
    """Parse command-line arguments, optionally seeded from a YAML config file.

    Parsing happens in two stages. A small pre-parser extracts only
    ``-c/--config``. If a config file is given, its top-level keys are
    installed as defaults on the main parser, so values passed explicitly on
    the command line still take precedence over the config file.

    Returns:
        argparse.Namespace: The parsed arguments, with an additional
        ``config_path`` attribute holding the config file path ('' if none).
    """
    # Stage 1: parser that only knows about the config file option.
    config_parser = argparse.ArgumentParser(description='Training Config', add_help=False)
    config_parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                               help='YAML config file specifying default arguments')

    # Stage 2: the main parser holding every training option.
    parser = argparse.ArgumentParser('Semantic segmentation fine-tuning script', add_help=False)
    parser.add_argument('--run_name', type=str, default='auto')

    parser.add_argument('--batch_size', default=4, type=int,
                        help='Batch size per GPU (default: %(default)s). '
                             'Effective batch size is batch_size * accum_iter * # gpus')
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
    parser.add_argument('--epochs', default=64, type=int)
    parser.add_argument('--save_ckpt_freq', default=20, type=int)

    # Task parameters
    parser.add_argument('--in_domains', default='rgb@224', type=str,
                        help='Input domain names, separated by hyphen')
    parser.add_argument('--out_domains', default='semseg_thor', type=str,
                        help='Output domain name')
    parser.add_argument('--use_mask_valid', action='store_true')
    parser.add_argument('--no_mask_valid', action='store_false', dest='use_mask_valid')
    parser.set_defaults(use_mask_valid=False)
    parser.add_argument('--crop_augmentation', action='store_true')
    parser.add_argument('--no_crop_augmentation', action='store_false', dest='crop_augmentation')
    parser.set_defaults(crop_augmentation=True)

    # Model parameters
    parser.add_argument('--model', default='multivit_base', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--patch_size', default=16, type=int,
                        help='base patch size for image-like modalities')
    parser.add_argument('--input_size', default=512, type=int,
                        help='images input size for backbone')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--encoder_norm', action='store_true',
                        help="Norm layer after encoder blocks (used in 4M pre-training)")
    parser.add_argument('--no_encoder_norm', action='store_false', dest='encoder_norm',
                        help="No norm layer after encoder blocks")
    parser.set_defaults(encoder_norm=True)

    parser.add_argument('--freeze_enc', action='store_true',
                        help="Freeze encoder weights")
    parser.add_argument('--no_freeze_enc', action='store_false', dest='freeze_enc',
                        help="No freeze encoder weights")
    parser.set_defaults(freeze_enc=False)
    parser.add_argument('--freeze_embeds', action='store_true',
                        help="Freeze embeddings weights")
    parser.add_argument('--no_freeze_embeds', action='store_false', dest='freeze_embeds',
                        help="No freeze embeddings weights")
    parser.set_defaults(freeze_embeds=False)
    parser.add_argument('--frozen_encoder_epochs', default=0, type=int,
                        help='Number of epochs where only decoder trained (default: %(default)s)')
    parser.add_argument('--frozen_encoder_lr', type=float, default=5e-5,
                        help='Learning rate for frozen encoder (default: %(default)s)')

    parser.add_argument('--output_head', type=str, default='convnext',
                        choices=['segmenter', 'convnext', 'dpt'],
                        help='One of [segmenter,  convnext, dpt] (default: convnext)')
    parser.add_argument('--decoder_dim', default=6144, type=int,
                        help='Token dimension for the decoder layers, for convnext and segmenter heads')
    parser.add_argument('--decoder_depth', default=4, type=int,
                        help='Depth of decoder (for convnext and segmenter heads')
    parser.add_argument('--drop_path_decoder', type=float, default=0.0, metavar='PCT',
                        help='Drop path rate (default: 0.0)')
    parser.add_argument('--decoder_preds_per_patch', type=int, default=16,
                        help='Predictions per patch for convnext head')
    parser.add_argument('--decoder_interpolate_mode', type=str, default='bilinear',
                        choices=['bilinear', 'nearest'], help='for convnext head')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD. 
        (Set the same value with args.weight_decay to keep weight decay no change)""")
    parser.add_argument('--decoder_decay', type=float, default=None,
                        help='decoder weight decay')
    parser.add_argument('--no_lr_scale_list', type=str, default='',
                        help='Weights that should not be affected by layer decay rate, separated by hyphen.')

    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=0.0, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (0.0)')
    parser.add_argument('--layer_decay', type=float, default=0.75,
                        help='layer-wise lr decay from ELECTRA')

    parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    # Help text previously said "epochs" (copy-paste from --warmup_epochs).
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='steps to warmup LR, if scheduler supports')

    # LoRA
    parser.add_argument('--lora', action='store_true')
    parser.add_argument('--no_lora', action='store_false', dest='lora')
    parser.set_defaults(lora=False)
    parser.add_argument('--lora_rank', type=int, default=4,
                        help='LoRA matrix rank (default: %(default)s)')
    parser.add_argument('--lora_scale', type=float, default=1.0,
                        help='LoRA scale (default: %(default)s)')
    parser.add_argument('--lora_modules', type=str, default='attn',
                        help='Modules to perform LoRA on. One of [attn, selfattn, xattn, mlp, all] (default: %(default)s)')
    parser.add_argument('--lora_train_embeds', action='store_true')
    parser.add_argument('--no_lora_train_embeds', action='store_false', dest='lora_train_embeds')
    parser.set_defaults(lora_train_embeds=False)

    # Augmentation parameters
    parser.add_argument('--main_augment_domain', type=str, default='rgb@224',
                        help='Main augment domain (default: rgb@224)')

    # Finetuning parameters
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')

    # Dataset parameters
    # type=int (was str): main() computes args.num_classes + 1, which fails on a
    # string whenever the value is supplied on the command line.
    parser.add_argument('--num_classes', default=150, type=int, help='number of semantic classes')
    parser.add_argument('--dataset_name', default=None, type=str, help='dataset name for plotting')
    parser.add_argument('--data_path', default=None, type=str, help='dataset path')
    parser.add_argument('--eval_data_path', default=None, type=str,
                        help='dataset path for evaluation')
    parser.add_argument('--test_data_path', default=None, type=str,
                        help='dataset path for testing')
    parser.add_argument('--max_val_images', default=None, type=int,
                        help='maximum number of validation images. (default: None)')
    parser.add_argument('--eval_freq', default=10, type=int, help="frequency of evaluation")
    parser.add_argument('--seg_reduce_zero_label', action='store_true',
                        help='set label 0 to ignore, reduce all other labels by 1')
    parser.add_argument('--seg_use_void_label', action='store_true', help='label border as void instead of ignore')
    parser.add_argument('--save_preds', action='store_true', default=False, help='save predictions')

    parser.add_argument('--colorize_preds', action='store_true')
    parser.add_argument('--no_colorize_preds', action='store_false', dest='colorize_preds')
    parser.set_defaults(colorize_preds=True)

    # Added
    parser.add_argument('--reset_pos_emb', action='store_true')
    parser.add_argument('--interpolate_pos_emb', action='store_true')
    parser.add_argument('--use_act_checkpoint', action='store_true',
                        help="Enables activation checkpointing")

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=True)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--test', action='store_true',
                        help='Perform testing only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                    help='Enabling distributed evaluation')
    parser.add_argument('--no_dist_eval', action='store_false', dest='dist_eval',
                    help='Disabling distributed evaluation')
    parser.set_defaults(dist_eval=False)
    parser.add_argument('--num_workers', default=16, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--find_unused_params', action='store_true')
    parser.add_argument('--no_find_unused_params', action='store_false', dest='find_unused_params')
    parser.set_defaults(find_unused_params=True)

    # Dtype
    parser.add_argument('--dtype', type=str, default='float16',
                        choices=['float16', 'bfloat16', 'float32', 'tfloat32', 'bf16', 'fp16', 'fp32', 'tf32'],
                        help='Data type (default: %(default)s')

    # Misc.
    parser.add_argument('--s3_endpoint', default='https://blob.mr3.simcloud.apple.com', type=str, help='S3 endpoint URL')
    parser.add_argument('--s3_path', default='', type=str, help='S3 path to model')
    parser.add_argument('--s3_save_dir', type=str, default="")

    # Wandb logging
    parser.add_argument('--log_wandb', default=False, action='store_true',
                        help='log training and validation metrics to wandb')
    parser.add_argument('--no_log_wandb', action='store_false', dest='log_wandb')
    parser.set_defaults(log_wandb=False)
    parser.add_argument('--wandb_project', default=None, type=str,
                        help='log training and validation metrics to wandb')
    parser.add_argument('--wandb_entity', default=None, type=str,
                        help='user or team name of wandb')
    parser.add_argument('--wandb_run_name', default=None, type=str,
                        help='run name on wandb')
    parser.add_argument('--log_images_wandb', action='store_true')
    parser.add_argument('--log_images_freq', default=5, type=int,
                        help="Frequency of image logging (in epochs)")
    parser.add_argument('--show_user_warnings', default=False, action='store_true')

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--dist_on_gpu', action='store_true')

    # Parse config file if there is one
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            # An empty YAML file loads as None; treat it as "no overrides"
            # instead of crashing in set_defaults(**None).
            cfg = yaml.safe_load(f) or {}
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Add the config path as a final args if given
    args.config_path = args_config.config

    return args


def setup_modality_info(args):
    """Return the TRANSFER_MODALITY_INFO entries for each domain in args.all_domains.

    Raises KeyError if a domain has no entry in TRANSFER_MODALITY_INFO.
    """
    selected_info = {}
    for domain in args.all_domains:
        selected_info[domain] = TRANSFER_MODALITY_INFO[domain]
    return selected_info


def main(args):
    """Run fine-tuning, or evaluation / testing only, for semantic segmentation.

    Steps, in order: initialize distributed mode, optionally download the
    pre-trained checkpoint, seed the RNGs, build data loaders, model,
    optimizer and LR/WD schedules, optionally auto-resume, then either
    evaluate once (``args.eval`` / ``args.test``) or run the training loop
    with periodic validation, checkpointing and logging. If a test loader
    exists, the best-validation checkpoint is reloaded and tested at the end.

    Args:
        args (argparse.Namespace): Parsed arguments from ``get_args``. This
            function mutates it: ``in_domains`` and ``out_domains`` become
            lists, and ``all_domains`` and ``num_classes_with_void`` are added.
    """
    fourm.utils.init_distributed_mode(args)
    num_tasks = fourm.utils.get_world_size()
    global_rank = fourm.utils.get_rank()

    # Download pre-trained model from S3 if needed
    if args.s3_path:
        dist_download_model(args)

    device = torch.device(args.device)
    # Offset the seed by rank so each process draws a different random stream.
    seed = args.seed + fourm.utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    if not args.show_user_warnings:
        warnings.filterwarnings("ignore", category=UserWarning)

    # Set up data type
    dtype = get_dtype(args.dtype)

    # Process domain arguments
    args.in_domains = args.in_domains.split('-')
    args.out_domains = [args.out_domains]
    args.all_domains = list(set(args.in_domains) | set(args.out_domains))
    args.num_classes_with_void = args.num_classes + 1 if args.seg_use_void_label else args.num_classes

    # Set up modality info and transforms
    modality_info = setup_modality_info(args)
    modality_transforms = TRANSFER_MODALITY_TRANSFORMS

    # Set up wandb logger (only on the global rank-0 process)
    if global_rank == 0 and args.log_wandb:
        log_writer = fourm.utils.WandbLogger(args)
    else:
        log_writer = None

    # Set up data loaders
    train_loader, val_loader, test_loader = setup_data_loaders(args, modality_info, modality_transforms, global_rank, num_tasks)
    batch_size_no_accum = args.batch_size * fourm.utils.get_world_size()
    num_training_steps_per_epoch = len(train_loader.dataset) // batch_size_no_accum

    # Set up model
    model = setup_model(args, device)
    # setup_model wraps the model in DDP only when args.distributed is set.
    model_without_ddp = model if not args.distributed else model.module

    # Set up optimizer and loss scaler
    optimizer, loss_scaler = setup_optimizer(args, model_without_ddp)
    # Set up scheduler
    lr_schedule_values, wd_schedule_values = setup_scheduler(args, num_training_steps_per_epoch)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=fourm.utils.SEG_IGNORE_INDEX)

    # Auto resume if enabled
    if args.auto_resume:
        fourm.utils.auto_load_model(
            args=args, model=model, model_without_ddp=model_without_ddp,
            optimizer=optimizer, loss_scaler=loss_scaler
        )

    # Evaluation only mode
    if args.eval:
        evaluate_model(args, model, criterion, val_loader, device, dtype)
        return

    # Test only mode
    if args.test:
        evaluate_model(args, model, criterion, test_loader, device, dtype, mode='test')
        return

    # Training loop
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_miou = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the distributed sampler so shuffling differs per epoch.
            train_loader.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch)

        # Training
        train_stats = train_one_epoch(
            model=model,
            criterion=criterion,
            data_loader=train_loader,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            frozen_encoder_epochs=args.frozen_encoder_epochs,
            loss_scaler=loss_scaler,
            accum_iter=args.accum_iter,
            max_norm=args.clip_grad,
            log_writer=log_writer,
            start_steps=epoch * len(train_loader),
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            in_domains=args.in_domains,
            out_domain=args.out_domains[0],
            dtype=dtype
        )

        # Save checkpoint periodically and after the final epoch
        if args.output_dir and args.save_ckpt > 0:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                fourm.utils.save_model(
                    args=args,
                    model=model,
                    model_without_ddp=model_without_ddp,
                    optimizer=optimizer,
                    loss_scaler=loss_scaler,
                    epoch=epoch
                )

        # Validation
        val_stats = None
        if val_loader is not None and (epoch % args.eval_freq == 0 or epoch == args.epochs - 1):
            log_images = args.log_wandb and args.log_images_wandb and (epoch % args.log_images_freq == 0)
            val_stats = evaluate_model(
                args, model, criterion, val_loader, device, dtype,
                epoch=epoch, log_images=log_images
            )

            # Save best model
            if max_miou < val_stats["mean_iou"]:
                max_miou = val_stats["mean_iou"]
                if args.output_dir and args.save_ckpt > 0:
                    fourm.utils.save_model(
                        args=args,
                        model=model,
                        model_without_ddp=model_without_ddp,
                        optimizer=optimizer,
                        loss_scaler=loss_scaler,
                        epoch="best"
                    )
            print(f'Max mIoU: {max_miou:.3f}')

        # Log stats: train metrics, then validation metrics when this epoch
        # was evaluated, then bookkeeping fields (same key order as before).
        log_stats = {f'train/{k}': v for k, v in train_stats.items()}
        if val_stats is not None:
            log_stats.update({f'val/{k}': v for k, v in val_stats.items()})
        log_stats['epoch'] = epoch
        log_stats['n_parameters'] = sum(p.numel() for p in model.parameters() if p.requires_grad)

        # Update wandb logs
        if log_writer is not None:
            log_writer.update(log_stats)

        # Save logs to file
        if args.output_dir and fourm.utils.is_main_process():
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    # Print training time
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # Test with best checkpoint
    if test_loader is not None:
        print('Loading model with best validation mIoU')
        checkpoint = torch.load(os.path.join(args.output_dir, 'checkpoint-best.pth'), map_location='cpu')
        # Load into the unwrapped module. The previous code prefixed every key
        # with 'module.' and loaded into `model`, which only matches when the
        # model is DDP-wrapped; without DDP, strict=False silently loaded
        # nothing and the test ran on the last-epoch weights.
        msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        print(msg)

        print('Testing with best checkpoint')
        test_stats = evaluate_model(
            args, model, criterion, test_loader, device, dtype,
            epoch=checkpoint['epoch'], log_images=True, mode='test'
        )

        # Log test stats
        log_stats = {f'test/{k}': v for k, v in test_stats.items()}
        if log_writer is not None:
            log_writer.set_step(args.epochs * num_training_steps_per_epoch)
            log_writer.update(log_stats)
        if args.output_dir and fourm.utils.is_main_process():
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

def get_dtype(dtype_str):
    """Map a dtype name (or its short alias) to the corresponding torch dtype.

    The 'tfloat32' / 'tf32' names return torch.float32 and, as a side effect,
    turn on TF32 for CUDA matmul and cuDNN. Unknown names raise ValueError.
    """
    # (accepted names, resulting dtype, whether to switch on TF32 backends)
    alias_table = (
        (('float16', 'fp16'), torch.float16, False),
        (('bfloat16', 'bf16'), torch.bfloat16, False),
        (('float32', 'fp32'), torch.float32, False),
        (('tfloat32', 'tf32'), torch.float32, True),
    )
    for aliases, torch_dtype, enable_tf32 in alias_table:
        if dtype_str not in aliases:
            continue
        if enable_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        return torch_dtype

    raise ValueError(f"Invalid dtype: {dtype_str}")

def dist_download_model(args):
    """Download the fine-tuning checkpoint from S3 and synchronize processes.

    Only the process whose ``args.gpu`` (the local rank) is 0 downloads, and
    only if ``args.finetune`` does not already exist as a file. When a
    torch.distributed process group is initialized, all processes then wait
    on a barrier so none proceeds before the download finished.

    Args:
        args: Namespace providing ``finetune`` (local destination path),
            ``s3_path`` (remote source), and ``gpu`` (local rank).
    """
    file_path = args.finetune
    if not file_path:
        return

    # args.gpu is the local rank
    download = args.gpu == 0

    msg = (
        "Downloading checkpoint"
        if download
        else "Waiting for other process to download data."
    )

    print(f"{args.gpu}: {msg}.")

    if download and not os.path.isfile(file_path):
        # Imported lazily so the S3 helper is only required when a download
        # actually has to happen.
        from fourm.utils.s3_utils import download_from_s3

        # dirname is '' for a bare filename; os.makedirs('') would raise.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        download_from_s3(args, args.s3_path, args.finetune)

    # barrier() raises when no default process group exists (e.g. a
    # single-process run), so only synchronize when one is initialized.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()

def setup_data_loaders(args, modality_info, modality_transforms, global_rank, num_tasks):
    """Set up data loaders for training, validation and testing.

    Args:
        args: Parsed arguments; reads all_domains, input_size,
            main_augment_domain, crop_augmentation, data_path, eval_data_path,
            test_data_path, dist_eval, batch_size, num_workers and pin_mem.
        modality_info: Modality info dict passed to the dataset builder.
        modality_transforms: Per-modality transforms dict used by
            UnifiedDataTransform and the dataset builder.
        global_rank: Rank of this process, used by the distributed samplers.
        num_tasks: Number of processes, used as the sampler replica count.

    Returns:
        tuple: (train_loader, val_loader, test_loader). val_loader is None
        when args.eval_data_path is None; test_loader is None when
        args.test_data_path is None.

    Raises:
        NotImplementedError: If args.test_data_path is set.
    """
    # Set up image augmenters. Pre-tokenized domains (any domain name that
    # contains 'tok') use the dedicated pre-tokenized augmenter.
    if 'tok' not in '-'.join(args.all_domains):
        train_image_augmenter = RandomCropImageAugmenter(
            target_size=args.input_size,
            hflip=0.5,
            crop_scale=(0.2, 1.0),
            main_domain=args.main_augment_domain
        )
        val_image_augmenter = CenterCropImageAugmenter(
            target_size=args.input_size,
            hflip=0.0,
            main_domain=args.main_augment_domain
        )
    else:
        train_image_augmenter = PreTokenizedImageAugmenter(
            target_size=args.input_size,
            no_aug=not args.crop_augmentation,
            main_domain=args.main_augment_domain
        )
        val_image_augmenter = PreTokenizedImageAugmenter(
            target_size=args.input_size,
            no_aug=True,
            main_domain=args.main_augment_domain
        )

    # Set up transforms
    train_transform = transforms.Compose([
        UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=train_image_augmenter),
    ])
    val_transform = transforms.Compose([
        UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=val_image_augmenter),
    ])

    # Build datasets
    dataset_train = build_fm_transfer_dataset(
        data_path=args.data_path,
        modality_info=modality_info,
        transform=train_transform,
        modality_transforms=modality_transforms,
        all_domains=args.all_domains
    )

    dataset_val = None
    if args.eval_data_path is not None:
        dataset_val = build_fm_transfer_dataset(
            data_path=args.eval_data_path,
            modality_info=modality_info,
            transform=val_transform,
            modality_transforms=modality_transforms,
            all_domains=args.all_domains
        )

    dataset_test = None
    if args.test_data_path is not None:
        raise NotImplementedError()

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train,
        num_replicas=num_tasks,
        rank=global_rank,
        shuffle=True,
        drop_last=True
    )

    def _make_eval_sampler(dataset):
        """Return an unshuffled sampler for an eval/test dataset, or None."""
        if dataset is None:
            return None
        if args.dist_eval:
            return torch.utils.data.DistributedSampler(
                dataset,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False
            )
        return torch.utils.data.SequentialSampler(dataset)

    def _make_eval_loader(dataset, sampler):
        """Return a non-dropping DataLoader for an eval/test dataset, or None."""
        if dataset is None:
            return None
        return torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )

    # The dataset_val None-check avoids len(None) when --dist_eval is passed
    # without an evaluation dataset.
    if args.dist_eval and dataset_val is not None and len(dataset_val) % num_tasks != 0:
        print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number.')
    sampler_val = _make_eval_sampler(dataset_val)
    sampler_test = _make_eval_sampler(dataset_test)

    # Set up data loaders
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True
    )
    val_loader = _make_eval_loader(dataset_val, sampler_val)
    test_loader = _make_eval_loader(dataset_test, sampler_test)

    print(f"Number of samples: {len(dataset_train)}")
    return train_loader, val_loader, test_loader

def setup_model(args, device):
    """Build the model with its output head, apply fine-tuning options, and
    move it to the target device.

    Order of operations: create head and model, load pre-trained weights
    (args.finetune), inject LoRA (args.lora), apply encoder / embedding
    freezing, move to ``device``, print parameter counts, and finally wrap in
    DistributedDataParallel when args.distributed is set.

    Returns:
        The model, DDP-wrapped if args.distributed is truthy.
    """
    # Registry of available output head constructors, keyed by --output_head.
    head_factories = {
        'convnext': partial(
            ConvNeXtHead,
            preds_per_patch=args.decoder_preds_per_patch,
            depth=args.decoder_depth,
            interpolate_mode=args.decoder_interpolate_mode
        ),
    }
    build_head = head_factories[args.output_head]
    output_head = build_head(
        num_classes=args.num_classes_with_void,
        img_size=args.input_size,
        embed_dim=args.decoder_dim,
        patch_size=args.patch_size
    )

    model = create_model(
        args.model,
        img_size=args.input_size,
        patch_size=args.patch_size,
        drop_path_rate=args.drop_path,
        encoder_norm=args.encoder_norm,
        output_head=output_head,
    )

    # Optional: initialize from a pre-trained checkpoint.
    if args.finetune:
        load_pretrained_weights(args, model)

    # Optional: LoRA. The encoder is frozen first, then the injected LoRA
    # layers are made trainable again.
    if args.lora:
        lora_targets = get_LoRA_module_names(args.lora_modules)
        inject_trainable_LoRA(
            model,
            rank=args.lora_rank,
            scale=args.lora_scale,
            target_replace_modules=lora_targets
        )
        model.freeze_encoder(freeze_embeddings=not args.lora_train_embeds)
        unfreeze_all_LoRA_layers(model)

    # Optional: freeze the whole encoder, and/or the embeddings.
    if args.freeze_enc:
        model.freeze_encoder(freeze_embeddings=True)
    if args.freeze_embeds:
        model.freeze_embeddings(freeze_rgb=False)

    model.to(device)

    # Tally trainable vs. frozen parameter counts in a single pass.
    trainable_count = 0
    frozen_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
        else:
            frozen_count += param.numel()
    print(f"Model = {str(model)}")
    print(f'number of trainable params: {trainable_count / 1e6} M')
    print(f'number of frozen params: {frozen_count / 1e6} M')

    if not args.distributed:
        return model

    return torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu],
        find_unused_parameters=args.find_unused_params
    )

def setup_optimizer(args, model):
    """Create the optimizer (with optional layer-wise LR decay) and the loss scaler.

    Args:
        args: Parsed arguments; reads layer_decay, model and dtype, and is
            forwarded to create_optimizer.
        model: The unwrapped model; must provide get_num_layers() and
            no_weight_decay().

    Returns:
        tuple: (optimizer, loss_scaler).
    """
    depth = model.get_num_layers()

    # Layer-wise decay is only active for decay factors below 1.
    assigner = None
    if args.layer_decay < 1.0:
        # One scale per layer index 0..depth+1; the last index gets scale 1.
        layer_scales = [args.layer_decay ** (depth + 1 - idx) for idx in range(depth + 2)]
        assigner = LayerDecayValueAssigner(
            layer_scales,
            is_beit3='beit3' in args.model,
        )
        print(f"Assigned values = {str(assigner.values)}")

    # Parameters the model asks to exclude from weight decay.
    skip_weight_decay_list = model.no_weight_decay()
    print("Skip weight decay list: ", skip_weight_decay_list)

    if assigner is None:
        layer_id_fn = None
        layer_scale_fn = None
    else:
        layer_id_fn = assigner.get_layer_id
        layer_scale_fn = assigner.get_scale

    optimizer = create_optimizer(
        args,
        model,
        skip_list=skip_weight_decay_list,
        get_num_layer=layer_id_fn,
        get_layer_scale=layer_scale_fn
    )

    # The loss scaler is enabled only for the float16 dtype names.
    use_scaler = args.dtype in ['float16', 'fp16']
    loss_scaler = NativeScaler(enabled=use_scaler)

    return optimizer, loss_scaler

def setup_scheduler(args, num_training_steps_per_epoch):
    """Build step-level learning-rate and weight-decay schedules.

    When ``args.frozen_encoder_epochs`` is positive, a constant segment
    (``args.frozen_encoder_lr`` / ``args.weight_decay``) covering those epochs
    is placed in front of the cosine schedule, and the cosine part only spans
    the remaining epochs.

    Side effect: ``args.weight_decay_end`` is set to ``args.weight_decay``
    when it is ``None``.

    Args:
        args: Parsed configuration namespace.
        num_training_steps_per_epoch: Number of schedule steps per epoch.

    Returns:
        tuple: ``(lr_schedule_values, wd_schedule_values)`` as concatenated
        NumPy arrays.
    """
    steps_per_epoch = num_training_steps_per_epoch
    frozen_epochs = args.frozen_encoder_epochs

    if frozen_epochs > 0:
        # Constant prefix used while the encoder is kept frozen.
        lr_prefix = fourm.utils.constant_scheduler(args.frozen_encoder_lr, frozen_epochs, steps_per_epoch)
        wd_prefix = fourm.utils.constant_scheduler(args.weight_decay, frozen_epochs, steps_per_epoch)
        cosine_epochs = args.epochs - frozen_epochs
    else:
        # No frozen phase: empty prefixes keep the concatenation below uniform.
        lr_prefix, wd_prefix = np.array([]), np.array([])
        cosine_epochs = args.epochs

    print("Use step level LR & WD scheduler!")
    lr_main = fourm.utils.cosine_scheduler(
        args.lr, args.min_lr, cosine_epochs, steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )

    # An unspecified final weight decay means "keep it constant".
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_main = fourm.utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, cosine_epochs, steps_per_epoch)

    lr_schedule_values = np.concatenate([lr_prefix, lr_main])
    wd_schedule_values = np.concatenate([wd_prefix, wd_main])
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
    return lr_schedule_values, wd_schedule_values

def verify_checkpoint_weights(checkpoint, model):
    """Check that a checkpoint covers the model's encoder/embedding parameters.

    A model parameter counts as required when its name contains the substring
    ``'encoder'`` or ``'embeddings'``. Any required name that is not a key of
    ``checkpoint['model']`` is reported on stdout.

    Args:
        checkpoint (dict): Checkpoint whose ``'model'`` entry maps parameter
            names to weights.
        model: Object providing ``named_parameters()``.

    Returns:
        bool: ``True`` when no required name is missing, ``False`` otherwise.
    """
    # Names the checkpoint must provide.
    needed = {
        param_name
        for param_name, _ in model.named_parameters()
        if 'encoder' in param_name or 'embeddings' in param_name
    }

    # Required names that the checkpoint does not contain.
    absent = needed.difference(checkpoint['model'].keys())
    if not absent:
        return True

    print("Warning: The following required encoder/embedding weights are missing from the checkpoint:")
    for param_name in sorted(absent):
        print(f"  - {param_name}")
    return False

def load_pretrained_weights(args, model):
    """Load the checkpoint named by ``args.finetune`` into ``model``.

    The checkpoint is fetched (URL, ``.safetensors`` file, or ``torch.load``
    file), its keys are normalised by ``process_checkpoint_keys``, the presence
    of encoder/embedding weights is verified, position embeddings are
    optionally interpolated, and the result is loaded non-strictly.

    Args:
        args: Parsed configuration. Reads ``finetune``, ``interpolate_pos_emb``
            and ``model``; also forwarded to ``process_checkpoint_keys``.
        model: Model that receives the weights.

    Raises:
        ValueError: If required encoder/embedding weights are missing.
    """
    is_remote = args.finetune.startswith('https')
    is_safetensors = args.finetune.endswith('.safetensors')

    if is_remote:
        print(f'Loading web model from: {args.finetune}')
        checkpoint = torch.hub.load_state_dict_from_url(args.finetune, map_location='cpu')
    else:
        print(f'Loading local model from: {args.finetune}')
        if is_safetensors:
            # Safetensors files hold the bare state dict; wrap it so the rest of
            # the pipeline can index checkpoint['model'].
            tensors, _ = load_safetensors(args.finetune)
            checkpoint = {'model': tensors}
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects; only
            # load checkpoints from trusted sources.
            checkpoint = torch.load(args.finetune, map_location='cpu', weights_only=False)

    # Normalise key names in place before comparing them with the model.
    process_checkpoint_keys(checkpoint, args)

    # Refuse checkpoints that lack encoder / embedding weights.
    weights_complete = verify_checkpoint_weights(checkpoint, model)
    if not weights_complete:
        raise ValueError("Checkpoint is missing required encoder or encoder embedding weights. Please use a valid checkpoint.")

    # Optionally interpolate position embeddings with the routine matching the model family.
    if args.interpolate_pos_emb:
        if 'beit3' in args.model:
            interpolate = interpolate_pos_embed_beit
        elif 'fm' in args.model:
            interpolate = interpolate_rgb_pos_emb_fm
        else:
            interpolate = interpolate_pos_embed_vit
        interpolate(model, checkpoint['model'])

    # strict=False: keys that do not match are reported in the printed result
    # instead of raising.
    load_result = model.load_state_dict(checkpoint['model'], strict=False)
    print(load_result)

def process_checkpoint_keys(checkpoint, args):
    """Add alias keys to ``checkpoint['model']`` for naming compatibility.

    For every key present when the function is entered (aliases are added,
    originals are kept):

    * keys containing ``'@'`` get an alias with every ``@<digits>`` removed;
    * keys containing ``'rgb'`` but not ``'rgb@'`` get an alias with each
      ``'rgb'`` replaced by ``'rgb@224'``;
    * keys containing ``'tok'`` and no ``'@'`` get an alias in which each
      ``tok_<name>`` is followed by ``@224``.

    When ``args.reset_pos_emb`` is truthy, ``checkpoint['model']`` is replaced
    by a new dict without any key containing ``'.pos_emb'`` or ``'pos_embed'``.

    Args:
        checkpoint (dict): Checkpoint with a ``'model'`` mapping; modified in place.
        args: Namespace providing ``reset_pos_emb``.
    """
    res_suffix_re = re.compile(r'@\d+')
    tok_name_re = re.compile(r'(tok_[a-zA-Z0-9_]+)')

    weights = checkpoint['model']

    # Snapshot the names first: aliases are inserted while iterating.
    for name in tuple(weights):
        if '@' in name:
            weights[res_suffix_re.sub('', name)] = weights[name]
        if 'rgb' in name and 'rgb@' not in name:
            weights[name.replace('rgb', 'rgb@224')] = weights[name]
        if 'tok' in name and '@' not in name:
            weights[tok_name_re.sub(r'\1@224', name)] = weights[name]

    # Optionally discard stored position embeddings.
    if args.reset_pos_emb:
        checkpoint['model'] = {
            name: tensor
            for name, tensor in weights.items()
            if ".pos_emb" not in name and "pos_embed" not in name
        }

@torch.no_grad()
def evaluate_model(args, model, criterion, data_loader, device, dtype, epoch=-1, log_images=False, mode='val'):
    """Evaluate the model on the given data loader.

    Runs with gradient tracking disabled so that no autograd graph is built
    during evaluation.

    Args:
        args: Parsed configuration. Reads ``in_domains``, ``out_domains``,
            ``dataset_name``, ``num_classes``, ``output_dir`` and ``dist_on_gpu``.
        model: Model to evaluate; switched to eval mode here.
        criterion: Loss callable invoked as ``criterion(prediction, target)``.
        data_loader: Iterable yielding dicts that map task names to tensors.
        device: Device the batches are moved to.
        dtype: Autocast dtype; autocast is disabled for ``torch.float32``.
        epoch: Epoch number, only used in the log header.
        log_images: When true, up to roughly ``max_images`` samples are
            collected and logged from the main process.
        mode: Name of the split, used in the log header and image prefix.

    Returns:
        dict: Global average of every meter held by the metric logger (the
        loss plus ``pixel_accuracy``, ``mean_accuracy`` and ``mean_iou``).
    """
    model.eval()
    metric_logger = fourm.utils.MetricLogger(delimiter="  ")
    header = f'({mode.capitalize()}) Epoch: [{epoch}]'
    print_freq = 20
    max_images = 100

    seg_preds = []
    seg_gts = []
    rgb_gts = [] if log_images else None
    seg_preds_with_void = [] if log_images else None

    for x in metric_logger.log_every(data_loader, print_freq, header=header):
        tasks_dict = {
            task: tensor.to(device, non_blocking=True)
            for task, tensor in x.items()
        }

        # Only the first input domain is fed to the model.
        input_samples = tasks_dict[args.in_domains[0]]

        with torch.amp.autocast('cuda', dtype=dtype, enabled=dtype != torch.float32):
            preds = model(input_samples)
            seg_pred, seg_gt = preds, tasks_dict[args.out_domains[0]]
            loss = criterion(seg_pred, seg_gt)

        loss_value = loss.item()
        # Class index per position: argmax over dim 1 of the prediction.
        seg_pred_argmax = seg_pred.argmax(dim=1)
        seg_preds.extend(list(seg_pred_argmax.cpu().numpy()))
        seg_gts.extend(list(seg_gt.cpu().numpy()))

        # The cap is checked per batch, so the final count may exceed max_images
        # by up to one batch.
        # NOTE(review): the 'rgb@224' key is hard-coded; confirm it is always
        # present in the batch when log_images is enabled.
        if log_images and len(rgb_gts) < max_images:
            rgb_gts.extend(tasks_dict['rgb@224'].cpu().unbind(0))
            seg_preds_with_void.extend(list(seg_pred.argmax(dim=1).cpu().numpy()))

        metric_logger.update(loss=loss_value)

    torch.cuda.empty_cache()

    # Log images if specified
    if log_images and fourm.utils.is_main_process():
        prefix = f'{mode}/img'
        log_semseg_wandb(
            rgb_gts,
            seg_preds_with_void,
            seg_gts,
            dataset_name=args.dataset_name,
            prefix=prefix,
            ignore_index=fourm.utils.SEG_IGNORE_INDEX
        )

    # Compute metrics
    scores = compute_metrics_distributed(
        seg_preds,
        seg_gts,
        size=len(data_loader.dataset),
        num_classes=args.num_classes,
        dataset_name=args.dataset_name,
        output_dir=args.output_dir,
        device=device,
        ignore_index=fourm.utils.SEG_IGNORE_INDEX,
        dist_on='gpu' if args.dist_on_gpu else 'cpu'
    )

    for k, v in scores.items():
        metric_logger.update(**{f"{k}": v})

    # Gather stats from all processes
    metric_logger.synchronize_between_processes()

    print(
        f'* mIoU {metric_logger.mean_iou.global_avg:.3f} '
        f'aAcc {metric_logger.pixel_accuracy.global_avg:.3f} '
        f'Acc {metric_logger.mean_accuracy.global_avg:.3f} '
        f'Loss {metric_logger.loss.global_avg:.3f}'
    )

    torch.cuda.empty_cache()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, frozen_encoder_epochs,
                   loss_scaler, accum_iter, max_norm=0, log_writer=None, start_steps=None,
                   lr_schedule_values=None, wd_schedule_values=None, in_domains=None,
                   out_domain=None, dtype=torch.float16):
    """Train the model for one epoch.

    Gradients are accumulated over ``accum_iter`` consecutive batches: the
    optimizer only steps, and gradients are only cleared, on the last batch of
    each accumulation window. On the other batches the DDP gradient all-reduce
    is skipped via ``no_sync()`` when the model provides it.

    Args:
        model: Model to train, either wrapped (exposing ``.module`` and
            ``no_sync()``) or bare.
        criterion: Loss callable invoked as ``criterion(prediction, target)``.
        data_loader: Iterable yielding dicts that map task names to tensors.
        optimizer: Optimizer whose param groups are updated from the schedules.
        device: Device the batches are moved to.
        epoch: Current epoch index.
        frozen_encoder_epochs: Number of initial epochs with a frozen encoder.
        loss_scaler: Callable performing backward and, when requested, the
            optimizer step; returns the gradient norm.
        accum_iter: Number of batches per optimizer step.
        max_norm: Forwarded to the scaler as ``clip_grad``.
        log_writer: Optional logger fed through ``update_wandb_logs``.
        start_steps: Global step index of the first batch (``None`` means 0).
        lr_schedule_values: Optional per-step learning-rate schedule.
        wd_schedule_values: Optional per-step weight-decay schedule.
        in_domains: Sequence of input domain names; only the first is used.
        out_domain: Name of the target domain in the batch dict.
        dtype: Autocast dtype; autocast is disabled for ``torch.float32``.

    Returns:
        dict: ``'[Epoch] <name>'`` -> global average for every tracked meter.
    """
    model.train()

    # The model is only wrapped in DistributedDataParallel for distributed runs,
    # so fall back to the model itself when there is no `.module`.
    net = getattr(model, 'module', model)

    # Handle frozen encoder epochs
    # NOTE(review): outside the frozen phase the encoder is unfrozen on every
    # epoch; this looks like it would also undo a freeze applied during model
    # setup - confirm this is intended.
    if frozen_encoder_epochs > 0 and epoch < frozen_encoder_epochs:
        net.freeze_encoder(freeze_embeddings=True)
    else:
        net.unfreeze_encoder(unfreeze_embeddings=True)

    metric_logger = fourm.utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', fourm.utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', fourm.utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    if start_steps is None:
        start_steps = 0

    # Start from clean gradients so leftovers from an incomplete accumulation
    # window of a previous epoch cannot leak into this one.
    optimizer.zero_grad()

    for step, x in enumerate(metric_logger.log_every(data_loader, print_freq, header=header)):
        # Update learning rate and weight decay
        it = start_steps + step
        # True on the last batch of each accumulation window.
        update_grad = (step + 1) % accum_iter == 0

        # Schedules are applied once, at the start of each accumulation window.
        if step % accum_iter == 0:
            update_optimizer_params(optimizer, it, lr_schedule_values, wd_schedule_values)

        # Prepare input data
        tasks_dict = {
            task: tensor.to(device, non_blocking=True)
            for task, tensor in x.items()
        }

        input_samples = tasks_dict[in_domains[0]]

        # Skip the gradient all-reduce on batches that do not step the optimizer.
        skip_sync = not update_grad and hasattr(model, 'no_sync')
        sync_context = model.no_sync() if skip_sync else nullcontext()

        # Forward + backward
        with sync_context:
            with torch.amp.autocast('cuda', enabled=dtype != torch.float32, dtype=dtype):
                preds = model(input_samples)
                seg_pred, seg_gt = preds, tasks_dict[out_domain]
                loss = criterion(seg_pred, seg_gt)

            loss_value = loss.item()
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order

            # Scale the loss so the accumulated gradient is the window average.
            loss = loss / accum_iter
            # NOTE(review): relies on the scaler accepting `update_grad` to
            # suppress the optimizer step on accumulation-only batches - confirm
            # against NativeScalerWithGradNormCount.
            # NOTE(review): max_norm is passed through unchanged as clip_grad;
            # confirm how the scaler treats the default of 0.
            grad_norm = loss_scaler(
                loss,
                optimizer,
                clip_grad=max_norm,
                parameters=model.parameters(),
                create_graph=is_second_order,
                update_grad=update_grad
            )

        # Clear gradients only after they have been applied; clearing them
        # before every backward would discard the accumulated gradients.
        if update_grad:
            optimizer.zero_grad()

        # Update metrics
        update_metrics(
            metric_logger,
            loss_value,
            optimizer,
            grad_norm,
            loss_scaler if dtype == torch.float16 else None
        )

        # Update wandb logs
        if log_writer is not None:
            update_wandb_logs(log_writer, loss_value, optimizer, grad_norm)

    # Gather stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {'[Epoch] ' + k: meter.global_avg for k, meter in metric_logger.meters.items()}

def update_optimizer_params(optimizer, it, lr_schedule_values, wd_schedule_values):
    """Write the scheduled learning rate and weight decay into the optimizer.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        it: Index into the schedules for the current step.
        lr_schedule_values: Indexable learning-rate schedule, or ``None`` to
            leave learning rates untouched.
        wd_schedule_values: Indexable weight-decay schedule, or ``None`` to
            leave weight decay untouched.
    """
    if lr_schedule_values is None and wd_schedule_values is None:
        return

    for param_group in optimizer.param_groups:
        if lr_schedule_values is not None:
            # Groups without a layer-decay factor use the schedule value as-is.
            lr_scale = param_group.get("lr_scale", 1.0)
            param_group["lr"] = lr_schedule_values[it] * lr_scale
        # Groups with zero (or no) weight decay stay exempt from weight decay.
        if wd_schedule_values is not None and param_group.get("weight_decay", 0) > 0:
            param_group["weight_decay"] = wd_schedule_values[it]

def update_metrics(metric_logger, loss_value, optimizer, grad_norm, loss_scaler=None):
    """Push the per-step training statistics into the metric logger.

    Reported values, in order: ``loss``, optionally ``loss_scale``, ``lr``
    (largest group LR, floored at 0.0), ``min_lr`` (smallest group LR, capped
    at 10.0), ``weight_decay`` (taken from the last param group with a positive
    value, else ``None``) and ``grad_norm``.

    Args:
        metric_logger: Logger exposing ``update(**kwargs)``.
        loss_value: Scalar loss of the current step.
        optimizer: Object exposing ``param_groups``.
        grad_norm: Gradient norm of the current step.
        loss_scaler: Optional scaler; its ``state_dict()["scale"]`` is logged.
    """
    metric_logger.update(loss=loss_value)
    if loss_scaler is not None:
        current_scale = loss_scaler.state_dict()["scale"]
        metric_logger.update(loss_scale=current_scale)

    # The sentinels 0.0 and 10.0 take part in the comparison, so they bound the result.
    group_lrs = [group["lr"] for group in optimizer.param_groups]
    metric_logger.update(lr=max([0.] + group_lrs))
    metric_logger.update(min_lr=min([10.] + group_lrs))

    # The last group with a positive weight decay wins.
    positive_wds = [
        group["weight_decay"]
        for group in optimizer.param_groups
        if group["weight_decay"] > 0
    ]
    metric_logger.update(weight_decay=positive_wds[-1] if positive_wds else None)
    metric_logger.update(grad_norm=grad_norm)

def update_wandb_logs(log_writer, loss_value, optimizer, grad_norm):
    """Send the per-step training statistics to the log writer.

    Logs ``loss``, ``lr`` (largest LR across param groups), ``weight_decay``
    (from the first param group with a positive value, else ``None``) and
    ``grad_norm``, then advances the writer's step.

    Args:
        log_writer: Writer exposing ``update(dict)`` and ``set_step()``.
        loss_value: Scalar loss of the current step.
        optimizer: Object exposing ``param_groups``.
        grad_norm: Gradient norm of the current step.
    """
    groups = optimizer.param_groups

    peak_lr = max(group["lr"] for group in groups)

    # The first group that actually applies weight decay provides the value.
    active_weight_decay = None
    for group in groups:
        if group["weight_decay"] > 0:
            active_weight_decay = group["weight_decay"]
            break

    payload = {
        'loss': loss_value,
        'lr': peak_lr,
        'weight_decay': active_weight_decay,
        'grad_norm': grad_norm,
    }
    log_writer.update(payload)
    log_writer.set_step()

def save_semseg_preds_distributed(rgb_gts, seg_preds_with_void, seg_gts, size, num_classes,
                                dataset_name, output_dir, ignore_index=fourm.utils.SEG_IGNORE_INDEX,
                                dist_on='cpu', colorize=True):
    """Save semantic segmentation predictions in a distributed manner.

    Every process contributes its local images, predictions and ground truths;
    the main process then writes the gathered results via ``save_semseg_preds``.

    Args:
        rgb_gts: Local RGB images.
        seg_preds_with_void: Local segmentation predictions.
        seg_gts: Local ground-truth segmentation maps.
        size: Forwarded to ``collect_results_cpu`` in ``'cpu'`` mode.
        num_classes: Accepted for interface compatibility; not used here.
        dataset_name: Forwarded to ``save_semseg_preds``.
        output_dir: Directory the predictions are saved to.
        ignore_index: Accepted for interface compatibility; not used here.
        dist_on: ``'cpu'`` to gather through ``collect_results_cpu`` or
            ``'gpu'`` to gather through ``dist.all_gather_object``.
        colorize: Forwarded to ``save_semseg_preds``.

    Raises:
        ValueError: If ``dist_on`` is neither ``'cpu'`` nor ``'gpu'``.
    """
    # Fail identically on every rank instead of crashing only on the main
    # process with unbound result variables.
    if dist_on not in ('cpu', 'gpu'):
        raise ValueError(f"dist_on must be 'cpu' or 'gpu', got {dist_on!r}")

    # Collect images from all devices
    if dist_on == 'cpu':
        all_rgb_gts = collect_results_cpu(rgb_gts, size, tmpdir=None)
        all_seg_preds = collect_results_cpu(seg_preds_with_void, size, tmpdir=None)
        all_seg_gts = collect_results_cpu(seg_gts, size, tmpdir=None)
    else:
        print("Collecting metrics using GPU")
        world_size = fourm.utils.get_world_size()
        # all_gather_object fills one slot per rank.
        all_rgb_gts = [None for _ in range(world_size)]
        all_seg_preds = [None for _ in range(world_size)]
        all_seg_gts = [None for _ in range(world_size)]
        dist.all_gather_object(all_rgb_gts, rgb_gts)
        dist.all_gather_object(all_seg_preds, seg_preds_with_void)
        dist.all_gather_object(all_seg_gts, seg_gts)

    # Only the main process writes to disk.
    if fourm.utils.is_main_process():
        save_semseg_preds(
            images=all_rgb_gts,
            preds=all_seg_preds,
            gts=all_seg_gts,
            dataset_name=dataset_name,
            save_dir=output_dir,
            image_count=None,
            colorize=colorize
        )

def compute_metrics_distributed(seg_preds, seg_gts, size, num_classes, device, dataset_name,
                              output_dir=None, ignore_index=fourm.utils.SEG_IGNORE_INDEX,
                              dist_on='cpu'):
    """Compute metrics in a distributed manner.

    Predictions and ground truths from all processes are gathered, the main
    process computes the segmentation metrics, and the three mean values are
    broadcast from rank 0 so every process returns the same numbers.

    Args:
        seg_preds: Local segmentation predictions.
        seg_gts: Local ground-truth segmentation maps.
        size: Forwarded to ``collect_results_cpu`` in ``'cpu'`` mode.
        num_classes: Number of classes, forwarded to ``mean_iou``.
        device: Device holding the tensor that is broadcast.
        dataset_name: Forwarded to ``save_metrics``.
        output_dir: When not ``None``, per-class metrics are saved by the main
            process through ``save_metrics``.
        ignore_index: Label value forwarded to ``mean_iou``.
        dist_on: ``'cpu'`` to gather through ``collect_results_cpu`` or
            ``'gpu'`` to gather through ``dist.all_gather_object``.

    Returns:
        dict: ``pixel_accuracy``, ``mean_accuracy`` and ``mean_iou`` as scalar
        elements of a float64 tensor on ``device``, in percent rounded to two
        decimals.

    Raises:
        ValueError: If ``dist_on`` is neither ``'cpu'`` nor ``'gpu'``.
    """
    # Validate before any collective: otherwise rank 0 would crash on unbound
    # results while the other ranks block in the broadcast below.
    if dist_on not in ('cpu', 'gpu'):
        raise ValueError(f"dist_on must be 'cpu' or 'gpu', got {dist_on!r}")

    # Collect metrics from all devices
    if dist_on == 'cpu':
        all_seg_preds = collect_results_cpu(seg_preds, size, tmpdir=None)
        all_seg_gts = collect_results_cpu(seg_gts, size, tmpdir=None)
    else:
        print("Collecting metrics using GPU")
        world_size = fourm.utils.get_world_size()
        # all_gather_object fills one slot per rank.
        all_seg_preds = [None for _ in range(world_size)]
        all_seg_gts = [None for _ in range(world_size)]
        dist.all_gather_object(all_seg_preds, seg_preds)
        dist.all_gather_object(all_seg_gts, seg_gts)

    # Placeholder on non-main ranks; overwritten by the broadcast below.
    ret_metrics_mean = torch.zeros(3, dtype=float, device=device)

    if fourm.utils.is_main_process():
        # Flatten the per-part results into one list each.
        ordered_seg_preds = [result for result_part in all_seg_preds for result in result_part]
        ordered_seg_gts = [result for result_part in all_seg_gts for result in result_part]

        ret_metrics = mean_iou(
            results=ordered_seg_preds,
            gt_seg_maps=ordered_seg_gts,
            num_classes=num_classes,
            ignore_index=ignore_index
        )

        # NaN-aware mean of each metric, expressed in percent with two decimals.
        ret_metrics_mean = torch.tensor(
            [
                np.round(np.nanmean(ret_metric.astype(float)) * 100, 2)
                for ret_metric in ret_metrics
            ],
            dtype=float,
            device=device,
        )

        if output_dir is not None:
            # The second and third entries are treated as per-class accuracy and per-class IoU.
            _, mean_acc_per_class, miou_per_class = ret_metrics
            save_metrics(miou_per_class, mean_acc_per_class, dataset_name, output_dir)

    # Broadcast metrics from 0 to all nodes
    dist.broadcast(ret_metrics_mean, 0)
    pix_acc, mean_acc, miou = ret_metrics_mean
    return dict(pixel_accuracy=pix_acc, mean_accuracy=mean_acc, mean_iou=miou)

# Script entry point: parse the configuration, prepare the run, then train.
if __name__ == '__main__':
    args = get_args()

    # Operates on args; presumably fills in the run name when it is left at
    # its 'auto' default - confirm in fourm.utils.
    fourm.utils.setup_run_name(args)

    # Create the output directory (including parents) when one is configured.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
