import itertools
import os
import sys
import time

from tqdm import tqdm
import wandb
import torch
from torch import Tensor
import torchvision
import torch.nn.functional as F

from se.configs import TrainConfig
from se.data import build_loaders
from se.models import build_model
from se.utils.noise_model import get_noise
from se.utils.metrics import psnr, ssim
from se.utils.train_utils import (
    save_config,
    save_model_weights,
    setup_experiment,
    run_name,
)
from se.utils.psnr_plot import plot_psnr_io
from experiments_cfg import *


def l2_sum(output: Tensor, target: Tensor) -> Tensor:
    """Return the summed squared error divided by twice ``target.size(0)``.

    The squared differences between ``output`` and ``target`` are summed over
    every element (``reduction="sum"``), then scaled by
    ``1 / (2 * target.size(0))``.

    Args:
        output: Predicted tensor.
        target: Reference tensor; its first dimension sets the normalisation.

    Returns:
        A scalar tensor holding the scaled sum of squared errors.
    """
    leading_dim = target.size(0)
    squared_error_total = F.mse_loss(output, target, reduction="sum")
    return squared_error_total / (leading_dim * 2)


def main(cfg: TrainConfig) -> None:
    """Train a model to recover clean inputs from noise-corrupted ones.

    The run sets up an experiment directory, builds the model, an Adam
    optimizer and an optional LR-halving scheduler, then alternates between
    a training pass over ``train_loader`` and (every ``cfg.valid_interval``
    epochs) a validation pass that also renders a PSNR plot and writes the
    "best" / "last" checkpoints. Metrics are sent to Weights & Biases unless
    ``cfg.wandb_cfg.mode`` is ``"disabled"``.

    Args:
        cfg: Training configuration. Fields read here: ``wandb_cfg``, ``lr``,
            ``lr_halving_epochs``, ``lr_halving_steps``, ``loss_type``,
            ``num_steps``, ``num_epochs``, ``min_noise``, ``max_noise``,
            ``noise_type``, ``log_interval``, ``valid_interval``,
            ``test_path``, ``psnr_eval_sigma_values`` and
            ``train_dataset_type``. ``cfg.wandb_cfg.name`` is overwritten
            with a timestamped run name.

    Raises:
        KeyError: If ``cfg.loss_type`` (lower-cased) is not one of
            ``"l2_sum"``, ``"l1"`` or ``"l2"``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    save_dir = setup_experiment(cfg)
    cfg.wandb_cfg.name = run_name(cfg) + f"_{time.strftime('%m%d_%H%M')}"
    config_dict = save_config(cfg, save_dir)

    # Build the model and its optimizer (data loaders are built further down).
    model = build_model(cfg).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    # Optional LR halving. Epoch milestones take precedence over the
    # step-based schedule when both are configured.
    scheduler = None
    scheduler_step_mode = None
    if cfg.lr_halving_epochs:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=cfg.lr_halving_epochs, gamma=0.5
        )
        scheduler_step_mode = "epoch"
    elif cfg.lr_halving_steps:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=cfg.lr_halving_steps, gamma=0.5
        )
        scheduler_step_mode = "step"

    print(
        f"Built a model consisting of {sum(p.numel() for p in model.parameters())/1e6:,}M parameters",
        flush=True,
    )
    loss_fn_dict = {
        "l2_sum": l2_sum,
        "l1": F.l1_loss,
        "l2": F.mse_loss,
    }
    loss_fn = loss_fn_dict[cfg.loss_type.lower()]

    # wandb_run stays None when logging is disabled; every wandb call below
    # is guarded on it.
    wandb_run = None
    wandb_cfg = cfg.wandb_cfg
    if wandb_cfg.mode != "disabled":
        wandb_run = wandb.init(
            **wandb_cfg.__dict__,
        )
        wandb_run.config.update(config_dict, allow_val_change=True)

    # global_step is incremented before use, so the first batch is step 0.
    global_step = -1
    start_epoch = 0

    train_loader, valid_loader = build_loaders(cfg)

    stop_training = False
    best_valid_psnr = float("-inf")
    if cfg.num_steps is None:
        epoch_iter = range(start_epoch, cfg.num_epochs)
    else:
        # Step-budgeted run: iterate epochs indefinitely and stop from inside
        # the batch loop once cfg.num_steps optimizer steps have been taken.
        epoch_iter = itertools.count(start_epoch)

    for epoch in tqdm(epoch_iter):
        # Training loop
        for batch in train_loader:
            # Validation leaves the model in eval mode; switch back here.
            model.train()

            global_step += 1
            clean_inputs = batch.to(device)
            # Noise bounds are divided by 255 before being handed to get_noise.
            noise = get_noise(
                clean_inputs,
                min_noise=cfg.min_noise / 255.0,
                max_noise=cfg.max_noise / 255.0,
                noise_type=cfg.noise_type,
            )
            noisy_inputs = noise + clean_inputs

            outputs = model(noisy_inputs)
            loss: Tensor = loss_fn(outputs, clean_inputs)

            model.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None and scheduler_step_mode == "step":
                scheduler.step()

            # A non-positive log_interval means "log every step".
            should_log = cfg.log_interval <= 0 or global_step % cfg.log_interval == 0
            if wandb_run is not None and should_log:
                # Metrics are for reporting only: keep autograd from
                # recording a graph through `outputs`.
                with torch.no_grad():
                    train_psnr = psnr(outputs, clean_inputs)
                    train_ssim = ssim(outputs, clean_inputs)
                stats = {
                    "epoch": epoch,
                    "train/loss": loss.item(),
                    "train/lr": optimizer.param_groups[0]["lr"],
                    "train/psnr": train_psnr.item(),
                    "train/ssim": train_ssim.item(),
                }
                wandb.log(
                    stats,
                    step=global_step,
                )

            # Checked after logging so the final step's metrics are not lost.
            if cfg.num_steps is not None and global_step + 1 >= cfg.num_steps:
                stop_training = True
                break

        # Validation loop
        if epoch % cfg.valid_interval == 0:
            model.eval()
            valid_psnr_total = 0.0
            valid_ssim_total = 0.0
            valid_count = 0
            print(
                f"\nStarting validation at epoch {epoch=}, {global_step=}\n", flush=True
            )

            # Validation uses one fixed noise level: the midpoint of the
            # training range, passed as both min and max.
            valid_noise_level = (cfg.min_noise + cfg.max_noise) / (2 * 255.0)

            for sample_id, sample in enumerate(valid_loader):
                with torch.no_grad():
                    sample = sample.to(device)
                    noise = get_noise(
                        sample,
                        min_noise=valid_noise_level,
                        max_noise=valid_noise_level,
                        noise_type=cfg.noise_type,
                    )

                    noisy_inputs = noise + sample
                    output = model(noisy_inputs)
                    valid_psnr = psnr(output, sample)
                    valid_ssim = ssim(output, sample)
                    valid_psnr_total += valid_psnr.item()
                    valid_ssim_total += valid_ssim.item()
                    valid_count += 1

                    # Upload clean / noisy / output image grids for the first
                    # ten validation batches only.
                    if wandb_run is not None and sample_id < 10:
                        image = torch.cat([sample, noisy_inputs, output], dim=0)
                        image = torchvision.utils.make_grid(
                            image.clamp(0, 1), nrow=3, normalize=False
                        )
                        wandb.log(
                            {
                                f"valid_samples/{sample_id}": wandb.Image(
                                    image.detach().cpu()
                                ),
                            },
                            step=global_step,
                        )

            # Averages are None when the validation loader yielded nothing.
            avg_valid_psnr = valid_psnr_total / valid_count if valid_count > 0 else None
            avg_valid_ssim = valid_ssim_total / valid_count if valid_count > 0 else None

            if wandb_run is not None and valid_count > 0:
                wandb.log(
                    {
                        "valid/psnr": avg_valid_psnr,
                        "valid/ssim": avg_valid_ssim,
                        "valid/epoch": epoch,
                    },
                    step=global_step,
                )
                sys.stdout.flush()

            # Render the PSNR plot over cfg.psnr_eval_sigma_values and keep
            # the AUC value it returns.
            plot_path = os.path.join(save_dir, f"epoch_{epoch}_psnr_plot.png")
            training_sigma = (cfg.min_noise, cfg.max_noise)
            _, _, _, psnr_auc = plot_psnr_io(
                models=[model],
                data_dirs=cfg.test_path,
                device=str(device),
                training_sigma=training_sigma,
                sigma_values=cfg.psnr_eval_sigma_values,
                save_path=plot_path,
                n_averages=10,
                dataset_mode=cfg.train_dataset_type.lower(),
                noise_type=cfg.noise_type,
            )
            if wandb_run is not None:
                wandb.log(
                    {
                        "valid/psnr_plot": wandb.Image(plot_path),
                        "valid/psnr_auc": psnr_auc,
                    },
                    step=global_step,
                )

            # "best" tracks the highest average validation PSNR seen so far;
            # "last" is refreshed on every validation pass.
            if avg_valid_psnr is not None and avg_valid_psnr > best_valid_psnr:
                best_valid_psnr = avg_valid_psnr
                save_model_weights(model, save_dir, "best")
            save_model_weights(model, save_dir, "last")

        if scheduler is not None and scheduler_step_mode == "epoch":
            scheduler.step()
        if stop_training:
            break

    # Always persist the final weights, even if the last epoch skipped
    # validation.
    save_model_weights(model, save_dir, "last")

    if wandb_run is not None:
        wandb_run.finish()


# Script entry point: run training with a fixed experiment configuration.
# NOTE(review): `cfg_50_dncnn_wne` is not defined in this file; presumably it
# is supplied by the wildcard import from `experiments_cfg` — confirm.
if __name__ == "__main__":
    main(cfg_50_dncnn_wne)
