import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights
import ipdb
from copy import deepcopy


class ConvBase(nn.Module):
    """Four-stage convolutional feature extractor.

    Every stage is ``conv3x3 (padding 1) -> BatchNorm -> MaxPool(2) -> ReLU``;
    ``forward`` returns the last feature map flattened per sample.

    Args:
        output_size: out_features of ``self.fc``.
        nc: number of input channels.
        hidden: accepted for interface compatibility but ignored; the width is
            fixed at 32 filters.

    NOTE(review): ``self.fc`` (sized for a 5x5x32 final map) and the empty
    ``self.f1`` are created but never used by ``forward``; ``fc``'s parameters
    are nevertheless registered on the module.
    """

    def __init__(self, output_size, nc, hidden):
        super().__init__()

        width = 32  # fixed width; `hidden` is intentionally not used
        final_side = 5  # spatial side assumed when sizing `self.fc`

        def make_conv(in_channels):
            return nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1)

        def make_norm():
            # track_running_stats=False: batch statistics are always used.
            return nn.BatchNorm2d(width, track_running_stats=False)

        self.f1 = nn.ModuleList()  # unused placeholder

        # Creation order (convs, then fc, then norms) fixes both the parameter
        # registration order and the sequence of RNG draws made at init time.
        self.conv1 = make_conv(nc)
        self.conv2 = make_conv(width)
        self.conv3 = make_conv(width)
        self.conv4 = make_conv(width)

        self.fc = nn.Linear(final_side * final_side * width, output_size)

        self.bn1 = make_norm()
        self.bn2 = make_norm()
        self.bn3 = make_norm()
        self.bn4 = make_norm()

        self.MP = nn.MaxPool2d(2)

    def forward(self, x):
        """Apply the four conv stages to ``x`` and flatten all dims after the first."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        h = x
        for conv, norm in stages:
            h = F.relu(self.MP(norm(conv(h))))
        return h.flatten(start_dim=1)


class PretrainedResnet(nn.Module):
    """Conv feature extractor, then an MLP projection, then a ``CAVIALayer`` head.

    NOTE(review): despite the name, nothing here loads pretrained ResNet
    weights; the backbone is a freshly initialised ``ConvBase``.

    Args:
        input_size: ``input_size[0]`` is taken as the (assumed square) image
            side length and ``input_size[-1]`` as the number of channels.
        z_size: context size passed to the ``CAVIALayer`` head.
        hidden: ignored (overridden with 64 below).
        output_size: number of outputs of the head.
        n_layers: ignored; the head is always built with a single layer.
    """

    def __init__(self, input_size, z_size, hidden, output_size, n_layers):
        super().__init__()
        num_filters = 32  # must match the fixed width ConvBase uses internally
        self.feature_extractor = ConvBase(output_size, input_size[-1], num_filters)

        hidden = 64  # overrides the argument; a 1-layer CAVIALayer does not use it

        # ConvBase applies four 2x max-pools (each flooring), so its output map
        # is input_size[0] // 16 on a side. The previous `1 if 28 else 5` only
        # matched 28px and 80-95px inputs and crashed for every other size.
        # The clamp avoids a zero-width Linear for inputs smaller than 16px.
        h_shape = max(input_size[0] // 16, 1)

        self.fc1 = nn.Sequential(
            nn.Linear(h_shape * h_shape * num_filters, 256),
            nn.ReLU(),
            nn.Linear(256, 128))
        self.head = CAVIALayer(128, z_size, hidden, output_size, 1)

    def forward(self, x, theta, l_start=None, l_finish=None):
        """Extract features from ``x``, project them to 128-d, and apply the head with ``theta``.

        ``l_start``/``l_finish`` are accepted for interface compatibility and ignored.
        """
        features = self.fc1(self.feature_extractor(x))
        return self.head(features, theta)



class base_learner(nn.Module):
    """Build one concrete learner from ``cnn``/``model_type`` and delegate every call to it.

    Selection:
        * ``cnn == 0`` -> ``MLP`` (for every ``model_type``)
        * ``cnn == 1`` and ``model_type`` in ``('cavia', 'lava')`` -> ``CNN``
        * ``cnn == 1`` and ``model_type == 'anil'`` -> ``PretrainedResnet``

    The remaining constructor arguments are passed through unchanged;
    ``use_batch_norm`` only reaches ``CNN``.

    Raises:
        ValueError: for any other ``cnn``/``model_type`` combination. Previously
            such an object was built without ``self.learner`` and the first
            delegated call failed with an ``AttributeError``.
    """

    def __init__(self, input_size, z_size, hidden, output_size, n_layers, cnn, model_type, use_batch_norm):
        super().__init__()

        if cnn == 0:
            self.learner = MLP(input_size, z_size, hidden, output_size, n_layers)
        elif cnn == 1 and model_type in ['cavia', 'lava']:
            # NOTE: MiniOmniglotCNN and CNNSymbol were alternative learners tried
            # here with the same constructor arguments.
            self.learner = CNN(input_size, z_size, hidden, output_size, n_layers, use_batch_norm)
        elif cnn == 1 and model_type in ['anil']:
            self.learner = PretrainedResnet(input_size, z_size, hidden, output_size, n_layers)
        else:
            raise ValueError(
                f"Unsupported learner configuration: cnn={cnn!r}, model_type={model_type!r}")

    def get_theta_shape(self):
        """Return the wrapped learner's ``get_theta_shape()``."""
        return self.learner.get_theta_shape()

    def get_theta_size_per_layer(self):
        """Return the wrapped learner's ``get_theta_size_per_layer()``."""
        return self.learner.get_theta_size_per_layer()

    def get_param_list(self, theta):
        """Return the wrapped learner's ``get_param_list(theta)``."""
        return self.learner.get_param_list(theta)

    def forward(self, x, theta, l_start=0, l_finish=-1):
        """Forward ``x``/``theta`` and the layer range to the wrapped learner."""
        return self.learner.forward(x, theta, l_start, l_finish)


class CAVIALayer(nn.Module):
    """MLP head applied to its input concatenated with a projected context vector.

    Architecture of ``self.net``:
        * ``n_layers == 1``: a single ``Linear(input_size + z_size, output_size)``.
        * otherwise: ``Linear(input_size + z_size, hidden), ReLU``, then
          ``n_layers`` repeats of ``Linear(hidden, hidden), ReLU`` (none when
          ``n_layers <= 0``), then ``Linear(hidden, output_size)``.

    NOTE(review): for ``n_layers >= 2`` this yields ``n_layers + 2`` linear
    layers, whereas in ``MLP`` ``n_layers`` counts all linear layers. Confirm
    this is intended before changing it, because a change would alter checkpoints.

    ``self.projection`` (``Linear(z_size, z_size)``) is applied to the context
    before concatenation.
    """

    def __init__(self, input_size, z_size, hidden, output_size, n_layers):
        super().__init__()

        if n_layers == 1:
            self.net = nn.Sequential(nn.Linear(input_size + z_size, output_size))
        else:
            layers = [nn.Linear(input_size + z_size, hidden), nn.ReLU()]
            for _ in range(n_layers):
                layers.append(nn.Linear(hidden, hidden))
                layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden, output_size))
            self.net = nn.Sequential(*layers)
        self.projection = nn.Linear(z_size, z_size)

    def forward(self, x, z):
        """Concatenate ``projection(z)`` to every row of ``x`` and apply ``self.net``.

        Args:
            x: 2-D batch of shape (batch, input_size).
            z: context of shape (z_size,), (1, z_size) or (batch, z_size).
        Returns:
            Tensor of shape (batch, output_size).
        """
        zc = self.projection(z)
        if zc.dim() == 1:
            zc = zc.unsqueeze(0)
        # Broadcast a single shared context over the batch. A per-sample context
        # that already matches the batch is used as-is. The previous
        # unconditional repeat turned it into batch*batch rows and crashed.
        if zc.shape[0] != x.shape[0]:
            zc = zc.repeat(x.shape[0], 1)
        return self.net(torch.cat((x, zc), -1))


class MLP(nn.Module):
    """Fully-connected learner with two execution modes selected by ``z_size``.

    * ``z_size > 0``: ``theta`` is a context tensor concatenated to the input
      before the first layer, and the module's own ``nn.Linear`` layers run.
    * ``z_size <= 0``: ``theta`` is a list of weight/bias tensors shaped like
      ``self.theta_shapes`` (see ``get_param_list``), and the network is
      evaluated functionally.

    ``self.f`` holds ``Linear(input_size + z_size, hidden), ReLU``, then
    ``n_layers - 2`` repeats of ``Linear(hidden, hidden), ReLU`` (none when
    ``n_layers <= 2``), then ``Linear(hidden, output_size)``.
    """

    def __init__(self, input_size, z_size, hidden, output_size, n_layers):
        super().__init__()

        self.n_layers = n_layers
        self.z_size = z_size

        modules = [nn.Linear(input_size + z_size, hidden), nn.ReLU()]
        for _ in range(n_layers - 2):
            modules.extend((nn.Linear(hidden, hidden), nn.ReLU()))
        modules.append(nn.Linear(hidden, output_size))
        self.f = nn.ModuleList(modules)

        # Shape of every parameter tensor in self.f, in registration order.
        self.theta_shapes = [list(p.shape) for module in self.f for p in module.parameters()]

    def get_theta_shape(self):
        """Return the per-tensor parameter shapes (each a list of ints)."""
        return self.theta_shapes

    def get_theta_size_per_layer(self):
        """Return the element count of each entry of ``self.theta_shapes``."""
        return list(map(np.prod, self.theta_shapes))

    def get_param_list(self, theta):
        """Cut row 0 of a flat ``theta`` into consecutive views shaped like ``self.theta_shapes``."""
        pieces = []
        offset = 0
        for numel, shape in zip(self.get_theta_size_per_layer(), self.theta_shapes):
            end = offset + int(numel)
            pieces.append(theta[0, offset:end].view(shape))
            offset = end
        return pieces

    def forward(self, x, theta, l_start=0, l_finish=-1):
        """Run hidden layers ``l_start`` .. ``l_finish - 1``, plus the output layer if ``l_finish == -1``.

        With ``z_size > 0`` and ``l_start == 0`` the context ``theta`` is
        repeated over the batch (when its first dimension differs from
        ``x``'s) and concatenated to ``x``. Otherwise ``x`` is treated as the
        hidden activation entering layer ``l_start``.
        """
        include_head = l_finish == -1
        stop = self.n_layers - 1 if include_head else l_finish

        if not self.z_size > 0:
            # Functional mode: theta = [w0, b0, w1, b1, ...].
            h = x
            for idx in range(l_start, stop):
                weight, bias = theta[2 * idx], theta[2 * idx + 1]
                h = F.relu(F.linear(h, weight, bias=bias))
            return F.linear(h, theta[-2], bias=theta[-1]) if include_head else h

        h = x
        if l_start == 0:
            if theta.shape[0] != x.shape[0]:
                theta = theta.repeat(x.shape[0], 1)
            h = torch.cat([x, theta], -1)

        for idx in range(l_start, stop):
            linear, activation = self.f[2 * idx], self.f[2 * idx + 1]
            h = activation(linear(h))

        if include_head:
            # self.f[-2] is the final ReLU (idempotent), self.f[-1] the output Linear.
            h = self.f[-1](self.f[-2](h))
        return h



class CNN(nn.Module):
    """Four-block CNN learner with optional FiLM conditioning.

    Behaviour depends on ``z_size``:

    * ``z_size > 0``: ``theta`` is a context tensor. ``self.film_layer`` maps
      it to per-channel ``beta`` (first half) and ``gamma`` (second half),
      which are applied as ``relu(gamma * h + beta)`` after the third conv
      block. The module's own layers (with optional BatchNorm) are used.
    * ``z_size <= 0``: ``theta`` is a list of tensors shaped like
      ``self.theta_shapes`` (e.g. from ``get_param_list``), and the network is
      evaluated functionally without BatchNorm or FiLM.

    Args:
        input_size: ``input_size[0]`` is the image side length (assumed
            square) and ``input_size[1]`` the number of input channels.
        z_size: context size; ``0`` selects the functional path.
        hidden: number of filters in every conv layer.
        output_size: number of outputs of the final linear layer.
        n_layers: stored on the instance; not otherwise used by this class.
        use_batchnorm: when false, the BatchNorm layers are ``nn.Identity``.
    """

    def __init__(self, input_size, z_size, hidden, output_size, n_layers, use_batchnorm):
        super().__init__()

        self.use_batchnorm = use_batchnorm
        self.n_layers = n_layers
        self.z_size = z_size

        filters = hidden

        # Each block keeps the spatial size (3x3 conv, padding 1) and then
        # halves it with a flooring 2x2 max-pool, so after four blocks the side
        # is input_size[0] // 16. This reproduces the former lookup table
        # (84->5, 60/50->3, 40->2, 30/28/20->1) and also supports any other
        # size >= 16; the old fallback of 5 broke e.g. 32x32 inputs. The clamp
        # keeps the old value of 1 for 14px inputs.
        h_shape = max(input_size[0] // 16, 1)

        # Module creation order below is unchanged so seeded initialisation
        # draws the same random numbers as before.
        self.f1 = nn.ModuleList([
            nn.Conv2d(input_size[1], filters, 3, stride=1, padding=1),
            nn.Conv2d(filters, filters, 3, stride=1, padding=1),
            nn.Conv2d(filters, filters, 3, stride=1, padding=1),
        ])

        self.bn1 = nn.BatchNorm2d(filters, track_running_stats=False) if self.use_batchnorm else nn.Identity()
        self.bn2 = nn.BatchNorm2d(filters, track_running_stats=False) if self.use_batchnorm else nn.Identity()
        self.bn3 = nn.BatchNorm2d(filters, track_running_stats=False) if self.use_batchnorm else nn.Identity()
        self.bn4 = nn.BatchNorm2d(filters, track_running_stats=False) if self.use_batchnorm else nn.Identity()

        # f2 = [4th conv, Flatten, output Linear]; forward indexes into it.
        self.f2 = nn.ModuleList([
            nn.Conv2d(filters, filters, 3, stride=1, padding=1),
            nn.Flatten(),
            nn.Linear(h_shape * h_shape * filters, output_size),
        ])

        self.MP = nn.MaxPool2d(2)

        # Produces [beta | gamma] for FiLM, `filters` values each.
        self.film_layer = nn.Linear(z_size, filters * 2)

        # Shapes of the f1/f2 parameters in order. This is the layout the
        # functional path expects in `theta`. film_layer and BatchNorm
        # parameters are not included.
        self.theta_shapes = [
            list(param.shape)
            for layer in (*self.f1, *self.f2)
            for param in layer.parameters()
        ]

        for m in (*self.f1, *self.f2):
            self.init_weights(m)

        torch.nn.init.kaiming_uniform_(self.film_layer.weight, nonlinearity='linear')

    def get_theta_shape(self):
        """Return the per-tensor shapes of the functional-path parameters."""
        return self.theta_shapes

    def get_theta_size_per_layer(self):
        """Return the element count of each entry of ``self.theta_shapes``."""
        return [np.prod(shape) for shape in self.theta_shapes]

    def get_param_list(self, theta):
        """Cut row 0 of a flat ``theta`` into consecutive views shaped like ``self.theta_shapes``."""
        params = []
        start = 0
        for size, shape in zip(self.get_theta_size_per_layer(), self.theta_shapes):
            stop = start + int(size)
            params.append(theta[0, start:stop].view(shape))
            start = stop
        return params

    def init_weights(self, m):
        """Apply Kaiming-uniform init to Conv2d (ReLU gain) and Linear (linear gain) and zero their biases.

        Other module types (e.g. ``nn.Flatten``) are left untouched.
        """
        if isinstance(m, nn.Conv2d):
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        elif isinstance(m, nn.Linear):
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='linear')
        else:
            return
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    def forward(self, x, theta, l_start=0, l_finish=-1):
        """Run the network on ``x``.

        ``l_start``/``l_finish`` are accepted for interface compatibility and ignored.

        Args:
            x: image batch, assumed (batch, channels, H, W).
            theta: when ``z_size > 0``, a context tensor; it is repeated across
                the batch if its first dimension differs from ``x``'s.
                Otherwise it is a list of parameter tensors in
                ``self.theta_shapes`` order.
        Returns:
            Tensor of shape (batch, output_size).
        """
        if self.z_size > 0:
            if x.shape[0] != theta.shape[0]:
                theta = theta.repeat(x.shape[0], 1)

            film = self.film_layer(theta)
            half = film.shape[-1] // 2
            # (batch, filters, 1, 1) so the FiLM terms broadcast over H and W.
            beta = film[:, :half].unsqueeze(-1).unsqueeze(-1)
            gamma = film[:, half:].unsqueeze(-1).unsqueeze(-1)

            h = F.relu(self.MP(self.bn1(self.f1[0](x))))
            h = F.relu(self.MP(self.bn2(self.f1[1](h))))
            # Third block: the ReLU is applied after the FiLM modulation.
            h = self.MP(self.bn3(self.f1[2](h)))
            h = F.relu(gamma * h + beta)
            h = F.relu(self.MP(self.bn4(self.f2[0](h))))
            return self.f2[2](self.f2[1](h))

        # Functional path: theta = [w1, b1, w2, b2, w3, b3, w4, b4, w_out, b_out].
        h = x
        for i in range(4):
            h = self.MP(F.relu(F.conv2d(h, theta[2 * i], bias=theta[2 * i + 1], stride=1, padding=1)))
        h = torch.flatten(h, start_dim=1)
        return F.linear(h, theta[8], bias=theta[9])
