import torch
from torch import nn
from vit_pytorch import ViT
import torch.nn.functional as F
from multihead_attention import MultiHeadAttention
import math


class Model(nn.Module):
    """Patch/memory attention front end followed by a ``vit_pytorch.ViT`` classifier.

    The input image is cut into overlapping patches (stride ``patch_size // 2``).
    The patches attend to each other (query = key = patches), and the resulting
    weights mix the rows of a memory matrix (value). The mixed patches are folded
    back into an image and classified by the ViT encoder.
    """

    # Number of output classes for every supported ``args.dataset`` value.
    _NUM_CLASSES = {
        'cifar10': 10,
        'cifar100': 100,
        'svhn': 10,
        'imagenet': 1000,
    }

    # Side length of the square images the ViT encoder is built for; the patch
    # count of the memory is derived from the same value so the two agree.
    _IMAGE_SIZE = 32

    def __init__(self, args):
        """Build the encoder, the patch memory and the attention module.

        :param args: namespace-like config. Reads ``dataset``, ``vit_patch_size``,
            ``depth``, ``patch_size``, ``train_value`` and ``std`` here, and
            ``norm_type`` in the forward pass. It is also handed to
            ``MultiHeadAttention``.
        :raises ValueError: if ``args.dataset`` is not supported, or if
            ``args.patch_size`` is not in ``[2, image_size]``.
        """
        super().__init__()

        self.args = args
        try:
            num_class = self._NUM_CLASSES[args.dataset]
        except KeyError:
            raise ValueError(
                f"Unsupported dataset {args.dataset!r}; expected one of "
                f"{sorted(self._NUM_CLASSES)}"
            ) from None

        image_size = self._IMAGE_SIZE
        self.image_size = image_size

        self.encoder = ViT(
                            image_size = image_size,
                            patch_size = args.vit_patch_size,
                            num_classes = num_class,
                            dim = 1024,
                            depth = args.depth,
                            heads = 16,
                            mlp_dim = 2048,
                            dropout = 0.1,
                            emb_dropout = 0.1
                            )

        self.patch_size = args.patch_size
        # A patch size of 1 would give stride 0, and one larger than the image
        # cannot be unfolded at all; both only failed later inside F.unfold.
        if not 2 <= self.patch_size <= image_size:
            raise ValueError(
                f"patch_size must be in [2, {image_size}], got {self.patch_size}"
            )
        self.stride = self.patch_size // 2
        # Number of patches F.unfold produces on an image_size x image_size input.
        # (The former ((64 // p) - 1) ** 2 was only right for even patch sizes.)
        self.k = self._num_patches(image_size, self.patch_size, self.stride)

        # Flattened patch width; assumes 3-channel (RGB) input.
        self.z_size_new = 3 * self.patch_size * self.patch_size

        if self.args.train_value == 1:
            # Trainable memory, standard-normal init.
            # NOTE(review): args.std is not applied in this branch - confirm intended.
            self.z_memory_key_list = nn.Parameter(torch.randn(self.k, self.z_size_new),
                                                  requires_grad=True)
        else:
            # Frozen random memory scaled by args.std.
            self.z_memory_key_list = nn.Parameter(torch.randn(self.k, self.z_size_new) * args.std,
                                                  requires_grad=False)

        # NOTE(review): many multi-head implementations require
        # in_features % head_num == 0; 3 * p * p is divisible by 32 only for some
        # patch sizes - confirm against MultiHeadAttention.
        self.multihead_attention = MultiHeadAttention(args=args, in_features=self.z_size_new, head_num=32,
                                                      activation=None, seq_len=self.k)

        # Affine parameters of the context norm applied before attention.
        self.gamma = nn.Parameter(torch.ones(self.z_size_new))
        self.beta = nn.Parameter(torch.zeros(self.z_size_new))

        # Affine parameters of the context norm applied after attention.
        self.gamma_key = nn.Parameter(torch.ones(self.z_size_new))
        self.beta_key = nn.Parameter(torch.zeros(self.z_size_new))

    @staticmethod
    def _num_patches(image_size, patch_size, stride):
        """Return how many patches F.unfold yields on a square image.

        Uses the sliding-window count ``(size - kernel) // stride + 1`` per side
        (no padding, no dilation), squared for the two spatial dimensions.
        Returns 0 when the patch does not fit.
        """
        per_side = (image_size - patch_size) // stride + 1
        return max(per_side, 0) ** 2

    def forward(self, x, device=None):
        """Classify a batch of images.

        :param x: image batch; presumably (batch, 3, 32, 32), since the memory
            width and the ViT ``image_size`` both assume it - TODO confirm.
        :param device: unused; kept for call-site compatibility.
        :return: ``(scores, predictions)``, where ``scores`` is the encoder output
            and ``predictions`` is its argmax over dim 1.
        """
        output_size = x.shape[-2:]

        patches = self._divide_patch(x)               # (batch, L, C * p * p)
        patches = self._z_meta_attention(patches)     # (batch, L, C * p * p)
        # F.fold expects (batch, C * p * p, L). It sums values where patches overlap.
        # NOTE(review): overlapping pixels are summed, not averaged - confirm intended.
        x = F.fold(patches.transpose(1, 2), output_size, kernel_size=self.patch_size,
                   stride=self.stride, padding=0)

        y_pred_linear = self.encoder(x)
        y_pred = y_pred_linear.argmax(1)
        return y_pred_linear, y_pred

    def _divide_patch(self, x):
        """Cut ``x`` into overlapping patches, one flattened patch per row.

        :param x: image batch of size (batch, C, H, W)
        :return: tensor of size (batch, L, C * patch_size ** 2)
        """
        patches = F.unfold(x, kernel_size=self.patch_size, stride=self.stride, padding=0)
        return patches.transpose(1, 2)

    def _z_meta_attention(self, z):
        """Replace every patch with an attention-weighted mixture of memory rows.

        The patches serve as both query and key; the memory matrix, tiled over the
        batch, is the value. Context norm is optionally applied before and after.
        Presumably seq must equal ``self.k`` for the value to line up with the
        patch-to-patch attention - TODO confirm against MultiHeadAttention.

        :param z: of size (batch, seq, A)
        :return: of size (batch, seq, A)
        """
        use_context_norm = self.args.norm_type == 'contextnorm'

        if use_context_norm:
            z = self.apply_context_norm(z, gamma=self.gamma, beta=self.beta)

        # One copy of the (num_memory, A) memory per batch element.
        memory = self.z_memory_key_list.unsqueeze(0).repeat(z.shape[0], 1, 1)
        z, _ = self.multihead_attention(query=z, key=z, value=memory)

        if use_context_norm:
            z = self.apply_context_norm(z, gamma=self.gamma_key, beta=self.beta_key)
        return z

    def apply_context_norm(self, z_seq, gamma, beta):
        """Standardize each feature across dim 1, then scale and shift.

        Mean and unbiased variance are taken over dim 1 separately for each
        batch element and feature. ``eps`` keeps the square root away from zero.

        :param z_seq: presumably (batch, seq, A), as in ``_z_meta_attention``
        :param gamma: per-feature scale, broadcast over the last dim
        :param beta: per-feature shift, broadcast over the last dim
        :return: tensor of the same size as ``z_seq``
        """
        eps = 1e-8
        z_mu = z_seq.mean(1, keepdim=True)
        z_sigma = (z_seq.var(1, keepdim=True) + eps).sqrt()
        return (z_seq - z_mu) / z_sigma * gamma + beta

