# -*- coding: UTF-8 -*-

import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from .util import LinearWarmupCosineAnnealingLR

from .model.vit.idi_vit import IDI_ViT_B32
from .model.vit.vit import ViT_B32
from .model.cnn.resnet_timm import resnet50, resnet152
from .model.cnn.seresnet_timm import legacy_seresnet50 as seresnet50
from .model.cnn.idi_resnet_timm import idi_resnet50, idi_resnet152
from .model.cnn.idi_seresnet_timm import idi_legacy_seresnet50 as idi_seresnet50

from .datasets.imagenet224 import ImageNet224Module

from timm.data import Mixup


# Registry mapping the value of ``--arch`` to the constructor that builds it.
Models = {
    "IDI_ViT_B32": IDI_ViT_B32,
    "ViT_B32": ViT_B32,
    "resnet50": resnet50,
    "idi_resnet50": idi_resnet50,
    "seresnet50": seresnet50,
    "idi_seresnet50": idi_seresnet50,
    "resnet152": resnet152,
    "idi_resnet152": idi_resnet152,
}


class ImageNet224Interface(ImageNet224Module):
    """Train/validate/test wrapper around one of the registered architectures.

    Extends ``ImageNet224Module`` with the model itself, Mixup/CutMix
    augmentation of training batches, top-1/top-5 accuracy logging, and an
    SGD optimizer driven by a linear-warmup + cosine-annealing schedule.
    """

    def __init__(self, arch, data_path, batch_size, workers,
                 weight_decay, momentum, num_classes, lr, max_epochs, weight_times, **kwargs):
        """Build the model and store the optimisation settings.

        Args:
            arch: Key into ``Models`` selecting the architecture.
            data_path: Forwarded to ``ImageNet224Module``.
            batch_size: Forwarded to ``ImageNet224Module``.
            workers: Forwarded to ``ImageNet224Module``.
            weight_decay: Weight decay handed to SGD.
            momentum: Momentum handed to SGD.
            num_classes: Number of output classes for the model and Mixup.
            lr: Peak learning rate handed to SGD.
            max_epochs: Total epochs; sizes the warmup and cosine schedule.
            weight_times: How many times the model is constructed; only the
                last instance is kept. Must be at least 1.
            **kwargs: Accepted and otherwise unused; excluded from the saved
                hyperparameters.

        Raises:
            ValueError: If ``weight_times`` is smaller than 1, because no
                model would be created.
        """
        super().__init__(workers=workers, batch_size=batch_size, data_path=data_path)

        if weight_times < 1:
            raise ValueError(
                f"weight_times must be >= 1 so that a model is built, got {weight_times}"
            )

        self.weight_decay = weight_decay
        self.num_classes = num_classes
        self.lr = lr
        self.max_epochs = max_epochs
        self.momentum = momentum
        self.weight_times = weight_times

        # Mixup one-hot encodes the targets, so it has to know the real class
        # count; relying on its default (1000) breaks any other num_classes.
        self.mixup = Mixup(cutmix_alpha=1.0, mixup_alpha=0.8, switch_prob=0.5,
                           label_smoothing=0.1, num_classes=self.num_classes)

        # NOTE(review): the model is rebuilt ``weight_times`` times and only
        # the last one survives. Presumably this advances the RNG so that
        # different values give different initial weights -- confirm.
        for _ in range(self.weight_times):
            self.model = Models[arch](num_classes=self.num_classes)

        # Save the __init__ arguments, leaving out the pass-through kwargs.
        # ``ignore`` expects names, so hand it the keys rather than the dict.
        self.save_hyperparameters(ignore=list(kwargs))

    def forward(self, x):
        """Return the logits of the wrapped model for input ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Run one optimisation step on a Mixup/CutMix-augmented batch.

        The loss is computed against the mixed (soft) targets, while the
        logged accuracies are measured against the original hard labels.

        Returns:
            The training loss tensor.
        """
        data, target = batch
        # NOTE(review): timm's Mixup asserts an even batch size; presumably
        # the training loader never yields an odd batch -- confirm.
        data, target_aug = self.mixup(data, target)

        output_logits = self(data)
        loss_train = F.cross_entropy(output_logits, target_aug)
        acc1, acc5 = self.__accuracy(output_logits, target, topk=(1, 5))
        self.log('train/loss', loss_train, on_epoch=True, logger=True)
        self.log('train/acc1', acc1, prog_bar=True, on_epoch=True)
        self.log('train/acc5', acc5, on_epoch=True)
        return loss_train

    def training_epoch_end(self, outputs):
        """Set ``self.size`` to 224 once epoch index 35 has finished.

        NOTE(review): ``self.size`` is not read in this class; presumably the
        parent data module uses it to pick the input resolution -- confirm.
        """
        if self.current_epoch == 35:
            self.size = 224

    def validation_step(self, batch, batch_idx):
        """Log loss and top-1/top-5 accuracy for one un-augmented batch."""
        data, target = batch

        output_logits = self(data)
        loss_val = F.cross_entropy(output_logits, target)
        acc1, acc5 = self.__accuracy(output_logits, target, topk=(1, 5))
        self.log('val/loss', loss_val, on_epoch=True)
        self.log('val/acc1', acc1, prog_bar=True, on_epoch=True)
        self.log('val/acc5', acc5, on_epoch=True)

    @staticmethod
    def __accuracy(output, target, topk=(1, )):
        """Computes the accuracy over the k top predictions for the specified values of k

        Args:
            output: Logits indexed as (sample, class).
            target: Integer class label per sample.
            topk: The values of k to evaluate.

        Returns:
            A list with one single-element tensor per entry of ``topk``,
            holding the accuracy in percent.
        """
        with torch.no_grad():
            # Never ask for more predictions than there are classes; otherwise
            # top-5 raises on models with fewer than five outputs.
            maxk = min(max(topk), output.size(1))
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()  # -> (maxk, batch): row r holds the rank-r guess
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                # A sample counts as a hit if any of its first k guesses match.
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def configure_optimizers(self):
        """Create Nesterov SGD with a warmup + cosine-annealing LR schedule.

        The warmup lasts 10% of ``max_epochs`` (rounded down) and starts from
        a learning rate of 1e-6.

        Returns:
            A pair ``([optimizer], [scheduler])``.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
            nesterov=True,
        )
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=int(0.1*self.max_epochs),
            max_epochs=self.max_epochs,
            warmup_start_lr=1e-6,
        )
        return [optimizer], [scheduler]

    def test_step(self, *args, **kwargs):
        """Evaluate a test batch by delegating to ``validation_step``.

        NOTE(review): as a consequence the test metrics are logged under the
        ``val/*`` keys.
        """
        return self.validation_step(*args, **kwargs)

    def test_epoch_end(self, *args, **kwargs):
        """Translate the validation epoch summary into test naming.

        Delegates to ``validation_epoch_end``. If that hook returns a dict,
        it is re-keyed from ``val`` to ``test``; if it returns anything else
        (for example ``None``), there is nothing to translate and ``None``
        is returned instead of failing on the lookup.
        """
        outputs = self.validation_epoch_end(*args, **kwargs)
        if not isinstance(outputs, dict):
            return None

        def substitute_val_keys(out):
            return {k.replace('val', 'test'): v for k, v in out.items()}

        outputs = {
            'test_loss': outputs['val_loss'],
            'progress_bar': substitute_val_keys(outputs['progress_bar']),
            'log': substitute_val_keys(outputs['log']),
        }
        return outputs

    def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
        """Log the learning rate of the optimizer's first parameter group."""
        self.log('lr', optimizer.param_groups[0]['lr'], on_epoch=True)

    @staticmethod
    def get_model_info(args):
        """Return a description string for the model; currently always empty."""
        model_info_str = ""

        return model_info_str

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """Register this interface's command-line options.

        Options are added to a new "ModelInterface" argument group on
        ``parent_parser``.

        Returns:
            The same ``parent_parser`` that was passed in.
        """
        parser = parent_parser.add_argument_group("ModelInterface")

        parser.add_argument("--arch", type=str, default="idi_resnet152")
        parser.add_argument("--data_name", type=str, default='imagenet')
        parser.add_argument("--data_path", type=str, default='./imagenet_data')
        parser.add_argument("--workers", default=8, type=int, help="number of data loading workers (default: 8)")
        parser.add_argument("--batch_size", default=256, type=int)
        parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
        parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
        parser.add_argument("--weight_decay", default=1e-4, type=float)
        # parser.add_argument("--is_pretrained", action="store_true", default=False)
        parser.add_argument("--num_classes", default=1000, type=int)
        parser.add_argument("--weight_times", type=int, default=1)

        return parent_parser


