import os
import torch
import torch.nn as nn
import os.path as osp
import torch.nn.functional as F
from tqdm.auto import tqdm
from utils import *
from svdiffusion import SVDiffusion
from unet1d import Unet1d
import numpy as np
from torch.utils.data import DataLoader, Dataset
from itertools import chain
import matplotlib.pyplot as plt
import h5py
from torch.utils.tensorboard import SummaryWriter


class Trainer():
    """Noise-prediction training loop built around a diffusion wrapper.

    The trainer takes the denoising network (``diffuser.model``), the forward
    process (``diffuser.forward``), the number of diffusion steps
    (``diffuser.time_steps``) and the sampling routine
    (``diffuser.sampling_sequence``) from ``diffuser``. The network is
    optimised with AdamW under a cosine learning-rate schedule that is stepped
    once per batch. Checkpoints, a text log and TensorBoard scalars are
    written to ``results_iclr/pretrain_{epochs}epoch``.

    Args:
        channels: Stored on the instance; not otherwise used by this class.
        epochs: Number of passes over ``train_loader``; also part of the
            output directory name.
        diffuser: Object exposing ``time_steps``, ``forward``, ``model`` and
            ``sampling_sequence``.
        train_loader: Sized iterable of batch tensors (``len()`` is used to
            size the learning-rate schedule).
        device: Device on which batches and timesteps are placed.
    """

    def __init__(self, channels, epochs, diffuser, train_loader, device=None):
        self.device = device
        self.diffuser = diffuser
        self.T = self.diffuser.time_steps
        self.forward_diffusion_sample = self.diffuser.forward
        self.unet = self.diffuser.model
        self.sampler = self.diffuser.sampling_sequence
        self.channels = channels

        self.model_save_dir = f"results_iclr/pretrain_{epochs}epoch"
        # exist_ok avoids the check-then-create race when several runs start together.
        os.makedirs(self.model_save_dir, exist_ok=True)

        self.epochs = epochs
        self.train_loader = train_loader

        self.optimizer = torch.optim.AdamW(self.unet.parameters(), lr=5e-4, weight_decay=0.01)
        # The scheduler is stepped after every batch, so one cosine cycle
        # spans the total number of optimisation steps.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs * len(self.train_loader))

    def get_loss(self, data, t):
        """Return the L1 loss between the forward-process noise and the network's prediction.

        Args:
            data: Batch passed to the forward process.
            t: Per-sample timesteps passed to the forward process and the network.
        """
        U_noisy, noise = self.forward_diffusion_sample(data, t)
        noise_pred = self.unet(U_noisy, t)
        return F.l1_loss(noise, noise_pred)

    def save_model_weight(self, epoch):
        """Write the network's ``state_dict`` to ``model_{epoch}.pt`` in the output directory."""
        torch.save({
            'unet': self.unet.state_dict()
        }, f'{self.model_save_dir}/model_{epoch}.pt')

    def _train_one_epoch(self):
        """Run one optimisation pass over ``train_loader`` and return the sample-weighted mean loss.

        Raises:
            ValueError: If the loader yields no samples.
        """
        total_loss = 0.0
        samples = 0
        for data in self.train_loader:
            self.optimizer.zero_grad()
            data = data.float().to(self.device)
            batch_size = data.shape[0]

            # Sample timesteps directly on the target device; randint already
            # returns int64, so no extra cast or copy is needed.
            t = torch.randint(0, self.T, (batch_size,), device=self.device)

            loss = self.get_loss(data, t)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += loss.item() * batch_size
            samples += batch_size

        if samples == 0:
            raise ValueError("train_loader yielded no samples; cannot compute an epoch loss")
        return total_loss / samples

    def train(self):
        """Train for ``self.epochs`` epochs, logging and checkpointing along the way.

        A checkpoint is written every 40 epochs and after the last epoch, so
        the final weights are kept even when ``epochs`` is not a multiple of
        40. The TensorBoard writer is closed even if training raises.
        """
        logfilename = osp.join(self.model_save_dir, 'train.log')
        logger = get_logger(logfilename)
        # NOTE(review): nothing in this loop feeds values into ravg; it only
        # contributes ravg.info() to the log line and is reset each epoch.
        ravg = RunningAverage()
        writer = SummaryWriter(log_dir=self.model_save_dir)

        try:
            for epoch in tqdm(range(1, self.epochs + 1)):
                epoch_loss = self._train_one_epoch()

                if epoch % 40 == 0 or epoch == self.epochs:
                    self.save_model_weight(epoch)

                line = f'epoch_loss: {epoch_loss}'
                writer.add_scalar('Loss/train', epoch_loss, epoch)
                line += ravg.info()
                logger.info(line)
                ravg.reset()
        finally:
            # Flush pending events and release the event file.
            writer.close()




class HDF5Dataset(Dataset):
    """Thin ``Dataset`` adapter around an already-loaded, indexable collection.

    Args:
        data: Any object supporting ``len()`` and integer indexing; it is
            stored by reference and never copied or converted here.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        # Return a single sample.
        sample = self.data[idx]
        return sample

    def __len__(self):
        size = len(self.data)
        return size

class HDF5MultiGroupDataset(Dataset):
    """Read samples lazily from the ``data`` dataset of every group in an HDF5 file.

    The constructor only scans the file to build a flat index; the file is
    reopened on the first ``__getitem__`` call, so each DataLoader worker
    ends up with its own handle.

    Args:
        h5_path: Path of the HDF5 file. Every top-level group is expected to
            contain a dataset named ``data``.
    """

    def __init__(self, h5_path):
        self.h5_path = h5_path
        self.file = None  # h5py file handle, opened on first __getitem__
        self.index_map = []  # per sample: (group name, index inside that group)

        # Build the flat-index -> (group, inner index) mapping, then close the file.
        with h5py.File(h5_path, 'r') as f:
            for group_name in f.keys():
                num_samples = len(f[group_name]['data'])
                self.index_map.extend((group_name, i) for i in range(num_samples))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a tensor created with ``torch.from_numpy``."""
        if self.file is None:
            self.file = h5py.File(self.h5_path, 'r')  # lazy init in worker

        group_name, inner_idx = self.index_map[idx]
        data = self.file[group_name]['data'][inner_idx]

        return torch.from_numpy(data)

    def close(self):
        """Close the lazily opened file handle, if any; safe to call repeatedly."""
        handle = getattr(self, 'file', None)
        self.file = None
        if handle is not None:
            handle.close()

    def __getstate__(self):
        # Drop the open handle when pickling (e.g. for spawn-based DataLoader
        # workers); the copy reopens the file on first access.
        state = self.__dict__.copy()
        state['file'] = None
        return state

    def __del__(self):
        # Best-effort cleanup: errors raised while the interpreter is shutting
        # down are not actionable, so they are deliberately ignored here.
        try:
            self.close()
        except Exception:
            pass
import argparse

def build_arg_parser():
    """Create the command-line parser for the pre-training script."""
    parser = argparse.ArgumentParser()
    # type=int with nargs='+' makes "--dim_mults 1 2 4" arrive as a list of
    # ints; without them a command-line value would be one raw string.
    parser.add_argument("--dim_mults", type=int, nargs="+", default=(1, 2, 3))
    parser.add_argument("--init_dim", default=128, type=int)
    parser.add_argument("--epochs", default=40, type=int)
    parser.add_argument("--device_ids", type=int, nargs="+", default=[0, 1, 2, 3],
                        help="GPU ids used by DataParallel; the first one hosts the model.")
    parser.add_argument("--h5_path", default='../data/new_U_data_order.hdf5',
                        help="HDF5 file holding the pre-training data.")
    return parser


def load_pretraining_data(h5_path):
    """Load the ``data`` dataset of every group in ``h5_path`` and concatenate along dim 0.

    All datasets used for pre-training are decomposed via SVD using the
    prepare.py script, with the resulting U matrices stored in the
    new_U_data_order.hdf5 file.
    """
    data_list = []
    with h5py.File(h5_path, 'r') as f:
        for group_name in f.keys():
            # np.array reads the whole dataset into memory before the file closes.
            data = torch.from_numpy(np.array(f[group_name]['data']))
            print(data.shape)
            data_list.append(data)
    return torch.cat(data_list, dim=0)


def main():
    """Parse arguments, build the data pipeline and model, and run pre-training."""
    args = build_arg_parser().parse_args()

    channels = 95
    all_data = load_pretraining_data(args.h5_path)
    dataset = HDF5Dataset(all_data)
    dataloader = DataLoader(dataset, batch_size=5096, shuffle=True, pin_memory=True)
    print(len(dataloader))

    device_ids = list(args.device_ids)
    # DataParallel expects the wrapped module to live on the first listed device.
    device = torch.device(f"cuda:{device_ids[0]}")

    unet = Unet1d(dim=args.init_dim, T=1000, channels=channels,
                  dim_mults=tuple(args.dim_mults)).to(device)
    unet = nn.DataParallel(unet, device_ids=device_ids)
    diffuser = SVDiffusion(time_steps=1000, unet=unet, w=2, device=device)
    diffuser_trainer = Trainer(channels=channels, epochs=args.epochs,
                               diffuser=diffuser, train_loader=dataloader, device=device)

    diffuser_trainer.train()


if __name__ == '__main__':
    main()

        



        
