# Cycle-GAN based aligner
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from datetime import datetime
from torch.utils.data import DataLoader
from sklearn.metrics import r2_score
from decoder.wiener_filter import format_data_from_trials, formalize_spike_and_cursor, train_wiener_filter, test_wiener_filter, transform_lag_latent_feature

class Generator(nn.Module):
    """Fully connected generator that maps a feature vector to one of the same width.

    Layer stack, in order: two (Linear -> Dropout -> ReLU) hidden blocks
    followed by a Linear projection back to ``input_dim`` and a final ReLU,
    so every output value is non-negative.
    """

    def __init__(self, input_dim, hidden_dim, drop_out):
        """
        input_dim: width of the input (and of the output) feature vector.
        hidden_dim: width of both hidden layers.
        drop_out: drop-out probability used after each hidden Linear layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.drop_out = drop_out

        # Build the two hidden blocks in a loop; the layers are created in the
        # same order as they appear in the network, so parameter names
        # (model.0, model.3, model.6) and initialisation order are fixed.
        stack = []
        fan_in = self.input_dim
        for _ in range(2):
            stack.extend([
                nn.Linear(fan_in, self.hidden_dim),
                nn.Dropout(self.drop_out),
                nn.ReLU(),
            ])
            fan_in = self.hidden_dim

        # Output head: project back to the input width and clamp at zero.
        stack.extend([nn.Linear(self.hidden_dim, self.input_dim), nn.ReLU()])
        self.model = nn.Sequential(*stack)

    def forward(self, input):
        """
        input: spike firing rate data, last dimension of size ``input_dim``.
        return: the transformed spike firing rate data (same shape as input).
        """
        return self.model(input)

class Discriminator(nn.Module):
    """Fully connected critic that scores a feature vector with a single value.

    Layer stack, in order: two (Linear -> Dropout -> ReLU) hidden blocks
    followed by a Linear layer with one output unit. No activation is applied
    to the output, so the score is unbounded.
    """

    def __init__(self, input_dim, hidden_dim, drop_out):
        """
        input_dim: width of the input feature vector.
        hidden_dim: width of both hidden layers.
        drop_out: drop-out probability used after each hidden Linear layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.drop_out = drop_out

        def hidden_block(fan_in):
            # One hidden stage: affine map, drop-out, rectifier.
            return (
                nn.Linear(fan_in, self.hidden_dim),
                nn.Dropout(self.drop_out),
                nn.ReLU(),
            )

        # Arguments are evaluated left to right, so the Linear layers are
        # created (and initialised) in network order.
        self.model = nn.Sequential(
            *hidden_block(self.input_dim),
            *hidden_block(self.hidden_dim),
            nn.Linear(self.hidden_dim, 1),
        )

    def forward(self, input):
        """
        input: spike firing rate data, last dimension of size ``input_dim``.
        return: one score per sample indicating if the input is real or fake.
        """
        return self.model(input)
    
def _make_cycle_gan_criteria(loss_type):
    """Return the (GAN, cycle, identity) loss modules selected by ``loss_type``."""
    # The adversarial term is always a least-squares (MSE) loss; only the
    # cycle and identity terms switch between L1 and MSE.
    if loss_type == 'L1':
        return torch.nn.MSELoss(), torch.nn.L1Loss(), torch.nn.L1Loss()
    if loss_type == 'MSE':
        return torch.nn.MSELoss(), torch.nn.MSELoss(), torch.nn.MSELoss()
    raise ValueError("Unsupported loss_type %r; expected 'L1' or 'MSE'" % (loss_type,))


def _make_cycle_gan_optimizer(optim_type, params, lr):
    """Return an optimizer of kind ``optim_type`` over ``params`` with learning rate ``lr``."""
    if optim_type == 'SGD':
        return optim.SGD(params, lr=lr, momentum=0.9)
    if optim_type == 'Adam':
        return optim.Adam(params, lr=lr)
    if optim_type == 'RMSProp':
        return optim.RMSprop(params, lr=lr)
    raise ValueError("Unsupported optim_type %r; expected 'SGD', 'Adam' or 'RMSProp'" % (optim_type,))


def train_cycle_gan_aligner(x1, x2, D_params, G_params, training_params, logs = False, cuda_device='0'):
    """
    Train a Cycle-GAN between two sets of spike firing rates and return the
    generator that maps the second domain (x2) into the first (x1).

    x1: M1 spike firing rates on day-0. A list, where each item is a numpy array containing the neural data of one
        trial. Each trial is flattened to (trial.shape[0], -1) and all trials are concatenated along the first axis.
        The network width is taken as shape[-2] * shape[-1] of the first trial, so trials are presumably 3-D
        (samples, a, b) -- TODO confirm against the caller.

    x2: M1 spike firing rates on day-k, same layout as x1. All of x2 is used for training (no validation split is
        held out).

    D_params: the hyper-parameters determining the structure of the discriminators, a dictionary.
        Key read: 'hidden_dim'.

    G_params: the hyper-parameters determining the structure of the generators, a dictionary.
        Key read: 'hidden_dim'.

    training_params: the hyper-parameters controlling the training process, a dictionary. Keys read:
        'loss_type' ('L1' or 'MSE'), 'optim_type' ('SGD', 'Adam' or 'RMSProp'), 'epochs', 'batch_size',
        'D_lr', 'G_lr', 'ID_loss_p', 'cycle_loss_p', 'drop_out_D', 'drop_out_G'.

    logs: to indicate if training logs is needed to be recorded as a .pkl file under ./train_logs/, a bool.

    cuda_device: index of the CUDA device to use when CUDA is available (string or int); otherwise the CPU is used.

    return: the trained "aligner" (generator2, which maps day-k data to the day-0 domain). It is returned in
        training mode, on the device used for training.

    raises: ValueError if x1 or x2 is empty, or if 'loss_type' / 'optim_type' is not one of the supported values.
    """
    if len(x1) == 0 or len(x2) == 0:
        raise ValueError("x1 and x2 must each contain at least one trial")

    #============================================= Specifying hyper-parameters =============================================
    D_hidden_dim = D_params['hidden_dim']
    G_hidden_dim = G_params['hidden_dim']
    loss_type = training_params['loss_type']
    optim_type = training_params['optim_type']
    epochs = training_params['epochs']
    batch_size = training_params['batch_size']
    D_lr = training_params['D_lr']
    G_lr = training_params['G_lr']
    ID_loss_p = training_params['ID_loss_p']
    cycle_loss_p = training_params['cycle_loss_p']
    drop_out_D = training_params['drop_out_D']
    drop_out_G = training_params['drop_out_G']

    #============================================= Defining networks ===================================================
    # generator1: domain 1 (x1) -> domain 2 (x2); generator2: domain 2 -> domain 1.
    # discriminator1 judges domain-1 samples, discriminator2 judges domain-2 samples.
    x_dim = int(x1[0].shape[-2]*x1[0].shape[-1])
    generator1 = Generator(x_dim, G_hidden_dim, drop_out_G)
    generator2 = Generator(x_dim, G_hidden_dim, drop_out_G)
    discriminator1 = Discriminator(x_dim, D_hidden_dim, drop_out_D)
    discriminator2 = Discriminator(x_dim, D_hidden_dim, drop_out_D)

    # Formatting (rather than concatenating) accepts both '0' and 0.
    device = torch.device("cuda:%s" % (cuda_device,)) if torch.cuda.is_available() else torch.device('cpu')
    for net in (generator1, generator2, discriminator1, discriminator2):
        net.to(device)

    #==================================== Losses and optimizers ===============================================
    criterion_GAN, criterion_cycle, criterion_identity = _make_cycle_gan_criteria(loss_type)
    gen1_optim = _make_cycle_gan_optimizer(optim_type, generator1.parameters(), G_lr)
    gen2_optim = _make_cycle_gan_optimizer(optim_type, generator2.parameters(), G_lr)
    dis1_optim = _make_cycle_gan_optimizer(optim_type, discriminator1.parameters(), D_lr)
    dis2_optim = _make_cycle_gan_optimizer(optim_type, discriminator2.parameters(), D_lr)

    #================================================  Define data Loaders ======================================================
    # Flatten every trial to (n_samples, features) and stack all trials of a day together.
    x1_flat = np.concatenate([np.reshape(trial, (trial.shape[0], -1)) for trial in x1])
    x2_flat = np.concatenate([np.reshape(trial, (trial.shape[0], -1)) for trial in x2])

    #--------------- loader1 is for day-0 data, loader2 is for day-k data ---------------------
    loader1 = DataLoader(torch.utils.data.TensorDataset(torch.as_tensor(x1_flat, dtype=torch.float32, device=device)),
                         batch_size = batch_size, shuffle = True)
    loader2 = DataLoader(torch.utils.data.TensorDataset(torch.as_tensor(x2_flat, dtype=torch.float32, device=device)),
                         batch_size = batch_size, shuffle = True)

    #============================================ Training logs =========================================================
    # The two 'decoder r2' entries stay empty: no decoder is evaluated during training.
    train_log = {'epoch':[], 'batch_idx': [],
                 'loss D1':[], 'loss D2':[],
                 'loss G1':[], 'loss G2':[],
                 'loss cycle 121':[], 'loss cycle 212':[],
                 'decoder r2 wiener': [],
                 'decoder r2 rnn': []}

    #============================================ Preparing to train ========================================================
    generator1.train()
    generator2.train()
    discriminator1.train()
    discriminator2.train()

    #================================================== The training loop ====================================================
    # Early stopping: with the least-squares GAN loss, a discriminator that outputs 0.5 for every sample has a
    # loss of (0.5**2 + 0.5**2)/2 = 0.25. A batch counts as "converged" when both discriminator losses are within
    # loss_delta of that value; training stops at the end of the epoch in which the running count reaches the limit.
    # Original note on these constants: "Chewie: 30 (epoch_limit), 3e-3 (loss_delta)".
    loss_delta, loss_base = 3e-3, 0.25
    converged_batches, converged_limit = 0, 30*max(len(loader1),len(loader2))
    for epoch in range(epochs):
        # zip() stops at the shorter loader, so surplus batches of the longer one are not used in this epoch.
        for batch_idx, (batch1, batch2) in enumerate(zip(loader1, loader2)):
            data1, data2 = batch1[0], batch2[0]
            # Skip pairs of unequal size (a trailing partial batch on one side only).
            if len(data1) != len(data2):
                continue

            #------------ The labels for real (1) and fake (0) samples --------------
            n_samples = data1.shape[0]
            target_real = torch.ones((n_samples, 1), dtype=torch.float32, device=device)
            target_fake = torch.zeros((n_samples, 1), dtype=torch.float32, device=device)

            #================================================== Generators ==================================================
            gen1_optim.zero_grad()
            gen2_optim.zero_grad()

            #------------ Identity loss, to make sure the generators do not distort the inputs --------------
            same2 = generator1(data2)
            loss_identity2 = criterion_identity(same2, data2)*ID_loss_p
            same1 = generator2(data1)
            loss_identity1 = criterion_identity(same1, data1)*ID_loss_p

            #------------ GAN loss for generator1: fool discriminator2 --------------
            fake2 = generator1(data1)
            loss_GAN2 = criterion_GAN(discriminator2(fake2), target_real)

            #------------ GAN loss for generator2: fool discriminator1 --------------
            fake1 = generator2(data2)
            loss_GAN1 = criterion_GAN(discriminator1(fake1), target_real)

            #------------ Cycle loss: 1 -> 2 -> 1 and 2 -> 1 -> 2 must reproduce the input --------------
            recovered1 = generator2(fake2)
            loss_cycle_121 = criterion_cycle(recovered1, data1)*cycle_loss_p

            recovered2 = generator1(fake1)
            loss_cycle_212 = criterion_cycle(recovered2, data2)*cycle_loss_p

            #----------- Total loss of G, the sum of all the losses defined above -----------
            loss_G = loss_identity1 + loss_identity2 + loss_GAN1 + loss_GAN2 + loss_cycle_121 + loss_cycle_212

            #-------- Backward() and step() for generators ---------
            loss_G.backward()
            gen1_optim.step()
            gen2_optim.step()

            #================================================== Discriminator 1 ==================================================
            # zero_grad() also discards the gradients the generator backward pass left on the discriminator.
            dis1_optim.zero_grad()

            loss_D_real = criterion_GAN(discriminator1(data1), target_real)

            # The fake sample is regenerated with the freshly updated generator; no gradient is needed
            # through it, so skip building the autograd graph.
            with torch.no_grad():
                fake1_for_D = generator2(data2)
            loss_D_fake = criterion_GAN(discriminator1(fake1_for_D), target_fake)

            loss_D1 = (loss_D_real + loss_D_fake)/2

            #-------- Backward() and step() for discriminator1 ---------
            loss_D1.backward()
            dis1_optim.step()

            #================================================== Discriminator 2 ==================================================
            dis2_optim.zero_grad()

            loss_D_real = criterion_GAN(discriminator2(data2), target_real)

            with torch.no_grad():
                fake2_for_D = generator1(data1)
            loss_D_fake = criterion_GAN(discriminator2(fake2_for_D), target_fake)

            loss_D2 = (loss_D_real + loss_D_fake)/2

            #-------- Backward() and step() for discriminator2 ---------
            loss_D2.backward()
            dis2_optim.step()

            loss_D1_value, loss_D2_value = loss_D1.item(), loss_D2.item()
            print("Epoch %d / Batch %d: loss_G = %f, loss_D1 = %f, loss_D2 = %f"%(epoch, batch_idx, loss_G.item(), loss_D1_value, loss_D2_value))

            #============================ early stop ==============================
            if abs(loss_D1_value-loss_base) < loss_delta and abs(loss_D2_value-loss_base) < loss_delta:
                converged_batches += 1

            #====================================== save the training logs ========================================
            if logs:
                train_log['epoch'].append(epoch)
                train_log['batch_idx'].append(batch_idx)
                train_log['loss D1'].append(loss_D1_value)
                train_log['loss D2'].append(loss_D2_value)
                train_log['loss G1'].append(loss_GAN1.item())
                train_log['loss G2'].append(loss_GAN2.item())
                train_log['loss cycle 121'].append(loss_cycle_121.item())
                train_log['loss cycle 212'].append(loss_cycle_212.item())
        if converged_batches >= converged_limit:
            break

    #============================================ save the training log =================================================
    if logs:
        import os  # local import: only needed when a log file is written
        # Make sure the target directory exists so a finished run is not lost to a FileNotFoundError.
        os.makedirs('./train_logs', exist_ok=True)
        dt_string = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
        with open('./train_logs/train_log_' + dt_string + '.pkl', 'wb') as fp:
            pickle.dump(train_log, fp)

    # generator2 transforms day-k (x2) spikes into the day-0 (x1) domain. It is returned directly so that
    # runs shorter than 10 epochs (or stopped early) still yield an aligner.
    return generator2

def test_cycle_gan_aligner(net, dayk_data, cuda_device='0'):
    """
    Apply a trained aligner to day-k data.

    net: the trained aligner (a torch module). It is switched to evaluation mode and left in that mode.
    dayk_data: the data that needs to be processed by the trained aligner, a tensor (or array-like) whose first
        axis indexes samples. All remaining axes are flattened into one feature axis before being fed to the net.
    cuda_device: kept for backward compatibility; the data is moved to whichever device (and dtype) the net's
        parameters already use, so this value is not consulted.

    return: the aligned data as a tensor with the same shape as dayk_data, placed on the device dayk_data came from.
        No gradient is tracked.
    """
    #------ Put the net in eval mode (disables drop-out) ------ #
    aligner = net.eval()

    #------ Match the input to the device / dtype of the net's parameters ------#
    source = torch.as_tensor(dayk_data)
    first_param = next(aligner.parameters(), None)
    if first_param is None:
        net_input = source
    else:
        net_input = source.to(device=first_param.device, dtype=first_param.dtype)

    #------ Use the trained aligner to process the dayk_data ------#
    with torch.no_grad():
        flat = torch.reshape(net_input, (net_input.shape[0], -1))
        aligned = aligner(flat)
        # Restore the full input shape (works for any number of trailing axes).
        aligned = torch.reshape(aligned, tuple(source.shape))

    #------ Return the aligned day-k data --------#
    return aligned.to(source.device)


# The "test_" prefix matches pytest's default collection pattern; opt out so the
# function is not collected as a test when it is imported into a test module.
test_cycle_gan_aligner.__test__ = False