# main_dil.py

import argparse
import psutil
def use_cpus(gpus: list, cpus_per_gpu=24):
    """
    Pin the current process to the CPU cores associated with the given GPUs.

    GPU ``i`` is mapped to the contiguous core range
    ``[i * cpus_per_gpu, (i + 1) * cpus_per_gpu)``. Cores that do not exist on
    this machine are skipped with a warning; if none of the requested cores
    exist, the affinity is left untouched.

    Args:
        gpus: List of GPU IDs, e.g., [0,1] means using GPU 0 and 1. Entries may
            be ints or strings; a string may hold several comma-separated IDs
            (e.g., ["0,1"]), which is the format of the ``--gpus`` CLI option.
        cpus_per_gpu: Number of CPU cores allocated per GPU

    Raises:
        ValueError: If an entry is not an integer or is negative.
    """
    # Validate and flatten the GPU ids first, before touching the process.
    gpu_ids = []
    for entry in gpus:
        # A CLI value such as "0,1" arrives as one string; split it into ids.
        tokens = entry.split(',') if isinstance(entry, str) else [entry]
        for token in tokens:
            if isinstance(token, str) and not token.strip():
                continue  # tolerate stray commas / blank pieces
            gpu_id = int(token)
            if gpu_id < 0:
                raise ValueError(f"GPU id must be non-negative, got {gpu_id}")
            gpu_ids.append(gpu_id)

    cpus = []
    for gpu_id in gpu_ids:
        cpus.extend(range(gpu_id * cpus_per_gpu, (gpu_id + 1) * cpus_per_gpu))
    # Drop duplicates (e.g. the same GPU listed twice) while keeping order.
    cpus = list(dict.fromkeys(cpus))

    # cpu_affinity() rejects core numbers that do not exist, so clamp the
    # request to the cores this machine actually has.
    total_cores = psutil.cpu_count()
    if total_cores is not None:
        existing = [core for core in cpus if core < total_cores]
        if len(existing) < len(cpus):
            print(f"Warning: {len(cpus) - len(existing)} requested CPU cores do not exist "
                  f"on this machine ({total_cores} cores); ignoring them")
        cpus = existing

    if not cpus:
        print("No usable CPU cores selected; leaving CPU affinity unchanged")
        return

    p = psutil.Process()
    p.cpu_affinity(cpus)

    print(f"Using {len(cpus)} CPU cores: {cpus}")
    print("Ensure num_workers is less than CPU core count")

import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import WeightedRandomSampler

import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_continual_dataloader
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download

from dil_models import DomainIncrementalModel
from engine_dil import train_and_evaluate_dil
from fft_utils import compute_and_save_amp_key

from methods.continual_iterative_model import create_iterative_token_routed_vit as create_base_vit

def get_args_parser():
    """
    Build the command-line parser for DIL (domain-incremental) fine-tuning.

    The parser is created with ``add_help=False`` so it can be used as a
    ``parents=[...]`` entry of another ``ArgumentParser``.

    Boolean options that default to ``True`` are registered as a pair of
    switches, ``--<name>`` and ``--no_<name>``, so they can actually be turned
    off from the command line (a lone ``store_true`` flag with ``default=True``
    can never become ``False``).

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser('DIL MAE fine-tuning', add_help=False)

    parser.add_argument('--train_mode', type=str, default='domain_incremental', choices=['domain_incremental'])
    parser.add_argument('--domains', nargs='+', default=['APTOS', 'IDRiD', 'Messidor2'],
                        help='List of domain names for the DIL training sequence.')
    parser.add_argument('--dil_method', type=str, default='fft',
                        choices=['fft', 'key_value', 'key_value_learnable', 'kmeans'],
                        help='Inference-time domain selection method.')

    _add_toggle(parser, 'pull_constraint', default=True,
                help_text='Enable the pull constraint auxiliary loss for learnable keys.')
    parser.add_argument('--pull_constraint_coeff', type=float, default=0.1,
                        help='Coefficient for the pull constraint loss.')

    parser.add_argument('--kmeans_n_clusters', type=int, default=1,
                        help='Number of clusters (keys) per domain for the K-Means method.')

    parser.add_argument('--transfer_weights', action='store_true',
                        help='If enabled, initialize new task modules with weights from the previous task.')
    parser.set_defaults(transfer_weights=False)

    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs_per_domain', default=50, type=int, help='Number of epochs to train on each domain.')
    parser.add_argument('--accum_iter', default=1, type=int)
    _add_toggle(parser, 'use_class_balance_sampler', default=True,
                help_text='Use class-balance sampler for each domain.')

    # Model parameters
    parser.add_argument('--model', default='vit_large_patch16_224', type=str, help='Name of base model.')
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT')
    parser.add_argument('--target_block_indices', type=int, nargs='+', default=[23],
                        help='A list of specific block indices to replace.')
    parser.add_argument('--lora_rank', type=int, default=16)
    parser.add_argument('--num_recursion_steps', type=int, default=3)

    # Optimizer parameters
    parser.add_argument('--clip_grad', type=float, default=None)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--lr', type=float, default=None)
    parser.add_argument('--blr', type=float, default=1e-3, help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--layer_decay', type=float, default=0.75)
    parser.add_argument('--min_lr', type=float, default=1e-6)
    parser.add_argument('--warmup_epochs', type=int, default=1)

    # Loss and Regularization
    _add_toggle(parser, 'use_focal_loss', default=True)
    parser.add_argument('--sparsity_target', type=float, default=0.3)
    parser.add_argument('--sparsity_lambda', type=float, default=0.01)

    # Finetuning params
    parser.add_argument('--finetune_repo', default='YukunZhou/RETFound_mae_natureCFP', type=str, help='HuggingFace repo for pre-trained weights.')
    parser.add_argument('--task', default='DIL_FFT_DR', type=str, help='A name for this experiment task.')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/', type=str)
    parser.add_argument('--nb_classes', default=5, type=int, help='Number of classes (should be consistent across domains).')
    parser.add_argument('--output_dir', default='./output_dil')
    parser.add_argument('--log_dir', default='./logs_dil')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume_dir', default='', help='Resume DIL training from a checkpoint directory.')
    parser.add_argument('--num_workers', default=8, type=int)
    _add_toggle(parser, 'pin_mem', default=True)

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    # Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
    # Checkpoint file and pooling
    parser.add_argument('--finetune', default='RETFound_mae_natureCFP', type=str,
                        help='finetune from checkpoint')
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use class token instead of global pool for classification')
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # fine-tuning parameters
    _add_toggle(parser, 'savemodel', default=True, help_text='Save model')
    _add_toggle(parser, 'use_early_stopping', default=True, help_text='Use early stopping')
    parser.add_argument('--early_stopping_patience', type=int, default=10,
                        help='Early stopping patience')
    parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
    parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
    parser.add_argument('--datasets_seed', default=2026, type=int)
    parser.add_argument('--gpus', type=str, default='3', help='GPUs to use')

    return parser

def _add_toggle(parser, name, *, default, help_text=None):
    """Register ``--<name>`` and ``--no_<name>`` switches writing to the same destination."""
    parser.add_argument(f'--{name}', dest=name, action='store_true', default=default, help=help_text)
    # Both actions carry the same default so the effective default does not
    # depend on which of the two argparse happens to inspect first.
    parser.add_argument(f'--no_{name}', dest=name, action='store_false', default=default,
                        help=f'Disable --{name}.')

def set_seed(seed):
    """
    Seed every random number generator used by the script and put cuDNN in
    deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU, plus all CUDA devices when CUDA is available).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for reproducible kernel selection.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def main(args):
    """
    Run the full domain-incremental (DIL) fine-tuning pipeline.

    Steps, in order: seed the RNGs, initialise distributed mode, build the
    continual dataloader, create the base ViT and wrap it in
    ``DomainIncrementalModel``, optionally load pre-trained weights from the
    HuggingFace Hub, activate the first domain, optionally wrap the model in
    ``DistributedDataParallel``, resolve the learning rate and hand over to
    ``train_and_evaluate_dil``.

    Args:
        args: Parsed command-line namespace (see ``get_args_parser``).
            ``args.lr`` is filled in from ``args.blr`` when it is ``None``.
            ``args.distributed`` / ``args.gpu`` are read here but not defined
            by the parser; presumably ``misc.init_distributed_mode`` sets
            them. TODO confirm.

    Raises:
        RuntimeError: If the model has no trainable parameters.
    """
    set_seed(args.seed)
    misc.init_distributed_mode(args)
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    continual_dataloader = build_continual_dataloader(args)

    print("Creating base ViT model...")
    base_model = create_base_vit(
        model_name=args.model,
        pretrained=True,
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool='token'
    )
    # Announce the wrap before it happens so the log reflects the real order.
    print("Wrapping base model in DomainIncrementalModel...")
    model = DomainIncrementalModel(args, base_model)
    if args.finetune_repo:
        _load_pretrained_weights(model, args)

    model.to(device)
    print(f"Setting active domain to: {args.domains[0]}")
    model.set_active_domain(args.domains[0])

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total_params/1e6:.2f}M")
    print(f"Trainable params: {trainable_params/1e6:.2f}M ({trainable_params/total_params*100:.4f}%)")

    if trainable_params == 0:
        raise RuntimeError("No trainable parameters found! Check the model setup.")

    # Keep a handle on the unwrapped module; DDP hides it behind ``.module``.
    model_without_ddp = model
    if args.distributed:
        ddp_device = args.local_rank if args.local_rank >= 0 else args.gpu
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[ddp_device])
        model_without_ddp = model.module

    print(f"Start DIL training for {len(args.domains)} domains using method: {args.dil_method}.")
    if args.transfer_weights:
        print("Weight transfer is ENABLED - new tasks will inherit weights from previous tasks.")
    else:
        print("Weight transfer is DISABLED - new tasks will use random initialization.")

    # Only the main process writes TensorBoard logs.
    log_writer = None
    if misc.is_main_process():
        log_dir = os.path.join(args.log_dir, args.task)
        os.makedirs(log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=log_dir)

    try:
        # The directory is created here but not used further in this function.
        amp_keys_dir = os.path.join(args.output_dir, args.task, "amp_keys")
        os.makedirs(amp_keys_dir, exist_ok=True)

        start_time = time.time()
        loss_scaler = NativeScaler()

        # Linear LR scaling rule: absolute_lr = base_lr * effective_batch / 256.
        eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
        if args.lr is None:
            args.lr = args.blr * eff_batch_size / 256

        print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
        print("actual lr: %.2e" % args.lr)

        train_and_evaluate_dil(
            model, model_without_ddp, continual_dataloader,
            loss_scaler, device, log_writer, args
        )

        total_time = time.time() - start_time
    finally:
        # Flush pending events even when training aborts; close() is a no-op
        # if the writer was already closed.
        if log_writer is not None:
            log_writer.close()

    print(f'Total DIL Training time {datetime.timedelta(seconds=int(total_time))}')

def _load_pretrained_weights(model, args):
    """Download ``<args.finetune>.pth`` from ``args.finetune_repo`` and load it into ``model`` (non-strict)."""
    print(f"Loading pre-trained weights from HuggingFace repo: {args.finetune_repo}")
    checkpoint_path = hf_hub_download(repo_id=args.finetune_repo, filename=f'{args.finetune}.pth')
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point --finetune_repo at repositories you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    checkpoint_model = checkpoint['model']

    # Map checkpoint parameter names onto the names used by the model; one
    # full pass per rule, applied in this order.
    key_renames = (
        ("backbone.", ""),
        ("mlp.w12.", "mlp.fc1."),
        ("mlp.w3.", "mlp.fc2."),
    )
    for old, new in key_renames:
        checkpoint_model = {k.replace(old, new): v for k, v in checkpoint_model.items()}

    # Drop classifier weights whose shape does not match the model's head.
    # Keys missing from the model are left alone: the non-strict load ignores them.
    state_dict = model.state_dict()
    for k in ('head.weight', 'head.bias'):
        if k in checkpoint_model and k in state_dict and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    from util.pos_embed import interpolate_pos_embed
    interpolate_pos_embed(model, checkpoint_model)

    msg = model.load_state_dict(checkpoint_model, strict=False)
    print("Pre-trained weights loaded with message:", msg)
    print("Missing keys:", msg.missing_keys)

if __name__ == '__main__':
    # get_args_parser() is built with add_help=False; wrapping it as a parent
    # yields a parser that also understands -h/--help.
    cli_parser = argparse.ArgumentParser('DIL MAE fine-tuning', parents=[get_args_parser()])
    args = cli_parser.parse_args()
    if args.output_dir:
        Path(os.path.join(args.output_dir, args.task)).mkdir(parents=True, exist_ok=True)
    # --gpus is a string that may list several comma-separated ids ("0,1");
    # hand them over one id per element and ignore empty pieces.
    use_cpus(gpus=[gpu_id for gpu_id in args.gpus.split(',') if gpu_id.strip()])
    main(args)
