
import datetime
import os
import sys
from functools import partial
from typing import List, Tuple, Callable

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as tmp
from timm import create_model
from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy
from timm.optim import AdamW, Lamb
from timm.utils import ModelEmaV2
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer

from arg import FineTuneArgs
from downstream_imagenet.mixup import BatchMixup
from lr_decay import get_param_groups


def time_str(for_dirname=False):
    """Format the current wall-clock time in the Asia/Shanghai time zone.

    Args:
        for_dirname: when true, return a filesystem-friendly ``MM-DD_HH-MM-SS``
            stamp; otherwise return a bracketed ``[MM-DD HH:MM:SS]`` log prefix.
    """
    shanghai = pytz.timezone('Asia/Shanghai')
    if for_dirname:
        pattern = '%m-%d_%H-%M-%S'
    else:
        pattern = '[%m-%d %H:%M:%S]'
    return datetime.datetime.now(tz=shanghai).strftime(pattern)


def init_distributed_environ():
    """Initialise NCCL-backed ``torch.distributed`` for this process and patch ``print``.

    The global rank is read from the ``RANK`` environment variable. This process
    is bound to GPU ``RANK % torch.cuda.device_count()``, the default process
    group is initialised with the NCCL backend, and cuDNN benchmarking is enabled
    (non-deterministic). Afterwards the builtin ``print`` is replaced so that only
    processes with local rank 0 print, unless ``force=True`` is passed. Each
    printed line is prefixed with a timestamp and the caller's file name and line.

    Returns:
        ``(world_size, global_rank, local_rank, device)``, where ``device`` is the
        CUDA device selected for this process.

    Raises:
        ValueError: if ``RANK`` is unset or not an integer.
        RuntimeError: if no CUDA device is visible.
    """
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if tmp.get_start_method(allow_none=True) is None:
        tmp.set_start_method('spawn')

    rank_str = os.environ.get('RANK')
    if rank_str is None:
        # Previously this surfaced as the cryptic `int('error')` ValueError.
        raise ValueError('environment variable `RANK` is not set; launch this script with a distributed launcher (e.g. torchrun)')
    global_rank, num_gpus = int(rank_str), torch.cuda.device_count()
    if num_gpus == 0:
        # Previously this surfaced as a bare ZeroDivisionError from the modulo below.
        raise RuntimeError('no CUDA device is visible, but NCCL distributed training requires at least one GPU')
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    tdist.init_process_group(backend='nccl')
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Print only when local_rank == 0, or when called as print(..., force=True).
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def prt(msg='', *args, **kwargs):
        # `msg` defaults to '' so that a bare `print()` keeps working after patching.
        force = kwargs.pop('force', False)
        if local_rank == 0 or force:
            # Report the caller's location, not this wrapper's.
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs)

    __builtin__.print = prt
    tdist.barrier()
    return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device


def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]:
    """Build the fine-tuning model, EMA copy, DDP wrapper, optimizer, mixup and loss.

    Args:
        args: fine-tuning config. Fields read here: ``model``, ``device``,
            ``drop_path``, ``ema``, ``sbn``, ``local_rank``, ``lr_scale``, ``opt``,
            ``lr`` and ``mixup``.

    Returns:
        ``(criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer)``.
        ``model`` is the DDP-wrapped ``model_without_ddp``, already switched to
        train mode. ``model_ema`` is a ``ModelEmaV2`` built from the model before
        any SyncBN conversion.

    Raises:
        ValueError: if ``args.opt`` is not one of ``'adam'``, ``'adamw'``, ``'lamb'``.
    """
    num_classes = 1000
    opt_cls = {
        # 'adam' also maps to AdamW.
        'adam': AdamW, 'adamw': AdamW,
        'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False),
    }
    # Validate the optimizer name before the (expensive) model / EMA / DDP setup.
    # Previously a typo only surfaced at the very end as a bare KeyError.
    if args.opt not in opt_cls:
        raise ValueError(f'unsupported optimizer `{args.opt}`, expected one of {sorted(opt_cls)}')

    model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device)
    model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M'
    # Create the EMA model after moving to the device, but before SyncBN conversion and DDP wrapping.
    model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device)
    if args.sbn:
        model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp)
    print(f'[模型={args.model}] [#参数={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n')
    model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False)
    model.train()

    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale)
    # param_groups[0] is like this: {'params': List[nn.Parameters], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float}
    optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0)
    print(f'[优化器={type(optimizer)}]')

    mixup_fn = BatchMixup(
        mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=1.0, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=num_classes
    )
    mixup_fn.mixup_enabled = args.mixup > 0.0
    if 'lamb' in args.opt:
        # Label smoothing is already handled by BatchMixup via `label_smoothing`, so use smoothing=0 here.
        criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None)
    else:
        criterion = SoftTargetCrossEntropy()
    print(f'[损失函数] {criterion}')
    print(f'[Mixup函数] {mixup_fn}')
    return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer


def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
    """Restore model, EMA and optimizer state from ``resume_from``.

    Model weights are loaded non-strictly. Keys in the model but absent from the
    checkpoint are reported as missing; keys only in the checkpoint are reported
    as unexpected. Optimizer and EMA states are restored best-effort: a load
    failure is reported as a warning, and a missing or ``None`` entry is skipped.

    Args:
        resume_from: checkpoint path. An empty/``None`` value or a non-existent
            path means "train from scratch".
        model_without_ddp: the bare (un-wrapped) model to load weights into.
        ema_module: object whose ``load_state_dict`` receives ``checkpoint['ema']``,
            or ``None``.
        optimizer: optimizer to restore, or ``None``.

    Returns:
        ``(ep_start, performance_desc, load_info)``:
        - ``ep_start`` is the saved ``epoch`` + 1, or 0 if absent.
        - ``performance_desc`` is the saved description string.
        - ``load_info`` is a dict with ``loaded_keys`` (int) and ``missing_keys`` /
          ``unexpected_keys`` (lists of key names).

    Raises:
        ValueError: if the checkpoint is flagged ``is_pretrain`` (a pre-training
            checkpoint that must not be used here).
    """
    load_info = {
        'loaded_keys': 0,
        'missing_keys': [],
        'unexpected_keys': [],
    }

    # The original code raised AttributeError here. It returns defaults instead,
    # to fit main.py's flow. `not resume_from` also covers None (len(None) would crash).
    if not resume_from or not os.path.exists(resume_from):
        print(f'[警告] 检查点文件 `{resume_from}` 未找到。将从头开始训练。', force=True)
        return 0, '[无性能描述]', load_info

    print(f'[正在尝试从文件 `{resume_from}` 恢复...]')
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(resume_from, map_location='cpu')

    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if checkpoint.get('is_pretrain', False):
        raise ValueError('请勿使用 `*_withdecoder_1kpretrained_spark_style.pth`，此文件仅用于恢复预训练。请使用 `*_1kpretrained_timm_style.pth` 或 `*_1kfinetuned*.pth`。')

    # Use .get() so that older / bare checkpoints without these keys still load.
    ep_start = checkpoint.get('epoch', -1) + 1
    performance_desc = checkpoint.get('performance_desc', '[无性能描述]')

    # Older checkpoints may be a bare state dict instead of {'module': state_dict, ...}.
    ckpt_model_state = checkpoint.get('module', checkpoint)
    model_state_dict = model_without_ddp.state_dict()

    # Computed in model / checkpoint key order, so the logged lists are deterministic.
    missing_keys = [k for k in model_state_dict if k not in ckpt_model_state]
    unexpected_keys = [k for k in ckpt_model_state if k not in model_state_dict]

    # strict=False allows partial loading.
    model_without_ddp.load_state_dict(ckpt_model_state, strict=False)

    # Number of model keys that were found in the checkpoint.
    load_info['loaded_keys'] = len(model_state_dict) - len(missing_keys)
    load_info['missing_keys'] = missing_keys
    load_info['unexpected_keys'] = unexpected_keys

    print(f'[检查点加载] 缺失的键: {missing_keys}')
    print(f'[检查点加载] 意外的键: {unexpected_keys}')
    print(f'[检查点加载] 起始Epoch: {ep_start}, 性能描述: {performance_desc}')

    # A None entry means the state was not saved, so skip it instead of reporting a failure.
    if optimizer is not None and checkpoint.get('optimizer') is not None:
        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception as e:
            # Best-effort: a mismatched optimizer state should not abort the run.
            print(f'[警告] 优化器状态加载失败: {e}', force=True)
        else:
            print('[检查点加载] 优化器状态已加载。')

    if ema_module is not None and checkpoint.get('ema') is not None:
        try:
            ema_module.load_state_dict(checkpoint['ema'])
        except Exception as e:
            print(f'[警告] EMA模型状态加载失败: {e}', force=True)
        else:
            print('[检查点加载] EMA模型状态已加载。')

    return ep_start, performance_desc, load_info


def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state):
    """Write a fine-tuning checkpoint to ``os.path.join(args.exp_dir, save_to)``.

    Only the process with ``args.is_local_master`` set writes; elsewhere this is a
    no-op. The data is first written to a temporary sibling file and then moved
    into place with ``os.replace``. An interruption mid-save therefore never leaves
    a truncated checkpoint at the final path.

    Args:
        save_to: file name (or path) joined onto ``args.exp_dir``.
        args: run config. ``exp_dir``, ``is_local_master`` and ``model`` are read,
            and ``str(args)`` is stored under ``'args'``.
        epoch: stored under ``'epoch'``.
        performance_desc: stored under ``'performance_desc'``.
        model_without_ddp_state: model state stored under ``'module'``.
        ema_state: stored under ``'ema'``.
        optimizer_state: stored under ``'optimizer'``.
    """
    if not args.is_local_master:
        return

    checkpoint_path = os.path.join(args.exp_dir, save_to)
    to_save = {
        'args': str(args),
        'arch': args.model,
        'epoch': epoch,
        'performance_desc': performance_desc,
        'module': model_without_ddp_state,
        'ema': ema_state,
        'optimizer': optimizer_state,
        'is_pretrain': False,
    }
    # Same directory as the target, so os.replace is a rename on one filesystem.
    tmp_path = checkpoint_path + '.tmp'
    try:
        torch.save(to_save, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only still present if torch.save or os.replace failed; drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)