# Cycle-GAN based aligner
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from datetime import datetime
from torch.utils.data import DataLoader

class Generator(nn.Module):
    """MLP generator that maps firing-rate features into a low-dimensional space.

    Two stages:

    * ``dim_reduction_layer``: Linear(input_dim -> low_dim), Dropout, ReLU,
      BatchNorm1d. It is applied only when ``forward`` is called with
      ``is_rec_dim=True``.
    * ``model``: Linear(low_dim -> hidden_dim) and Linear(hidden_dim ->
      hidden_dim), each followed by Dropout + ReLU, then Linear(hidden_dim ->
      low_dim) + ReLU. Because of the final ReLU, outputs are never negative.

    NOTE(review): BatchNorm1d is assumed to receive 2-D (batch, low_dim) input.
    In training mode it needs more than one sample per batch. TODO confirm
    against the callers.
    """

    def __init__(self, input_dim, low_dim, hidden_dim, drop_out):
        """
        input_dim: number of input channels seen by the optional reduction stage.
        low_dim: width of the reduced (latent) space; also the output width.
        hidden_dim: number of neurons in each hidden layer of ``model``.
        drop_out: dropout probability used after every hidden Linear layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.drop_out = drop_out
        self.low_dim = low_dim

        # Submodules are created in a fixed order with fixed Sequential
        # indices. This keeps seeded weight initialisation and state_dict keys
        # (e.g. "model.3.weight") stable for any saved checkpoint.
        self.dim_reduction_layer = nn.Sequential(
            nn.Linear(input_dim, low_dim),
            nn.Dropout(drop_out),
            nn.ReLU(),
            nn.BatchNorm1d(low_dim),
        )

        widths = (low_dim, hidden_dim, hidden_dim)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.Dropout(drop_out), nn.ReLU()]
        # Output projection back to low_dim: ReLU but no dropout.
        layers += [nn.Linear(hidden_dim, low_dim), nn.ReLU()]
        self.model = nn.Sequential(*layers)

    def forward(self, input, is_rec_dim=False):
        """
        input: spike firing rate data. Its last dimension must be input_dim
            when is_rec_dim is True. Otherwise it must be low_dim, because the
            input is then fed straight into ``model``.
        is_rec_dim: if True, pass input through ``dim_reduction_layer`` first.
        return: transformed data whose last dimension is low_dim; every entry
            is >= 0 because of the final ReLU.
        """
        features = input
        if is_rec_dim:
            features = self.dim_reduction_layer(features)
        return self.model(features)

class Discriminator(nn.Module):
    """Three-layer MLP that assigns a single score to each input sample.

    Architecture: two hidden blocks of Linear -> Dropout -> ReLU, followed by
    a Linear head with one output unit. No sigmoid is applied, so the score
    is unbounded. NOTE(review): it is presumably paired with a logits-based
    loss (e.g. ``BCEWithLogitsLoss``) or a critic-style loss; confirm with
    the training code.
    """

    def __init__(self, input_dim, hidden_dim, drop_out):
        """
        input_dim: the number of input channels.
        hidden_dim: the number of neurons in each hidden layer.
        drop_out: dropout probability applied after each hidden Linear layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.drop_out = drop_out

        # Layers are appended in a fixed order, which keeps the Sequential
        # indices (and thus state_dict keys such as "model.6.weight") stable.
        layers = []
        for fan_in in (input_dim, hidden_dim):
            layers.extend((nn.Linear(fan_in, hidden_dim), nn.Dropout(drop_out), nn.ReLU()))
        layers.append(nn.Linear(hidden_dim, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """
        input: spike firing rate data whose last dimension is input_dim.
        return: a label indicating if the input data is real or fake, given
            as a tensor whose last dimension has size 1 (unbounded score).
        """
        return self.model(input)