import wandb
from common.pipeline import DEFAULT_PARTITION_DESCRIPTION, load_and_preproc_data
from common.data import NumpyDataset, SlidingDataset
from model import TransformerDenoiser

from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D
from torch.utils.data import DataLoader, RandomSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR
from tqdm import tqdm
# from common.pipeline import load_dataset_from_str
from common.preproc import minmaxscale

import argparse
import torch
import matplotlib.pyplot as plt
import numpy as np

import sys
sys.path.append('../')
from data import load_dataset_from_str 


def main(argv=None):
    """Parse the command-line arguments and launch a training run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the regular
            command-line behaviour, so existing ``main()`` callers are
            unaffected.

    Returns:
        Whatever ``evalatue_parameters`` returns for the parsed arguments.

    Raises:
        SystemExit: Raised by ``argparse`` (exit code 2) when an argument is
            missing or not one of the allowed choices.
    """
    # `prog` is left to argparse's default (derived from sys.argv[0]) so the
    # usage line shows the real script name instead of a placeholder.
    parser = argparse.ArgumentParser(
        description=(
            'Train an unconditional diffusion model with a transformer '
            'denoiser on a dataset and save generated samples.'
        ),
    )

    parser.add_argument(
        'dataset',
        choices=[
            'traffic', 'stock', 'energy', 'solar', 'electricity',
            'atc-first-day', 'geolife-25', 'geolife-100',
            'porto-25', 'porto-100',
        ],
        help='name of the dataset to train on',
    )
    parser.add_argument(
        'seq_len',
        type=int,
        choices=[24, 32, 64, 128, 256, 512, 1024],
        help='sequence length of the training windows',
    )

    args = parser.parse_args(argv)

    return evalatue_parameters(args)


def evalatue_parameters(args):
    """Train an unconditional diffusion model and save weights and samples.

    The dataset named by ``args.dataset`` is loaded, min-max scaled and cut
    into sliding windows of ``args.seq_len``. A ``TransformerDenoiser`` wrapped
    in ``GaussianDiffusion1D`` is trained on those windows, the resulting
    state dicts are written to ``model_weights/<exp_desc>.pth`` and generated
    samples are written to ``_out/<exp_desc>.npy``.

    Args:
        args: Object exposing ``dataset`` (dataset name understood by
            ``load_dataset_from_str``) and ``seq_len`` (int window length),
            e.g. the ``argparse.Namespace`` produced by ``main``.

    Returns:
        None. The results are the two files described above.
    """
    # Local import: only needed here to create the output directories.
    import os

    dataset_str = args.dataset
    seq_len = args.seq_len

    # --- Hyper-parameters -------------------------------------------------
    num_workers = 20
    batch_size = 32
    timesteps = 100
    hidden_dim = 64
    objective = 'pred_noise'
    beta_schedule = 'cosine'
    n_heads = 4
    n_layers = 2
    learning_rate = 2.9e-4
    adam_betas = (0.9, 0.99)
    epochs = 100
    number_warmup_epochs = 10
    max_plot_dim = 8

    exp_desc = f'unconditional-tddpm-{dataset_str}-{seq_len}'

    # Create the output directories up front so a missing directory cannot
    # throw away a finished training run at save time.
    os.makedirs('model_weights', exist_ok=True)
    os.makedirs('_out', exist_ok=True)

    # Fall back to the CPU when no CUDA device is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- Data -------------------------------------------------------------
    dataset = load_dataset_from_str(dataset_str, prefix='../data/')

    # A single 2-D array gets a leading axis of size one.
    if isinstance(dataset, np.ndarray) and len(dataset.shape) == 2:
        dataset = np.expand_dims(dataset, axis=0)

    train_new = minmaxscale(dataset)

    # Size of the last axis of the first element of the first series.
    n_dim = train_new[0][0].shape[-1]
    plot_dims = min(max_plot_dim, n_dim)

    dataset = SlidingDataset(train_new, seq_len)

    # Cap the number of windows visited per epoch.
    samples_per_epoch = min(10000, len(dataset))

    sampler = RandomSampler(dataset, replacement=False, num_samples=samples_per_epoch)

    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              sampler=sampler)

    # --- Model ------------------------------------------------------------
    model = TransformerDenoiser(
        dim_data=n_dim,
        dim_embedding=hidden_dim,
        seq_len=seq_len,
        n_heads=n_heads,
        n_layers=n_layers
    )

    diffusion = GaussianDiffusion1D(
        model,
        seq_length=seq_len,
        timesteps=timesteps,
        objective=objective,
        beta_schedule=beta_schedule,
    )

    diffusion.to(device)

    # --- Optimisation -----------------------------------------------------
    optimizer = torch.optim.Adam(diffusion.parameters(), lr=learning_rate, betas=adam_betas)

    # https://stackoverflow.com/questions/65343377/adam-optimizer-with-warmup-on-pytorch/65344276#65344276
    train_scheduler = CosineAnnealingLR(optimizer, epochs)

    def warmup(current_step: int):
        # Exponential warm-up: the factor grows tenfold per step and reaches
        # 0.1 on the last warm-up step.
        return 1 / (10 ** (float(number_warmup_epochs - current_step)))

    warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup)

    # Warm-up first, then cosine annealing from `number_warmup_epochs` on.
    lr_scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, train_scheduler],
        milestones=[number_warmup_epochs],
    )

    def generate_samples(n_samples=100, sample_batch_size=batch_size, plot=True):
        """Draw ``n_samples`` samples from the model, optionally plotting them.

        Returns a tuple ``(samples, figure)``. ``samples`` is the concatenated
        output of ``diffusion.sample`` with its last two axes swapped, and
        ``figure`` is the matplotlib figure, or ``None`` when ``plot`` is
        false.
        """
        n_full_batches, remainder = divmod(n_samples, sample_batch_size)
        batch_sizes = [sample_batch_size] * n_full_batches
        # Only add the tail batch when it is non-empty; sampling a batch of
        # zero would still run the whole reverse process for nothing.
        if remainder:
            batch_sizes.append(remainder)

        diffusion.eval()
        try:
            batches = [diffusion.sample(size).cpu().numpy() for size in batch_sizes]
        finally:
            # Always restore training mode, even if sampling fails.
            diffusion.train()

        all_samples = np.concatenate(batches, axis=0)

        f = None
        if plot:
            f, axs = plt.subplots(figsize=(6*plot_dims, 4), ncols=plot_dims, sharey=True, sharex=True, constrained_layout=True)

            # With a single column `plt.subplots` returns one Axes object
            # rather than an array of them.
            if n_dim == 1:
                axs.plot(all_samples[:, 0, :].T)
            else:
                for i, ax in enumerate(axs):
                    ax.plot(all_samples[:, i, :].T)

        return np.transpose(all_samples, (0, 2, 1)), f

    # --- Training loop ----------------------------------------------------
    progress_bar = tqdm(range(epochs))

    for _epoch in progress_bar:
        epoch_loss = 0
        for x in tqdm(train_loader):
            # Swap the last two axes before feeding the diffusion model.
            x = torch.transpose(x, 1, 2)
            x = x.to(device)

            # NOTE(review): assumes the diffusion forward pass returns a pair
            # whose first item is the loss -- confirm against the installed
            # denoising_diffusion_pytorch version.
            loss, _ = diffusion(x)
            epoch_loss += loss.item()

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        # Mean loss over the batches of this epoch.
        epoch_loss /= len(train_loader)

        # NOTE: experiment-tracker logging is currently disabled; show the
        # epoch loss on the progress bar so it is not silently discarded.
        progress_bar.set_postfix(epoch_loss=epoch_loss)

        lr_scheduler.step()

    # --- Persist results --------------------------------------------------
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'diffusion_state_dict': diffusion.state_dict(),
            'optim_state_dict': optimizer.state_dict()
        },
        f"model_weights/{exp_desc}.pth"
    )

    n_samples = min(len(dataset), 10000)
    # The figure was never used for the final export, so skip building it.
    all_samples, _ = generate_samples(n_samples=n_samples, plot=False)
    np.save(f'_out/{exp_desc}.npy', all_samples)


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
