from models.datasets.circle_multimodal import CircleMultimodal
from models.datasets.mnist_digits import MNISTDigitsDataset
from models.trained2d import ScoreModel2D
from models.trained_images import ScoreModelImage, TrainedImage
from models.base_trained import WEIGHT_DIR

from torch.utils.data import DataLoader, RandomSampler, random_split
from torch.optim import Adam
import torch
from torch import sqrt, exp

import os
import math

import matplotlib.pyplot as plt

def train(dataloader, model, optimizer, n_time_samples, device, T,  schedule_g):
    """
    Runs one pass over `dataloader`, training `model` to predict negated noise.

    For every batch, `n_time_samples` times are drawn uniformly from
    [1e-4, T + 1e-4) and mapped through `schedule_g`. For each resulting value
    g the batch is perturbed as exp(-g/2) * batch + sqrt(1 - exp(-g)) * noise
    and an MSE loss pushes model(perturbed, g) towards -noise. The per-g losses
    are averaged, back-propagated once per batch, and the gradient norm is
    clipped to 1 before the optimizer step.

    Parameters:
    - dataloader: sized iterable yielding batch tensors
    - model: module called as model(perturbed, g)
    - optimizer: optimizer stepping the model parameters
    - n_time_samples: number of time draws averaged per batch
    - device: device the batch and the random draws are placed on
    - T: scale applied to the uniform time draws
    - schedule_g: callable mapping the drawn times to g values

    Returns:
    - 1-D tensor holding the averaged loss of each batch, in iteration order
    """
    mse = torch.nn.MSELoss()
    model.train()
    n_batches = len(dataloader)
    history = []

    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        batch = batch.to(device)

        # One column of times per batch; the 1e-4 offset keeps them away from 0.
        times = torch.rand((n_time_samples, 1), device=device) * T + 1e-4

        total = 0.0
        for g in schedule_g(times):
            noise = torch.randn(batch.shape, device=device)
            signal_scale = torch.exp(-g / 2)
            noise_scale = sqrt(1 - torch.exp(-g))
            noisy = signal_scale * batch + noise_scale * noise
            # The regression target is the negated noise that was injected.
            total += 1 / n_time_samples * mse(model(noisy, g), -noise)

        loss_value = total.item()
        history.append(loss_value)
        total.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        if step % 10 == 0:
            print(f"loss: {loss_value:>7f}  [{step:>5d}/{n_batches:>5d}]")

    return torch.Tensor(history)

def save_checkpoint(model, optimizer, epoch, loss):
    """
    Saves a training checkpoint under WEIGHT_DIR/checkpoints.

    The file is named model.get_filename() + '.ckpt' and holds a dict with the
    keys "epoch", "model_state_dict", "optimizer_state_dict" and "loss".

    Parameters:
    - model: object providing state_dict() and get_filename()
    - optimizer: optimizer whose state_dict() is stored alongside the model
    - epoch: epoch value recorded in the checkpoint
    - loss: loss value recorded in the checkpoint
    """
    filename = os.path.join(WEIGHT_DIR, 'checkpoints', model.get_filename() + '.ckpt')
    # torch.save does not create missing parent directories; ensure the folder
    # exists so a fresh checkout does not fail after a full epoch of training.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss
    }
    torch.save(checkpoint, filename)

def cosine_schedule_g(t, delta = 1e-4):
    """
    Evaluates the cosine schedule
    g(t) = -log( cos^2(pi/2 * (t + delta) / (1 + delta)) / cos^2(pi/2 * delta / (1 + delta)) ).

    The denominator normalises the ratio to 1 at t = 0, so g(0) = 0.

    Parameters:
    - t: time(s) as a tensor, or any value accepted by torch.tensor
    - delta: offset used inside the cosine argument

    Returns:
    - tensor of g values with the same shape as t
    """
    time = t if isinstance(t, torch.Tensor) else torch.tensor(t)

    angle = (time + delta) / (1 + delta) * torch.pi / 2
    # Value of cos^2 at t = 0, used as the normalising constant.
    cos_sq_at_zero = math.cos(delta / (1 + delta) * torch.pi / 2) ** 2

    ratio = torch.cos(angle) ** 2 / cos_sq_at_zero
    return -torch.log(ratio)

def cosine_alpha(t, delta = 1e-4):
    """
    Returns exp(-g(t)), where g is cosine_schedule_g evaluated with the same delta.

    Parameters:
    - t: time(s), forwarded unchanged to cosine_schedule_g
    - delta: offset forwarded unchanged to cosine_schedule_g
    """
    g = cosine_schedule_g(t, delta)
    return torch.exp(-g)


def plot_image_grid(images, dir, name, cols=5, figsize=(10, 10), cmap=None):
    """
    Plots a list of images in a grid and saves the figure to dir/name.

    Parameters:
    - images: list of image arrays (H x W or H x W x C)
    - dir: directory the figure is written to (created if it does not exist)
    - name: file name of the saved figure
    - cols: number of columns in the grid
    - figsize: size of the entire figure
    - cmap: colormap for grayscale images (e.g., 'gray')
    """
    n_images = len(images)
    # Keep at least one row so an empty list yields a blank grid instead of
    # plt.subplots rejecting a zero-row layout.
    rows = max(1, math.ceil(n_images / cols))

    # squeeze=False always returns a 2-D array of axes, so flatten() also works
    # for a 1x1 grid (where the default would return a bare Axes object).
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    try:
        for i, ax in enumerate(axes.flatten()):
            ax.axis('off')
            if i < n_images:
                ax.imshow(images[i], cmap=cmap)

        fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
        if dir:
            os.makedirs(dir, exist_ok=True)
        fig.savefig(os.path.join(dir, name))
    finally:
        # Close the figure rather than only clearing it; otherwise every call
        # leaves one more open figure registered with pyplot.
        plt.close(fig)

# Script entry point: trains the MNIST digit score model, checkpointing and
# saving a grid of generated samples after every epoch.
if __name__ == '__main__':
    name = 'mnist_digits'
    image_dir = './image/mnist'
    dataset = MNISTDigitsDataset()
    sampler = TrainedImage(
        N = 4000,
        T = 0.9,
        image_dim= 28,
        model_name = name,
        dim_mults=(2,4,8),
        channels=1,
        noise_schedule='cosine'
    )

    model = sampler.score_model

    # NOTE(review): training draws times up to T = 1 while the sampler above is
    # built with T = 0.9 — confirm the mismatch is intentional.
    T = 1
    model.train()

    epochs = 30
    batch_size = 64

    # Reshuffle every epoch so the optimizer does not see the batches in the
    # same fixed order on each pass over the dataset.
    train_dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = Adam(model.parameters(), lr = 2e-5)

    n_t_samples = 5
    schedule_g = cosine_schedule_g
    steps = len(train_dataloader)
    # One row of per-batch losses for each epoch.
    losses = torch.zeros((epochs, steps))

    # Create the sample-image folder up front so a missing directory is
    # reported before training starts rather than after the first epoch.
    os.makedirs(image_dir, exist_ok=True)

    for epoch in range(epochs):
        print("EPOCH: ", epoch + 1)
        print('==============================')
        losses[epoch] = train(train_dataloader, 
                              model, optimizer, n_t_samples, device, T, schedule_g = schedule_g)
        # Record the last batch loss as a plain float instead of a tensor view
        # into `losses`, keeping the checkpoint independent of that buffer.
        save_checkpoint(model, optimizer, epoch, losses[epoch, -1].item())
        model.save_weights()
        print("Generating a few samples: ")
        traj = sampler.sample_reverse_trajectory(n_samples=1, n_noise_realizations=10)
        # presumably picks one vector per noise realization from the returned
        # trajectory — TODO confirm the axis layout against the sampler.
        image_list = traj[0, :, 0, :]
        image_list = [sampler.vec_to_image(im).numpy()[0] for im in image_list]
        plot_image_grid(images=image_list, dir = image_dir, name = f'training-epoch-{epoch+1}.png')
        print(f"Results saved in: image/mnist/training-epoch-{epoch+1}.png")
        print('\n')

    model.save_weights()