import argparse
import torch
import os
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt

from data.gaussians import sample_gaussian
import pytorch_lightning as pl
from torch import nn

from models.score_matching import VPSDE, MLPDenoiser, ScoreMatching

# Command-line interface for the experiment.
# NOTE(review): the description says "Flow Matching", but this script trains a
# VPSDE score-matching model; confirm which wording is intended.
parser = argparse.ArgumentParser(description='Run Gaussian Flow Matching experiment')
# `--s` falls back to 3 when not given on the command line. Previously the
# parsed value was unconditionally overwritten with 3 right after parsing,
# which silently discarded any user-supplied value.
# NOTE(review): args.s is not read by the rest of this script as shown; either
# wire it into the experiment or drop the flag.
parser.add_argument('--s', type=float, default=3, help='s value for the experiment')
args = parser.parse_args()


# Training and validation sets are drawn independently with identical
# sampler arguments.
# NOTE(review): the positional arguments (10000, 5) presumably mean
# (number of samples, number of mixture components); confirm against
# data/gaussians.py.
train_data = sample_gaussian(10000, 5, sigma=1, dim=2)
val_data = sample_gaussian(10000, 5, sigma=1, dim=2)
print(train_data[:10])  # quick sanity check of the first few samples

# Shuffle the training set every epoch so mini-batches are not drawn in the
# same fixed order (and, if sample_gaussian emits samples grouped by
# component, so that each batch mixes components).
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
# Validate on the held-out `val_data`. This loader previously wrapped
# `train_data`, so validation metrics were really training-set metrics.
val_loader = DataLoader(val_data, batch_size=64)


# Score / noise-prediction network.
# NOTE(review): in_dim=3 presumably is the 2 data dimensions plus 1 time
# input, and out_dim=2 matches the data dimension; confirm in
# models/score_matching.py.
denoiser = MLPDenoiser(
    in_dim=3,
    hid_dim=64,
    out_dim=2,
    num_hid_layers=2,
    dropout=0.1,
    activation=F.relu,
)

# Variance-preserving SDE configured with beta_min / beta_max; N is presumably
# the number of discretisation steps. It is handed the same denoiser instance.
sde = VPSDE(
    beta_min=0.01,
    beta_max=20.0,
    N=1000,
    denoiser=denoiser,
)

# The module passed to trainer.fit below; it wraps the denoiser and the SDE.
score_matching = ScoreMatching(denoiser, sde)


# Trainer: CPU only, 50 epochs. Validation runs every 20 epochs, i.e. at the
# end of the 20th and 40th epochs with max_epochs=50.
trainer = pl.Trainer(
    max_epochs=50,
    check_val_every_n_epoch=20,
    accelerator="cpu",
    devices=1,
)

# Train the model
trainer.fit(score_matching, train_loader, val_loader)

# Put the module in eval mode before sampling. The denoiser is built with
# dropout, and trainer.fit does not guarantee the module is left in eval mode,
# so without this dropout would be applied at every reverse-sampling step.
score_matching.eval()
# NOTE(review): wrapping the sampler call in torch.no_grad() would avoid
# building an autograd graph, but only if run_sampler_backward does not rely on
# autograd internally; confirm before adding.
# Start the backward sampler from 64 standard-normal 2-D points. The result is
# indexed with [-1], presumably the final state of a trajectory; verify.
samples = score_matching.run_sampler_backward(torch.randn(64, 2))
print(samples[-1].mean())
