import numpy as np
import argparse
import time
import yaml
import os

import torch
import torch.nn as nn
import torchvision.utils
from loader import create_loader
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
    convert_splitbn_model, model_parameters
from timm.utils import *

from transforms_factory import contrastive_learning_transforms as DataTransforms
import model
import barlowtwins


# Stage 1: a minimal parser that only knows about the YAML config path. It is used with
# `parse_known_args`, so every other command-line token is handed on to the main parser.
config_parser = parser = argparse.ArgumentParser(description='Eval Config', add_help=False)
parser.add_argument('-c', '--config', default='cifar10.yml', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

# Stage 2: the main parser. `parser` is rebound here; `config_parser` keeps pointing at stage 1.
parser = argparse.ArgumentParser(description='PyTorch ')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Path to the pre-trained checkpoint whose backbone weights are loaded '
                         '(default: none)')
parser.add_argument('--eval_proportion', default=0.01, type=float,
                    help='proportion of the training dataset used for fine-tuning')
# argparse applies %-formatting to help strings, so a literal percent sign must be written '%%';
# an unescaped '%' makes `--help` fail instead of printing the usage text.
parser.add_argument('--val_epochs', default=60, type=int,
                    help='For 1%% labeled data we fine-tune for 60 epochs, and for 10%% labeled data '
                         'we fine-tune for 30 epochs')


def _parse_args():
    """Parse command-line arguments, seeding their defaults from a YAML config file.

    Precedence, lowest to highest: defaults declared on ``parser``, values from the YAML file
    named by ``-c/--config``, then options given explicitly on the command line.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the parsed ``argparse.Namespace`` and
        ``args_text`` is that namespace dumped as a YAML string.
    """
    # First pass: extract only the config path and leave every other token untouched.
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
        # An empty YAML document loads as None; treat it as "no overrides" rather than
        # failing inside `**cfg` with an unhelpful TypeError.
        if cfg:
            parser.set_defaults(**cfg)  # config values replace the parser's built-in defaults

    # Second pass: explicit command-line options win over both kinds of defaults.
    args = parser.parse_args(remaining)

    # Keep a text copy of the resolved arguments so callers can store it with their outputs.
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


def train(args, loader, model, criterion, optimizer):
    """Run one optimisation pass over ``loader`` and accumulate per-batch metrics.

    Args:
        args: Namespace-like object; only ``args.device`` is read.
        loader: Iterable yielding ``(x, y)`` batches of inputs and integer class labels.
        model: ``torch.nn.Module`` whose forward result unpacks as ``(output, _)``; ``output``
            is scored against ``y`` along dimension 1.
        criterion: Loss callable invoked as ``criterion(output, y)``.
        optimizer: Optimizer that updates the model parameters.

    Returns:
        tuple: ``(loss_sum, top1_sum, top5_sum)`` as Python floats. Each is the SUM of the
        per-batch values (not an average); divide by the number of batches to get the mean.
    """
    # Make sure train-time behaviour is active even if the model was left in eval mode
    # (for example by an earlier evaluation pass).
    model.train()

    loss_epoch = 0.0
    accuracy_epoch_top1 = 0.0
    accuracy_epoch_top5 = 0.0
    for x, y in loader:
        x = x.to(args.device)
        y = y.to(args.device)

        optimizer.zero_grad()
        output, _ = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)

        # top-1: fraction of samples whose highest-scoring class is the label
        correct_top1 = (output.argmax(1) == y).sum().item()
        accuracy_epoch_top1 += correct_top1 / batch_size

        # top-5: fraction of samples whose label is among the five best-scoring classes.
        # k is clamped so that outputs with fewer than five classes do not make topk raise.
        k = min(5, output.size(1))
        topk_classes = output.topk(k, dim=1).indices
        correct_top5 = (topk_classes == y.unsqueeze(1)).any(dim=1).sum().item()
        accuracy_epoch_top5 += correct_top5 / batch_size

        # .item() keeps the running totals as plain floats instead of device tensors
        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch_top1, accuracy_epoch_top5


def test(args, loader, model, criterion):
    """Evaluate ``model`` on ``loader`` without updating it and accumulate per-batch metrics.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        args: Namespace-like object; only ``args.device`` is read.
        loader: Iterable yielding ``(x, y)`` batches of inputs and integer class labels.
        model: ``torch.nn.Module`` whose forward result unpacks as ``(output, _)``; ``output``
            is scored against ``y`` along dimension 1.
        criterion: Loss callable invoked as ``criterion(output, y)``.

    Returns:
        tuple: ``(loss_sum, top1_sum, top5_sum)`` as Python floats. Each is the SUM of the
        per-batch values (not an average); divide by the number of batches to get the mean.
    """
    model.eval()
    # No backward pass happens here, so clearing gradients once is enough.
    model.zero_grad()

    loss_epoch = 0.0
    accuracy_epoch_top1 = 0.0
    accuracy_epoch_top5 = 0.0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(args.device)
            y = y.to(args.device)

            output, _ = model(x)
            loss = criterion(output, y)

            batch_size = y.size(0)

            # top-1: fraction of samples whose highest-scoring class is the label
            correct_top1 = (output.argmax(1) == y).sum().item()
            accuracy_epoch_top1 += correct_top1 / batch_size

            # top-5: fraction of samples whose label is among the five best-scoring classes.
            # k is clamped so that outputs with fewer than five classes do not make topk raise.
            k = min(5, output.size(1))
            topk_classes = output.topk(k, dim=1).indices
            correct_top5 = (topk_classes == y.unsqueeze(1)).any(dim=1).sum().item()
            accuracy_epoch_top5 += correct_top5 / batch_size

            # .item() keeps the running totals as plain floats instead of device tensors
            loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch_top1, accuracy_epoch_top5


# The name matches pytest's default `test*` collection pattern; opt out so that importing this
# function into a test module does not make pytest try to run it as a test case.
test.__test__ = False

if __name__ == "__main__":
    # Script entry point: fine-tune a pre-trained backbone on a labelled subset of the training
    # split, then report loss and top-1/top-5 accuracy on the test split.

    args, args_text = _parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Normalise to float; presumably the config may supply the value as a string - TODO confirm.
    # NOTE(review): args.lr is not used below; the optimizer uses a fixed 3e-4. Confirm intent.
    args.lr = float(args.lr)
    print(vars(args))

    if args.dataset != "torch/cifar10":
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    # Both splits use the deterministic test-time transform; args.dataset doubles as the root dir.
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset,
        train=True,
        download=True,
        transform=DataTransforms(size=args.img_size).test_transform,
    )
    test_dataset = torchvision.datasets.CIFAR10(
        args.dataset,
        train=False,
        download=True,
        transform=DataTransforms(size=args.img_size).test_transform,
    )

    # Create a 1%/10%/100% data subset using random indices
    num_samples = int(len(train_dataset) * args.eval_proportion)
    if num_samples < 1:
        raise ValueError(
            f"eval_proportion={args.eval_proportion} selects no samples from "
            f"{len(train_dataset)} training images"
        )
    indices = torch.randperm(len(train_dataset))[:num_samples].tolist()
    train_subset = torch.utils.data.Subset(train_dataset, indices)

    # Dropping the ragged final batch is only safe when at least one full batch exists;
    # otherwise the loader would be empty and the per-epoch averages would divide by zero.
    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=args.val_batch_size,
        shuffle=True,
        drop_last=len(train_subset) >= args.val_batch_size,
        num_workers=args.workers,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.val_batch_size,
        shuffle=False,
        drop_last=len(test_dataset) >= args.val_batch_size,
        num_workers=args.workers,
    )

    # Build the network. It is bound to `net` so the imported `model` module is not shadowed.
    net = create_model(
        'spikformer',
        pretrained=False,
        drop_rate=0.,
        drop_path_rate=0.,
        drop_block_rate=None,
        img_size_h=args.img_size, img_size_w=args.img_size,
        patch_size=args.patch_size, embed_dims=args.dim, num_heads=args.num_heads, mlp_ratios=args.mlp_ratio,
        in_channels=3, num_classes=args.num_classes, qkv_bias=False,
        depths=args.layer, sr_ratios=1,
        T=args.time_step
    )

    # Fail early with a clear message instead of a cryptic error from torch.load('').
    if not args.resume or not os.path.isfile(args.resume):
        parser.error(f"--resume must point to an existing checkpoint file (got {args.resume!r})")

    # NOTE: torch.load unpickles the file, which can execute arbitrary code; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(args.resume, map_location=args.device.type)

    # Keep only the backbone weights and drop the 'backbone.' marker so that the keys line up
    # with the names used by the target model.
    backbone_state = {
        key.replace("backbone.", ""): value
        for key, value in checkpoint['state_dict'].items()
        if "backbone." in key
    }
    if not backbone_state:
        print(f"WARNING: no 'backbone.' weights found in {args.resume}; the model keeps its random initialisation")

    # strict=False tolerates keys that exist on only one side; report the mismatch counts so
    # a checkpoint that does not actually match the model is visible in the log.
    load_result = net.load_state_dict(backbone_state, strict=False)
    print(
        f"Loaded {len(backbone_state)} backbone tensors "
        f"(missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)})"
    )
    net = net.to(args.device)

    trainable_params = (p for p in net.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(trainable_params, lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(args.val_epochs):
        loss_epoch, accuracy_epoch_top1, accuracy_epoch_top5 = train(
            args, train_loader, net, criterion, optimizer
        )
        # train() returns per-batch sums, so divide by the batch count to log epoch means.
        if epoch % 20 == 0:
            print(
                f"Epoch [{epoch}/{args.val_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy top-1: {accuracy_epoch_top1 / len(train_loader)}\t Accuracy top-5: {accuracy_epoch_top5 / len(train_loader)}"
            )

    # final testing
    loss_epoch, accuracy_top1, accuracy_top5 = test(
        args, test_loader, net, criterion
    )
    print(
        f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy top-1: {accuracy_top1 / len(test_loader)}\t Accuracy top-5: {accuracy_top5 / len(test_loader)}"
    )

