import numpy as np
import time
import torch
from tqdm.auto import tqdm

def train_step(model, initializer, batch, noise_scheduler, optim, lr_scheduler, device):
    """Run a single optimisation step on one batch and return its loss.

    Args:
        model: Callable invoked as ``model(noisy, timesteps, condition)``.
        initializer: Callable invoked as ``initializer(noise, prefix, condition)``;
            must expose a ``mask_size`` attribute.
        batch: Mapping with ``'label'`` and ``'feature'`` tensors.
        noise_scheduler: Object exposing ``num_train_timesteps`` and
            ``add_noise(sample, noise, timesteps)``.
        optim: Optimizer whose gradients are reset and stepped here.
        lr_scheduler: Learning-rate scheduler, stepped once per call.
        device: Device the batch tensors are moved to.

    Returns:
        The scalar loss of this batch as a Python float.
    """
    optim.zero_grad()

    target = batch['label'].to(device)
    condition = batch['feature'].to(device)

    # The initializer turns fresh Gaussian noise into the starting sample; it
    # is also given the first ``mask_size`` entries along dim 1 of the target.
    gaussian = torch.randn_like(target).to(device)
    prefix = target[:, :initializer.mask_size]
    start = initializer(gaussian, prefix, condition)

    # One random timestep per batch element.
    upper = noise_scheduler.num_train_timesteps
    timesteps = torch.randint(0, upper, (target.shape[0],), device=device).long()

    noisy = noise_scheduler.add_noise(target, start, timesteps)

    # Mean squared error between the model output and the initializer output.
    residual = model(noisy, timesteps, condition) - start
    loss = residual.pow(2).mean()

    # Backpropagate, update the weights, then advance the LR schedule.
    loss.backward()
    optim.step()
    lr_scheduler.step()

    return loss.item()

def test_step(model, initializer, batch, noise_scheduler, device):
    x_1 = batch['label'].to(device)
    c = batch['feature'].to(device) 
    z = torch.randn_like(x_1).to(device)

    x_0 = initializer(z, x_1[:, :initializer.mask_size], c)

    t = torch.randint(0, noise_scheduler.num_train_timesteps, (x_1.shape[0],), device=device).long()

    x_n = noise_scheduler.add_noise(x_1, x_0, t)

    pred = model(x_n, t, c)
    loss = torch.pow(pred - x_0, 2).mean() 

    return loss.item()

def train(model, initializer, train_dataloader, test_dataloader, noise_scheduler, optim, lr_scheduler, device, config, validation=True):
    """Train ``model`` and ``initializer`` for ``config.num_epochs`` epochs.

    Each epoch runs ``train_step`` on every batch of ``train_dataloader`` and,
    when ``validation`` is true, ``test_step`` on every batch of
    ``test_dataloader``. Every ``config.print_every`` epochs a progress line
    is printed with the mean per-batch training and validation loss of the
    current epoch and the current learning rate.

    Args:
        model: Network being trained; must support ``train()`` / ``eval()``.
        initializer: Module producing the starting sample; must support
            ``train()`` / ``eval()``.
        train_dataloader: Iterable of training batches.
        test_dataloader: Iterable of validation batches.
        noise_scheduler: Scheduler forwarded to the step functions.
        optim: Optimizer forwarded to ``train_step``.
        lr_scheduler: LR scheduler forwarded to ``train_step``; also queried
            via ``get_last_lr()`` for logging.
        device: Device forwarded to the step functions.
        config: Object exposing ``num_epochs`` and ``print_every``.
        validation: When false, validation is skipped and the logged
            ``val_loss`` is ``0.0``.

    Returns:
        None.
    """
    start_time = time.time()
    for epoch in range(config.num_epochs):
        model.train()  # set the model to training mode
        initializer.train()  # set the initializer to training mode

        # Accumulate into dedicated totals; assigning the step result to the
        # accumulator itself would discard every batch but the last.
        train_total = 0.0
        train_batches = 0
        for batch in train_dataloader:
            train_total += train_step(model, initializer, batch, noise_scheduler, optim, lr_scheduler, device)
            train_batches += 1
        # Guard against an empty dataloader.
        loss = train_total / train_batches if train_batches else 0.0

        val_loss = 0.0
        if validation:
            model.eval()  # set the model to evaluation mode
            initializer.eval()  # set the initializer to evaluation mode
            val_total = 0.0
            val_batches = 0
            # No weights are updated during validation, so skip autograd.
            with torch.no_grad():
                for batch in test_dataloader:
                    val_total += test_step(model, initializer, batch, noise_scheduler, device)
                    val_batches += 1
            val_loss = val_total / val_batches if val_batches else 0.0

        # log loss
        if (epoch + 1) % config.print_every == 0:
            elapsed = time.time() - start_time
            # NOTE: the timing column is wall time per epoch averaged over the
            # last ``print_every`` epochs, despite the "ms/step" label.
            print('| iter {:6d} | {:5.2f} ms/step | loss {:8.6f} | val_loss {:8.6f} | lr {:.6f} |' 
                .format(epoch + 1, elapsed * 1000 / config.print_every, loss, val_loss, lr_scheduler.get_last_lr()[0]))
            start_time = time.time()

def generate_samples(dataset, model, initializer, noise_scheduler, device, config, num_samples=10, return_records=False):
    """Generate predictions for the test split by iterating the noise scheduler.

    Starting from Gaussian noise passed through ``initializer``, the sample is
    updated once per entry of ``noise_scheduler.timesteps`` using the model
    output and ``noise_scheduler.step``; after every update the initializer is
    applied again with the first ``mask_size`` entries along dim 1 of the test
    labels. Both ``model`` and ``initializer`` are switched to eval mode and
    the whole procedure runs without gradient tracking.

    Args:
        dataset: Object exposing ``test_data.label``, ``test_data.feature``
            and ``label_transformation.inverse_transform``.
        model: Callable invoked as ``model(x, timesteps, condition)``.
        initializer: Module invoked as ``initializer(x, prefix, condition)``;
            must expose ``mask_size`` and ``eval()``.
        noise_scheduler: Object exposing ``timesteps`` and
            ``step(model_output=..., timestep=..., sample=...)`` whose result
            has a ``prev_sample`` attribute.
        device: Device on which sampling runs.
        config: Object exposing ``sequence_length``.
        num_samples: Currently unused; kept for interface compatibility.
        return_records: When true, also return intermediate samples.

    Returns:
        ``(x_test, x_pred)``, both inverse-transformed, or
        ``(x_test, x_pred, records)`` when ``return_records`` is true, where
        ``records`` is a NumPy array of the inverse-transformed samples taken
        at every timestep divisible by 100.
    """
    model.eval()  # set the model to evaluation mode
    initializer.eval()  # keep the initializer consistent with the model

    x_test = dataset.test_data.label.to(device)
    c_test = dataset.test_data.feature.to(device)
    batch_size = x_test.shape[0]

    # Loop-invariant values: the known prefix of each test label (already on
    # ``device``) and the inverse label transformation.
    prefix = x_test[:, :initializer.mask_size]
    inverse_transform = dataset.label_transformation.inverse_transform

    records = []
    with torch.no_grad():
        x = torch.randn((batch_size, config.sequence_length), dtype=torch.float32, device=device)
        x = initializer(x, prefix, c_test)  # initialize the samples
        for t in tqdm(noise_scheduler.timesteps):
            # Broadcast the current timestep to one entry per sample.
            tb = t.repeat(x.shape[0]).to(device)
            pred = model(x, tb, c_test)
            x = noise_scheduler.step(model_output=pred, timestep=t, sample=x).prev_sample
            # Re-apply the initializer with the known prefix after each step.
            x = initializer(x, prefix, c_test)
            if t % 100 == 0:
                records.append(inverse_transform(x.cpu()))

    x_test = inverse_transform(dataset.test_data.label.cpu())
    x_pred = inverse_transform(x.cpu())
    if return_records:
        return x_test, x_pred, np.array(records)
    return x_test, x_pred
