"""
    ref: https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import make_layers, cfgs, VGG
from typing import Union, List, Dict, Any, cast

import math
import numpy as np
from functools import reduce
from itertools import permutations
from copy import deepcopy
from collections import Counter
from utils import *

class VGGTiny(VGG):
    """VGG whose classifier consumes the flattened conv features directly.

    ``forward`` does not apply ``self.avgpool``, so the first classifier layer
    must match the flattened size of ``self.features``' output.

    Args:
        n_dims: Input size of the first classifier layer. When ``None`` it
            defaults to ``out_channels`` of the last ``nn.Conv2d`` in
            ``features``; that is only the right size if the final feature
            map is 1x1 spatially (assumption about input resolution -- confirm).
        ratio: Unused; accepted for interface compatibility.
        **kwargs: Forwarded to torchvision's ``VGG`` (``features``,
            ``num_classes``, ...).
    """

    def __init__(self, n_dims=None, ratio=None, **kwargs):
        super(VGGTiny, self).__init__(**kwargs)
        if n_dims is None:
            # Search for the last conv instead of indexing features[-4]: the
            # fixed offset only lands on a Conv2d when batch_norm=True.
            convs = [m for m in self.features if isinstance(m, nn.Conv2d)]
            n_dims = convs[-1].out_channels
        self.n_dims = n_dims
        # The base VGG.__init__ already built a classifier ending in a Linear
        # sized for num_classes; reuse it so omitting num_classes still works.
        num_classes = self.classifier[-1].out_features
        self.classifier = nn.Sequential(
            nn.Linear(self.n_dims, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x


class VGGTinyManifoldMixup(VGG):
    """VGG with manifold mixup applied to the flattened conv features.

    Without labels, ``forward`` returns plain logits. With labels, it mixes the
    feature vectors via ``mixup_data`` (from ``utils``) and returns the mixed
    logits together with both label sets and the mixing coefficient.

    Args:
        n_dims: Input size of the first classifier layer. When ``None`` it
            defaults to ``out_channels`` of the last ``nn.Conv2d`` in
            ``features`` (assumes a 1x1 final feature map -- confirm).
        ratio: Unused; accepted for interface compatibility.
        mixup_alpha: ``alpha`` passed to ``mixup_data`` (previously
            hard-coded to 2.0).
        **kwargs: Forwarded to torchvision's ``VGG``.
    """

    def __init__(self, n_dims=None, ratio=None, *, mixup_alpha=2.0, **kwargs):
        super(VGGTinyManifoldMixup, self).__init__(**kwargs)
        if n_dims is None:
            # Last conv layer, independent of whether batch_norm is used.
            convs = [m for m in self.features if isinstance(m, nn.Conv2d)]
            n_dims = convs[-1].out_channels
        self.n_dims = n_dims
        self.mixup_alpha = mixup_alpha
        # Reuse the class count the base VGG classifier was built with.
        num_classes = self.classifier[-1].out_features
        self.classifier = nn.Sequential(
            nn.Linear(self.n_dims, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x, label=None, rank=None):
        """Compute logits, or manifold-mixup outputs when ``label`` is given.

        Args:
            x: Input batch.
            label: Targets; when ``None`` only logits are returned.
            rank: Kept for backward compatibility. ``lam`` is now placed on
                the same device as the outputs instead of ``cuda(rank)``.

        Returns:
            ``logits`` if ``label`` is ``None``, otherwise
            ``(mixed_logits, y_a, y_b, lam)``.
        """
        x = self.features(x)
        feature_vector = torch.flatten(x, 1)
        if label is None:
            return self.classifier(feature_vector)

        # Only the mixed features go through the classifier in this branch;
        # the previous unconditional clean forward pass was discarded.
        out, y_a, y_b, lam = mixup_data(feature_vector, label, alpha=self.mixup_alpha)
        out = self.classifier(out)
        lam = torch.tensor(lam, device=out.device)
        return out, y_a, y_b, lam


def _vgg_tiny(arch, cfg, batch_norm, **kwargs):
    """Construct a ``VGGTiny`` from the torchvision layer config named ``cfg``.

    ``arch`` is accepted but not used.
    """
    layer_stack = make_layers(cfgs[cfg], batch_norm=batch_norm)
    model = VGGTiny(features=layer_stack, **kwargs)
    return model

def _vgg_tiny_manifold_mixup(cfg, batch_norm, **kwargs):
    """Construct a ``VGGTinyManifoldMixup`` from the torchvision config ``cfg``."""
    layer_stack = make_layers(cfgs[cfg], batch_norm=batch_norm)
    model = VGGTinyManifoldMixup(features=layer_stack, **kwargs)
    return model

def Vgg11Tiny(**kwargs):
    """Build a ``VGGTiny`` using layer config 'A' with batch normalization."""
    model = _vgg_tiny('Vgg11Tiny', 'A', True, **kwargs)
    return model


def Vgg11TinyManifoldMixup(**kwargs):
    """Build a ``VGGTinyManifoldMixup`` using layer config 'A' with batch normalization."""
    model = _vgg_tiny_manifold_mixup('A', True, **kwargs)
    return model

class VGGTinyAMA(VGG):
    """VGG that augments each labelled batch with class-balanced virtual samples.

    When labels are given, pairs of samples are drawn with per-sample
    probability inversely proportional to their class frequency. The pairs are
    interpolated in feature space with weight ``p = exp(-beta * tr_acc)``, and
    the virtual samples inherit the label of their first (source) sample.

    Args:
        n_dims: Input size of the first classifier layer. When ``None`` it
            defaults to ``out_channels`` of the last ``nn.Conv2d`` in
            ``features`` (assumes a 1x1 final feature map -- confirm).
        beta: Decay rate of the interpolation weight with training accuracy.
        **kwargs: Forwarded to torchvision's ``VGG``.
    """

    def __init__(self, n_dims=None, beta=2./3., **kwargs):
        super(VGGTinyAMA, self).__init__(**kwargs)
        if n_dims is None:
            # Previously n_dims=None crashed inside nn.Linear; derive it from
            # the last conv layer like the sibling VGGTiny classes do.
            convs = [m for m in self.features if isinstance(m, nn.Conv2d)]
            n_dims = convs[-1].out_channels
        self.n_dims = n_dims
        # Reuse the class count the base VGG classifier was built with.
        num_classes = self.classifier[-1].out_features
        self.classifier = nn.Sequential(
            nn.Linear(self.n_dims, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )
        self.beta = beta

    def forward(self, x, tr_acc=None, label=None):
        """Compute logits, optionally extended with virtual-sample logits.

        Args:
            x: Input batch.
            tr_acc: Current training accuracy; required when ``label`` is given.
            label: Targets; when ``None`` only logits are returned.

        Returns:
            ``logits`` if ``label`` is ``None``; otherwise ``(output, labels)``
            where virtual-sample logits/labels are appended after the real ones
            (when any virtual samples were produced).

        Raises:
            TypeError: if ``label`` is given without ``tr_acc``.
        """
        feature_vector = self.features(x)
        feature_vector = feature_vector.view(feature_vector.size(0), -1)
        logits = self.classifier(feature_vector)
        if label is None:
            return logits
        if tr_acc is None:
            raise TypeError("tr_acc is required when label is given")

        p = math.exp(-self.beta * tr_acc)
        virtual_samples, src_labels, _ = self.create_virtual_sample_balance(feature_vector, label, p)
        if len(virtual_samples) == 0:
            return logits, label
        virtual_logits = self.classifier(virtual_samples)
        output = torch.vstack((logits, virtual_logits))
        labels = torch.cat((label, src_labels))
        return output, labels

    def create_virtual_sample_balance(self, feature_vector, label, p):
        """Create interpolated samples from class-balanced random pairs.

        Each sample's draw probability is ``1 / (n_classes * count(its class))``,
        so every class is drawn with equal total probability. ``len(batch)``
        pairs are drawn with replacement, and pairs with identical indices are
        dropped.

        Args:
            feature_vector: Flattened features, one row per sample.
            label: Labels aligned with ``feature_vector``.
            p: Weight of the source sample in the interpolation.

        Returns:
            ``(virtual_samples, src_labels, tgt_labels)``. These are three empty
            lists when no valid pair exists.
        """
        bsz = len(feature_vector)
        if bsz < 2:
            # Two distinct indices are impossible; also avoids calling
            # np.random.choice with an empty population.
            return [], [], []

        label_np = label.detach().cpu().numpy()
        counts = Counter(label_np)
        n_classes = len(counts)
        prob = [1. / (n_classes * counts[v]) for v in label_np]
        indices = np.random.choice(bsz, (bsz, 2), p=prob)
        # Vectorized filter; keeps shape (k, 2) even when k == 0.
        indices = indices[indices[:, 0] != indices[:, 1]]
        if len(indices) == 0:
            return [], [], []

        src_idx, tgt_idx = indices[:, 0], indices[:, 1]
        virtual_samples = feature_vector[src_idx] * p + \
                          feature_vector[tgt_idx] * (1. - p)
        return virtual_samples, label[src_idx], label[tgt_idx]


def _vgg_tiny_AMA(cfg, batch_norm, beta, **kwargs):
    """Construct a ``VGGTinyAMA`` from the torchvision config ``cfg`` with the given ``beta``."""
    layer_stack = make_layers(cfgs[cfg], batch_norm=batch_norm)
    model = VGGTinyAMA(features=layer_stack, beta=beta, **kwargs)
    return model


def Vgg11TinyAMA(beta=2./3., **kwargs):
    """Build a ``VGGTinyAMA`` using layer config 'A' with batch normalization.

    ``beta`` now defaults to the same value as ``VGGTinyAMA.__init__`` (2/3);
    callers that pass it explicitly are unaffected.
    """
    return _vgg_tiny_AMA('A', True, beta=beta, **kwargs)

