# https://github.com/zhirongw/lemniscate.pytorch/blob/master/test.py
import logging

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from ...data_preprocessing.Landmarks_per.data_loader import get_global_data_loader


class KNNValidation(object):
    """Nearest-neighbour retrieval accuracy of a feature-extracting model.

    Features of every training sample are extracted once, L2-normalised and
    cached as a retrieval bank. Each validation sample is then assigned the
    label of its most similar (cosine similarity) training sample, and the
    fraction of correctly labelled validation samples is reported.
    Adapted from lemniscate.pytorch's test.py.

    Args:
        args: Config namespace, only read when a dataloader is not supplied.
            Uses ``dataset``, ``data_dir``, ``batch_size`` and
            ``accumulation_steps``; for ``"gld23k_per"`` it is also passed
            through to ``get_global_data_loader``.
        train_dataloader: Loader for the retrieval bank; built from ``args``
            when None.
        val_dataloader: Loader for the queries; built from ``args`` when None.
        feat_dim: Width of the feature vectors the model produces.
    """

    def __init__(self, args, train_dataloader=None, val_dataloader=None, feat_dim=2048):
        self.feat_dim = feat_dim
        if train_dataloader is None:
            self.train_dataloader = self._build_train_data(args)
        else:
            self.train_dataloader = train_dataloader

        if val_dataloader is None:
            self.val_dataloader = self._build_test_data(args)
        else:
            self.val_dataloader = val_dataloader

        # The training-feature bank is built lazily on the first eval() call
        # and reused afterwards (see reset_train_features()).
        self.is_feature_generated = False

    def update_val_dataloader(self, val_dataloader):
        """Replace the query loader; the cached training bank is kept."""
        self.val_dataloader = val_dataloader

    def reset_train_features(self):
        """Discard the cached training bank so the next eval() re-extracts it.

        Call this after the model's weights change; otherwise retrieval keeps
        comparing against features produced by the model seen at the first
        eval() call.
        """
        self.is_feature_generated = False

    def _build_cifar_dataloader(self, args, train):
        """Build an unshuffled, un-augmented CIFAR-10/100 loader for one split.

        Args:
            args: Config namespace with ``dataset`` ("cifar10" or "cifar100"),
                ``data_dir``, ``batch_size`` and ``accumulation_steps``.
            train: True for the training split, False for the test split.
        """
        if args.dataset == "cifar10":
            dataset_cls = datasets.CIFAR10
            mean = (0.49139968, 0.48215827, 0.44653124)
            std = (0.24703233, 0.24348505, 0.26158768)
        else:  # "cifar100"
            dataset_cls = datasets.CIFAR100
            mean = (0.5071, 0.4865, 0.4409)
            std = (0.2673, 0.2564, 0.2762)

        base_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        dataset = dataset_cls(root=args.data_dir,
                              train=train,
                              download=True,
                              transform=base_transforms)
        return DataLoader(dataset,
                          batch_size=args.batch_size * args.accumulation_steps,
                          shuffle=False,
                          num_workers=1,
                          pin_memory=True,
                          # Evaluation must see every sample; the bank
                          # extraction copes with a short final batch.
                          drop_last=False)

    def _build_train_data(self, args):
        """Build the loader for the retrieval bank from ``args.dataset``.

        Raises:
            ValueError: if ``args.dataset`` is not a supported dataset.
        """
        if args.dataset in ("cifar10", "cifar100"):
            return self._build_cifar_dataloader(args, train=True)
        if args.dataset == "gld23k_per":
            batch_size = args.batch_size * args.accumulation_steps
            train_dataloader, _ = get_global_data_loader(args, batch_size)
            return train_dataloader
        raise ValueError("no such dataset: %s" % (args.dataset,))

    def _build_test_data(self, args):
        """Build the loader for the validation queries from ``args.dataset``.

        Raises:
            ValueError: if ``args.dataset`` is not a supported dataset.
        """
        if args.dataset in ("cifar10", "cifar100"):
            return self._build_cifar_dataloader(args, train=False)
        if args.dataset == "gld23k_per":
            batch_size = args.batch_size * args.accumulation_steps
            _, val_dataloader = get_global_data_loader(args, batch_size)
            return val_dataloader
        raise ValueError("no such dataset: %s" % (args.dataset,))

    def _extract_train_features(self, model, device):
        """Fill ``self.train_features`` / ``self.train_labels`` from the train loader.

        ``train_features`` is a (feat_dim, n_filled) matrix of L2-normalised
        column features; ``train_labels`` holds the matching labels, taken from
        the loader's own batches so shuffling or a dropped tail cannot
        misalign them with the bank. Assumes ``model`` maps a batch to
        (batch, feat_dim) features and runs under ``torch.no_grad()``.
        """
        n_data = len(self.train_dataloader.dataset)
        logging.info("n_data = %d", n_data)
        n_batches = len(self.train_dataloader)

        bank = torch.zeros([self.feat_dim, n_data], device=device)
        labels = []
        filled = 0
        for batch_idx, (inputs, targets) in enumerate(self.train_dataloader):
            features = nn.functional.normalize(model(inputs.to(device)))
            batch_size = features.size(0)
            # Advance by the number of rows actually written so a short final
            # batch lands in the right columns.
            bank[:, filled:filled + batch_size] = features.t()
            labels.append(targets.to(device).view(-1).long())
            filled += batch_size
            logging.info("(train) batch_idx = %d/%d", batch_idx, n_batches)

        # Drop never-written (all-zero) columns, e.g. when the loader uses
        # drop_last, so they cannot be retrieved as phantom neighbours.
        self.train_features = bank[:, :filled]
        if labels:
            self.train_labels = torch.cat(labels)
        else:
            self.train_labels = torch.zeros(0, dtype=torch.long, device=device)
        self.is_feature_generated = True

    def _score_queries(self, model, device, K):
        """Return the fraction of validation samples whose nearest bank entry shares their label."""
        n_bank = self.train_features.size(1)
        if n_bank == 0:
            raise ValueError("training dataloader produced no samples")
        # topk would fail when asked for more neighbours than the bank holds.
        k = min(K, n_bank)

        total = 0
        correct = 0
        n_batches = len(self.val_dataloader)
        logging.info("self.val_dataloader = %d", n_batches)
        for batch_idx, (inputs, targets) in enumerate(self.val_dataloader):
            targets = targets.to(device).view(-1)
            # Normalise queries too, so ``sim`` is a true cosine similarity.
            features = nn.functional.normalize(model(inputs.to(device)))

            sim = torch.mm(features, self.train_features)
            _, top_idx = sim.topk(k, dim=1, largest=True, sorted=True)
            # Only the single nearest neighbour decides the prediction.
            predicted = self.train_labels[top_idx[:, 0]]

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            logging.info("(val) batch_idx = %d/%d", batch_idx, n_batches)

        if total == 0:
            logging.warning("validation dataloader produced no samples; returning 0.0")
            return 0.0
        return correct / total

    def _topk_retrieval(self, model, device, K):
        """Extract features from validation split and search on train split features.

        Args:
            model: Module mapping an input batch to (batch, feat_dim) features.
                It is moved to ``device``; its train/eval mode is restored on
                return.
            device: Device on which features are computed and compared.
            K: Number of neighbours retrieved per query (clamped to the bank
                size); only the closest one is used for the prediction.

        Returns:
            Top-1 nearest-neighbour accuracy as a float in [0, 1].

        Raises:
            ValueError: if ``K < 1`` or the training loader yields no samples.
        """
        if K < 1:
            raise ValueError("K must be >= 1, got %r" % (K,))

        model.to(device)
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                if not self.is_feature_generated:
                    self._extract_train_features(model, device)
                return self._score_queries(model, device, K)
        finally:
            # Do not leave a caller's model stuck in eval mode.
            model.train(was_training)

    def eval(self, model, device, K):
        """Return the top-1 nearest-neighbour accuracy of ``model`` (see ``_topk_retrieval``)."""
        return self._topk_retrieval(model, device, K)
