# Use this script when training with PyTorch Lightning
import torch
import torchvision
import numpy as np
import pytorch_lightning as pl
import argparse
import os

from modules.resnet import get_resnet
from modules.resnet_spiking import get_resnet_spiking
from loss import BarlowTwinsLoss, BarlowTwinsTemporalLoss

from knn_eval import knn_predict, BenchmarkModule
from BartonTwins import BartonTwins
from BartonTwins_spiking import BartonTwinsSpiking
from utils import yaml_config_hook
from modules.transformations import DataTransforms


class BartonTwinsModel(BenchmarkModule):
    """Self-supervised model pairing a ResNet backbone with ``BarlowTwinsLoss``.

    The backbone is built by ``get_resnet``; its ``fc`` head is replaced by an
    identity layer and the result is wrapped in ``BartonTwins``.

    Args:
        args: Configuration namespace. This class reads ``args.model``,
            ``args.n_classes``, ``args.device`` and ``args.epochs``.
        dataloader_kNN: Forwarded unchanged to ``BenchmarkModule.__init__``.
        gpus: Forwarded unchanged to ``BenchmarkModule.__init__``.
        classes: Forwarded unchanged to ``BenchmarkModule.__init__``.
        knn_k: Forwarded unchanged to ``BenchmarkModule.__init__``.
        knn_t: Forwarded unchanged to ``BenchmarkModule.__init__``.
    """

    def __init__(self, args, dataloader_kNN, gpus, classes, knn_k, knn_t):
        super().__init__(dataloader_kNN, gpus, classes, knn_k, knn_t)

        self.args = args
        # create a ResNet backbone and remove the classification head
        self.backbone = get_resnet(args.model, args.n_classes)
        n_features = self.backbone.fc.in_features
        # This module imports ``torch`` but never ``nn``, so the layer has to
        # be referenced through the fully qualified ``torch.nn`` namespace.
        self.backbone.fc = torch.nn.Identity()

        self.barton_twins = BartonTwins(self.backbone, n_features)
        self.criterion = BarlowTwinsLoss(device=args.device)

    def forward(self, x1, x2):
        """Run both inputs through the wrapped ``BartonTwins`` module."""
        return self.barton_twins(x1, x2)

    def training_step(self, batch, batch_idx):
        """Compute and log the self-supervised loss for one batch.

        The batch is unpacked as ``((x0, x1), _, _)``; only the two inputs
        are used. Of the four values returned by ``forward``, the second and
        fourth are passed to the criterion.
        """
        (x0, x1), _, _ = batch
        _, z_a, _, z_b = self(x0, x1)  # self->call->forward
        loss = self.criterion(z_a, z_b)
        self.log('train_loss_ssl', loss)
        return loss

    # learning rate warm-up
    def optimizer_steps(self,
                        epoch=None,
                        batch_idx=None,
                        optimizer=None,
                        optimizer_idx=None,
                        optimizer_closure=None,
                        on_tpu=None,
                        using_native_amp=None,
                        using_lbfgs=None):
        """Linearly warm up the learning rate, then step the optimizer.

        For the first 1000 global steps every parameter group's ``lr`` is set
        to ``1e-3 * (global_step + 1) / 1000``. Afterwards the learning rate
        is left untouched. The optimizer is always stepped and its gradients
        are zeroed.

        NOTE(review): PyTorch Lightning's hook is named ``optimizer_step``
        (singular). Under this name the method is presumably never invoked
        by the Trainer unless ``BenchmarkModule`` calls it explicitly, so
        the warm-up may be inactive. Confirm against the installed Lightning
        version before renaming: the expected signature and the handling of
        ``optimizer_closure`` differ between releases.
        """
        # 120 steps ~ 1 epoch
        if self.trainer.global_step < 1000:
            lr_scale = min(1., float(self.trainer.global_step + 1) / 1000.)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * 1e-3

        # update params
        optimizer.step()
        optimizer.zero_grad()

    def configure_optimizers(self):
        """Return SGD with momentum and a cosine-annealing LR schedule.

        The scheduler's ``T_max`` is ``args.epochs``.
        """
        optim = torch.optim.SGD(self.barton_twins.parameters(), lr=1e-3,
                                momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.args.epochs)
        return [optim], [scheduler]


class BartonTwinsSpikingModel(BenchmarkModule):
    """Spiking variant pairing a spiking ResNet with ``BarlowTwinsTemporalLoss``.

    The backbone is built by ``get_resnet_spiking``; its ``fc`` head is
    replaced by an identity layer and the result is wrapped in
    ``BartonTwinsSpiking``.

    Args:
        args: Configuration namespace. This class reads ``args.model``,
            ``args.timestep``, ``args.n_classes``, ``args.device`` and
            ``args.epochs``.
        dataloader_kNN: Forwarded unchanged to ``BenchmarkModule.__init__``.
        gpus: Forwarded unchanged to ``BenchmarkModule.__init__``.
        classes: Forwarded unchanged to ``BenchmarkModule.__init__``.
        knn_k: Forwarded unchanged to ``BenchmarkModule.__init__``.
        knn_t: Forwarded unchanged to ``BenchmarkModule.__init__``.
    """

    def __init__(self, args, dataloader_kNN, gpus, classes, knn_k, knn_t):
        super().__init__(dataloader_kNN, gpus, classes, knn_k, knn_t)

        self.args = args
        # create a ResNet backbone and remove the classification head
        self.backbone = get_resnet_spiking(args.model, args.timestep, args.n_classes)
        n_features = self.backbone.fc.in_features
        # This module imports ``torch`` but never ``nn``, so the layer has to
        # be referenced through the fully qualified ``torch.nn`` namespace.
        self.backbone.fc = torch.nn.Identity()

        self.barton_twins = BartonTwinsSpiking(self.backbone, n_features, args.timestep)
        self.criterion = BarlowTwinsTemporalLoss(device=args.device)

    def forward(self, x1, x2):
        """Run both inputs through the wrapped ``BartonTwinsSpiking`` module."""
        return self.barton_twins(x1, x2)

    def training_step(self, batch, batch_idx):
        """Compute and log the self-supervised loss for one batch.

        The batch is unpacked as ``((x0, x1), _, _)``; only the two inputs
        are used. Of the six values returned by ``forward``, the third and
        fourth are passed to the criterion.
        """
        (x0, x1), _, _ = batch
        _, _, z_a, z_b, _, _ = self(x0, x1)  # self->call->forward
        loss = self.criterion(z_a, z_b)
        self.log('train_loss_ssl', loss)
        return loss

    # learning rate warm-up
    def optimizer_steps(self,
                        epoch=None,
                        batch_idx=None,
                        optimizer=None,
                        optimizer_idx=None,
                        optimizer_closure=None,
                        on_tpu=None,
                        using_native_amp=None,
                        using_lbfgs=None):
        """Linearly warm up the learning rate, then step the optimizer.

        For the first 1000 global steps every parameter group's ``lr`` is set
        to ``1e-3 * (global_step + 1) / 1000``. Afterwards the learning rate
        is left untouched. The optimizer is always stepped and its gradients
        are zeroed.

        NOTE(review): PyTorch Lightning's hook is named ``optimizer_step``
        (singular). Under this name the method is presumably never invoked
        by the Trainer unless ``BenchmarkModule`` calls it explicitly, so
        the warm-up may be inactive. Confirm against the installed Lightning
        version before renaming: the expected signature and the handling of
        ``optimizer_closure`` differ between releases.
        """
        # 120 steps ~ 1 epoch
        if self.trainer.global_step < 1000:
            lr_scale = min(1., float(self.trainer.global_step + 1) / 1000.)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * 1e-3

        # update params
        optimizer.step()
        optimizer.zero_grad()

    def configure_optimizers(self):
        """Return SGD with momentum and a cosine-annealing LR schedule.

        The scheduler's ``T_max`` is ``args.epochs``.
        """
        optim = torch.optim.SGD(self.barton_twins.parameters(), lr=1e-3,
                                momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.args.epochs)
        return [optim], [scheduler]


def parse_bool_flag(text):
    """Convert a command-line string into a bool.

    ``bool("False")`` is ``True``, so the builtin cannot be used as an
    argparse ``type``. The empty string maps to ``False``, as it does with
    ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a recognised spelling.
    """
    normalized = text.strip().lower()
    if normalized in {"true", "t", "yes", "y", "on", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "off", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def add_config_arguments(parser, config):
    """Register one ``--<key>`` option on ``parser`` per entry of ``config``.

    Each config value becomes the option's default. Command-line overrides
    are converted with the type of that default, with two exceptions:
    booleans go through ``parse_bool_flag``, and ``None`` defaults accept the
    raw string because ``NoneType`` cannot be called as a converter.
    """
    for key, value in config.items():
        if isinstance(value, bool):
            parser.add_argument(f"--{key}", default=value, type=parse_bool_flag)
        elif value is None:
            parser.add_argument(f"--{key}", default=None)
        else:
            parser.add_argument(f"--{key}", default=value, type=type(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    config = yaml_config_hook("./config/config.yaml")
    add_config_arguments(parser, config)

    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.model_path, exist_ok=True)

    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.num_gpus = torch.cuda.device_count()
    args.lr = float(args.lr)
    print(vars(args))

    pl.seed_everything(args.seed)

    # Both supported datasets take the same constructor arguments, so pick
    # the class once instead of duplicating the construction per dataset.
    if args.dataset == "CIFAR10":
        dataset_cls = torchvision.datasets.CIFAR10
    elif args.dataset == "CIFAR100":
        dataset_cls = torchvision.datasets.CIFAR100
    else:
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    # The SSL training set uses the full DataTransforms object as its
    # transform; the kNN feature bank and the test set use its test_transform.
    train_dataset = dataset_cls(
        args.dataset_dir,
        download=True,
        transform=DataTransforms(size=args.image_size),
    )
    train_dataset_knn = dataset_cls(
        args.dataset_dir,
        train=True,
        download=True,
        transform=DataTransforms(size=args.image_size).test_transform,
    )
    dataset_test = dataset_cls(
        args.dataset_dir,
        train=False,
        download=True,
        transform=DataTransforms(size=args.image_size).test_transform,
    )

    dataloader_train_ssl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
    )
    dataloader_train_kNN = torch.utils.data.DataLoader(
        train_dataset_knn,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )
    dataloader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )

    model_cls = BartonTwinsSpikingModel if args.spiking else BartonTwinsModel
    model = model_cls(
        args,
        dataloader_train_kNN,
        gpus=args.gpus,
        classes=args.n_classes,
        knn_k=args.knn_k,
        knn_t=args.knn_t,
    )

    # NOTE(review): args.gpus is handed to the model but not to the Trainer,
    # which is created with default device settings - confirm this is intended.
    trainer = pl.Trainer(max_epochs=args.epochs)
    trainer.fit(
        model,
        train_dataloaders=dataloader_train_ssl,
        val_dataloaders=dataloader_test,
    )

    print(f'Highest test accuracy: {model.max_accuracy:.4f}')
