import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

import timm
import timm.optim as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

# Import the dual-mask self-supervised model module (DE is the file name that
# was provided); it is assumed to define mae_vit_1d_dualmask_256(x_noisy, x_clean).
import DE


from engine_pretrain_de import train_one_epoch_de
#####################################
# 1) New dataset: TimeSeriesDatasetPair
#####################################
class TimeSeriesDatasetPair(torch.utils.data.Dataset):
    """Dataset yielding (noisy, clean) time-series pairs loaded from ``.npy`` files.

    Files are paired by position after sorting each directory listing, so the
    two directories must hold the same number of ``*.npy`` files and their
    sorted orders must correspond one-to-one.

    Args:
        root_noisy: Directory containing the noisy ``*.npy`` samples.
        root_clean: Directory containing the clean ``*.npy`` samples.
        transform: Optional callable applied to the noisy tensor only.

    Raises:
        ValueError: If the two directories contain a different number of
            ``*.npy`` files.
    """

    def __init__(self, root_noisy, root_clean, transform=None):
        self.root_noisy = Path(root_noisy)
        self.root_clean = Path(root_clean)
        self.transform = transform

        self.files_noisy = sorted(self.root_noisy.glob("*.npy"))
        self.files_clean = sorted(self.root_clean.glob("*.npy"))
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(self.files_noisy) != len(self.files_clean):
            raise ValueError("噪声文件数量与干净文件数量不匹配！")

    def __len__(self):
        """Return the number of (noisy, clean) pairs."""
        return len(self.files_noisy)

    @staticmethod
    def _load_series(path):
        """Load one ``.npy`` file as a float32 tensor with a leading axis added."""
        # The files are presumably 1-D series (e.g. shape (256,)), giving a
        # [1, 256] tensor — confirm against the data on disk.
        return torch.tensor(np.load(path), dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, index):
        """Return the ``(noisy, clean)`` tensor pair stored at ``index``."""
        data_noisy = self._load_series(self.files_noisy[index])
        data_clean = self._load_series(self.files_clean[index])

        # The transform is applied to the noisy sample only; the clean target
        # is returned untouched.
        if self.transform:
            data_noisy = self.transform(data_noisy)

        return data_noisy, data_clean


#####################################
# 2) add_weight_decay (unchanged)
#####################################
def add_weight_decay(model, weight_decay=0.05, skip_list=()):
    """Split a model's trainable parameters into AdamW-style parameter groups.

    Parameters that are 1-D, whose name ends in ``.bias``, or whose name is in
    ``skip_list`` are exempt from weight decay; everything else receives
    ``weight_decay``. Parameters with ``requires_grad=False`` are left out of
    both groups.

    Args:
        model: Object exposing ``named_parameters()``.
        weight_decay: Decay coefficient for the regularized group.
        skip_list: Parameter names that must never be decayed.

    Returns:
        A two-element list of param-group dicts: the exempt group first
        (``weight_decay`` 0), then the regularized group.
    """
    exempt, regularized = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            skip_decay = (
                param.ndim == 1
                or name.endswith(".bias")
                or name in skip_list
            )
            (exempt if skip_decay else regularized).append(param)

    return [
        {'params': exempt, 'weight_decay': 0.},
        {'params': regularized, 'weight_decay': weight_decay},
    ]

#####################################
# 3) Argument parser: adds data_path_noisy / data_path_clean
#####################################
def get_args_parser():
    """Build the command-line parser for the training script.

    The parser is created with ``add_help=False`` so it can be used as a
    ``parents=[...]`` entry of another ``ArgumentParser``.

    Returns:
        argparse.ArgumentParser: Parser defining all training options.
    """
    parser = argparse.ArgumentParser('MAE Noise Training', add_help=False)

    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=41, type=int)
    parser.add_argument('--accum_iter', default=1, type=int)

    # Model
    parser.add_argument('--model', default='', type=str,
                        help='Name of model in model_noisemae_o')

    # Dataset input paths: noisy & clean
    parser.add_argument('--data_path_noisy', default=r'', type=str,
                        help='path to NOISY time series dataset (npy files)')
    parser.add_argument('--data_path_clean', default=r'', type=str,
                        help='path to CLEAN time series dataset (npy files)')

    parser.add_argument('--output_dir', default='',
                        help='path where to save checkpoints')
    parser.add_argument('--log_dir', default='',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true')
    # pin_mem defaults to True, so --pin_mem alone can never turn it off;
    # --no_pin_mem is the switch that actually disables pinned memory.
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                        help='disable pinned memory in the DataLoader')
    parser.set_defaults(pin_mem=True)

    # Optimizer hyper-parameters
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--lr', type=float, default=None)
    parser.add_argument('--blr', type=float, default=1e-3)
    parser.add_argument('--min_lr', type=float, default=0.)
    parser.add_argument('--warmup_epochs', type=int, default=0)

    # Distributed training
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://')

    return parser

#####################################
# 4) load_checkpoint (can be kept as is)
#####################################
def load_checkpoint(args, model_without_ddp, optimizer, loss_scaler, device):
    """Restore training state from ``args.resume`` when it names an existing file.

    Args:
        args: Namespace providing ``resume`` (checkpoint path, may be empty).
        model_without_ddp: Model whose ``load_state_dict`` receives
            ``checkpoint['model']``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``checkpoint['optimizer']``.
        loss_scaler: Optional scaler; restored only if the checkpoint holds a
            non-None ``'loss_scaler'`` entry.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        int: Epoch to start from — the stored ``'epoch'`` (0 if absent) plus
        one, or 0 when no checkpoint was loaded.
    """
    resume_path = args.resume
    if not (resume_path and os.path.isfile(resume_path)):
        print("No valid resume checkpoint, start from epoch 0")
        return 0

    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # resume from checkpoint files that come from a trusted source.
    state = torch.load(resume_path, map_location=device, weights_only=False)
    model_without_ddp.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])

    scaler_state = state.get('loss_scaler') if loss_scaler is not None else None
    if scaler_state is not None:
        loss_scaler.load_state_dict(scaler_state)

    start_epoch = state.get('epoch', 0) + 1
    print(f"Resumed checkpoint from {args.resume}, starting from epoch {start_epoch}")
    return start_epoch

#####################################
# 5) main
#####################################
def main(args):
    """Run paired noisy/clean pre-training driven by parsed CLI arguments.

    Initializes distributed mode and seeding, builds the paired dataset and
    its loader, instantiates the model named by ``args.model`` from the ``DE``
    module, optionally resumes from ``args.resume``, then trains up to
    ``args.epochs`` epochs. A checkpoint is saved every fifth epoch and on the
    final epoch, and per-epoch stats are appended to ``log.txt`` in
    ``args.output_dir`` when that directory is set.

    Args:
        args: Namespace produced by ``get_args_parser()``. ``args.lr`` is
            filled in from ``args.blr`` when it is None. ``args.distributed``
            is not defined by the parser; it is presumably set by
            ``misc.init_distributed_mode`` — confirm.
    """
    misc.init_distributed_mode(args)
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("Arguments:\n{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # Offset the seed by rank so each process draws a different random stream.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    transform_train = None

    # Paired noisy/clean dataset
    dataset_train = TimeSeriesDatasetPair(
        root_noisy=args.data_path_noisy,
        root_clean=args.data_path_clean,
        transform=transform_train
    )
    print("Dataset with pair noisy+clean:", dataset_train)

    if args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train =", str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    # --log_dir defaults to '' (never None), so test truthiness: an empty
    # path would make os.makedirs('') raise FileNotFoundError.
    if misc.get_rank() == 0 and args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Build the model: look up the constructor named by --model in module DE.
    model = DE.__dict__[args.model]()
    model.to(device)
    model_without_ddp = model

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:
        # Linear scaling rule: lr = blr * effective_batch_size / 256.
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations:", args.accum_iter)
    print("effective batch size:", eff_batch_size)

    if args.distributed:
        # NOTE(review): args.local_rank defaults to -1 unless passed on the
        # CLI or set by misc.init_distributed_mode — confirm it holds a valid
        # device index here.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], find_unused_parameters=True
        )
        model_without_ddp = model.module

    param_groups = add_weight_decay(model_without_ddp, weight_decay=args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print("Optimizer:", optimizer)
    loss_scaler = NativeScaler()

    start_epoch = load_checkpoint(args, model_without_ddp, optimizer, loss_scaler, device)
    # NOTE(review): load_checkpoint above has already restored state from
    # args.resume; misc.load_model presumably loads args.resume again —
    # confirm whether this second call is still needed.
    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            # Reseed the distributed sampler so shuffling differs per epoch.
            data_loader_train.sampler.set_epoch(epoch)

        # ---- train one epoch ----
        train_stats = train_one_epoch_de(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args
        )

        # Save a checkpoint every fifth epoch and on the final epoch.
        if args.output_dir and ((epoch % 5 == 0) or (epoch + 1 == args.epochs)):
            misc.save_model(
                args=args, model=model,
                model_without_ddp=model_without_ddp,
                optimizer=optimizer,
                loss_scaler=loss_scaler,
                epoch=epoch
            )

        # Append this epoch's stats to the JSON-lines log.
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    # Flush pending events and release the TensorBoard writer.
    if log_writer is not None:
        log_writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time:', total_time_str)





if __name__ == '__main__':
    # get_args_parser() builds its parser with add_help=False, so using it
    # directly rejects -h/--help. Wrapping it as a parent of a help-enabled
    # parser keeps every option and default while restoring --help.
    parser = argparse.ArgumentParser('MAE Noise Training', parents=[get_args_parser()])
    args = parser.parse_args()

    # Create the checkpoint/log directory up front when one was requested.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
