import torch
import torch.nn as nn
import torchvision
import numpy as np
import argparse
import os
import matplotlib.pyplot as plt

from modules.resnet import get_resnet
from modules.resnet_spiking import get_resnet_spiking
from modules.vgg import get_vgg
from modules.vgg_spiking import get_vgg_spiking
from loss import BarlowTwinsLoss, BarlowTwinsTemporalLoss

from BartonTwins import BartonTwins
from BartonTwins_spiking import BartonTwinsSpiking
from utils import yaml_config_hook
from model import load_optimizer, save_model
from modules.transformations import DataTransforms

# TensorBoard
from torch.utils.tensorboard import SummaryWriter


def train(args, train_loader, model, criterion, optimizer, writer):
    """Run one pass over ``train_loader`` and return the summed batch loss.

    Args:
        args: Namespace read for ``spiking``, ``temporal_loss`` and
            ``global_step`` (incremented once per batch). If it carries a
            ``device`` attribute the inputs are moved there; otherwise the
            current CUDA device is used.
        train_loader: Iterable yielding ``((x_i, x_j), _)`` batches, i.e. a
            pair of views plus a second element that is ignored here.
        model: Callable taking both views. It must return four values, or six
            when ``args.spiking`` is truthy (the last two being the temporal
            projections).
        criterion: Callable mapping two projections to a scalar loss tensor.
        optimizer: Optimizer stepped once per batch.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        The sum of ``loss.item()`` over all batches (not the mean).
    """
    # Follow the device selected by the caller so a CPU-only run does not
    # crash on an unconditional ``.cuda()``; fall back to CUDA when no device
    # was configured.
    device = getattr(args, "device", None)
    if device is None:
        device = torch.device("cuda")

    use_temporal = bool(args.temporal_loss and args.spiking)
    n_steps = len(train_loader)

    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.to(device, non_blocking=True)
        x_j = x_j.to(device, non_blocking=True)

        # positive pair, with encoding
        if args.spiking:
            h_i, h_j, z_i, z_j, z_i_temporal, z_j_temporal = model(x_i, x_j)
        else:
            h_i, h_j, z_i, z_j = model(x_i, x_j)

        if use_temporal:
            loss = criterion(z_i_temporal, z_j_temporal)
        else:
            loss = criterion(z_i, z_j)

        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for printing, logging and the sum.
        loss_value = loss.item()

        if step % 100 == 0:
            print(f"Step [{step}/{n_steps}]\t Loss: {loss_value}")

        # Logged once per batch, indexed by the running global step.
        writer.add_scalar("Loss/train_epoch", loss_value, args.global_step)
        args.global_step += 1

        loss_epoch += loss_value
    return loss_epoch


def _build_backbone(args):
    """Create the encoder selected by ``args.model`` with its ``fc`` head removed.

    The spiking or non-spiking factory is chosen by ``args.spiking``. The
    width of the removed head is read from ``backbone.fc`` (ResNet) or
    ``backbone.fc[0]`` (VGG) before ``fc`` is replaced by ``nn.Identity``.

    Returns:
        A ``(backbone, n_features)`` tuple.

    Raises:
        NotImplementedError: If ``args.model`` contains neither ``"resnet"``
            nor ``"vgg"``.
    """
    if "resnet" in args.model:
        if args.spiking:
            backbone = get_resnet_spiking(args.model, args.timestep, args.n_classes)
        else:
            backbone = get_resnet(args.model, args.n_classes)
        n_features = backbone.fc.in_features  # get dimensions of fc layer
    elif "vgg" in args.model:
        if args.spiking:
            backbone = get_vgg_spiking(args.model, args.timestep, args.n_classes)
        else:
            backbone = get_vgg(args.model, args.n_classes)
        n_features = backbone.fc[0].in_features
    else:
        # Fail here with a clear error instead of an UnboundLocalError later.
        raise NotImplementedError(f"Unsupported model: {args.model!r}")
    backbone.fc = nn.Identity()
    return backbone, n_features


def main(gpu, args):
    """Train a model for ``args.start_epoch .. args.epochs`` and return the losses.

    Seeds torch and numpy, builds the CIFAR10/CIFAR100 dataset and loader,
    constructs the (optionally spiking) model, optionally reloads a
    checkpoint, then trains epoch by epoch while logging to TensorBoard.

    Args:
        gpu: Unused; kept so existing callers keep working.
        args: Configuration namespace. ``global_step`` and ``current_epoch``
            are written on it during the run.

    Returns:
        A list with the mean per-batch loss of every epoch trained.

    Raises:
        NotImplementedError: For an unsupported ``args.dataset`` or
            ``args.model``.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    dataset_classes = {
        "CIFAR10": torchvision.datasets.CIFAR10,
        "CIFAR100": torchvision.datasets.CIFAR100,
    }
    if args.dataset not in dataset_classes:
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")
    train_dataset = dataset_classes[args.dataset](
        args.dataset_dir,
        download=True,
        transform=DataTransforms(size=args.image_size),
    )

    dataloader_train_ssl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
    )

    # initialize the encoder (resnet/vgg, spiking or not) and wrap it
    backbone, n_features = _build_backbone(args)
    if args.spiking:
        model = BartonTwinsSpiking(backbone, in_dim=n_features, timestep=args.timestep)
    else:
        model = BartonTwins(backbone, in_dim=n_features)

    if args.reload:
        model_fp = os.path.join(
            args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
        )
        model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    if args.temporal_loss and args.spiking:
        criterion = BarlowTwinsTemporalLoss(args.device, args.timestep, args.cross_temporal)
    else:
        criterion = BarlowTwinsLoss(device=args.device)

    n_batches = len(dataloader_train_ssl)
    ls_l = []
    args.global_step = 0
    args.current_epoch = args.start_epoch

    writer = SummaryWriter()
    try:
        for epoch in range(args.start_epoch, args.epochs):
            print('===================================')

            # Read before training so the logged value is the rate used for
            # this epoch, not the one the scheduler sets afterwards.
            lr = optimizer.param_groups[0]["lr"]
            loss_epoch = train(args, dataloader_train_ssl, model, criterion, optimizer, writer)

            if scheduler:
                scheduler.step()

            mean_loss = loss_epoch / n_batches
            writer.add_scalar("Loss/train", mean_loss, epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t lr: {round(lr, 7)}"
            )
            args.current_epoch += 1
            ls_l.append(mean_loss)
    finally:
        # main() may be called repeatedly (e.g. a learning-rate sweep); close
        # the writer so its pending events are flushed and the file released.
        writer.close()
    return ls_l


def str2bool(value):
    """Parse a command-line boolean such as ``"False"`` or ``"1"``.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--flag False`` would silently enable the flag. This converter accepts
    the usual spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    config = yaml_config_hook("./config/config.yaml")
    # Expose every config entry as a command-line flag with the config value
    # as its default.
    for k, v in config.items():
        if isinstance(v, bool):
            parser.add_argument(f"--{k}", default=v, type=str2bool)
        elif v is None:
            # NoneType cannot be used as a converter; keep the raw string.
            parser.add_argument(f"--{k}", default=v)
        else:
            parser.add_argument(f"--{k}", default=v, type=type(v))

    args = parser.parse_args()

    os.makedirs(args.model_path, exist_ok=True)

    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.num_gpus = torch.cuda.device_count()
    args.lr = float(args.lr)
    print(vars(args))

    # Learning-rate sweep: every run trains 5 epochs from scratch and its
    # per-epoch mean loss is drawn as one curve.
    lr_list = [1e-3, 1e-4, 1e-5, 1e-6]
    args.epochs = 5
    args.start_epoch = 0
    plt.figure()

    for lr in lr_list:
        args.lr = lr
        ls_list = main(0, args)
        plt.plot(ls_list, label="{}".format(lr))
    plt.legend()
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss Value')
    plt.show()


