from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial
import matplotlib.pyplot as plt
import wandb
import numpy as np
import argparse

from denoising_diffusion_pytorch import GaussianDiffusion1D

from model import TransformerDenoiser
from dataset import GeoLife25, GeoLife100, Porto25, Porto100


def main():
    """Train a heatmap-conditioned 1D diffusion model and save its weights.

    The dataset name is read from the command line.  A ``TransformerDenoiser``
    wrapped in ``GaussianDiffusion1D`` is trained on CUDA with Adam and a
    cosine-annealed learning rate.  The mean per-batch loss and the learning
    rate are logged to Weights & Biases every epoch; every ``plot_every``
    epochs, conditioned sample plots for a fixed set of dataset indices are
    logged as well.  At the end the model, diffusion and optimizer state
    dicts are written to ``model_weights/conditional-<dataset>.pth``.
    """
    # Local import: only needed to create the checkpoint directory.
    import os

    dataset_classes = {
        'porto-25': Porto25,
        'porto-100': Porto100,
        'geolife-25': GeoLife25,
        'geolife-100': GeoLife100,
    }

    parser = argparse.ArgumentParser(
        description='Train a heatmap-conditioned trajectory diffusion model.')
    parser.add_argument('dataset', choices=list(dataset_classes))
    args = parser.parse_args()

    dataset_str = args.dataset

    # Data / model hyper-parameters.
    num_workers = 10
    n_dim = 2
    seq_len = 128
    batch_size = 32
    hidden_dim = 64
    n_heads = 4
    n_layers = 2

    # Diffusion hyper-parameters.
    timesteps = 100
    objective = 'pred_noise'
    reconstruction_loss_weight = 0
    beta_schedule = 'cosine'

    # Optimisation hyper-parameters.
    learning_rate = 2.9e-4
    adam_betas = (0.9, 0.99)
    epochs = 100

    # Qualitative evaluation: which dataset items to condition on, how often.
    plot_every = 10
    plot_indices = [9359, 4204, 3604, 3076, 9904, 382, 352, 115, 272, 221, 777, 991, 602, 947, 114]

    checkpoint_dir = 'model_weights'
    # Create the output directory up front so a missing directory cannot
    # discard a finished training run at save time.
    os.makedirs(checkpoint_dir, exist_ok=True)

    dataset = dataset_classes[dataset_str]()

    def test_conditioned_generation(idx, n_samples=100):
        """Sample from the model conditioned on one dataset heatmap and plot.

        Args:
            idx: Index into ``dataset``; the item is unpacked as
                ``(heatmap, trajectory)``.
            n_samples: Number of samples to draw (must be at least 1).  They
                are generated in chunks of at most ``batch_size``.

        Returns:
            A matplotlib figure with three panels: the log-heatmap with the
            dataset trajectory drawn on top, the generated samples alone, and
            the log-heatmap with the generated samples drawn on top.  The
            caller is responsible for closing the figure.

        Raises:
            ValueError: If ``n_samples`` is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError(f'n_samples must be >= 1, got {n_samples}')

        # presumably the second item is the ground-truth trajectory for the
        # heatmap (it is what the training loop feeds to the diffusion model).
        query_heatmap, reference = dataset[idx]

        def sample(count):
            # The same heatmap conditions every sample in the chunk.
            # NOTE(review): the expand hard-codes a 1x64x64 heatmap - confirm
            # this matches every dataset variant.
            condition = query_heatmap.unsqueeze(0).expand(count, 1, 64, 64).float().cuda()
            return diffusion.sample_conditional(condition, count).cpu().numpy()

        n_batches, remainder = divmod(n_samples, batch_size)
        chunk_sizes = [batch_size] * n_batches
        if remainder:
            # Only request the tail chunk when it is non-empty; a zero-sized
            # batch would otherwise be passed to the sampler.
            chunk_sizes.append(remainder)

        diffusion.eval()
        try:
            with torch.no_grad():
                all_samples = np.concatenate(
                    [sample(count) for count in chunk_sizes], axis=0)
        finally:
            # Always hand the model back in training mode, even on failure.
            diffusion.train()

        f, (ax1, ax2, ax3) = plt.subplots(
            figsize=(12, 4), ncols=3, sharey=True, sharex=True,
            constrained_layout=True)

        # Coordinates are plotted as (column 1, 1 - column 0) so that the
        # lines share the orientation of the imshow-rendered heatmap.
        ax1.imshow(np.log(query_heatmap[0]), extent=[0, 1, 0, 1])
        ax1.plot(reference[:, 1], 1 - reference[:, 0])

        ax2.plot(all_samples[:, 1, :].T, 1 - all_samples[:, 0, :].T)
        ax2.set_xlim(0, 1)
        ax2.set_ylim(0, 1)

        ax3.imshow(np.log(query_heatmap[0]), extent=[0, 1, 0, 1])
        ax3.plot(all_samples[:, 1, :].T, 1 - all_samples[:, 0, :].T, alpha=.50)
        return f

    # Create data loader for the data
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True)

    model = TransformerDenoiser(n_dim, hidden_dim, seq_len=seq_len, n_heads=n_heads, n_layers=n_layers)

    diffusion = GaussianDiffusion1D(
        model,
        seq_length=seq_len,
        timesteps=timesteps,
        objective=objective,
        beta_schedule=beta_schedule,
        reconstruction_loss_weight=reconstruction_loss_weight
    )

    diffusion.cuda()

    optimizer = torch.optim.Adam(diffusion.parameters(), learning_rate, adam_betas)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Skip hard-coded indices that do not exist in the selected dataset
    # instead of failing part-way through training.
    valid_plot_indices = [i for i in plot_indices if i < len(dataset)]

    progress_bar = tqdm(range(epochs), position=0, desc='Epochs:')

    wandb.init(
        project="tddpm",
        group=f'{dataset_str}',
        reinit=True
    )

    for epoch in progress_bar:
        epoch_loss = 0
        epoch_pbar = tqdm(total=len(dataset), position=1, desc='Iterating over train set', leave=False)
        for heatmap, x in train_loader:
            heatmap = heatmap.cuda()

            # Swap dims 1 and 2 before feeding the diffusion model
            # (presumably (batch, seq_len, n_dim) -> (batch, n_dim, seq_len);
            # TODO confirm against the dataset classes).
            x = torch.transpose(x, 1, 2)
            x = x.cuda()

            loss, loss_d = diffusion(x, conditional=heatmap)
            loss.backward()

            # Detach before accumulating so the running sum never keeps
            # autograd state alive across iterations.
            ddpm_loss = loss_d['ddpm_loss']
            if torch.is_tensor(ddpm_loss):
                ddpm_loss = ddpm_loss.detach()
            epoch_loss += ddpm_loss

            optimizer.step()
            optimizer.zero_grad()

            epoch_pbar.update(len(x))
        epoch_pbar.close()

        # Mean of the per-batch losses.
        epoch_loss /= len(train_loader)

        results = {
            'epoch_loss': epoch_loss,
            # get_last_lr() reports the rate used during this epoch; get_lr()
            # is only meant to be called from inside scheduler.step().
            'lr': lr_scheduler.get_last_lr()[0]
        }

        figures = []
        if epoch % plot_every == 0:
            for i in valid_plot_indices:
                figure = test_conditioned_generation(i, n_samples=100)
                results[f'samples_{i}'] = figure
                figures.append(figure)

        wandb.log(results)

        # Figures are no longer needed once logged; closing them stops
        # pyplot from accumulating open figures over the whole run.
        for figure in figures:
            plt.close(figure)

        lr_scheduler.step()

    diffusion.cpu()

    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'diffusion_state_dict': diffusion.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
        },
        f"model_weights/conditional-{dataset_str}.pth"
    )

# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
