import torch
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D, Trainer1D, Dataset1D
import time

def Train(data, batch, n, index, train_step):
    """Train a 1-D Gaussian diffusion model on *data*, save it, and log the training time.

    Builds a ``Unet1D`` denoiser wrapped in ``GaussianDiffusion1D`` and trains it
    with ``Trainer1D`` for ``train_step`` steps. Afterwards the diffusion model's
    ``state_dict`` is written to ``diffusion_{n}_{index}_{batch}_{train_step}.pth``.
    One line containing the training duration in seconds is appended to
    ``TrainingTimeUsage.txt``. Both paths are relative to the current working
    directory.

    Args:
        data: Training samples, passed unchanged to ``Dataset1D``. Presumably a
            tensor shaped ``(num_samples, 1, 128)`` to match ``channels=1`` and
            ``seq_length=128`` below -- TODO confirm against callers.
        batch: Run label only. It tags the checkpoint filename and the log line.
            It does NOT set the training batch size, which is fixed at 32.
        n: Run label used in the checkpoint filename and the log line.
        index: Run label used in the checkpoint filename and the log line.
        train_step: Number of training steps, forwarded as ``train_num_steps``.

    Returns:
        None. Results are persisted to disk only.
    """
    model = Unet1D(
        dim=64,
        dim_mults=(1, 2, 4, 8),
        channels=1,
    )
    diffusion = GaussianDiffusion1D(
        model,
        seq_length=128,
        timesteps=1000,
        objective='pred_v',
    )
    dataset = Dataset1D(data)

    trainer = Trainer1D(
        diffusion,
        dataset=dataset,
        train_batch_size=32,
        train_lr=8e-5,
        train_num_steps=train_step,
        gradient_accumulate_every=2,  # presumably ~32 * 2 samples per optimizer step
        ema_decay=0.995,
        amp=True,
    )

    # perf_counter is monotonic and high-resolution. Unlike time.time(), the
    # measured duration cannot be skewed by system clock adjustments. The
    # figure covers only trainer.train(), including any internal work it does
    # (e.g. periodic sampling/checkpointing), but not model construction.
    train_start = time.perf_counter()
    trainer.train()
    TrainTimeUsage = time.perf_counter() - train_start

    # These are the weights of `diffusion` itself. NOTE(review): Trainer1D also
    # keeps an EMA copy (ema_decay above); confirm which one should be saved
    # for later sampling.
    torch.save(diffusion.state_dict(), f'diffusion_{n}_{index}_{batch}_{train_step}.pth')

    res = (
        f'n = {n}, steps = {train_step}, index = {index}, '
        f'batch = {batch}, TrainTimeUsage = {TrainTimeUsage}\n'
    )

    # Append mode: each run adds one line instead of overwriting earlier runs.
    with open('TrainingTimeUsage.txt', 'a', encoding='utf-8') as file:
        file.write(res)