# train.py
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from config_numerical import config
from dataloader_05res import train_Dataset, test_Dataset
from baseline.UNet import UNet
from baseline.SimVP import SimVP
from baseline.DiT import Dit_plus
from baseline.AsNet import Assimilation
from baseline.ERANet import ERANet
#from qgm_functional import init_qgm_state, compute_q_over_f0_from_p, step_qgm

class QGMTrainer:
    """Distributed (DDP) trainer that fits a network mapping `input_ssta` to `target_ssta`.

    The default process group must be initialised before construction,
    because the constructor queries the process rank and wraps the model in
    ``DistributedDataParallel``.

    Attributes set by ``__init__``:
        device: CUDA device used by this process.
        num_gpus: ``torch.cuda.device_count()``; kept for backward
            compatibility, no longer used for loss averaging.
        rank: global rank of this process.
        param: ``config.QGM_PARAMS``.
        model, criterion, optimizer, scheduler: training components.
        train_loader, val_loader: per-rank sharded data loaders.
    """

    def __init__(self, device):
        """Build the model, loss, optimizer, LR scheduler and data loaders.

        Args:
            device: ``torch.device`` (CUDA) this process trains on.
        """
        # System configuration
        self.device = device
        self.num_gpus = torch.cuda.device_count()
        self.rank = dist.get_rank()

        # Physical parameters
        #self.state = init_qgm_state(config.QGM_PARAMS)
        self.param = config.QGM_PARAMS

        # Model components
        self.model = self._init_model()
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=config.INIT_LR)
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=config.LR_STEP_SIZE,
            gamma=config.LR_GAMMA,
        )

        # Data loading
        self.train_loader, self.val_loader = self._prepare_dataloaders()

    def _init_model(self):
        """Create the backbone on ``self.device`` and wrap it in DDP.

        Returns:
            The ``DistributedDataParallel``-wrapped model.
        """
        # Alternative backbones that were tried previously:
        #model_gm = UNet(**config.MODEL_GM).to(self.device)
        #model_S = UNet(**config.MODEL_S).to(self.device)
        # model_gm = Dit_plus(shape_in=(1, 1, 3, 360, 720), shape_out=(1, 1, 2, 360, 720), hid_S=32, hid_T=64, N_S=4, N_T=8, time_step=1000).to(self.device)
        # model_S = Dit_plus(shape_in=(1, 1, 7, 360, 720), shape_out=(1, 1, 1, 360, 720), hid_S=32, hid_T=64, N_S=4, N_T=8, time_step=1000).to(self.device)
        # model = Assimilation(in_c=12, out_c=1, channel=64).to(self.device)
        # model = SimVP(shape_in=[1, 12, 360, 720], shape_out=[1, 1, 360, 720]).to(self.device)
        #model_gm = SimVP(shape_in=[1, 3, 360, 720], shape_out=[1, 2, 360, 720]).to(self.device)
        #model_S = SimVP(shape_in=[1, 7, 360, 720], shape_out=[1, 1, 360, 720]).to(self.device)
        model = ERANet(in_c=12, out_c=1).to(self.device)
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.device.index],
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    def _prepare_dataloaders(self):
        """Create the distributed train and validation data loaders.

        Returns:
            Tuple ``(train_loader, val_loader)``. Both use a
            ``DistributedSampler`` so every rank processes its own shard.
        """
        train_set = train_Dataset(config.DATA)
        val_set = test_Dataset(config.DATA)

        train_loader = DataLoader(
            train_set,
            batch_size=config.BATCH_SIZE,
            sampler=DistributedSampler(train_set),
            num_workers=config.NUM_WORKERS,
            pin_memory=True,
        )
        # DistributedSampler shuffles by default and the DataLoader's own
        # `shuffle` flag has no effect once a sampler is supplied, so the
        # "no shuffling" intent has to be expressed on the sampler itself.
        # `_test` relies on this fixed order to reassemble its outputs.
        val_loader = DataLoader(
            val_set,
            batch_size=config.BATCH_SIZE,
            sampler=DistributedSampler(val_set, shuffle=False),
            num_workers=config.NUM_WORKERS,
            pin_memory=True,
        )
        return train_loader, val_loader

    def _sync_tensor(self, tensor):
        """Average a tensor over all processes in the default group.

        Args:
            tensor: tensor holding this rank's value; it is not modified.

        Returns:
            A new tensor with the element-wise mean across ranks.
        """
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        # Divide by the number of participating processes, not by the number
        # of GPUs visible on this node (they differ on multi-node runs).
        return rt / dist.get_world_size()

    def _global_mean_loss(self, loss_sum, sample_count):
        """Combine per-rank loss sums into one sample-weighted mean loss.

        Args:
            loss_sum: 0-dim tensor, sum of ``batch_loss * batch_size`` seen
                by this rank.
            sample_count: number of samples this rank processed.

        Returns:
            Mean loss per sample over all ranks as a Python float. The
            result is NaN when no rank processed any sample.
        """
        stats = torch.stack([
            loss_sum.to(torch.float64),
            torch.tensor(float(sample_count), dtype=torch.float64,
                         device=loss_sum.device),
        ])
        # Averaging both entries over ranks leaves their ratio equal to
        # (global loss sum) / (global sample count).
        stats = self._sync_tensor(stats)
        return (stats[0] / stats[1]).item()

    def _train_epoch(self):
        """Run one training epoch over this rank's shard.

        Returns:
            Mean training loss per sample, averaged over all ranks.
        """
        self.model.train()
        loss_sum = torch.zeros((), dtype=torch.float64, device=self.device)
        sample_count = 0

        for input_ssta, target_ssta in tqdm(self.train_loader, desc="Training", disable=self.rank!=0):
            input_ssta = input_ssta.to(self.device, non_blocking=True).float()
            target_ssta = target_ssta.to(self.device, non_blocking=True).float()

            self.optimizer.zero_grad()
            pred_ssta = self.model(input_ssta)
            loss = self.criterion(pred_ssta, target_ssta)

            # Back-propagation (DDP averages gradients across ranks here)
            loss.backward()
            self.optimizer.step()

            # Accumulate on-device; the cross-rank reduction happens once
            # at the end of the epoch.
            batch_size = input_ssta.size(0)
            loss_sum += loss.detach() * batch_size
            sample_count += batch_size

        return self._global_mean_loss(loss_sum, sample_count)

    def _validate(self):
        """Evaluate the model on this rank's validation shard.

        Returns:
            Mean validation loss per sample, averaged over all ranks.
        """
        self.model.eval()
        loss_sum = torch.zeros((), dtype=torch.float64, device=self.device)
        sample_count = 0

        with torch.no_grad():
            for input_ssta, target_ssta in tqdm(self.val_loader, desc="Validating", disable=self.rank!=0):
                input_ssta = input_ssta.to(self.device, non_blocking=True).float()
                target_ssta = target_ssta.to(self.device, non_blocking=True).float()

                pred_ssta = self.model(input_ssta)
                loss = self.criterion(pred_ssta, target_ssta)

                batch_size = input_ssta.size(0)
                loss_sum += loss * batch_size
                sample_count += batch_size

        return self._global_mean_loss(loss_sum, sample_count)

    def _load_checkpoint(self):
        """Load pretrained weights from ``{config.CKPT_PATH_PRE}_best_model.pth``.

        If the file does not exist a warning is logged (rank 0 only) and the
        model is left untouched. Every rank loads the file so that all DDP
        replicas hold identical weights afterwards.
        """
        ckpt_path = f"{config.CKPT_PATH_PRE}_best_model.pth"
        if not os.path.exists(ckpt_path):
            if self.rank == 0:
                logging.warning(f"No checkpoint found at {ckpt_path}, training from scratch")
            return

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(ckpt_path, map_location=self.device)

        # Strip the prefix added when a DDP/DataParallel wrapper was saved.
        prefix = 'module.'
        if all(k.startswith(prefix) for k in state_dict):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

        self.model.module.load_state_dict(state_dict)
        if self.rank == 0:
            logging.info(f"Successfully loaded checkpoint from {ckpt_path}")

    def _gather_batch(self, tensor):
        """All-gather one per-rank batch and interleave it into sampler order.

        With a non-shuffling ``DistributedSampler``, local sample ``j`` on
        rank ``r`` is global sample ``j * world_size + r``; stacking the
        rank shards on a new second axis and flattening restores that order.

        Args:
            tensor: batch tensor with the same shape on every rank.

        Returns:
            Tensor whose first dimension is ``world_size`` times larger.
        """
        tensor = tensor.contiguous()
        shards = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(shards, tensor)
        return torch.stack(shards, dim=1).flatten(0, 1)

    def _test(self):
        """Run inference on the validation loader and save arrays to ``./results``.

        All ranks take part in the forward passes and the gathers; rank 0
        alone collects the full inputs, outputs and targets and writes them
        as ``{config.BACKBONE}_inputs.npy``, ``..._outputs.npy`` and
        ``..._targets.npy``. Samples the sampler appended as padding are
        trimmed so each array has ``len(dataset)`` entries.
        """
        self.model.eval()
        all_inputs = []
        all_targets = []
        all_outputs = []

        with torch.no_grad():
            for input_ssta, target_ssta in tqdm(self.val_loader, desc="Testing", disable=self.rank!=0):
                input_ssta = input_ssta.to(self.device, non_blocking=True).float()
                target_ssta = target_ssta.to(self.device, non_blocking=True).float()

                pred_ssta = self.model(input_ssta)

                # Collective calls: every rank must execute them.
                full_input = self._gather_batch(input_ssta)
                full_pred = self._gather_batch(pred_ssta)
                full_target = self._gather_batch(target_ssta)

                if self.rank == 0:
                    all_inputs.append(full_input.cpu().numpy())
                    all_outputs.append(full_pred.cpu().numpy())
                    all_targets.append(full_target.cpu().numpy())

        # A single writer avoids several ranks clobbering the same files.
        if self.rank != 0:
            return

        num_samples = len(self.val_loader.dataset)
        all_inputs = np.concatenate(all_inputs, axis=0)[:num_samples]
        all_outputs = np.concatenate(all_outputs, axis=0)[:num_samples]
        all_targets = np.concatenate(all_targets, axis=0)[:num_samples]

        os.makedirs('./results', exist_ok=True)
        np.save(f'./results/{config.BACKBONE}_inputs.npy', all_inputs)
        np.save(f'./results/{config.BACKBONE}_outputs.npy', all_outputs)
        np.save(f'./results/{config.BACKBONE}_targets.npy', all_targets)

    def run(self):
        """Main loop: train, step the scheduler, validate, checkpoint and log."""
        best_loss = float('inf')

        # self._load_checkpoint()
        # self._test()

        for epoch in range(config.EPOCHS):
            # Re-seed the distributed sampler so each epoch is shuffled differently
            self.train_loader.sampler.set_epoch(epoch)

            # Training phase
            train_loss = self._train_epoch()
            self.scheduler.step()

            # Validation phase
            val_loss = self._validate()

            # Only the main process saves and logs
            if self.rank == 0:
                self._save_checkpoint(val_loss < best_loss)
                best_loss = min(val_loss, best_loss)
                self._log_progress(epoch, train_loss, val_loss)

    def _save_checkpoint(self, is_best):
        """Save the unwrapped model weights when ``is_best`` is true.

        Args:
            is_best: whether the current model is the best seen so far.
        """
        if is_best:
            torch.save(self.model.module.state_dict(), f"{config.CKPT_PATH}_best_model.pth")

    def _log_progress(self, epoch, train_loss, val_loss):
        """Log and print one progress line for the finished epoch.

        Args:
            epoch: zero-based epoch index (displayed one-based).
            train_loss: mean training loss per sample.
            val_loss: mean validation loss per sample.
        """
        lr = self.optimizer.param_groups[0]['lr']
        # The losses are already global per-sample means, so no rescaling
        # by the number of GPUs is applied.
        log_msg = (
            f"Epoch {epoch+1:03d} | "
            f"LR: {lr:.2e} | "
            f"Train: {train_loss:.7f} | "
            f"Val: {val_loss:.7f}"
        )
        logging.info(log_msg)
        print(log_msg)  # console output

def main():
    """Initialise the distributed environment, train, and always clean up.

    Reads ``LOCAL_RANK`` from the environment to pick the CUDA device for
    this process.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
    """
    # Distributed environment initialisation
    dist.init_process_group(backend='nccl')
    try:
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)

        # Seed by global rank: LOCAL_RANK repeats on every node, which would
        # give processes on different nodes the same random stream.
        seed = config.SEED + dist.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)

        if local_rank == 0:
            logging.info(f"Training configuration:\n{vars(config)}")
            os.makedirs(config.CKPT_PATH, exist_ok=True)

        trainer = QGMTrainer(device)
        trainer.run()
    finally:
        # Tear the process group down even if training raised.
        dist.destroy_process_group()

# Script entry point. main() reads LOCAL_RANK from the environment, so this is
# presumably meant to be started through a distributed launcher (e.g. torchrun).
if __name__ == "__main__":
    main()