import os

from spikingjelly.activation_based import functional
import torch
from torch import nn

from utils import AverageMeter, get_logger
from tqdm import tqdm
import torchvision
from torch.cuda.amp import autocast
import torch.nn.functional as F
from dataset import crop_bsds, pad_nyud


def train_bsds(
        train_loader,
        model,
        opt,
        lr_schd,
        print_freq,
        max_epoch,
        epoch,
        save_dir,
        logger,
        device,
        loss,
        scaler,
):
    """Run one training epoch on BSDS500 with a deeply supervised loss.

    Args:
        train_loader: Iterable of batches; every batch is indexed with the
            keys ``"images"`` and ``"labels"``.
        model: Network being trained. ``model(images)`` must return an
            indexable collection of predictions; they are combined by
            ``DeepSupervisionModel`` and ``preds[0]`` is used for the
            snapshot image.
        opt: Optimizer, zeroed and stepped once per batch.
        lr_schd: Learning-rate scheduler. It is only read for logging
            (``last_epoch`` and ``get_last_lr()``); it is not stepped here.
        print_freq: Log and dump a snapshot image on every
            ``print_freq``-th batch.
        max_epoch: Total number of epochs (used in the log message only).
        epoch: Zero-based index of the current epoch.
        save_dir: Directory that receives the snapshot images; created if
            it does not exist.
        logger: Logger whose ``info`` method receives progress messages.
        device: Device the images and labels are moved to.
        loss: Callable ``loss(pred, labels)`` applied to each prediction.
        scaler: Gradient scaler for mixed precision, or ``None`` to train
            in full precision.

    Returns:
        ``avg`` of the ``AverageMeter`` fed with every batch loss
        (presumably the mean batch loss of the epoch).
    """
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    model.train()

    batch_loss_meter = AverageMeter()

    # NOTE(review): this module is rebuilt on every call and its
    # `raw_weights` parameter is never handed to `opt`, so within this
    # function the weights are not updated and stay uniform. Confirm
    # whether learnable weights were intended.
    deep_loss = DeepSupervisionModel(loss)

    progress_bar = tqdm(train_loader, ncols=200, desc=f"Epoch {epoch}")
    for batch_index, data in enumerate(progress_bar):
        images, labels = data["images"].to(device), data["labels"].to(device)
        images, labels = crop_bsds(images), crop_bsds(labels)  # BSDS500
        opt.zero_grad()

        if scaler is not None:
            # Mixed precision: forward under autocast, then scaled backward.
            with autocast():
                preds = model(images)
                batch_loss = deep_loss(preds, labels)

            scaler.scale(batch_loss).backward()
            scaler.step(opt)
            scaler.update()
        else:
            preds = model(images)
            batch_loss = deep_loss(preds, labels)
            batch_loss.backward()
            opt.step()

        # Reset the network state between batches (spikingjelly).
        functional.reset_net(model)

        # Read the loss once; each .item() call forces a device sync.
        loss_value = batch_loss.item()

        # Show the current loss on the progress bar.
        progress_bar.set_postfix(train_loss_step=f"{loss_value:.4f}")

        # Record the loss.
        batch_loss_meter.update(loss_value)

        if batch_index % print_freq != print_freq - 1:
            continue

        logger.info(
            (
                    "Training epoch:{}/{}, batch:{}/{} current iteration:{}, "
                    + "current batch batch_loss:{}, epoch average batch_loss:{}, learning rate list:{}."
            ).format(
                epoch + 1,
                max_epoch,
                batch_index + 1,
                len(train_loader),
                lr_schd.last_epoch + 1,
                batch_loss_meter.val,
                batch_loss_meter.avg,
                lr_schd.get_last_lr(),
            )
        )

        # Snapshot of the first sample: first prediction (after sigmoid) and
        # its label. Detach so the image does not hold on to the autograd
        # graph of this step.
        snapshots = [preds[0].detach().sigmoid(), labels.detach()]
        height, width = snapshots[0].shape[2:]
        interm_image = torch.zeros((len(snapshots), 1, height, width))
        for row, tensor in enumerate(snapshots):
            interm_image[row, 0, :, :] = tensor[0, 0, :, :]
        torchvision.utils.save_image(
            interm_image,
            os.path.join(save_dir, "batch-{}-1st-image.png".format(batch_index)),
        )

    return batch_loss_meter.avg


def train_voc(
        train_loader,
        model,
        opt,
        lr_schd,
        print_freq,
        max_epoch,
        epoch,
        save_dir,
        logger,
        device,
        loss,
):
    """Run one training epoch of the VOC training loop.

    Args:
        train_loader: Iterable of batches; every batch is indexed with the
            keys ``"images"`` and ``"labels"``.
        model: Network being trained. ``model(images)`` must return an
            iterable of prediction tensors, which is passed as a whole to
            ``loss``.
        opt: Optimizer, zeroed and stepped once per batch.
        lr_schd: Learning-rate scheduler. It is only read for logging
            (``last_epoch`` and ``get_last_lr()``); it is not stepped here.
        print_freq: Log and dump a snapshot image on every
            ``print_freq``-th batch.
        max_epoch: Total number of epochs (used in the log message only).
        epoch: Zero-based index of the current epoch.
        save_dir: Directory that receives the snapshot images; created if
            it does not exist.
        logger: Logger whose ``info`` method receives progress messages.
        device: Device the images and labels are moved to.
        loss: Callable ``loss(preds, labels)`` returning the batch loss.

    Returns:
        ``avg`` of the ``AverageMeter`` fed with every batch loss
        (presumably the mean batch loss of the epoch).
    """
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    model.train()

    batch_loss_meter = AverageMeter()

    for batch_index, data in enumerate(tqdm(train_loader)):
        images, labels = data["images"].to(device), data["labels"].to(device)
        opt.zero_grad()
        preds = model(images)
        batch_loss = loss(preds, labels)
        batch_loss.backward()
        opt.step()

        batch_loss_meter.update(batch_loss.item())

        if batch_index % print_freq != print_freq - 1:
            continue

        # Batch number is reported 1-based so the last batch reads N/N,
        # matching train_bsds.
        logger.info(
            (
                    "Training epoch:{}/{}, batch:{}/{} current iteration:{}, "
                    + "current batch batch_loss:{}, epoch average batch_loss:{}, learning rate list:{}."
            ).format(
                epoch + 1,
                max_epoch,
                batch_index + 1,
                len(train_loader),
                lr_schd.last_epoch + 1,
                batch_loss_meter.val,
                batch_loss_meter.avg,
                lr_schd.get_last_lr(),
            )
        )

        # Snapshot of the first sample: every prediction followed by the
        # label. Detach so the image does not hold on to the autograd graph
        # of this step.
        snapshots = [pred.detach() for pred in preds] + [labels.detach()]
        height, width = snapshots[0].shape[2:]
        interm_image = torch.zeros((len(snapshots), 1, height, width))
        for row, tensor in enumerate(snapshots):
            interm_image[row, 0, :, :] = tensor[0, 0, :, :]
        torchvision.utils.save_image(
            interm_image,
            os.path.join(save_dir, "batch-{}-1st-image.png".format(batch_index)),
        )

    return batch_loss_meter.avg


def train_nyud(
        train_loader,
        model,
        opt,
        lr_schd,
        print_freq,
        max_epoch,
        epoch,
        save_dir,
        logger,
        device,
        loss,
):
    """Run one training epoch of the NYUD training loop.

    Args:
        train_loader: Iterable of batches; every batch is indexed with the
            keys ``"images"`` and ``"labels"``.
        model: Network being trained. ``model(images)`` must return an
            iterable of prediction tensors, which is passed as a whole to
            ``loss``.
        opt: Optimizer, zeroed and stepped once per batch.
        lr_schd: Learning-rate scheduler. It is only read for logging
            (``last_epoch`` and ``get_last_lr()``); it is not stepped here.
        print_freq: Log and dump a snapshot image on every
            ``print_freq``-th batch.
        max_epoch: Total number of epochs (used in the log message only).
        epoch: Zero-based index of the current epoch.
        save_dir: Directory that receives the snapshot images; created if
            it does not exist.
        logger: Logger whose ``info`` method receives progress messages.
        device: Device the images and labels are moved to.
        loss: Callable ``loss(preds, labels)`` returning the batch loss.

    Returns:
        ``avg`` of the ``AverageMeter`` fed with every batch loss
        (presumably the mean batch loss of the epoch).
    """
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    model.train()

    batch_loss_meter = AverageMeter()

    for batch_index, data in enumerate(tqdm(train_loader)):
        images, labels = data["images"].to(device), data["labels"].to(device)
        # Images and labels go through the same project helper so they stay
        # aligned (presumably spatial padding -- confirm in dataset.pad_nyud).
        images, labels = pad_nyud(images), pad_nyud(labels)
        opt.zero_grad()
        preds = model(images)
        batch_loss = loss(preds, labels)
        batch_loss.backward()
        opt.step()

        batch_loss_meter.update(batch_loss.item())

        if batch_index % print_freq != print_freq - 1:
            continue

        # Batch number is reported 1-based so the last batch reads N/N,
        # matching train_bsds.
        logger.info(
            (
                    "Training epoch:{}/{}, batch:{}/{} current iteration:{}, "
                    + "current batch batch_loss:{}, epoch average batch_loss:{}, learning rate list:{}."
            ).format(
                epoch + 1,
                max_epoch,
                batch_index + 1,
                len(train_loader),
                lr_schd.last_epoch + 1,
                batch_loss_meter.val,
                batch_loss_meter.avg,
                lr_schd.get_last_lr(),
            )
        )

        # Snapshot of the first sample: every prediction followed by the
        # label. Detach so the image does not hold on to the autograd graph
        # of this step.
        snapshots = [pred.detach() for pred in preds] + [labels.detach()]
        height, width = snapshots[0].shape[2:]
        interm_image = torch.zeros((len(snapshots), 1, height, width))
        for row, tensor in enumerate(snapshots):
            interm_image[row, 0, :, :] = tensor[0, 0, :, :]
        torchvision.utils.save_image(
            interm_image,
            os.path.join(save_dir, "batch-{}-1st-image.png".format(batch_index)),
        )

    return batch_loss_meter.avg


class DeepSupervisionModel(nn.Module):
    """Weighted sum of one loss term per supervised model output.

    Each of the first ``num_outputs`` predictions is scored against the same
    labels with ``loss``; the terms are combined with weights obtained by a
    softmax over the learnable ``raw_weights`` parameter. The parameter is
    initialised to ones, so the initial weights are uniform.

    Args:
        loss: Callable ``loss(pred, labels)`` returning a scalar tensor.
        num_outputs: Number of predictions that are supervised. Defaults to
            5, the value that was previously hard-coded.

    Raises:
        ValueError: If ``num_outputs`` is smaller than 1.
    """

    def __init__(self, loss, num_outputs=5):
        super().__init__()
        if num_outputs < 1:
            raise ValueError(
                "num_outputs must be at least 1, got {}".format(num_outputs)
            )
        # Unnormalised weights; softmax in forward() turns them into a
        # convex combination.
        self.raw_weights = nn.Parameter(torch.ones(num_outputs), requires_grad=True)
        self.loss = loss

    def forward(self, preds, labels):
        """Return the weighted sum of ``loss(preds[i], labels)``.

        Args:
            preds: Indexable collection holding at least ``num_outputs``
                predictions; extra entries are ignored.
            labels: Target passed unchanged to every loss term.

        Returns:
            Scalar tensor with the combined loss.
        """
        weights = torch.softmax(self.raw_weights, dim=0)
        # preds is indexed explicitly (not zipped) so that too few
        # predictions still fail loudly instead of being silently skipped.
        # Terms are accumulated left to right, as in the unrolled original.
        batch_loss = weights[0] * self.loss(preds[0], labels)
        for index in range(1, weights.shape[0]):
            batch_loss = batch_loss + weights[index] * self.loss(preds[index], labels)
        return batch_loss
