import datetime

import torch
import os
import argparse
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
import torch.utils.data as data
from model.dataset_exp import ClassificationDataset
from model.dataset_exp import ShapeNetDataset
from model.meshmae_exp_unused import Mesh_mae
from model.reconstruction import save_results
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule, get_cosine_schedule_with_warmup


def train(net, optim, scheduler, names, train_dataset, epoch, args, ratio):
    """Run one training epoch and checkpoint the model when the epoch loss improves.

    Args:
        net: model called as ``net(faces, feats, centers, Fs, cordinates, ratio)``
            in train mode and expected to return ``(loss, f_loss, s_loss)``.
        optim: optimizer; ``zero_grad``/``step`` are called once per batch.
        scheduler: LR scheduler; ``step`` is called once per batch (not per epoch).
        names: run name; checkpoints are written to ``checkpoints/<names>/``.
        train_dataset: iterable of batches
            ``(feats_patch, center_patch, coordinate_patch, face_patch, np_Fs, label, mesh_paths)``.
        epoch: current epoch index, used for logging and in the checkpoint filename.
        args: parsed CLI namespace; an optional ``grad_clip`` attribute sets the
            gradient-clipping max norm (defaults to 1.0 when absent).
        ratio: masking ratio forwarded to ``net``.

    Side effects:
        Keeps the best epoch loss and epoch on the function attributes
        ``train.best_loss`` / ``train.best_epoch`` (treated as +inf / unset if the
        caller never initialised them). Saves ``net.state_dict()`` whenever the
        epoch loss improves. An empty loader logs a message and returns without
        touching the best-loss state.
    """
    net.train()
    running_loss = 0.0
    running_f_loss = 0.0
    running_s_loss = 0.0
    n_samples = 0
    # Fall back to a max norm of 1.0 when the namespace has no grad_clip option.
    max_norm = getattr(args, 'grad_clip', 1.0)

    for feats_patch, center_patch, coordinate_patch, face_patch, np_Fs, _label, _mesh_paths in train_dataset:
        optim.zero_grad()
        faces = face_patch.to(torch.float32).cuda()
        feats = feats_patch.to(torch.float32).cuda()
        centers = center_patch.to(torch.float32).cuda()
        Fs = np_Fs.cuda()
        cordinates = coordinate_patch.to(torch.float32).cuda()

        batch_size = faces.shape[0]
        n_samples += batch_size

        loss, f_loss, s_loss = net(faces, feats, centers, Fs, cordinates, ratio)
        loss.backward()
        # Gradient clipping before the optimizer step.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=max_norm)

        # Optimization step; the scheduler is advanced per batch.
        optim.step()
        scheduler.step()

        # Weight each batch's loss by its size so the epoch value is a per-sample
        # average (assumes net returns batch-mean losses — TODO confirm).
        running_loss += loss.item() * batch_size
        running_f_loss += f_loss.item() * batch_size
        running_s_loss += s_loss.item() * batch_size

    if n_samples == 0:
        # Avoid ZeroDivisionError on an empty loader; nothing was trained.
        print('epoch ({:}): {:} skipped: training loader produced no samples'.format(names, epoch))
        return

    epoch_loss = running_loss / n_samples
    epoch_f_loss = running_f_loss / n_samples
    epoch_s_loss = running_s_loss / n_samples

    if getattr(train, 'best_loss', float('inf')) > epoch_loss:
        train.best_loss = epoch_loss
        train.best_epoch = epoch
        # state_dict() is serialised immediately, so no defensive deepcopy is needed.
        # NOTE: epoch is formatted with ':.4f' (e.g. '3.0000'); kept so existing
        # checkpoint naming stays unchanged.
        torch.save(net.state_dict(), os.path.join('checkpoints', names, f'loss-{epoch_loss:.4f}-{epoch:.4f}.pkl'))
    print('当前时间: {:}'.format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    print('epoch ({:}): {:} Train Loss: {:.4f}'.format(names, epoch, epoch_loss))
    print('epoch ({:}): {:} Train Feature Loss: {:.4f}'.format(names, epoch, epoch_f_loss))
    print('epoch ({:}): {:} Train Shape Loss: {:.4f}'.format(names, epoch, epoch_s_loss))


def test(net, names, test_dataset, epoch, args):
    """Run ``net`` over ``test_dataset`` in eval mode and save its reconstructions.

    Use this to inspect reconstructed shapes. For every batch, the model is called
    without ``ratio`` under ``torch.no_grad()``. It must return five values. The
    last four (masked indices, unmasked indices, predicted vertex coordinates and
    coordinates) are passed to ``save_results`` along with the batch's
    ``mesh_paths``. The first value (the loss) is discarded.

    ``names``, ``epoch`` and ``args`` are accepted for call compatibility but are
    not used. Returns ``None``.
    """

    def _cuda_float(tensor):
        # Cast to float32 first, then move to the GPU.
        return tensor.to(torch.float32).cuda()

    net.eval()  # switch to evaluation mode
    for batch in test_dataset:
        feats_patch, center_patch, coordinate_patch, face_patch, np_Fs, _label, mesh_paths = batch

        faces = _cuda_float(face_patch)
        feats = _cuda_float(feats_patch)
        centers = _cuda_float(center_patch)
        Fs = np_Fs.cuda()  # moved as-is, without a dtype cast
        coords = _cuda_float(coordinate_patch)

        with torch.no_grad():
            outputs = net(faces, feats, centers, Fs, coords)
        _loss, masked_idx, unmasked_idx, pred_coords, target_coords = outputs

        save_results(masked_idx, unmasked_idx, pred_coords, target_coords, mesh_paths)

def adjust_num_masked(num_masked, divisor=64):
    """Snap ``num_masked`` to the nearest multiple of ``divisor``.

    If ``num_masked`` is below ``divisor``, return 0 (mask nothing).

    Otherwise, round the quotient ``num_masked / divisor`` with Python's built-in
    ``round`` and multiply it back by ``divisor``. Exact ties therefore go to the
    even multiple: with the default divisor of 64, 96 -> 128 but 160 -> 128.
    """
    if num_masked < divisor:
        # Too few to fill even one group of `divisor`: mask nothing.
        return 0
    quotient = num_masked / divisor
    return int(round(quotient)) * divisor




if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=['train', 'test'])
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--lr_milestones', type=str, default=None)
    parser.add_argument('--num_warmup_steps', type=str, default=None)

    parser.add_argument('--depth', type=int, required=True)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--n_dropout', type=int, default=1)
    parser.add_argument('--encoder_depth', type=int, default=6)
    parser.add_argument('--decoder_depth', type=int, default=6)
    parser.add_argument('--decoder_dim', type=int, default=512)
    parser.add_argument('--decoder_num_heads', type=int, default=6)
    parser.add_argument('--dim', type=int, default=384)
    parser.add_argument('--weight', type=float, default=0.2)
    parser.add_argument('--optim', type=str, default='adam')
    parser.add_argument('--heads', type=int, required=True)
    parser.add_argument('--patch_size', type=int, required=True)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--n_epoch', type=int, default=500)
    parser.add_argument('--max_epoch', type=int, default=300)
    parser.add_argument('--dataroot', type=str, required=True)
    parser.add_argument('--n_classes', type=int)
    parser.add_argument('--no_center_diff', action='store_true')
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--n_worker', type=int, default=52)
    parser.add_argument('--augment_scale', action='store_true')
    parser.add_argument('--augment_orient', action='store_true')
    parser.add_argument('--augment_deformation', action='store_true')
    parser.add_argument('--channels', type=int, default=10)
    parser.add_argument('--mask_ratio', type=float, default=0.25)

    args = parser.parse_args()
    mode = args.mode
    dataroot = args.dataroot

    # ========== Dataset ==========
    augments = []
    if args.augment_scale:
        augments.append('scale')
    if args.augment_orient:
        augments.append('orient')
    if args.augment_deformation:
        augments.append('deformation')

    # if 'ShapeNet' in dataroot:
    #     train_dataset = ShapeNetDataset(dataroot, train=True, augment=augments)

    if 'dataset' in dataroot:
        train_dataset = ShapeNetDataset(dataroot, train=True, augment=augments)
        # test_dataset = ShapeNetDataset(dataroot, train=False)
        # print(len(test_dataset))
        # test_data_loader = data.DataLoader(test_dataset, num_workers=args.n_worker, batch_size=args.batch_size,
        #                                    shuffle=True, pin_memory=True)
    else:
        train_dataset = ClassificationDataset(dataroot, train=True, augment=augments)
        test_dataset = ClassificationDataset(dataroot, train=False)
        print(len(test_dataset))
        test_data_loader = data.DataLoader(test_dataset, num_workers=args.n_worker, batch_size=args.batch_size,
                                           shuffle=True, pin_memory=True)
    print(len(train_dataset))
    train_data_loader = data.DataLoader(train_dataset, num_workers=args.n_worker, batch_size=args.batch_size,
                                        shuffle=True, pin_memory=True)

    # ========== Network ==========
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = Mesh_mae(masking_ratio=args.mask_ratio,
                   channels=args.channels,
                   num_heads=args.heads,
                   encoder_depth=args.encoder_depth,
                   embed_dim=args.dim,
                   decoder_num_heads=args.decoder_num_heads,
                   decoder_depth=args.decoder_depth,
                   decoder_embed_dim=args.decoder_dim,
                   patch_size=args.patch_size,
                   weight=args.weight
                   ).to(device)

    is_finetuning = args.checkpoint.lower() != 'none'

    if is_finetuning:
        # 如果提供了checkpoint，我们进入微调模式
        print(f"--- Finetuning Mode Detected (checkpoint provided) ---")
        print(f"--- Loading weights from: {args.checkpoint} ---")
        state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
        net.load_state_dict(state_dict, strict=True)
        print("--- Weights loaded successfully. ---")

    current_lr = args.lr
    warmup_steps = int(args.num_warmup_steps)
    print(f"--- Overriding LR to {current_lr} and Warmup Steps to {warmup_steps} for fine-tuning. ---")


    # ========== Optimizer (永远创建新的) ==========
    if args.optim.lower() == 'adamw':
        optim = optim.AdamW(net.parameters(), lr=current_lr, weight_decay=args.weight_decay)

    # ========== Scheduler (保留您的原始逻辑) ==========
    if args.lr_milestones.lower() != 'none':
        ms = args.lr_milestones.split()
        ms = [int(j) for j in ms]
        scheduler = MultiStepLR(optim, milestones=ms, gamma=0.1)
    else:
        scheduler = get_cosine_schedule_with_warmup(optim,
                                                    num_warmup_steps=int(args.num_warmup_steps),
                                                    num_training_steps=args.n_epoch * len(train_data_loader))  # <--- 使用正确的总步数

    print(scheduler)

    # ========== MISC ==========

    checkpoint_path = os.path.join('checkpoints', args.name)
    os.makedirs(checkpoint_path, exist_ok=True)
    #
    # if args.checkpoint.lower() != 'none':
    #     net.load_state_dict(torch.load(args.checkpoint, weights_only=True), strict=True)

    train.best_loss = 999
    train.best_epoch = 0
  #  ema = LitEma(net, decay=0.99)
    # ========== Start Training ==========

    if args.mode == 'train':
        print(args.n_worker)
        for epoch in range(args.n_epoch):
            # start_rate = 0.5
            # end_rate = 0.12
            # rate = start_rate - (start_rate - end_rate) * (epoch / args.n_epoch)
            print('epoch', epoch)
            print('当前时间: {:}'.format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
            print(f"Current learning rates: {[pg['lr'] for pg in optim.param_groups]}")
            train(net, optim, scheduler, args.name, train_data_loader, epoch, args, args.mask_ratio)
            print('train finished')


    else:
        test(net,  args.name, test_data_loader, 0, args)
