import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
# The 1D time-series model constructors are looked up by name in the ORE module (see --model)
import ORE
from engine_pretrain_ore import train_one_epoch_ore


# -------------------------------
# Custom dataset: loads time-series data (one .npy file per sample, expected shape (512,))
# -------------------------------
class TimeSeriesDataset(torch.utils.data.Dataset):
    """Dataset over the ``*.npy`` files located directly inside ``root``.

    Files are indexed in sorted path order. Each item is loaded with
    ``np.load``, converted to a ``float32`` tensor and given a leading
    dimension of size 1. Items are returned as ``(tensor, 0)``; the ``0`` is
    a placeholder label.

    Args:
        root: Directory that contains the ``.npy`` sample files.
        transform: Optional callable applied to the tensor before it is
            returned.
    """

    def __init__(self, root, transform=None):
        root_dir = Path(root)
        self.root = root_dir
        # Sorting fixes the index -> file mapping across runs.
        self.files = sorted(root_dir.glob("*.npy"))
        self.transform = transform

    def __len__(self):
        """Return the number of ``.npy`` files that were discovered."""
        return len(self.files)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(tensor, 0)``."""
        raw = np.load(self.files[index])  # a 1D array such as (512,) is expected
        # Prepend a size-1 dimension: (L,) -> (1, L).
        sample = torch.tensor(raw, dtype=torch.float32).unsqueeze(0)
        transform = self.transform
        if transform:
            sample = transform(sample)
        # The second element is a dummy target.
        return sample, 0


# -------------------------------
# Optimizer parameter groups and command-line argument parser
# -------------------------------
def add_weight_decay(model, weight_decay=0.05, skip_list=()):
    """Split a model's trainable parameters into two optimizer groups.

    A parameter goes into the no-decay group when it has at most one
    dimension (scalars and vectors), when its name ends with ``.bias``, or
    when its name is listed in ``skip_list``. Everything else goes into the
    decay group. Parameters with ``requires_grad == False`` are left out.

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        weight_decay: Weight decay assigned to the decay group.
        skip_list: Parameter names that must never be decayed.

    Returns:
        A two-element list of param-group dicts: first the no-decay group
        (``weight_decay`` 0.0), then the decay group.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not optimized at all
        # ``ndim <= 1`` (rather than ``== 1``) also keeps 0-dim scalar
        # parameters out of the decay group.
        if param.ndim <= 1 or name.endswith(".bias") or name in skip_list:
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': no_decay, 'weight_decay': 0.},
            {'params': decay, 'weight_decay': weight_decay}]


def get_args_parser():
    """Create the command-line parser for 1D time-series MAE pre-training.

    The parser is built with ``add_help=False`` and is returned without
    parsing anything; the caller is expected to call ``parse_args`` on it.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        'MAE Pre-training for 1D Time Series', add_help=False)
    add = parser.add_argument

    # ---- Training loop ----
    add('--batch_size', type=int, default=64,
        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    add('--epochs', type=int, default=400)
    add('--accum_iter', type=int, default=1,
        help='Accumulate gradient iterations to increase effective batch size')

    # ---- Model ----
    add('--model', type=str, default='mae_vit_1d_patchmean_base', metavar='MODEL',
        help='Name of model to train')
    add('--seq_len', type=int, default=256,
        help='Length of input time series (default: 256)')
    add('--mask_ratio', type=float, default=0.75,
        help='Masking ratio (percentage of removed patches).')
    add('--norm_pix_loss', action='store_true',
        help='Use normalized pixels (per patch) as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # ---- Optimizer ----
    add('--weight_decay', default=0.05, type=float,
        help='weight decay (default: 0.05)')
    add('--lr', default=None, type=float, metavar='LR',
        help='learning rate (absolute lr)')
    add('--blr', default=1e-3, type=float, metavar='LR',
        help='base learning rate: absolute_lr = blr * effective_batch_size / 256')
    add('--min_lr', default=1e-6, type=float, metavar='LR',
        help='lower lr bound for cyclic schedulers that hit 0')
    add('--warmup_epochs', default=0, type=int, metavar='N',
        help='epochs to warmup LR')

    # ---- Dataset, output locations and run control ----
    add('--data_path', type=str, default='',
        help='path to time series dataset (expects .npy files)')
    add('--output_dir', default='600',
        help='path where to save checkpoints (empty for no saving)')
    add('--log_dir', default='600',
        help='path where to tensorboard log')
    add('--device', default='cuda',
        help='device to use for training / testing')
    add('--seed', type=int, default=0)
    # Path of the checkpoint file to resume from.
    add('--resume', default='',
        help='resume from checkpoint')
    add('--start_epoch', type=int, default=0, metavar='N',
        help='start epoch')
    add('--num_workers', type=int, default=10)
    # --pin_mem / --no_pin_mem share one destination; pinning is on by default.
    add('--pin_mem', action='store_true',
        help='Pin CPU memory in DataLoader for more efficient transfer to GPU')
    add('--no_pin_mem', dest='pin_mem', action='store_false')
    parser.set_defaults(pin_mem=True)

    # ---- Distributed training ----
    add('--world_size', type=int, default=1,
        help='number of distributed processes')
    add('--local_rank', type=int, default=-1)
    add('--dist_on_itp', action='store_true')
    add('--dist_url', default='env://',
        help='url used to set up distributed training')

    return parser


def load_checkpoint(args, model_without_ddp, optimizer, loss_scaler, device):
    """Restore training state from ``args.resume`` and return the start epoch.

    When ``args.resume`` is non-empty and points to an existing file, the
    checkpoint is loaded onto ``device`` and its ``'model'`` and
    ``'optimizer'`` entries are loaded into ``model_without_ddp`` and
    ``optimizer``. If ``loss_scaler`` is given, its state is restored from
    the ``'loss_scaler'`` entry or, failing that, the ``'scaler'`` entry.

    Args:
        args: Namespace providing ``resume`` (checkpoint path or empty).
        model_without_ddp: Model (not wrapped in DDP) receiving the weights.
        optimizer: Optimizer receiving the saved optimizer state.
        loss_scaler: Object with ``load_state_dict``, or ``None`` to skip.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        ``checkpoint['epoch'] + 1`` (``1`` if the checkpoint has no epoch),
        or ``0`` when no checkpoint was loaded.
    """
    if not (args.resume and os.path.isfile(args.resume)):
        print("No valid resume checkpoint provided, starting from epoch 0")
        return 0

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only resume from checkpoint files that come from a trusted source.
    checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
    model_without_ddp.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    if loss_scaler is not None:
        scaler_state = checkpoint.get('loss_scaler')
        if scaler_state is None:
            # NOTE(review): the reference MAE ``save_model`` stores the scaler
            # under 'scaler'; accept that key too so the AMP scale is not
            # silently reset — confirm against util.misc.save_model.
            scaler_state = checkpoint.get('scaler')
        if scaler_state is not None:
            loss_scaler.load_state_dict(scaler_state)

    # The checkpoint stores the last finished epoch; continue with the next.
    start_epoch = checkpoint.get('epoch', 0) + 1
    print(f"Resumed checkpoint from {args.resume}, starting from epoch {start_epoch}")
    return start_epoch


def main(args):
    """Run pre-training end to end for the parsed command-line ``args``.

    In order: initialise distributed mode, seed torch and numpy, build the
    ``TimeSeriesDataset`` loader, instantiate ``ORE.<args.model>``, create an
    AdamW optimizer, optionally resume from ``args.resume``, build a
    cosine-annealing LR scheduler positioned at the start epoch, then train
    up to ``args.epochs``.

    Side effects: sets ``args.lr`` when it is ``None``; on rank 0 writes
    TensorBoard events to ``args.log_dir``; when ``args.output_dir`` is set,
    saves checkpoints through ``misc.save_model`` and appends one JSON line
    per epoch to ``<output_dir>/log.txt`` on the main process.

    Args:
        args: Namespace produced by the parser from ``get_args_parser``.
    """
    misc.init_distributed_mode(args)
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("Arguments:\n{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # Fix the seed for reproducibility; offset by rank so ranks differ.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    transform_train = None

    dataset_train = TimeSeriesDataset(root=args.data_path, transform=transform_train)
    print("Training dataset:", dataset_train)

    if args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train =", str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    # Only rank 0 writes TensorBoard events.
    if misc.get_rank() == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Build the model: the constructor is looked up by name in the ORE module.
    model = ORE.__dict__[args.model]()
    model.to(device)
    model_without_ddp = model

    # Derive the absolute LR from the base LR unless --lr was given.
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # Optimizer: AdamW over the decay / no-decay parameter groups.
    param_groups = add_weight_decay(model_without_ddp, weight_decay=args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print("Optimizer:", optimizer)
    loss_scaler = NativeScaler()

    # Resume from a checkpoint (if any) BEFORE creating the scheduler, so the
    # scheduler can be positioned at the epoch training continues from.
    start_epoch = load_checkpoint(args, model_without_ddp, optimizer, loss_scaler, device)

    # NOTE(review): misc.load_model is not visible here; in the reference MAE
    # code it also restores from args.resume, which would duplicate the load
    # above — confirm and keep only one of the two.
    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    # Honour --start_epoch; previously the flag was parsed but never used.
    start_epoch = max(start_epoch, args.start_epoch)

    # Cosine annealing over the whole run, bottoming out at args.min_lr.
    # Passing last_epoch fast-forwards the schedule when training does not
    # start at epoch 0; a scheduler created at -1 would restart the cosine
    # curve from the beginning after a resume.
    if start_epoch > 0:
        # A scheduler constructed with last_epoch != -1 requires 'initial_lr'
        # in every param group; supply it when the loaded state lacks it.
        for group in optimizer.param_groups:
            group.setdefault('initial_lr', args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=args.min_lr, last_epoch=start_epoch - 1)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            # Reseed the distributed sampler so each epoch shuffles differently.
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch_ore(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args
        )
        # Advance the LR schedule once per finished epoch.
        scheduler.step()
        # Checkpoint every 5th epoch and after the final epoch.
        if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # Flush pending events and release the writer's file handle.
    if log_writer is not None:
        log_writer.close()


if __name__ == '__main__':
    # Script entry point: parse the CLI, make sure the output directory
    # exists (skipped when --output_dir is empty), then start training.
    args = get_args_parser().parse_args()
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    main(args)