import os
import argparse
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

from modules import get_resnet
from modules import get_resnet_spiking, modify_resnet_model
from modules import get_vgg
from modules import get_vgg_spiking
from modules.transformations import DataTransforms
from modules import LogisticRegression

from BartonTwins import BartonTwins
from BartonTwins_spiking import BartonTwinsSpiking
from model import load_optimizer, save_model
from utils import yaml_config_hook

import modules.vit
from timm.models import create_model
from spikingjelly.clock_driven import functional


def inference(args, loader, bt_model, device):
    """Encode every batch of ``loader`` with ``bt_model`` and collect the results.

    Each batch ``x`` is moved to ``device`` and passed to the model as
    ``bt_model(x, x)`` under ``torch.no_grad()``; only the first returned
    value (the encoding) is kept.

    Args:
        args: namespace providing ``spiking`` and, in spiking mode,
            ``timestep``.
        loader: iterable of ``(x, y)`` batches that also supports ``len()``.
        bt_model: callable returning a 6-tuple in spiking mode and a 4-tuple
            otherwise.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(features, labels)`` as numpy arrays. In spiking mode the labels of
        each batch are repeated ``args.timestep`` times, whole batch after
        whole batch, and are stored as float64; per the original author's note
        the encoding is "timestep augmented", so this presumably keeps labels
        aligned with the encoding rows.
    """
    collected_features = []
    collected_labels = []

    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # The two modes differ only in how many values the model returns.
        with torch.no_grad():
            if args.spiking:
                encoding, _, _, _, _, _ = bt_model(x, x)
            else:
                encoding, _, _, _ = bt_model(x, x)

        collected_features.extend(encoding.detach().cpu().numpy())

        if args.spiking:
            per_batch = y.shape[0]
            tiled = np.zeros((args.timestep * per_batch,) + y.shape[1:])
            # View the buffer as (timestep, batch, ...) so one broadcast
            # assignment writes the batch labels into every timestep slot.
            tiled.reshape((args.timestep, per_batch) + y.shape[1:])[...] = np.asarray(y)
            collected_labels.extend(tiled)
        else:
            collected_labels.extend(y.numpy())

        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(collected_features)
    labels_vector = np.array(collected_labels)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(args, bt_model, train_loader, test_loader, device):
    """Extract features and labels for both splits via ``inference``.

    The train loader is processed first, then the test loader.

    Returns:
        ``(train_X, train_y, test_X, test_y)``.
    """
    train_pair = inference(args, train_loader, bt_model, device)
    test_pair = inference(args, test_loader, bt_model, device)
    (train_X, train_y), (test_X, test_y) = train_pair, test_pair
    return train_X, train_y, test_X, test_y


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    """Wrap pre-computed numpy features/labels in unshuffled DataLoaders.

    Args:
        X_train, y_train: numpy arrays for the training split.
        X_test, y_test: numpy arrays for the test split.
        batch_size: batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``; both iterate in array order
        (``shuffle=False``).
    """
    loaders = []
    for features, targets in ((X_train, y_train), (X_test, y_test)):
        dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(features), torch.from_numpy(targets)
        )
        loaders.append(
            torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
        )

    train_loader, test_loader = loaders
    return train_loader, test_loader


def train(args, loader, bt_model, model, criterion, optimizer):
    """Train ``model`` for one pass over ``loader`` of pre-computed features.

    Args:
        args: namespace providing ``device``, ``spiking`` and, in spiking
            mode, ``timestep``.
        loader: iterable of ``(x, y)`` batches.
        bt_model: unused; kept so existing call sites keep working.
        model: classifier being trained.
        criterion: loss called as ``criterion(output, y.long())``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        ``(loss_epoch, accuracy_epoch)``: the per-batch loss and per-batch
        accuracy, each summed over all batches (callers divide by
        ``len(loader)`` to obtain averages).

    In spiking mode each batch is treated as ``timestep`` consecutive chunks
    of equal size: the labels of the first chunk are used as targets and the
    model outputs are averaged across the chunks before computing the loss.
    """
    # test() switches the classifier to eval mode; always restore train mode
    # so a training pass after an evaluation does not silently run in eval.
    model.train()

    loss_epoch = 0
    accuracy_epoch = 0
    for x, y in loader:
        optimizer.zero_grad()

        x = x.to(args.device)
        if args.spiking:
            b_size = y.shape[0] // args.timestep
            # Labels repeat once per timestep; the first chunk is sufficient.
            y = y[:b_size, ...].to(args.device)
        else:
            y = y.to(args.device)

        output = model(x)
        if args.spiking:
            # Average the predictions of the per-timestep chunks.
            o = torch.zeros((b_size,) + output.shape[1:], device=output.device)
            for t in range(args.timestep):
                o += output[t * b_size:(t + 1) * b_size, ...]
            output = o / args.timestep
        loss = criterion(output, y.long())

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch


def test(args, loader, bt_model, model, criterion, optimizer):
    """Evaluate ``model`` on ``loader`` of pre-computed features.

    Args:
        args: namespace providing ``device``, ``spiking`` and, in spiking
            mode, ``timestep``.
        loader: iterable of ``(x, y)`` batches.
        bt_model: unused; kept so existing call sites keep working.
        model: classifier being evaluated; it is switched to eval mode.
        criterion: loss called as ``criterion(output, y.long())``.
        optimizer: unused; kept so existing call sites keep working.

    Returns:
        ``(loss_epoch, accuracy_epoch)``: the per-batch loss and per-batch
        accuracy, each summed over all batches (callers divide by
        ``len(loader)`` to obtain averages).

    In spiking mode each batch is treated as ``timestep`` consecutive chunks
    of equal size: the labels of the first chunk are used as targets and the
    model outputs are averaged across the chunks before computing the loss.
    """
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()

    # No parameters are updated here, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(args.device)
            if args.spiking:
                b_size = y.shape[0] // args.timestep
                # Labels repeat once per timestep; the first chunk is sufficient.
                y = y[:b_size, ...].to(args.device)
            else:
                y = y.to(args.device)

            output = model(x)
            if args.spiking:
                # Average the predictions of the per-timestep chunks.
                o = torch.zeros((b_size,) + output.shape[1:], device=output.device)
                for t in range(args.timestep):
                    o += output[t * b_size:(t + 1) * b_size, ...]
                output = o / args.timestep
            loss = criterion(output, y.long())

            predicted = output.argmax(1)
            acc = (predicted == y).sum().item() / y.size(0)
            accuracy_epoch += acc

            loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch


# Despite its name this is not a pytest test; opt out of collection in case
# the function is imported into a test module.
test.__test__ = False


def str2bool(value):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``
    (so ``--flag False`` would enable the flag); this converter parses the
    text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Every key of the YAML config becomes a command-line option whose default
    # is the configured value.
    parser = argparse.ArgumentParser()
    config = yaml_config_hook("./config/config.yaml")
    for k, v in config.items():
        if isinstance(v, bool):
            # type=bool would turn any non-empty string (even "False") into True.
            arg_type = str2bool
        elif v is None:
            # NoneType cannot parse a value; accept the override as plain text.
            arg_type = str
        else:
            arg_type = type(v)
        parser.add_argument(f"--{k}", default=v, type=arg_type)

    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.lr = float(args.lr)
    print(vars(args))

    # Both splits use the deterministic test-time transform: features are
    # extracted once and reused for every epoch of classifier training.
    if args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=True,
            download=True,
            transform=DataTransforms(size=32).test_transform,
        )
        test_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=False,
            download=True,
            transform=DataTransforms(size=32).test_transform,
        )
    elif args.dataset == "CIFAR100":
        train_dataset = torchvision.datasets.CIFAR100(
            args.dataset_dir,
            train=True,
            download=True,
            transform=DataTransforms(size=32).test_transform,
        )
        test_dataset = torchvision.datasets.CIFAR100(
            args.dataset_dir,
            train=False,
            download=True,
            transform=DataTransforms(size=32).test_transform,
        )
    elif args.dataset == "ImageNet":
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(args.dataset_dir, 'train'),
            transform=DataTransforms(size=args.image_size).test_transform,
        )
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(args.dataset_dir, 'val'),
            transform=DataTransforms(size=args.image_size).test_transform,
        )
    else:
        raise NotImplementedError

    # Keep a random `eval_proportion` fraction (e.g. 1% / 10% / 100%) of the
    # training set for fitting the classifier.
    num_samples = int(len(train_dataset) * args.eval_proportion)
    indices = torch.randperm(len(train_dataset))[:num_samples]
    train_subset = torch.utils.data.Subset(train_dataset, indices)

    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=args.logistic_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
    )

    # NOTE(review): drop_last=True discards the final partial test batch, so
    # the reported accuracy may not cover the whole test set; confirm this is
    # intended.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
    )

    # Build the encoder: a 'spikformer' backbone created through timm
    # (presumably registered by `import modules.vit` -- verify).
    # NOTE(review): image size (32x32) and num_classes (10) are hard-coded
    # regardless of args.dataset / args.image_size; confirm before running on
    # anything but CIFAR-sized inputs.
    backbone = create_model(
        'spikformer',
        pretrained=False,
        drop_rate=0.,
        drop_path_rate=0.,
        drop_block_rate=None,
        img_size_h=32, img_size_w=32,
        patch_size=128, embed_dims=384, num_heads=12, mlp_ratios=4,
        in_channels=3, num_classes=10, qkv_bias=False,
        depths=4, sr_ratios=1,
        T=args.timestep
    )
    n_features = 384  # feature width fed to the classifier; matches embed_dims above
    backbone.head = nn.Identity()  # drop the classification head, keep the features
    # NOTE(review): the spiking wrapper is built even when args.spiking is
    # false, while inference() then expects a 4-tuple from it -- verify.
    bt_model = BartonTwinsSpiking(backbone, in_dim=n_features, out_dim=args.projection_dim, timestep=args.timestep)

    # Restore the pre-trained weights and freeze the encoder in eval mode.
    model_fp = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.epoch_num))
    bt_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
    bt_model = bt_model.to(args.device)
    bt_model.eval()

    ## Logistic Regression
    n_classes = args.n_classes  # CIFAR-10 / STL-10 / CIFAR-100
    model = LogisticRegression(n_features, n_classes)
    model = model.to(args.device)

    # The classifier uses a fixed learning rate; args.lr is not used here.
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    print("### Creating features from pre-trained context model ###")
    (train_X, train_y, test_X, test_y) = get_features(
        args, bt_model, train_loader, test_loader, args.device
    )

    if args.spiking:
        # Each original batch contributes `timestep` chunks of rows, so the
        # feature loaders use a proportionally larger batch to keep the chunks
        # of one original batch together.
        arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
            train_X, train_y, test_X, test_y, args.logistic_batch_size * args.timestep  # time augmented
        )
    else:
        arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
            train_X, train_y, test_X, test_y, args.logistic_batch_size
        )

    for epoch in range(args.logistic_epochs):
        loss_epoch, accuracy_epoch = train(
            args, arr_train_loader, bt_model, model, criterion, optimizer
        )
        if epoch % 20 == 0:
            print(
                f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(arr_train_loader)}\t Accuracy: {accuracy_epoch / len(arr_train_loader)}"
            )

    # final testing
    loss_epoch, accuracy_epoch = test(
        args, arr_test_loader, bt_model, model, criterion, optimizer
    )
    print(
        f"[FINAL]\t Loss: {loss_epoch / len(arr_test_loader)}\t Accuracy: {accuracy_epoch / len(arr_test_loader)}"
    )
