import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.distributions import Beta
from numpy import linalg as LA
import numpy as np

class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one.

    Maps ``(N, d1, d2, ...)`` to ``(N, d1 * d2 * ...)``. Uses ``Tensor.view``,
    so the input must be view-compatible (e.g. contiguous).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)


class Conv_Standard(nn.Module):
    """Four-block convolutional classifier with functional (MAML / ANIL) forward paths.

    Each block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).
    The flattened output of the last block feeds a linear ``logits`` head.

    Args:
        args: configuration object; only ``args.num_classes`` is read here.
        x_dim: number of input channels.
        hid_dim: output channels of the first three blocks.
        z_dim: output channels of the final block.
        final_layer_size: flattened size of the last block's output, i.e. the
            input width of the ``logits`` layer.
    """

    def __init__(self, args, x_dim, hid_dim, z_dim, final_layer_size):
        super(Conv_Standard, self).__init__()
        self.args = args
        self.net = nn.Sequential(self.conv_block(x_dim, hid_dim), self.conv_block(hid_dim, hid_dim),
                                 self.conv_block(hid_dim, hid_dim), self.conv_block(hid_dim, z_dim))
        # Beta(2, 2) distribution; not used by any method of this class.
        # NOTE(review): presumably sampled by external callers — confirm before removing.
        self.dist = Beta(torch.FloatTensor([2]), torch.FloatTensor([2]))
        self.hid_dim = hid_dim

        self.logits = nn.Linear(final_layer_size, self.args.num_classes)

    def conv_block(self, in_channels, out_channels):
        """Build one Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2) block.

        Inside ``self.net`` the conv is sub-module ``0`` and the batch norm is
        sub-module ``1``. These are the ``net.{i}.0.*`` / ``net.{i}.1.*`` keys
        read by the functional forward passes.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def functional_conv_block(self, x, weights, biases,
                              bn_weights, bn_biases, dropout=0, is_training=False):
        """Stateless equivalent of one ``conv_block`` using explicit parameters.

        Args:
            x: input feature map.
            weights, biases: conv kernel and bias.
            bn_weights, bn_biases: batch-norm affine parameters. They may be
                ``None``, in which case no affine transform is applied.
            dropout: dropout probability applied after pooling.
            is_training: enables dropout only. Batch norm ignores it; see below.

        Returns:
            The block's output feature map.
        """
        x = F.conv2d(x, weights, biases, padding=1)
        # Batch norm always normalises with the current batch's statistics
        # (no running stats, training=True), independent of ``is_training``.
        x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.dropout(x, p=dropout, training=is_training)
        return x

    def _functional_features(self, x, weights, dropout, is_training):
        """Run every conv block with parameters taken from ``weights`` and flatten per sample."""
        for block in range(len(self.net)):
            x = self.functional_conv_block(
                x,
                weights[f'net.{block}.0.weight'],
                weights[f'net.{block}.0.bias'],
                # BN affine params are optional: a missing key yields None.
                weights.get(f'net.{block}.1.weight'),
                weights.get(f'net.{block}.1.bias'),
                dropout,
                is_training,
            )
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Standard forward pass through the module's own parameters; returns logits."""
        x = self.net(x)

        x = x.view(x.size(0), -1)

        return self.logits(x)

    def forward_anil(self, x, inin_weights, weights, dropout=0, is_training=False):
        """ANIL-style forward pass with separately supplied body and head parameters.

        Args:
            x: input batch.
            inin_weights: mapping with the conv-body parameters, keyed like
                ``named_parameters()`` (``net.{i}.0.weight``, ...).
            weights: mapping with the linear head under ``'weight'`` / ``'bias'``.
            dropout, is_training: forwarded to ``functional_conv_block``.

        Returns:
            Logits computed with ``weights`` as the classification head.
        """
        x = self._functional_features(x, inin_weights, dropout, is_training)
        return F.linear(x, weights['weight'], weights['bias'])

    def forward_maml(self, x, weights, dropout=0, is_training=False):
        """MAML-style forward pass using one mapping of all (possibly adapted) parameters.

        Args:
            x: input batch.
            weights: mapping keyed like ``named_parameters()``, including
                ``logits.weight`` / ``logits.bias``.
            dropout, is_training: forwarded to ``functional_conv_block``.

        Returns:
            Logits.
        """
        x = self._functional_features(x, weights, dropout, is_training)
        return F.linear(x, weights['logits.weight'], weights['logits.bias'])

    def kmeans_forward(self, x):
        """Embed every image in ``x`` with ``self.net`` and L2-normalise each embedding.

        Args:
            x: 5-D tensor; ``x[i][j]`` is one ``(C, H, W)`` image.
                NOTE(review): dims 0/1 presumably are (task, sample) — confirm.

        Returns:
            A flat list of length ``x.shape[0] * x.shape[1]``, ordered row-major
            over the first two dims. Each entry is an embedding as a list of
            floats divided by its L2 norm. All-zero embeddings are returned
            unscaled rather than as NaN.
        """
        # Follow the model's own device instead of hard-coding .cuda(), so this
        # also works on CPU and on non-default GPUs.
        device = next(self.parameters()).device
        kmeans_lists = []
        # Images go through self.net one at a time on purpose. In training mode
        # BatchNorm uses per-batch statistics, so batching them would change
        # the embeddings.
        with torch.no_grad():
            for i in range(x.shape[0]):
                for j in range(x.shape[1]):
                    data = x[i][j].to(device).unsqueeze(0)
                    data = self.net(data).detach().cpu()
                    data = data.view(data.size(0), -1)
                    norm = float(LA.norm(data.numpy()))
                    vec = data / norm if norm > 0 else data
                    kmeans_lists.append(vec.squeeze(0).tolist())
        return kmeans_lists
