import os
import argparse
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

from BartonTwins import BartonTwins
from BartonTwins_spiking import BartonTwinsSpiking
from model import load_optimizer, save_model
from utils import yaml_config_hook

from modules.transformations import DataTransforms
from modules import get_resnet, get_resnet_spiking, modify_resnet_model, get_vgg, get_vgg_spiking, LogisticRegression
from modules.spike_layer import MixedLIF, LIFt


def inference(args, loader, bt_model, device):
    """Encode every batch of ``loader`` with the frozen ``bt_model``.

    Args:
        args: namespace; reads ``args.spiking`` and, when spiking,
            ``args.timestep``.
        loader: iterable of ``(x, y)`` batches that supports ``len()``.
        bt_model: callable invoked as ``bt_model(x, x)``; it must return a
            4-tuple whose first element is the encoding ``h``.
        device: device the inputs are moved to before encoding.

    Returns:
        ``(features, labels)`` as NumPy arrays. Each ``h`` has its first two
        dimensions merged. In spiking mode each batch's labels are repeated
        ``args.timestep`` times block-wise (``[y, y, ..., y]``) so that row
        ``i`` of ``features`` lines up with row ``i`` of ``labels``. Labels
        keep the dtype of the incoming ``y`` in both modes. An empty loader
        yields two empty arrays of shape ``(0,)``.
    """
    feature_chunks = []
    label_chunks = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # get encoding; the encoder is only evaluated, so no autograd graph
        with torch.no_grad():
            h, _, _, _ = bt_model(x, x)
        # Merge the two leading dims. In spiking mode this presumably is
        # (timestep, batch) -> timestep * batch rows, time-major — TODO confirm
        # against BartonTwinsSpiking; the label tiling below relies on it.
        h = h.flatten(0, 1)
        feature_chunks.append(h.cpu().numpy())

        y_np = y.numpy()
        if args.spiking:
            # Repeat the whole label block once per timestep. np.tile keeps the
            # integer dtype, whereas filling an np.zeros buffer made it float64.
            y_np = np.tile(y_np, (args.timestep,) + (1,) * (y_np.ndim - 1))
        label_chunks.append(y_np)

        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    if feature_chunks:
        feature_vector = np.concatenate(feature_chunks, axis=0)
        labels_vector = np.concatenate(label_chunks, axis=0)
    else:
        # Empty loader: np.concatenate([]) would raise, so return empty arrays.
        feature_vector = np.array([])
        labels_vector = np.array([])
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(args, bt_model, train_loader, test_loader, device):
    """Encode the train split and then the test split with ``bt_model``.

    Returns a flat 4-tuple ``(train_X, train_y, test_X, test_y)`` built from
    the two ``inference`` results.
    """
    train_pair = inference(args, train_loader, bt_model, device)
    test_pair = inference(args, test_loader, bt_model, device)
    return (*train_pair, *test_pair)


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    """Wrap pre-computed NumPy feature/label arrays in ordered DataLoaders.

    Both loaders use ``batch_size`` and never shuffle, so batch boundaries
    follow the array order exactly (the spiking path depends on this).

    Returns:
        ``(train_loader, test_loader)``.
    """
    def _ordered_loader(features, targets):
        # Arrays are shared with the tensors (torch.from_numpy does not copy).
        pairs = torch.utils.data.TensorDataset(
            torch.from_numpy(features), torch.from_numpy(targets)
        )
        return torch.utils.data.DataLoader(pairs, batch_size=batch_size, shuffle=False)

    train_loader = _ordered_loader(X_train, y_train)
    test_loader = _ordered_loader(X_test, y_test)
    return train_loader, test_loader


def train(args, loader, bt_model, model, criterion, optimizer):
    """Run one epoch of linear-probe training over pre-computed features.

    Args:
        args: namespace; reads ``args.device``, ``args.spiking`` and, when
            spiking, ``args.timestep``.
        loader: iterable of ``(x, y)`` feature/label batches. In spiking mode
            each batch is expected to hold ``timestep`` time-major copies of the
            same samples, so only the first ``len(y) // timestep`` labels are
            used. Per-timestep logits are averaged before the loss.
        bt_model: unused; kept for signature compatibility with callers.
        model: classifier being trained (put into train mode here).
        criterion: loss called as ``criterion(logits, int64_targets)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(loss_sum, accuracy_sum)`` accumulated over batches; divide by the
        number of batches for per-batch means.
    """
    loss_epoch = 0
    accuracy_epoch = 0
    # test() switches the classifier to eval mode and leaves it there, so set
    # train mode explicitly to be independent of call order.
    model.train()
    for x, y in loader:
        optimizer.zero_grad()

        x = x.to(args.device)
        if args.spiking:
            b_size = y.shape[0] // args.timestep
            y = y[:b_size, ...].to(args.device)
        else:
            y = y.to(args.device)

        output = model(x)
        if args.spiking:
            # Rows are laid out [t0 batch, t1 batch, ...]: view them as
            # (timestep, batch, ...) and average the logits over time.
            output = output[: args.timestep * b_size].reshape(
                (args.timestep, b_size) + tuple(output.shape[1:])
            ).mean(dim=0)
        loss = criterion(output, y.long())

        predicted = output.argmax(1)
        accuracy_epoch += (predicted == y).sum().item() / y.size(0)

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch


def test(args, loader, bt_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(args.device)
        if args.spiking:
            b_size = int(y.shape[0] / args.timestep)
            y = y[0:b_size, ...].to(args.device)
        else:
            y = y.to(args.device)

        output = model(x)
        if args.spiking:
            o = torch.zeros((b_size,) + output.shape[1:], device=output.device)
            for t in range(args.timestep):
                o += output[t * b_size:(t + 1) * b_size, ...]
            output = o / args.timestep
        loss = criterion(output, y.long())

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    config = yaml_config_hook("./config/config.yaml")
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v, type=type(v))

    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.lr = float(args.lr)
    print(vars(args))

    if args.dataset == "CIFAR10":
        args.n_classes = 10
        args.img_size = 32
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=True,
            download=True,
            transform=DataTransforms(size=args.img_size).test_transform,
        )
        test_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=False,
            download=True,
            transform=DataTransforms(size=args.img_size).test_transform,
        )
    elif args.dataset == "oxford_flowers102":
        args.n_classes = 102
        args.img_size = 32
        train_dataset = torchvision.datasets.Flowers102(
            root=args.dataset_dir,
            split='train',
            download=True,
            transform=DataTransforms(size=args.img_size).test_transform,
        )
        test_dataset = torchvision.datasets.Flowers102(
            root=args.dataset_dir,
            split='test',
            download=True,
            transform=DataTransforms(size=args.img_size).test_transform,
        )
    else:
        raise NotImplementedError

    # Create a 1%/10%/100% data subset using the random indices
    num_samples = int(len(train_dataset) * args.eval_proportion)
    indices = torch.randperm(len(train_dataset))[:num_samples]
    train_subset = torch.utils.data.Subset(train_dataset, indices)

    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=args.logistic_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
    )

    # initialize active function of lif
    if args.act_func == 'MixedLIF':
        Act_func = MixedLIF
    else:
        Act_func = LIFt

    # initialize ResNet, encoder is resnet/resnet_snn
    # load pre-trained model from checkpoint
    if args.spiking:
        if "resnet" in args.model:
            backbone = get_resnet_spiking(args.model, args.timestep, args.sync_norm, Act_func, args.n_classes)
            n_features = backbone.fc.in_features  # get dimensions of fc layer
            backbone.fc = nn.Identity()
        bt_model = BartonTwinsSpiking(backbone, in_dim=n_features, out_dim=args.projection_dim, act_func=Act_func, timestep=args.timestep)
    else:
        if "resnet" in args.model:
            backbone = get_resnet(args.model, args.n_classes)
            n_features = backbone.fc.in_features
            backbone.fc = nn.Identity()
        bt_model = BartonTwins(backbone, in_dim=n_features, out_dim=args.projection_dim)

    # reload from checkpoint
    model_fp = os.path.join(
        args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
    )
    checkpoint = torch.load(model_fp, map_location=args.device)
    bt_model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model and optimizer loaded from checkpoint '{model_fp}' at epoch {checkpoint['epoch']}.")
    bt_model = bt_model.to(args.device)
    bt_model.eval()

    ## Logistic Regression
    n_classes = args.n_classes  # CIFAR-10 / oxford_flowers102
    model = LogisticRegression(n_features, n_classes)
    model = model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    print("### Creating features from pre-trained context model ###")
    (train_X, train_y, test_X, test_y) = get_features(
        args, bt_model, train_loader, test_loader, args.device
    )

    if args.spiking:
        arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
            train_X, train_y, test_X, test_y, args.logistic_batch_size * args.timestep  # time augmented
        )
    else:
        arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
            train_X, train_y, test_X, test_y, args.logistic_batch_size
        )

    for epoch in range(args.logistic_epochs):
        loss_epoch, accuracy_epoch = train(
            args, arr_train_loader, bt_model, model, criterion, optimizer
        )
        if epoch % 20 == 0:
            print(
                f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(arr_train_loader)}\t Accuracy: {accuracy_epoch / len(arr_train_loader)}"
            )

    # final testing
    loss_epoch, accuracy_epoch = test(
        args, arr_test_loader, bt_model, model, criterion, optimizer
    )
    print(
        f"[FINAL]\t Loss: {loss_epoch / len(arr_test_loader)}\t Accuracy: {accuracy_epoch / len(arr_test_loader)}"
    )
