from pytorch_lightning.loggers.wandb import WandbLogger
import torch
from src.diffusion.SDE import VPSDE, GenPinnedBrownSDE
from src.constants import PROJECT_PATH, RAW_DATA_PATH
import os
from src.models.model_toy_2D import ToyDiffusion
from src.denoisers.MLP import MLPDenoiser
import matplotlib.pyplot as plt
from hydra import initialize, compose
from pytorch_lightning import Trainer, seed_everything
from src.data.datamodules import get_complicated_dataset
import wandb
import argparse


# Load the toy-2D experiment config through Hydra's compose API. Hydra resolves
# `config_path` relative to the file that calls `initialize`.
with initialize(version_base=None, config_path="../configs/"):
    cfg = compose(config_name="config_toy_2D.yaml")

# Command-line interface. Values given here override fields of the composed
# Hydra config further down in the script.
parser = argparse.ArgumentParser(
    description=(
        "Train a ToyDiffusion model with an MLP denoiser and log generated "
        "samples to Weights & Biases."
    )
)
# --dim is mandatory: the denoiser's input size is computed as dim + 1, which
# fails on None. Requiring it makes argparse reject the call up front instead
# of crashing after the logger and config have already been set up.
parser.add_argument(
    "--dim",
    type=int,
    required=True,
    help="data dimensionality; also the denoiser's output size",
)
parser.add_argument(
    "--mix",
    type=int,
    help="forwarded as `mix` to get_complicated_dataset; 10 also selects the 'sphere_' prefix",
)
parser.add_argument(
    "--hid_dim",
    type=int,
    help="hidden width of the MLP denoiser (overrides cfg.model.hid_dim)",
)
parser.add_argument(
    "--std",
    type=str,
    help="'vp' selects VPSDE; any other value (or omitting it) selects GenPinnedBrownSDE",
)

args = parser.parse_args()

# Every run goes to the W&B project "first_run"; the parsed CLI arguments are
# recorded as the run config.
wandb_logger = WandbLogger(project='first_run', config=args)
dim, mix = args.dim, args.mix

# Only mix == 10 uses the "sphere_" prefix that is handed to
# get_complicated_dataset; every other value uses no prefix.
sphere = "sphere_" if mix == 10 else ""

# Push the CLI sizes into the composed config so ToyDiffusion sees them.
cfg.model.data_dim = dim
cfg.model.hid_dim = args.hid_dim

# "vp" -> VPSDE; any other value, including an omitted --std (None), falls
# back to GenPinnedBrownSDE.
sde = VPSDE() if args.std == "vp" else GenPinnedBrownSDE()

# The denoiser takes one input feature beyond the data dimension (presumably
# the diffusion time / noise level -- TODO confirm against ToyDiffusion) and
# outputs a vector of the data dimension.
denoiser = MLPDenoiser(
    in_dim=dim + 1,
    hid_dim=cfg.model.hid_dim,
    num_hid_layers=cfg.model.num_hid_layers,
    dropout=cfg.model.dropout,
    out_dim=dim,
)
diffuser = ToyDiffusion(cfg, denoiser, sde)
diffuser.to('cpu')


# Train for a fixed 50 epochs; metrics are reported through the W&B logger.
trainer = Trainer(logger=wandb_logger, max_epochs=50)

# Build the training data from the CLI-derived settings and fit the model.
dataset = get_complicated_dataset(dim=dim, mix=mix, sphere=sphere)
trainer.fit(diffuser, dataset)

# Draw a fixed number of sample batches from the trained model and log them
# to W&B as a table with one column per data dimension.
n_sample_batches = 50

# After trainer.fit the module is left in training mode. The denoiser was
# built with cfg.model.dropout, so switch to eval mode so dropout does not
# perturb sampling. NOTE(review): assumes ToyDiffusion.sample does not rely
# on training mode -- confirm.
diffuser.eval()
x_new = torch.cat(
    [diffuser.sample(eta=1.) for _ in range(n_sample_batches)], dim=0
)

# Tensor.numpy() only works on a CPU tensor that does not require grad;
# detach and move to CPU first so logging cannot crash on either condition.
samples = x_new.detach().cpu().numpy()
columns = [f'col_{i}' for i in range(samples.shape[1])]
wandb_logger.experiment.log({'x_new': wandb.Table(data=samples, columns=columns)})
