import os

import hydra
import numpy as np
import torch
import torchinfo
from hydra.core.global_hydra import GlobalHydra
from matplotlib import pyplot as plt
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from quant.utils import get_rundir
from utils.utils import get_train_val_dataset, to_np, limit_iterable


def run_validation(model, dataloader, criterion, device):
    """Return the mean per-batch ``criterion`` loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and left in that mode; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: Module to evaluate. It must expose a ``use_vq`` attribute; when
            it is truthy the forward pass is expected to return a pair whose
            first element is the reconstruction.
        dataloader: Iterable of batches; each batch must support ``.to(device)``.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            tensor. The input batch itself is used as the target.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    print("Running validation")
    tot_loss = 0.0
    batch_count = 0
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every validation batch.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = batch.to(device)

            if model.use_vq:
                # On validation we report only reconstruction loss.
                out, _ = model(batch)
            else:
                out = model(batch)

            loss = criterion(out, batch)
            tot_loss += loss.item()
            batch_count += 1

    if batch_count == 0:
        raise ValueError("Cannot compute a validation loss: the dataloader yielded no batches")
    return tot_loss / batch_count


def get_dataloader(dataset, batch_size, *, shuffle=True, num_workers=4, pin_memory=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    The keyword defaults reproduce the previously hard-coded settings, so
    ``get_dataloader(dataset, batch_size)`` behaves exactly as before.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether to draw samples in random order.
        num_workers: Number of loader worker processes (0 loads in the
            calling process).
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def plot_encodings(model, batch, plot_window):
    """Plot a random window of the first sample of ``batch`` against its reconstruction.

    The first sample is passed through ``model.encode`` and ``model.decode``.
    Even positions of the chosen window are drawn on the left axes and odd
    positions on the right axes (presumably interleaved real / imaginary
    components, judging by the variable names -- confirm against the dataset).

    Args:
        model: Object exposing ``encode`` and ``decode``.
        batch: Tensor indexed as ``(sample, position)``.
        plot_window: Number of consecutive positions to show.

    Returns:
        The matplotlib figure holding both plots.

    Raises:
        ValueError: If ``plot_window`` exceeds the number of positions per sample.
    """
    inp = batch[[0]]
    # Plotting needs values only; skip autograd bookkeeping for this pass.
    with torch.no_grad():
        enc = model.encode(inp)
        out = model.decode(enc)
    assert inp.shape == out.shape

    signal_len = inp.shape[1]
    if plot_window > signal_len:
        raise ValueError(
            f"plot_window ({plot_window}) exceeds the signal length ({signal_len})"
        )

    # Windows start on even offsets 0, 2, ... and must end within the signal.
    # Counting them this way keeps the randint range non-empty even when the
    # window spans the whole signal.
    n_starts = (signal_len - plot_window) // 2 + 1
    window_start = torch.randint(0, n_starts, ()).item() * 2

    fig, (ax_real, ax_imag) = plt.subplots(1, 2, figsize=(10, 5))
    x_real = np.arange(window_start, window_start + plot_window, 2)
    x_imag = np.arange(window_start + 1, window_start + plot_window, 2)

    inp_np = to_np(inp)
    out_np = to_np(out)

    # x-axis shows the pair index, i.e. the flat position divided by two.
    for ax, idx in ((ax_real, x_real), (ax_imag, x_imag)):
        ax.plot(idx // 2, inp_np[0, idx], label="signal")
        ax.plot(idx // 2, out_np[0, idx], label="reconstruction")
        ax.legend()
        ax.set_xlabel("Sample id")
        ax.set_ylabel("Value")

    # Leave the top 5% of the figure free for the title added below.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.suptitle(f"[{inp.shape[1]}] -> [{enc.shape[1]}], tokens in {enc.min().item()}..{enc.max().item()}")
    return fig


def _build_model(cfg: DictConfig, device):
    """Instantiate ``cfg.model`` on ``device``, restoring saved weights when requested."""
    model = hydra.utils.instantiate(cfg.model).to(device)
    if cfg.training.load_checkpoint:
        run_folder = f"quant_cnn{cfg.training.run_id:04}"
        checkpoint_path = os.path.join("runs", run_folder, "model_best.ckpt")
        # map_location lets a checkpoint saved on a different device
        # (e.g. a GPU) be restored on the configured one.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
    return model


def _record_run_metadata(logger, log_dir, cfg: DictConfig, model):
    """Store the config and the model summary in TensorBoard and as files in ``log_dir``."""
    cfg_yaml = OmegaConf.to_yaml(cfg)
    logger.add_text("config", cfg_yaml)
    # Explicit UTF-8 so the write does not depend on the platform default encoding.
    with open(os.path.join(log_dir, "config"), "w", encoding="utf-8") as f:
        f.write(cfg_yaml)

    model_stats = str(torchinfo.summary(model, verbose=0))
    logger.add_text("summary", model_stats)
    with open(os.path.join(log_dir, "summary"), "w", encoding="utf-8") as f:
        f.write(model_stats + "\n")


@hydra.main(version_base="1.3", config_path="exp", config_name=None)
def main(cfg: DictConfig):
    """Train the model described by ``cfg`` and log the run under ``runs/``.

    Steps, in order:
      * seed torch, build the model (optionally from ``model_best.ckpt`` of
        run ``cfg.training.run_id``), the train/val loaders, criterion,
        optimizer and optional scheduler / EMA wrapper;
      * write the config and model summary to the new run directory;
      * log an initial reconstruction plot;
      * for each epoch: train, validate (on the EMA model when enabled), step
        the scheduler with the validation loss, optionally plot, and save
        ``model_best.ckpt`` / ``model_<epoch>.ckpt``;
      * save ``model_final.ckpt`` and log the best validation loss.

    When ``cfg.training.test_run`` is set, each epoch uses only 5 training and
    5 validation batches.
    """
    torch.manual_seed(42)
    test_run = cfg.training.test_run
    if test_run:
        print("TEST RUN!")

    device = cfg.training.device
    model = _build_model(cfg, device)

    all_dataset = hydra.utils.instantiate(cfg.dataset)
    train_dataset, val_dataset = get_train_val_dataset(all_dataset, cfg.training.train_fraction)
    train_loader = get_dataloader(train_dataset, cfg.training.batch_size)
    val_loader = get_dataloader(val_dataset, cfg.training.batch_size)

    criterion = hydra.utils.instantiate(cfg.criterion)
    optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer) if cfg.training.use_scheduler else None

    use_ema = cfg.training.use_ema
    log_every = cfg.training.log_loss_every

    tot_loss = 0
    steps_done = 0
    best_loss = None

    log_dir = os.path.join("runs", get_rundir("quant_cnn"))
    print("Logging directory", log_dir)

    # The context manager closes (and thereby flushes) the writer even if
    # training is interrupted by an exception.
    with SummaryWriter(log_dir=log_dir) as logger:
        _record_run_metadata(logger, log_dir, cfg, model)

        ema = hydra.utils.instantiate(cfg.ema, model) if use_ema else None

        # Plot in eval mode; the training loop switches back with model.train().
        model.eval()
        batch = next(iter(val_loader)).to(device)
        fig = plot_encodings(model, batch, cfg.plot.window)
        logger.add_figure("test_reconstruction", fig)

        # Weight of the quantisation loss; looked up once instead of per step.
        w_q = None
        if model.use_vq:
            w_q = 1 if cfg.training.w_q is None else cfg.training.w_q

        for epoch in range(cfg.training.n_epochs):
            model.train()
            train_iterable = limit_iterable(train_loader, 5) if test_run else train_loader
            for batch in tqdm(train_iterable):
                batch = batch.to(device)
                optimizer.zero_grad()

                if model.use_vq:
                    out, quant_loss = model(batch)
                    loss = criterion(out, batch) + w_q * quant_loss
                else:
                    out = model(batch)
                    loss = criterion(out, batch)

                loss.backward()
                optimizer.step()
                if use_ema:
                    ema.update()

                loss_value = loss.item()
                tot_loss += loss_value
                steps_done += 1
                if steps_done % log_every == 0:
                    avg_loss = tot_loss / log_every
                    tot_loss = 0

                    print(f"Average loss: {avg_loss:.4}")
                    logger.add_scalar("train/avg_loss", avg_loss, steps_done)
                    # Log the plain float rather than the graph-attached tensor.
                    logger.add_scalar("train/loss", loss_value, steps_done)

            val_model = ema.ema_model if use_ema else model
            val_iterable = limit_iterable(val_loader, 5) if test_run else val_loader
            val_loss = run_validation(val_model, val_iterable, criterion, device)
            logger.add_scalar("val/loss", val_loss, epoch)
            logger.add_scalar("train/lr", optimizer.param_groups[0]["lr"], epoch)

            if scheduler is not None:
                scheduler.step(val_loss)

            epoch_no = epoch + 1
            if epoch_no in cfg.plot.after_epochs or epoch_no % cfg.plot.every_epochs == 0:
                batch = next(iter(val_loader)).to(device)
                fig = plot_encodings(val_model, batch, cfg.plot.window)
                logger.add_figure(f"reconstruction{epoch_no}", fig)

            if best_loss is None or val_loss < best_loss:
                torch.save(val_model.state_dict(), os.path.join(log_dir, "model_best.ckpt"))
                best_loss = val_loss

            if epoch_no in cfg.training.checkpoint_after_epochs:
                torch.save(val_model.state_dict(), os.path.join(log_dir, f"model_{epoch_no}.ckpt"))

        # Re-derived here so the final save also works when n_epochs is 0.
        val_model = ema.ema_model if use_ema else model
        torch.save(val_model.state_dict(), os.path.join(log_dir, "model_final.ckpt"))
        logger.add_text("best_loss", str(best_loss))


# Script entry point: ``main`` is wrapped by ``@hydra.main``, so it is called
# without arguments here and receives its ``cfg`` from Hydra.
if __name__ == "__main__":
    main()
