import argparse
import torch
import os
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt

from data.gaussians import sample_gaussian

from models.score_matching import ScoreMatching, training


# Command-line interface. `--s` is forwarded as `sigma` to `sample_gaussian`
# below. The value 3.0 is supplied as the argparse default rather than being
# assigned after parsing, so a value given on the command line is honoured.
parser = argparse.ArgumentParser(description='Run Gaussian Flow Matching experiment')
parser.add_argument('--s', type=float, default=3.0, help='s value for the experiment')
args = parser.parse_args()


# Training and validation sets are drawn by two separate calls with identical
# arguments.
# NOTE(review): the positional arguments (10000, 5) are presumably the sample
# count and the number of mixture components -- confirm against
# data.gaussians.sample_gaussian.
dset = sample_gaussian(10000, 5, sigma=args.s, dim=2)
print(dset[:5])  # quick look at the first five training points
val_data = sample_gaussian(10000, 5, sigma=args.s, dim=2)


# The validation loader wraps the held-out draw, not the training data.
training_loader = DataLoader(dset, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)

# Score model: an MLP mapping a 2-D point to a 2-D score vector.
# Three (Linear -> LogSigmoid) stages of width 64 are followed by a linear
# read-out; modules are created in the same order as they appear in the stack.
score_network = nn.Sequential(
    *[
        module
        for fan_in, fan_out in ((2, 64), (64, 64), (64, 64))
        for module in (nn.Linear(fan_in, fan_out), nn.LogSigmoid())
    ],
    nn.Linear(64, 2),
)

# The function transforms are imported from `torch.func`, which exports both
# `jacrev` and `vmap`.
from torch.func import jacrev, vmap


def calc_loss(score_network: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Return the score-matching loss of ``score_network`` on a batch.

    For every sample the loss is ``0.5 * ||s(x)||^2 + tr(ds/dx)``, where
    ``s`` is the network output; the result is the mean over the batch.

    Args:
        score_network: Module mapping a ``(d,)`` (or batched ``(batch, d)``)
            input to an output of the same shape.
        x: Batch of training points, shape ``(batch, d)``.

    Returns:
        Scalar tensor holding the batch-averaged loss.
    """
    score = score_network(x)  # (batch, d)

    # First term: half the squared Euclidean norm of the score, computed as a
    # plain sum of squares so no square root is taken and then undone.
    term1 = 0.5 * score.pow(2).sum(dim=-1)

    # Second term: trace of the per-sample Jacobian. `jacrev` differentiates
    # the network for a single point and `vmap` maps that over the batch.
    jac = vmap(jacrev(score_network))(x)  # (batch, d, d)
    term2 = torch.einsum("bii->b", jac)  # sum of the diagonal per sample
    return (term1 + term2).mean()

# Training: optimise the score network with Adam over shuffled mini-batches.
import time

opt = torch.optim.Adam(score_network.parameters(), lr=3e-4)
dloader = DataLoader(dset, shuffle=True, batch_size=32)
t0 = time.time()
for i_epoch in range(500):
    total_loss = 0
    for batch in dloader:
        opt.zero_grad()

        # One optimisation step on the current mini-batch.
        loss = calc_loss(score_network, batch)
        loss.backward()
        opt.step()

        # The loss is a batch mean; weight it by the batch size so the epoch
        # figure printed below is a per-sample average.
        total_loss += loss.item() * batch.size(0)

    # Report elapsed time and the epoch's average loss on every 50th epoch.
    if i_epoch % 50:
        continue
    print(f"{i_epoch} ({time.time() - t0}s): {total_loss / len(dset)}")


def generate_samples(score_net: torch.nn.Module, nsamples: int, eps: float = 0.001, nsteps: int = 1000) -> torch.Tensor:
    """Draw samples with Langevin MCMC driven by ``score_net``.

    The chain starts from points uniform in ``[-1, 1)`` on each of the two
    coordinates and applies
    ``x <- x + eps * score_net(x) + sqrt(2 * eps) * z`` with ``z ~ N(0, I)``
    for ``nsteps`` iterations.

    Args:
        score_net: Module mapping a ``(nsamples, 2)`` tensor to a tensor of
            the same shape.
        nsamples: Number of chains to run, i.e. number of returned samples.
        eps: Step size of the Langevin update.
        nsteps: Number of Langevin updates.

    Returns:
        Tensor of shape ``(nsamples, 2)`` holding the final chain states. It
        carries no autograd history.
    """
    # Sampling never needs gradients; without no_grad the autograd graph of
    # every one of the `nsteps` network evaluations would be kept alive.
    with torch.no_grad():
        x0 = torch.rand((nsamples, 2)) * 2 - 1
        noise_scale = (2 * eps) ** 0.5  # loop-invariant noise amplitude
        for _ in range(nsteps):
            z = torch.randn_like(x0)
            x0 = x0 + eps * score_net(x0) + noise_scale * z
    return x0

# Draw 1000 points from the trained network via Langevin sampling and write a
# scatter plot of the two coordinates to '2d_plot.png' (the figure is saved,
# not shown on screen).
samples = generate_samples(score_network, nsamples=1000).detach()
first_coord = samples[:, 0]
second_coord = samples[:, 1]
plt.plot(first_coord, second_coord, 'C1.')
plt.savefig('2d_plot.png', bbox_inches='tight')
