import torch
from torch import nn
import numpy as np
from torch.autograd import Variable


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis.

    Input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``; a 1-D
    input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def forward(self, input):
        batch = input.size(0)
        # Compute the trailing width explicitly instead of passing -1, so an
        # empty batch (N == 0) still flattens; torch cannot infer -1 from zero
        # elements. ``torch.Size().numel()`` of an empty slice is 1.
        features = input.shape[1:].numel()
        # reshape (not view) so non-contiguous inputs, e.g. after a transpose,
        # work too; it returns a view whenever one is possible.
        return input.reshape(batch, features)


class UnFlatten(nn.Module):
    """Reshape each batch element to a fixed per-sample shape.

    Args:
        size: per-sample target shape; the output is ``(N, *size)`` where
            ``N`` is the input's first dimension.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, input):
        # reshape (not view) so inputs whose trailing dims cannot be regrouped
        # as a view (non-contiguous tensors) are copied instead of raising.
        return input.reshape(input.size(0), *self.size)


class VAEBase(nn.Module):
    """Common state and MLP factories shared by the VAE-style networks.

    Reads ``task``, ``z_dim``, ``n_categories``, ``dropout_rate``,
    ``image_dim`` and ``synthetic_generation`` from ``opt``. Subclasses call
    ``posterior_architecture`` / ``gen_x_architecture`` to build their
    encoder / decoder stacks.
    """

    def __init__(self, opt):
        super().__init__()
        self.no_batch_norm = True
        self.task = opt.task
        self.latent_shape = opt.z_dim
        self.n_categories = opt.n_categories
        self.dropout_rate = opt.dropout_rate

        # Width of one flattened input. With synthetic generation the input is
        # widened by n_categories extra columns (presumably a one-hot label;
        # confirm against the data pipeline).
        flat_width = int(np.prod(opt.image_dim))
        if opt.synthetic_generation:
            flat_width += opt.n_categories
        self.n_pixels = flat_width

    def _mlp_layers(self, widths):
        """Build Linear layers joining consecutive entries of ``widths``.

        Every Linear except the last is followed by ``ReLU`` and then
        ``Dropout(p=self.dropout_rate)``; the last Linear is left bare.

        Args:
            widths: layer sizes ``[in, hidden..., out]`` (at least two entries).

        Returns:
            A flat list of modules, ready to splat into ``nn.Sequential``.
        """
        layers = []
        last_step = len(widths) - 2
        for step, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if step < last_step:
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(p=self.dropout_rate))
        return layers

    def gen_x_architecture(self, input_dim, learn_covar=True):
        """Build the decoder MLP that maps a representation to pixel space.

        Args:
            input_dim: width of the decoder input.
            learn_covar: if true the output width is ``2 * n_pixels``
                (presumably mean plus a per-pixel scale); otherwise
                ``n_pixels``.

        Returns:
            ``nn.Sequential`` with hidden widths 500-500 for the
            ``'LendingClub'`` task and 50-150-400 otherwise.
        """
        hidden = (500, 500) if self.task == 'LendingClub' else (50, 150, 400)
        n_out = self.n_pixels * (2 if learn_covar else 1)
        return nn.Sequential(*self._mlp_layers([input_dim, *hidden, n_out]))

    def posterior_architecture(self, output_dim, learn_covar=True):
        """Build the encoder MLP from a flattened input to ``output_dim`` values.

        Args:
            output_dim: width of the representation.
            learn_covar: if true the output width is doubled (presumably mean
                plus a scale term).

        Returns:
            ``nn.Sequential`` starting with ``Flatten()``, with hidden widths
            500-500 for the ``'LendingClub'`` task and 400-150-50 otherwise.
        """
        hidden = (500, 500) if self.task == 'LendingClub' else (400, 150, 50)
        n_out = output_dim * 2 if learn_covar else output_dim
        widths = [self.n_pixels, *hidden, n_out]
        return nn.Sequential(Flatten(), *self._mlp_layers(widths))


class NetVAE(VAEBase):
    """Encoder (``posterior_r_net``) plus decoder (``gen_x_net``) in one module.

    Reads ``posterior_std``, ``rep_dim`` and ``diagonal_x_std`` from ``opt`` in
    addition to what ``VAEBase`` reads.
    """

    def __init__(self, opt):
        super().__init__(opt)

        # A given posterior_std means the encoder output width is not doubled.
        self.posterior_r_net = self.posterior_architecture(
            output_dim=opt.rep_dim,
            learn_covar=opt.posterior_std is None,
        )

        # Without diagonal_x_std the decoder output is not doubled, and a
        # single learnable scalar (registered nn.Parameter, init 0.5) is kept
        # instead.
        if not opt.diagonal_x_std:
            self.x_std = nn.Parameter(torch.FloatTensor([0.5]))

        self.gen_x_net = self.gen_x_architecture(
            input_dim=opt.rep_dim,
            learn_covar=opt.diagonal_x_std,
        )


class NetDPEncoder(VAEBase):
    """Encoder-only network: builds ``posterior_r_net`` over ``opt.rep_dim``."""

    def __init__(self, opt):
        super().__init__(opt)
        # When opt.posterior_std is supplied the encoder output width stays at
        # rep_dim; otherwise posterior_architecture doubles it.
        self.posterior_r_net = self.posterior_architecture(
            output_dim=opt.rep_dim,
            learn_covar=opt.posterior_std is None,
        )


class NetDPDecoder(VAEBase):
    """Decoder-only network: ``gen_x_net`` plus an optional scalar ``x_std``."""

    def __init__(self, opt):
        super().__init__(opt)

        if not opt.diagonal_x_std:
            # NOTE(review): unlike NetVAE, x_std here is a plain leaf tensor
            # requiring grad, not an nn.Parameter. It is therefore not returned
            # by .parameters(), not stored in state_dict and not moved by
            # .to(device). This looks deliberate (the original kept the
            # nn.Parameter form commented out), possibly to keep it away from a
            # DP training wrapper. Confirm before changing.
            self.x_std = torch.FloatTensor([0.5]).requires_grad_(True)

        self.gen_x_net = self.gen_x_architecture(
            input_dim=opt.rep_dim,
            learn_covar=opt.diagonal_x_std,
        )


class ClassifierBase(nn.Module):
    """Shared config and layer factories for the label classifiers.

    Reads ``noise_features_directly``, ``z_dim``, ``image_dim``,
    ``n_categories``, ``dropout_rate``, ``posterior_dim`` and
    ``conv_classifier`` from ``opt``. With ``conv_classifier`` set, only
    ``image_dim == (1, 28, 28)`` is supported; anything else raises
    ``NotImplementedError``.
    """

    def __init__(self, opt):
        super().__init__()
        if not opt.noise_features_directly:
            self.latent_shape = opt.z_dim
        self.feature_dim = opt.image_dim
        self.n_categories = opt.n_categories
        self.dropout_rate = opt.dropout_rate
        self.posterior_dim = opt.posterior_dim
        self.conv_classifier = opt.conv_classifier
        self.no_batch_norm = True

        if not opt.conv_classifier:
            return

        if not (opt.image_dim == (1, 28, 28)):
            raise NotImplementedError("Conv classifier not implemented for image dim {}".format(opt.image_dim))

        # Five unpadded stride-1 convs shrink 28x28 down to 1x1:
        # 28 -> 20 -> 14 -> 8 -> 4 -> 1 (each step: size - kernel + 1).
        self.chans = [16, 32, 64, 128, 256]
        self.kernels = [9, 7, 7, 5, 4]
        self.stride = [1] * 5
        self.padding = [0] * 5

    def _conv(self, in_channels, out_channels, kernel_size, stride=1, padding=0, final_layer=False):
        """Build one conv unit.

        Returns ``Sequential(Conv2d)`` when ``final_layer`` is true. Otherwise
        returns ``Conv2d -> Dropout -> [BatchNorm2d] -> ReLU``; the BatchNorm
        is included only when ``self.no_batch_norm`` is false.
        """
        layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                            stride=stride, padding=padding)]
        if not final_layer:
            layers.append(nn.Dropout(p=self.dropout_rate))
            if not self.no_batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def rep_classifier_architecture(self, input_dim, softmax=False):
        """Build a one-hidden-layer (50 units) classifier over a flat input.

        Produces ``n_categories`` raw scores, or probabilities when ``softmax``
        is true (a ``Softmax(dim=1)`` is appended).
        """
        layers = [Flatten()]
        layers.append(nn.Linear(input_dim, 50))
        layers.append(nn.Dropout(p=self.dropout_rate))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(50, self.n_categories))
        if softmax:
            layers.append(nn.Softmax(dim=1))
        return nn.Sequential(*layers)

    def pixel_classifier_architecture(self, data_join_task, softmax=False, clean_join_data_dim=None):
        """Build a classifier that operates directly on input features.

        Args:
            data_join_task: if true, ``clean_join_data_dim`` is added to the
                MLP input width (it must then be an int).
            softmax: append ``Softmax(dim=1)`` to the output.
            clean_join_data_dim: width of the joined data; only read when
                ``data_join_task`` is true.

        Returns:
            ``nn.Sequential``. When ``self.conv_classifier`` is set, this is
            the 5-layer conv stack plus a linear head; the widened input
            width from ``data_join_task`` is not used on that path. Otherwise
            it is a 400-150-50 MLP.
        """
        flat_dim = int(np.prod(self.feature_dim))
        if data_join_task:
            flat_dim += clean_join_data_dim

        if self.conv_classifier:
            in_chans = [self.feature_dim[0]] + self.chans[:-1]
            specs = zip(in_chans, self.chans, self.kernels, self.stride, self.padding)
            layers = [self._conv(c_in, c_out, k, s, p) for c_in, c_out, k, s, p in specs]
            # For the only supported 28x28 input the stack ends at 1x1, so the
            # head sees exactly chans[-1] features.
            layers += [Flatten(), nn.Linear(self.chans[-1], self.n_categories)]
        else:
            layers = [Flatten()]
            widths = [flat_dim, 400, 150, 50]
            for fan_in, fan_out in zip(widths, widths[1:]):
                layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(p=self.dropout_rate)]
            layers.append(nn.Linear(50, self.n_categories))

        if softmax:
            layers.append(nn.Softmax(dim=1))
        return nn.Sequential(*layers)


class NetClassifier(ClassifierBase):
    """Concrete classifier whose network is stored in ``classifier_net``.

    With ``opt.pixel_level`` set it classifies raw features; otherwise it
    classifies a representation. A final softmax is appended when
    ``opt.use_label_noise`` is true.
    """

    def __init__(self, opt):
        super().__init__(opt)
        self.softmax = opt.use_label_noise

        if not opt.pixel_level:
            # NOTE(review): the data-join width is built from z_dim, while the
            # plain case uses rep_dim. Confirm this asymmetry is intended.
            if opt.data_join_task:
                input_dim = opt.z_dim + opt.clean_join_data_dim
            else:
                input_dim = opt.rep_dim
            self.classifier_net = self.rep_classifier_architecture(input_dim=input_dim, softmax=self.softmax)
        else:
            self.classifier_net = self.pixel_classifier_architecture(
                opt.data_join_task,
                softmax=self.softmax,
                clean_join_data_dim=opt.clean_join_data_dim,
            )
