import numpy as np
from distributions import GaussianMixtureGenerator
import tqdm
from models import MLP
from losses import score_matching_loss
import torch
from visualizations import (
    visualize_loss,
)
import argparse


# Per --data-type mixture settings:
# (data dimension, number of Gaussian clouds, mixing probabilities).
_MIXTURE_CONFIGS = {
    "2-mode": (1, 2, [0.1, 0.9]),
    "4-mode": (2, 4, [0.05, 0.25, 0.45, 0.25]),
}


def annealed_sigma(
    epoch: int, sigma_epochs: int, sigma_min: float, sigma_max: float
) -> float:
    """Return the noise level sigma to use at training step ``epoch``.

    Sigma is annealed geometrically from ``sigma_max`` (at step 0) towards
    ``sigma_min`` over the first ``sigma_epochs`` steps, with a quadratic
    warp of the progress so that sigma stays large early on. From step
    ``sigma_epochs`` onwards sigma is held fixed at the last annealed value
    (the value of step ``sigma_epochs - 1``, marginally above ``sigma_min``).

    Args:
        epoch: Zero-based training step.
        sigma_epochs: Number of steps over which sigma is annealed. If it is
            not positive, no annealing happens and ``sigma_max`` is returned.
        sigma_min: Lower end of the annealing range (must be non-zero).
        sigma_max: Sigma used at step 0.

    Returns:
        The sigma value for the given step.
    """
    if sigma_epochs <= 0:
        return sigma_max
    # Clamp so that every step past the annealing phase reuses its last value.
    step = min(epoch, sigma_epochs - 1)
    return sigma_min * (sigma_max / sigma_min) ** (1 - (step / sigma_epochs) ** 2)


def main() -> None:
    """Train an MLP with the score matching loss on a Gaussian mixture.

    Command-line arguments:
        --data-type: "2-mode" or "4-mode"; selects the mixture configuration.
        --save: integer flag; any non-zero value saves the model's
            ``state_dict`` after training.

    The per-step losses are passed to ``visualize_loss`` once training ends.
    """
    # Local import: only needed to prepare the checkpoint directory.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-type",
        help="Please indicate the type of gaussians to generate",
        required=True,
        type=str,
        choices=list(_MIXTURE_CONFIGS),
    )
    parser.add_argument(
        "--save",
        help="Please indicate if you would like to save the model",
        required=True,
        type=int,
    )
    # Parse before touching the device so a bad command line fails immediately.
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Adaptive sigma for easier training; sigma is held fixed for the final
    # part of training (see annealed_sigma).
    sigma_min = 4.0
    sigma_max = 14.0
    sigma_epochs = 8000

    data_dim, n_clouds, probs = _MIXTURE_CONFIGS[args.data_type]
    print(f"Training a score matching model to a Mixture of {n_clouds} Gaussians")
    sample_dist = GaussianMixtureGenerator(
        n_clouds, 1.5, 0.2, n_dims=data_dim, probs=torch.Tensor(probs)
    ).generate

    save_path = (
        f"../steering_diffusion/model/fk_failed_{data_dim}d_{n_clouds}clouds.pt"
    )
    if args.save:
        # torch.save does not create missing parent directories. Create the
        # directory up front so a problem with the output location surfaces
        # before the training run instead of after it.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

    epochs = 12000
    lr = 1e-4
    model = MLP(data_dim, 512).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    losses = []
    tqdm_epochs = tqdm.trange(epochs)
    for epoch in tqdm_epochs:
        optimizer.zero_grad()
        # A fresh batch of 1024 samples is drawn from the mixture every step.
        points = torch.tensor(sample_dist(1024), dtype=torch.float32, device=device)
        sigma = annealed_sigma(epoch, sigma_epochs, sigma_min, sigma_max)

        loss = score_matching_loss(model, points, sigma=sigma)
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        loss_value = loss.item()
        losses.append(loss_value)
        tqdm_epochs.set_description(f"Average Loss: {loss_value:.5f}, sigma:{sigma}")

    print("Visualizing Loss")
    visualize_loss(losses)

    if args.save:
        print("Saving Model")
        torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()
