import numpy as np
import argparse
import time
import yaml
import os
import datetime

import torch
import torch.nn as nn
import torchvision.utils
from loader import create_loader
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint
from timm.utils import *

from transforms_factory import contrastive_learning_transforms as DataTransforms
import model
import barlowtwins

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler


def cleanup():
    """Destroy the default distributed process group, if one exists.

    Safe to call when no group was initialized (e.g. a single-process run):
    in that case it does nothing instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def reduce_tensor(tensor, world_size):
    """Average ``tensor`` over all ranks of the default process group.

    When ``world_size`` is 1 (or less) the very same tensor object is returned.
    Otherwise a clone is summed with ``dist.all_reduce`` and divided by
    ``world_size``; the input tensor itself is left untouched.
    """
    if world_size <= 1:
        return tensor
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= world_size
    return averaged


# Stage 1: a tiny parser that only extracts -c/--config. `_parse_args` loads that
# YAML file and passes its keys to `parser.set_defaults`, so the config file
# provides defaults that explicit command-line flags can still override.
config_parser = argparse.ArgumentParser(description='Eval Config', add_help=False)
config_parser.add_argument('-c', '--config', default='imagenet.yml', type=str, metavar='FILE',
                           help='YAML config file specifying default arguments')

# Stage 2: the main parser. Several settings read by `main` (img_size,
# val_batch_size, patch_size, dim, num_heads, mlp_ratio, layer, time_step,
# val_epochs) have no flag here and must be supplied by the YAML config.
parser = argparse.ArgumentParser(
    description='PyTorch linear evaluation of a pre-trained Barlow Twins Spikformer')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--lr', default=1e-3, type=float,
                    help='learning rate of the linear classifier')
parser.add_argument('--dataset', '-d', metavar='PATH', default='dataset',
                    help="dataset root with train/ and val/ image folders "
                         "(only paths containing 'imagenet' are supported)")
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='checkpoint of the pre-trained Barlow Twins model to evaluate')
parser.add_argument('--eval_proportion', default=1.0, type=float,
                    help='fraction of the training set used to fit the classifier '
                         '(e.g. 0.01, 0.1, 1.0)')
# NOTE(review): args.local_rank is not read in this file; per-process ranks come
# from torch.multiprocessing.spawn in the entry point.
parser.add_argument('--local-rank', default=0, type=int, help='Local rank for DistributedDataParallel')



def _parse_args():
    """Parse command-line arguments, using an optional YAML file for defaults.

    Two-stage parse: ``config_parser`` extracts ``-c/--config``; every key of
    that YAML mapping becomes a default on ``parser``; the remaining CLI
    arguments are then parsed and override those defaults.

    Returns:
        (args, args_text): the parsed namespace and a YAML dump of its contents
        (kept so the configuration can be saved alongside outputs).

    Raises:
        TypeError: if the config file does not contain a YAML mapping.
    """
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file: treat it as "no overrides".
            cfg = yaml.safe_load(f) or {}
        if not isinstance(cfg, dict):
            raise TypeError(
                f"config file {args_config.config!r} must contain a YAML mapping, "
                f"got {type(cfg).__name__}"
            )
        parser.set_defaults(**cfg)

    # The main parser handles the rest; its defaults may now come from the config.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later.
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text

class LogisticRegression(nn.Module):
    """Single linear layer mapping feature vectors to class logits.

    No softmax is applied; pair it with a loss that consumes raw logits such as
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()
        # The attribute name `model` determines the state_dict keys
        # (`model.weight`, `model.bias`); keep it stable for saved checkpoints.
        self.model = nn.Linear(in_features=n_features, out_features=n_classes)

    def forward(self, x):
        logits = self.model(x)
        return logits


class MLP_classifation(nn.Module):
    """Three-layer MLP classification head: dim -> 512 -> 256 -> num_classes.

    Each hidden layer is followed by an in-place ReLU and dropout (default
    probability of ``nn.Dropout``).
    """

    def __init__(self, dim, num_classes):
        super().__init__()
        layers = []
        for in_dim, out_dim in ((dim, 512), (512, 256)):
            layers += [nn.Linear(in_dim, out_dim), nn.ReLU(True), nn.Dropout()]
        layers.append(nn.Linear(256, num_classes))
        # Positions inside `fc` (Linear layers at 0, 3, 6) define the state_dict keys.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)


def inference(args, loader, bt_model, device):
    """Run the frozen backbone over ``loader`` and collect features and labels.

    Args:
        args: unused; kept for interface compatibility.
        loader: iterable of ``(images, labels)`` batches that supports ``len()``.
        bt_model: backbone called as ``bt_model(x, x)`` returning four outputs;
            the first is taken as the representation and averaged over dim 0
            (presumably the spiking time-step axis, ``[T, B, C] -> [B, C]`` —
            TODO confirm against the model definition).
        device: device the images are moved to.

    Returns:
        (features, labels): NumPy arrays with one row / entry per sample
        (both are empty 1-D arrays when ``loader`` yields nothing).
    """
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for step, (x, y) in enumerate(loader):
            x = x.to(device)
            h, _, _, _ = bt_model(x, x)
            h = h.mean(0).detach()  # [B, C]
            feature_batches.append(h.cpu().numpy())
            # Convert the whole label batch at once (rather than extending a list
            # with per-sample 0-d tensors) so the result has a plain numeric dtype.
            label_batches.append(torch.as_tensor(y).cpu().numpy())

            if step % 20 == 0:
                print(f"Step [{step}/{len(loader)}]\t Computing features...")

    if feature_batches:
        feature_vector = np.concatenate(feature_batches, axis=0)
        labels_vector = np.concatenate(label_batches, axis=0)
    else:
        feature_vector = np.array([])
        labels_vector = np.array([])
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(args, bt_model, train_loader, test_loader, device):
    """Extract backbone features for both splits via ``inference``.

    The training split is processed first, then the test split.

    Returns:
        (train_X, train_y, test_X, test_y)
    """
    features_tr, labels_tr = inference(args, train_loader, bt_model, device)
    features_te, labels_te = inference(args, test_loader, bt_model, device)
    return features_tr, labels_tr, features_te, labels_te


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size,
                                    *, shuffle_train=False, num_workers=4):
    """Wrap pre-computed feature/label NumPy arrays in DataLoaders.

    Args:
        X_train, y_train: training features and labels (NumPy arrays).
        X_test, y_test: test features and labels (NumPy arrays).
        batch_size: batch size used by both loaders.
        shuffle_train: reshuffle the training loader each epoch. Defaults to
            False (the original behavior); the test loader is never shuffled.
        num_workers: worker processes per loader. The data is already in
            memory, so 0 may be cheaper than spawning workers every epoch.

    Returns:
        (train_loader, test_loader)
    """
    def _make_loader(features, labels, shuffle):
        # torch.from_numpy shares memory with the arrays instead of copying.
        dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(features), torch.from_numpy(labels)
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )

    train_loader = _make_loader(X_train, y_train, shuffle_train)
    test_loader = _make_loader(X_test, y_test, False)
    return train_loader, test_loader


def train(args, loader, model, criterion, optimizer, world_size=1):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        args: namespace providing ``device``.
        loader: iterable of ``(features, labels)`` batches.
        model: classifier producing ``[B, num_classes]`` logits.
        criterion: loss taking ``(logits, int64 labels)``.
        optimizer: optimizer over ``model``'s parameters.
        world_size: number of DDP processes; when > 1 the returned values are
            averaged across ranks with ``reduce_tensor``.

    Returns:
        (loss_sum, top1_sum, top5_sum) as Python floats: per-batch loss and
        accuracies *summed* over batches (divide by ``len(loader)`` for means).
        Top-5 becomes top-k with k = num_classes when there are fewer than 5.
    """
    model.train()
    # Float accumulators so the reduction below always works on a float tensor.
    loss_epoch = 0.0
    accuracy_epoch_top1 = 0.0
    accuracy_epoch_top5 = 0.0
    for x, y in loader:
        optimizer.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)  # [B, num_classes]
        loss = criterion(output, y.long())

        with torch.no_grad():
            # top-1
            accuracy_epoch_top1 += (output.argmax(1) == y).float().mean().item()
            # top-5: a sample counts if its label is among the k highest logits.
            k = min(5, output.size(1))
            top_k = output.topk(k, dim=1).indices  # [B, k]
            accuracy_epoch_top5 += (top_k == y[:, None]).any(dim=1).float().mean().item()

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    # One collective for all three statistics (every rank must reach this call).
    stats = torch.tensor(
        [loss_epoch, accuracy_epoch_top1, accuracy_epoch_top5], device=args.device
    )
    loss_epoch, accuracy_epoch_top1, accuracy_epoch_top5 = reduce_tensor(stats, world_size).tolist()
    return loss_epoch, accuracy_epoch_top1, accuracy_epoch_top5


def test(args, loader, model, criterion):
    loss_epoch = 0
    accuracy_epoch_top1 = 0
    accuracy_epoch_top5 = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        T, B = args.time_step, int(x.shape[0]/args.time_step)
        model.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)  # [B, num_classes]
        loss = criterion(output, y.long())

        # top-1
        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch_top1 += acc
        # top-5
        _, pred = output.topk(5, 1, True, True)
        pred = pred.t()
        acc_5 = pred.eq(y[None])
        acc_5 = acc_5.flatten().sum(dtype=torch.float32) / y.size(0)
        accuracy_epoch_top5 += acc_5

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch_top1, accuracy_epoch_top5


def _build_image_loaders(rank, world_size, args):
    """Create the image loaders that are consumed once for feature extraction.

    Returns ``(train_loader, test_loader)``. Only ImageNet-style folders are
    supported: ``args.dataset`` must contain 'imagenet' and hold ``train/`` and
    ``val/`` subfolders; anything else raises ``NotImplementedError``.
    """
    if 'imagenet' in args.dataset:
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(args.dataset, 'train'),
            transform=DataTransforms(size=args.img_size).train_evaluation_transform,
        )
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(args.dataset, 'val'),
            transform=DataTransforms(size=args.img_size).test_transform,
        )
    else:
        raise NotImplementedError

    # Keep a random `eval_proportion` (e.g. 1%/10%/100%) of the training set.
    # DistributedSampler shards by *position* inside `train_subset`, so every
    # rank must draw the same permutation; an explicitly seeded generator makes
    # that independent of whatever global RNG state each process happens to have.
    generator = torch.Generator().manual_seed(getattr(args, 'seed', 0))
    num_samples = int(len(train_dataset) * args.eval_proportion)
    indices = torch.randperm(len(train_dataset), generator=generator)[:num_samples].tolist()
    train_subset = torch.utils.data.Subset(train_dataset, indices)

    if world_size > 1:
        sampler = DistributedSampler(train_subset, num_replicas=world_size, rank=rank, shuffle=True)
    else:
        sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=args.val_batch_size,
        shuffle=(sampler is None),
        drop_last=True,
        num_workers=4,
        sampler=sampler,
    )
    # NOTE: drop_last=True keeps every feature batch full, so the per-batch
    # averaged accuracy from `test` is exact — at the cost of ignoring the last
    # len(test_dataset) % val_batch_size validation images.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.val_batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=4,
    )
    return train_loader, test_loader


def _load_backbone(args):
    """Build the Barlow Twins Spikformer, load ``args.resume`` and freeze it.

    The model is moved to ``args.device`` and switched to eval mode.
    """
    bt_model = create_model(
        'barlow_twins_spikformer',
        pretrained=False,
        drop_rate=0.,
        drop_path_rate=0.,
        drop_block_rate=None,
        img_size_h=args.img_size, img_size_w=args.img_size,
        patch_size=args.patch_size, embed_dims=args.dim, num_heads=args.num_heads, mlp_ratios=args.mlp_ratio,
        in_channels=3, num_classes=args.num_classes, qkv_bias=False,
        depths=args.layer, sr_ratios=1,
        act_func='LIFt',
        T=args.time_step
    )
    # Only the weights matter here: no optimizer / loss scaler is restored and
    # the returned resume epoch is not needed.
    resume_checkpoint(
        bt_model, args.resume,
        optimizer=None,
        loss_scaler=None,
        log_info=0)
    bt_model = bt_model.to(args.device)
    bt_model.eval()
    return bt_model


def main(rank, world_size, args):
    """Linear evaluation of a frozen pre-trained Barlow Twins Spikformer.

    Features are extracted once with the frozen backbone; a
    ``LogisticRegression`` head is trained on the training features for
    ``args.val_epochs`` epochs (wrapped in DDP when ``world_size > 1``), then
    top-1/top-5 accuracy on the validation features is reported by rank 0.

    Args:
        rank: index of this worker process (also its CUDA device index).
        world_size: total number of worker processes.
        args: parsed configuration namespace (see ``_parse_args``).
    """
    # init DDP
    args.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    if world_size > 1:
        dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size,
                                timeout=datetime.timedelta(seconds=120))
        torch.cuda.set_device(rank)
    is_main_process = rank == 0

    try:
        if is_main_process:
            print("Now start loading dataset")
        train_loader, test_loader = _build_image_loaders(rank, world_size, args)
        bt_model = _load_backbone(args)

        # Linear probe on top of the frozen backbone features.
        classifier = LogisticRegression(args.dim, args.num_classes).to(args.device)
        if world_size > 1:
            classifier = DDP(classifier, device_ids=[rank])
        optimizer = torch.optim.Adam(classifier.parameters(), lr=args.lr)
        criterion = torch.nn.CrossEntropyLoss()

        if is_main_process:
            print("### Creating features from pre-trained context model ###")
        train_X, train_y, test_X, test_y = get_features(
            args, bt_model, train_loader, test_loader, args.device
        )
        arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
            train_X, train_y, test_X, test_y, args.val_batch_size
        )

        # No sampler.set_epoch() here: the image loader was consumed once by
        # get_features, and the feature loaders have no DistributedSampler.
        num_train_batches = max(len(arr_train_loader), 1)
        for epoch in range(args.val_epochs):
            loss_epoch, acc_top1, acc_top5 = train(
                args, arr_train_loader, classifier, criterion, optimizer, world_size
            )
            if is_main_process and epoch % 20 == 0:
                print(
                    f"Epoch [{epoch}/{args.val_epochs}]\t Loss: {loss_epoch / num_train_batches}\t Accuracy top-1: {acc_top1 / num_train_batches}\t Accuracy top-5: {acc_top5 / num_train_batches}"
                )

        # final testing: test_loader has no sampler, so every rank holds the full
        # validation features; only rank 0 reports.
        loss_epoch, accuracy_top1, accuracy_top5 = test(
            args, arr_test_loader, classifier, criterion
        )
        if is_main_process:
            num_test_batches = max(len(arr_test_loader), 1)
            print(
                f"[FINAL]\t Loss: {loss_epoch / num_test_batches}\t Accuracy top-1: {float(accuracy_top1) / num_test_batches}\t Accuracy top-5: {float(accuracy_top5) / num_test_batches}"
            )
    finally:
        # Release the NCCL process group even if evaluation fails.
        if world_size > 1:
            cleanup()


if __name__ == "__main__":
    args, args_text = _parse_args()
    # The YAML config may yield `lr` as a string (PyYAML reads e.g. `1e-3`,
    # which has no dot, as str), so coerce it explicitly.
    args.lr = float(args.lr)
    print(vars(args))

    # Rendezvous address for init_method="env://" in main(); values already
    # present in the environment take precedence.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "1234")

    world_size = torch.cuda.device_count()
    if world_size > 1:
        # One process per visible GPU; spawn passes the process index as `rank`.
        torch.multiprocessing.spawn(main, args=(world_size, args), nprocs=world_size, join=True)
    else:
        main(0, 1, args)