import os
import sys

import pennylane as qml
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from data_utils.aae_dataset import FractalDB_Dataset, MNIST_AAE_Dataset, MNISTDataset
from data_utils.plot import plot_2d
from loss import FidLossDotProd, FidLossMSE
from models.state_generators import StateGenerator
from utils import add_noise, append_log, norm_image, resize

# The script takes exactly one CLI argument: the path of the configuration
# file that main() loads with OmegaConf.
if len(sys.argv) == 2:
    config_file_path = sys.argv[1]
else:
    print(f"Usage: python {sys.argv[0]} <path-to-configuration>")
    sys.exit(1)


# Registry mapping the `state_generator.loss` config value to a loss class.
LOSS_FN = {"DotProd": FidLossDotProd, "MSE": FidLossMSE}


def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value.

    Args:
        seed: Integer seed applied, in order, to ``random``, ``numpy.random``
            and ``torch``.
    """
    import random

    import numpy as np
    import torch

    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def _truncate_log(loss_path):
    """Empty ``loss_path`` if it already exists so each run starts a fresh log."""
    if os.path.exists(loss_path):
        with open(loss_path, "w"):
            pass


def _maybe_resume(model, config):
    """Load ``config.checkpoint.save_path`` into ``model`` when ``resume_from_ckpt`` is set.

    Resuming is best-effort. A missing/false flag means "start from scratch",
    and a failed load is reported on stderr instead of aborting the run.
    """
    if not config.get("resume_from_ckpt", False):
        return
    try:
        model.load(config.checkpoint.save_path)
    except Exception as err:  # deliberate best-effort; KeyboardInterrupt still propagates
        print(
            f"Warning: could not resume from checkpoint ({err}); training from scratch.",
            file=sys.stderr,
        )


def _build_optimizer(params, optimizer_cfg):
    """Instantiate the ``torch.optim`` class named by ``optimizer_cfg``.

    Accepted layouts (``name`` defaults to ``"Adam"``):
      * nested: ``{name: ..., args: {...}}`` -> kwargs are taken from ``args``;
      * flat:   ``{name: ..., lr: ...}``     -> every key except ``name`` is a kwarg.
    """
    opt_name = optimizer_cfg.get("name", "Adam")
    opt_args = optimizer_cfg.get("args")
    if opt_args is None:
        # "name" selects the class; it must not be forwarded as a constructor kwarg.
        opt_args = {k: v for k, v in optimizer_cfg.items() if k != "name"}
    optimizer_cls = getattr(torch.optim, opt_name)
    return optimizer_cls(params, **opt_args)


def train(config, loader: DataLoader):
    """Train a ``StateGenerator`` on batches from ``loader`` and return it.

    Per-step losses are appended to ``<checkpoint.logs>/loss.txt``. That file is
    emptied at start if present. Losses are also written to TensorBoard under
    ``<checkpoint.logs>/tensorboard``: ``Loss/Step`` uses a monotonically
    increasing global step, and ``Loss/Epoch`` is the mean over the steps
    actually taken in that epoch.

    Args:
        config: OmegaConf config. Reads ``device``, ``n_epochs``, ``noise_factor``,
            ``noisy_probability``, ``checkpoint.*``, ``optimizer`` and
            ``state_generator.*``, plus the optional ``resume_from_ckpt``.
        loader: yields dict batches with an ``"images"`` entry.

    Returns:
        The trained ``StateGenerator``.
    """
    loss_path = os.path.join(config.checkpoint.logs, "loss.txt")
    _truncate_log(loss_path)

    superencoder = StateGenerator(config).to(config.device)
    _maybe_resume(superencoder, config)

    loss_args = config.get("state_generator").get("loss_args", {})
    loss_cls = LOSS_FN[config.state_generator.loss]
    loss_fn = loss_cls(superencoder.qc, **loss_args).to(config.device)

    optimizer = _build_optimizer(superencoder.parameters(), config.optimizer)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.n_epochs
    )

    writer = SummaryWriter(os.path.join(config.checkpoint.logs, "tensorboard"))
    n_qubits = config.state_generator.aae_encoder.n_qubits
    global_step = 0  # counts optimizer steps across all epochs
    try:
        for epoch in range(config.n_epochs):
            n_steps = 0
            epoch_loss_sum = 0.0
            with tqdm(loader, leave=False) as bar:
                for batch in bar:
                    images = resize(batch["images"], n_qubits).to(config.device)
                    if (
                        config.noise_factor
                        and torch.rand(1).item() < config.noisy_probability
                    ):
                        # Randomly perturb some batches to make samples more varied.
                        images = add_noise(images, config.noise_factor, config.device)

                    images = norm_image(images)
                    # Some batches contain NaN after resize/norm_image. This was
                    # observed at one fixed batch index with the current config.
                    # Skip them rather than poison the gradients.
                    if images.isnan().any():
                        continue

                    pred = superencoder(images)
                    loss = loss_fn(pred, images)

                    optimizer.zero_grad()
                    loss.backward()
                    # The loss was observed to become NaN during training;
                    # gradient clipping is the mitigation.
                    torch.nn.utils.clip_grad_norm_(superencoder.parameters(), 1.0)
                    optimizer.step()

                    loss_value = loss.item()  # one device->host sync per step
                    n_steps += 1
                    global_step += 1
                    epoch_loss_sum += loss_value

                    bar.set_postfix(loss=loss_value)
                    writer.add_scalar("Loss/Step", loss_value, global_step)
                    append_log(loss_path, loss_value)

            scheduler.step()

            # Mean over the steps actually taken; guard against an epoch in
            # which every batch was skipped.
            epoch_loss = epoch_loss_sum / max(n_steps, 1)
            writer.add_scalar("Loss/Epoch", epoch_loss, epoch)
            print(f"Epoch [{epoch+1}/{config.n_epochs}], Loss: {epoch_loss:.4f}")
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()

    return superencoder


def main():
    """Load the config, build the dataset and loader, train, then save the model.

    The config path comes from the module-level ``config_file_path``, which is
    taken from ``sys.argv``.

    Raises:
        NotImplementedError: if ``dataset_name`` names an unsupported dataset.
    """
    # NOTE(review): this resolver lets `${eval:...}` interpolations in the config
    # execute arbitrary Python through eval(). Only load trusted config files.
    OmegaConf.register_new_resolver("eval", eval)
    config = OmegaConf.load(config_file_path)

    seed_everything(config.seed)

    dataset_classes = {"FractalDB": FractalDB_Dataset, "MNIST": MNISTDataset}
    dataset_name = config.get("dataset_name", "FractalDB")
    if dataset_name not in dataset_classes:
        raise NotImplementedError(f"Unsupported dataset : {dataset_name}")
    dataset = dataset_classes[dataset_name](**config.dataset)
    loader = DataLoader(dataset, shuffle=True, **config.dataloader)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(config.checkpoint.logs, exist_ok=True)

    model = train(config, loader)

    model.save(config.checkpoint.save_path)


if __name__ == "__main__":
    # Start training only when run as a script.
    # NOTE(review): the sys.argv check near the top of the file still runs,
    # and may call sys.exit, when this module is merely imported.
    main()
