import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class Flatten(nn.Module):
    '''
    Collapse every dimension after the first into a single one.

    Given a tensor of shape (N, ...), forward() returns a tensor of shape
    (N, -1); e.g. (N, C, H, W) becomes (N, C * H * W).
    '''
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        '''
        Parameters:
            x: a tensor whose first dimension is the batch dimension
        Returns:
            a 2-D tensor with one row per batch element
        '''
        # reshape() instead of view(): view() raises on non-contiguous input
        # (for example after permute/transpose), whereas reshape() returns a
        # view when possible and only copies when it has to.
        return x.reshape(x.size(0), -1)
    

class Classifier(nn.Module):
    '''
    Convolutional image classifier that produces 5 unnormalized class scores.

    Values:
        input_channel: the number of channels in the input images, a scalar
        use_gpu: if True, move the model to CUDA (as float32) when CUDA is
              available; when it is not, the model stays on the CPU
        feature: if True, forward() returns the convolutional feature map
              instead of the class scores
    '''
    def __init__(self, input_channel=1, use_gpu=True, feature=False):
        super(Classifier, self).__init__()
        # Store what is actually in effect, so code that consults use_gpu does
        # not try to move inputs to a device that does not exist.
        self.use_gpu = bool(use_gpu) and torch.cuda.is_available()
        self.feature_extract = feature

        # Three stride-2 convolutions followed by an unpadded 3x3 convolution.
        # The Linear(128, ...) head below needs a 128 x 1 x 1 feature map,
        # which is what a 20x20 input produces (20 -> 10 -> 5 -> 3 -> 1).
        self.feature = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0),
        )

        self.classify = nn.Sequential(
            # Not in-place: the feature map may also be handed to the caller
            # (feature mode) and must not be modified behind its back.
            nn.LeakyReLU(0.2, inplace=False),
            # Built-in flatten: (N, 128, 1, 1) -> (N, 128).
            nn.Flatten(),
            nn.Linear(128, 5),
        )

        # Previously this cast to torch.cuda.FloatTensor unconditionally,
        # which raises on machines without CUDA even though use_gpu defaults
        # to True.
        if self.use_gpu:
            self.to(device="cuda", dtype=torch.float32)

    def forward(self, input):
        '''
        Parameters:
            input: an image batch of shape (N, input_channel, H, W)
        Returns:
            the feature map when the model was built with feature=True,
            otherwise a (N, 5) tensor of class scores
        '''
        features = self.feature(input)
        if self.feature_extract:
            return features
        return self.classify(features)
    
    
class Classifier_fullsize(nn.Module):
    '''
    Convolutional image classifier that returns log-probabilities over 10
    classes.

    Values:
        input_channel: the number of channels in the input images, a scalar
        use_gpu: if True, move the model to CUDA (as float32) when CUDA is
              available; when it is not, the model stays on the CPU
        feature: if True, forward() returns the convolutional feature map
              instead of the class log-probabilities
    '''
    def __init__(self, input_channel=1, use_gpu=True, feature=False):
        super(Classifier_fullsize, self).__init__()
        # Store what is actually in effect, so code that consults use_gpu does
        # not try to move inputs to a device that does not exist.
        self.use_gpu = bool(use_gpu) and torch.cuda.is_available()
        self.feature_extract = feature

        # Three stride-2 convolutions followed by an unpadded 3x3 convolution.
        # The Linear(128, ...) head below needs a 128 x 1 x 1 feature map,
        # which is what a 20x20 input produces (20 -> 10 -> 5 -> 3 -> 1).
        self.feature = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0),
        )

        self.classify = nn.Sequential(
            # Not in-place: the feature map may also be handed to the caller
            # (feature mode) and must not be modified behind its back.
            nn.LeakyReLU(0.2, inplace=False),
            # Built-in flatten: (N, 128, 1, 1) -> (N, 128).
            nn.Flatten(),
            nn.Linear(128, 10),
        )

        # Previously this cast to torch.cuda.FloatTensor unconditionally,
        # which raises on machines without CUDA even though use_gpu defaults
        # to True.
        if self.use_gpu:
            self.to(device="cuda", dtype=torch.float32)

    def forward(self, input):
        '''
        Parameters:
            input: an image batch of shape (N, input_channel, H, W)
        Returns:
            the feature map when the model was built with feature=True,
            otherwise a (N, 10) tensor of log-probabilities (log_softmax over
            the class dimension)
        '''
        features = self.feature(input)
        if self.feature_extract:
            return features
        return F.log_softmax(self.classify(features), dim=1)

class LeNet5(nn.Module):
    '''
    Small convolutional network for 20x20 single-channel images with a
    10-way linear output.

    forward() applies conv1 -> ReLU -> average1 -> conv2 -> ReLU -> conv3 ->
    ReLU -> fc1. The average2 and flatten sub-modules are created but are not
    used by forward().
    '''
    def __init__(self):
        super(LeNet5, self).__init__()

        # Convolutional stages and their pooling layers.
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=1)
        self.average1 = nn.AvgPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=1)
        self.average2 = nn.AvgPool2d(2, stride=2)
        self.conv3 = nn.Conv2d(16, 128, kernel_size=4, stride=1)

        self.flatten = Flatten()

        # Final linear layer: 128 features -> 10 outputs.
        self.fc1 = nn.Linear(128, 10)

    def forward(self, xb):
        '''
        Parameters:
            xb: a tensor that can be viewed as (-1, 1, 20, 20)
        Returns:
            the output of fc1, one row of 10 values per image
        '''
        images = xb.view(-1, 1, 20, 20)

        # Spatial size: 20 -> 16 (conv1) -> 8 (pool) -> 4 (conv2) -> 1 (conv3).
        stage1 = self.average1(F.relu(self.conv1(images)))
        stage2 = F.relu(self.conv2(stage1))
        stage3 = F.relu(self.conv3(stage2))

        # Keep the channel dimension and fold everything else into the batch.
        flat = stage3.view(-1, stage3.shape[1])
        return self.fc1(flat)
    
class Generator(nn.Module):
    '''
    Generator Class
    Maps a noise vector to an image through four transposed-convolution
    blocks.
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels in the generated images, a scalar
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim

        # Channel widths of the three hidden blocks, widest first.
        wide, middle, narrow = hidden_dim * 4, hidden_dim * 2, hidden_dim

        # Blocks are created in network order so parameter initialisation
        # happens in the same sequence as the layers are applied.
        first = self.make_gen_block(z_dim, wide)
        second = self.make_gen_block(wide, middle, kernel_size=4, stride=1)
        third = self.make_gen_block(middle, narrow)
        last = self.make_gen_block(narrow, im_chan, kernel_size=4, final_layer=True)
        self.gen = nn.Sequential(first, second, third, last)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Build one generator block: a transposed convolution followed by
        batchnorm + ReLU, or by Tanh alone when this is the final layer.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise
                      (affects activation and batchnorm)
        Returns:
            an nn.Sequential holding the block's layers
        '''
        layers = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers.append(nn.BatchNorm2d(output_channels))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, noise):
        '''
        Run the generator on a batch of noise vectors.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        Returns:
            the generated images, i.e. the output of the block stack
        '''
        n_samples = len(noise)
        # Treat each noise vector as a z_dim-channel 1x1 "image".
        seed_maps = noise.view(n_samples, self.z_dim, 1, 1)
        return self.gen(seed_maps)

class Critic(nn.Module):
    '''
    Critic Class
    Scores each input image with a single real number.
    Values:
        input_channel: the number of channels in the input images, a scalar
        im_chan: accepted for backward compatibility; not used
        hidden_dim: accepted for backward compatibility; not used
        feature_extractor: accepted for backward compatibility; not used
    '''
    def __init__(self, input_channel=1, im_chan=1, hidden_dim=64, feature_extractor=None):
        super(Critic, self).__init__()
        # Three stride-2 convolutions followed by an unpadded 4x4 convolution.
        # The Linear(128, 1) head needs a 128 x 1 x 1 feature map, which is
        # what a 28x28 input produces (28 -> 14 -> 7 -> 4 -> 1).
        self.feature = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 128, kernel_size=4, stride=1, padding=0),
        )
        self.crit = nn.Sequential(
            nn.Linear(128, 1),
        )

    def forward(self, image):
        '''
        Function for completing a forward pass of the critic: Given an image
        tensor, returns one score per image.
        Parameters:
            image: an image tensor with dimensions (n_samples, input_channel, H, W)
        Returns:
            a tensor of shape (n_samples,), also when n_samples is 1
        '''
        features = self.feature(image)
        # Merge only the trailing (C, H, W) axes. A bare squeeze() would also
        # drop the batch axis when n_samples == 1 and turn the result into a
        # 0-d tensor, so a batch of one came back with a different rank than
        # any other batch size.
        flat = features.flatten(start_dim=-3)
        # Drop just the size-1 output axis of the linear layer.
        return self.crit(flat).squeeze(-1)

    def freeze(self):
        '''
        Stop gradient updates for the convolutional feature layers.
        '''
        for param in self.feature.parameters():
            param.requires_grad = False

    def unfreeze(self):
        '''
        Re-enable gradient updates for the convolutional feature layers.
        '''
        for param in self.feature.parameters():
            param.requires_grad = True
    
    
def get_noise(n_samples, z_dim, device='gpu'):
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, z_dim)
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
      n_samples: the number of samples to generate, a scalar
      z_dim: the dimension of the noise vector, a scalar
      device: the device to create the tensor on. The default 'gpu' means
          CUDA when it is available and the CPU otherwise; any other value
          is passed to torch unchanged.
    Returns:
      a tensor of shape (n_samples, z_dim)
    '''
    # 'gpu' is not a device string torch accepts, so the default used to raise
    # on every call. Resolve it to a real device here.
    if isinstance(device, str) and device == 'gpu':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.randn(n_samples, z_dim, device=device)