import math
import os
from typing import Optional

import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset, DataLoader

from denoiser import NavierStokesDenoiser


def train(
    save_path: str,
    denoiser: NavierStokesDenoiser,
    num_epochs: int,
    learning_rate: float,
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_denoiser_path: Optional[str] = None,
    max_grad_norm: float = 1.0,
):
    """
    Train a denoiser to estimate E[hat{z}^{k+1} | hat{z}^{k+1}_{t}, hat{x}^{k}].

    Each epoch performs one pass over `train_loader` (AdamW updates with gradient-norm
    clipping, times `t` drawn uniformly from [0, 1) for every sample), steps a cosine
    learning-rate schedule, then computes the average loss over `val_loader`.
    The denoiser is checkpointed whenever the validation loss is at least as good as
    the best one seen so far, and once more after the last epoch.

    Input(s):
        - save_path (str): path to the folder where the trained denoiser is saved.
          The folder is created if it does not exist.
        - denoiser (NavierStokesDenoiser): the denoiser to train. It must expose
          `loss(z_kp1=..., t=..., x_k=...)` returning a scalar tensor.
        - num_epochs (int): number of epochs.
        - learning_rate (float): learning rate for the optimizer.
        - train_loader (Dataloader): torch dataloader for the training data.
        - val_loader (Dataloader): torch dataloader for the validation data.
        - input_denoiser_path (Optional[str]): path to a trained denoiser in order to
          continue the training. When given, it replaces the `denoiser` argument.
        - max_grad_norm (float): maximum norm for the gradients.
    Output(s):
        - None. The following files are written in `save_path`:
          `best_trained_denoiser.pt`, `best_trained_denoiser_dict.pt`,
          `last_epoch_model.pt` and `last_epoch_model_dict.pt`
          (full pickled module and its `state_dict`, respectively).
    """
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the denoiser if a path is provided
    if input_denoiser_path is not None:
        # map_location lets a checkpoint saved from a GPU be resumed on a CPU-only host.
        # NOTE: weights_only=False unpickles arbitrary Python objects; only load
        # checkpoints coming from a trusted source.
        denoiser = torch.load(
            input_denoiser_path,
            map_location=device,
            weights_only=False,
        )
    denoiser.to(device)

    # Resolve the output locations once, and make sure the folder exists so that
    # the first checkpoint does not fail after a full epoch of training
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    best_save_path = os.path.join(save_path, "best_trained_denoiser.pt")
    best_save_path_dict = os.path.join(save_path, "best_trained_denoiser_dict.pt")
    last_model_path = os.path.join(save_path, "last_epoch_model.pt")
    last_model_path_dict = os.path.join(save_path, "last_epoch_model_dict.pt")

    # Define the optimizer and the scheduler
    optimizer = torch.optim.AdamW(
        params=denoiser.parameters(),
        lr=learning_rate,
    )

    def lr_lambda(epoch_index):
        # Cosine decay of the learning-rate multiplier, starting at 1 for epoch 0
        return (1.0 + math.cos(math.pi * epoch_index / num_epochs)) / 2.0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Loop on the number of epochs
    best_val_loss = +np.inf
    uniform = torch.distributions.Uniform(0.0, 1.0)
    for epoch in range(num_epochs):
        # Instantiate a running loss for the epoch
        num_samples = 0
        running_loss = 0.0

        # Switch the denoiser to training mode
        denoiser.train()

        # Loop on batches
        progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for hat_x_k, hat_z_kp1 in progress_bar:
            # Get the current batch size and draw times
            current_batch_size = hat_x_k.size(0)
            t = uniform.sample((current_batch_size,))

            # Switch tensors to device
            t = t.to(device)
            hat_x_k = hat_x_k.to(device)
            hat_z_kp1 = hat_z_kp1.to(device)

            # Clean gradients
            optimizer.zero_grad()

            # Compute the loss for the current batch
            loss = denoiser.loss(z_kp1=hat_z_kp1, t=t, x_k=hat_x_k)

            # Do backpropagation
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=denoiser.parameters(),
                max_norm=max_grad_norm,
            )
            optimizer.step()

            # Update the running loss (weighted by the batch size; this assumes the
            # loss is a per-batch mean - to be confirmed against the denoiser)
            num_samples += current_batch_size
            running_loss += loss.item() * current_batch_size

            # Update the tqdm bar
            avg_loss = running_loss / num_samples
            progress_bar.set_postfix(loss=avg_loss)

        # Update the learning rate
        scheduler.step()

        # Evaluate the denoiser on the validation dataset
        num_samples_val = 0
        running_loss_val = 0.0

        # Switch denoiser to eval mode
        denoiser.eval()

        # Loop on batches, without tracking gradients
        with torch.no_grad():
            for hat_x_k, hat_z_kp1 in val_loader:
                # Get the batch size and draw times
                current_batch_size = hat_x_k.size(0)
                t = uniform.sample((current_batch_size,))

                # Switch tensors to device
                t = t.to(device)
                hat_x_k = hat_x_k.to(device)
                hat_z_kp1 = hat_z_kp1.to(device)

                # Compute the loss
                val_loss = denoiser.loss(z_kp1=hat_z_kp1, t=t, x_k=hat_x_k)

                # Update the running loss
                num_samples_val += current_batch_size
                running_loss_val += val_loss.item() * current_batch_size

        if num_samples_val == 0:
            # An empty validation loader gives no average to compare: keep training
            # instead of failing on a division by zero
            print(
                f"Epoch {epoch + 1} - No validation samples: skipping best-model selection."
            )
        else:
            # Print epoch loss
            validation_loss = running_loss_val / num_samples_val
            print(
                f"Epoch {epoch + 1} - Average validation loss: {validation_loss:.3f}"
            )

            # Save the model if the validation loss is lower
            if validation_loss <= best_val_loss:
                best_val_loss = validation_loss
                torch.save(denoiser, best_save_path)
                torch.save(denoiser.state_dict(), best_save_path_dict)

    # Save the last model
    torch.save(denoiser, last_model_path)
    torch.save(denoiser.state_dict(), last_model_path_dict)


class NavierStokesDataset(Dataset):
    """
    Dataset of normalized (conditioning states, next increment) pairs.

    The class lives at module level (not under the `__main__` guard) so that
    DataLoader worker processes started with the `spawn` method can unpickle it.

    Input(s):
        - samples_path (str): path to a saved tensor whose items unpack into three
          consecutive states (x_{k-1}, x_k, x_{k+1}).
        - stats_dir (str): folder holding `mean_x.pt`, `std_x.pt` and `std_z.pt`.
    """

    def __init__(self, samples_path: str, stats_dir: str = "./../data/stats"):
        super().__init__()
        self.mean_x = torch.load(f"{stats_dir}/mean_x.pt").to(dtype=torch.float32)
        self.std_x = torch.load(f"{stats_dir}/std_x.pt").to(dtype=torch.float32)
        self.std_z = torch.load(f"{stats_dir}/std_z.pt").to(dtype=torch.float32)
        self.data = torch.load(samples_path).to(dtype=torch.float32)

    def __len__(self):
        # One item per entry along the first dimension of the samples tensor
        return self.data.shape[0]

    def __getitem__(self, idx):
        """
        Return (hat_x, hat_z_kp1) for the sample at index `idx`:
            - hat_x: standardized x_{k-1} and x_k, concatenated along dim 0.
            - hat_z_kp1: increment x_{k+1} - x_k divided by std_z.
        """
        x_km1, x_k, x_kp1 = self.data[idx]
        z_kp1 = x_kp1 - x_k
        hat_x_k = (x_k - self.mean_x) / self.std_x
        hat_x_km1 = (x_km1 - self.mean_x) / self.std_x
        hat_x = torch.cat((hat_x_km1, hat_x_k), dim=0)
        hat_z_kp1 = z_kp1 / self.std_z
        return hat_x, hat_z_kp1


if __name__ == "__main__":
    # Training parameters
    rel_save_path = "./../save/"
    rel_input_denoiser_path = None
    num_epochs = 20
    learning_rate = 2e-4

    # Instantiate a denoiser
    denoiser = NavierStokesDenoiser()

    # Create datasets and dataloaders
    train_loader = DataLoader(
        dataset=NavierStokesDataset("./../data/training_samples.pt"),
        batch_size=32,
        num_workers=4,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

    val_loader = DataLoader(
        dataset=NavierStokesDataset("./../data/validation_samples.pt"),
        batch_size=32,
        num_workers=4,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
    )

    train(
        save_path=rel_save_path,
        denoiser=denoiser,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        train_loader=train_loader,
        val_loader=val_loader,
        input_denoiser_path=rel_input_denoiser_path,
    )
