import argparse
import copy
import csv
import time

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets

from lyapunov_loss import LyapunovLoss

def data_preparation(data_dir='../data/imdb_crop', n_train=20000, n_valid=4000,
                     n_test=4000, batch_size=20, num_workers=0,
                     shuffle_train=True):
    """Load IMDB face crops with ages and wrap them in train/valid/test DataLoaders.

    Reads ``<data_dir>/imdb_path.csv`` (image paths relative to ``data_dir``) and
    ``<data_dir>/imdb_age.csv`` (ages) in lockstep. The first row of each file is
    skipped as a header. Each image is read as grayscale, resized to 32x32 and
    scaled by 1/255 into a float32 tensor of shape (1, 32, 32). The label is
    ``age / 100`` as a Python float.

    The first ``n_train`` data rows go to the training split, the next
    ``n_valid`` to validation, the next ``n_test`` to test, and any further rows
    are not read.

    Args:
        data_dir: Directory holding both CSV files and the image tree.
        n_train, n_valid, n_test: Number of rows assigned to each split.
        batch_size: Samples per batch for all three loaders.
        num_workers: DataLoader worker processes.
        shuffle_train: Reshuffle the training split every epoch.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.

    Raises:
        FileNotFoundError: If an image listed in the path CSV cannot be read.
    """
    end_train = n_train
    end_valid = n_train + n_valid
    end_test = end_valid + n_test

    train_data, val_data, test_data = [], [], []
    # Context managers guarantee both CSV handles are closed, even on error.
    with open(data_dir + '/imdb_age.csv', newline='') as f_age, \
            open(data_dir + '/imdb_path.csv', newline='') as f_path:
        rows = zip(csv.reader(f_path), csv.reader(f_age))
        next(rows, None)  # first row of each file is treated as a header
        for index, (impath, age) in enumerate(rows, start=1):
            if index > end_test:
                break
            img_file = data_dir + '/' + impath[0]
            img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise FileNotFoundError('could not read image: ' + img_file)
            # Scale every image the same way (uint8 -> [0, 1] float32) and add the
            # leading channel axis, giving a (1, 32, 32) tensor.
            img = cv2.resize(img, (32, 32)).astype(np.float32) / 255.0
            sample = (torch.from_numpy(img).unsqueeze(0), float(age[0]) / 100)
            if index <= end_train:
                train_data.append(sample)
            elif index <= end_valid:
                val_data.append(sample)
            else:
                test_data.append(sample)
            if index % 1000 == 0:
                print('loaded {} images'.format(index))

    print(len(train_data))
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                               shuffle=shuffle_train,
                                               num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size,
                                               num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                              num_workers=num_workers)
    return train_loader, valid_loader, test_loader

class Net(nn.Module):
    """Three-layer MLP mapping a flattened image to one value in (0, 1) per sample.

    Architecture: ``in_features -> hidden_1 -> hidden_2 -> 1``. ReLU activations
    and dropout follow each hidden layer, and a sigmoid is applied to the output.
    The sigmoid matches the targets produced by ``data_preparation``
    (``age / 100``, which is non-negative). A log-sigmoid output is bounded above
    by 0 and could never fit them.
    """

    def __init__(self, in_features=32 * 32, hidden_1=512, hidden_2=512,
                 p_drop=0.2):
        super(Net, self).__init__()
        # linear layer (in_features -> hidden_1)
        self.fc1 = nn.Linear(in_features, hidden_1)
        # linear layer (hidden_1 -> hidden_2)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        # linear layer (hidden_2 -> 1): single regression output
        self.fc3 = nn.Linear(hidden_2, 1)
        # Dropout regularizes the hidden layers. The attribute keeps its original
        # (misspelled) name so external references to ``model.droput`` still work.
        self.droput = nn.Dropout(p_drop)
        # Parameter-free, so saved state_dicts remain compatible.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of predictions in (0, 1)."""
        # Flatten any (batch, ..., H, W) input to (batch, in_features).
        x = x.view(-1, self.fc1.in_features)
        x = self.droput(F.relu(self.fc1(x)))
        x = self.droput(F.relu(self.fc2(x)))
        return self.sigmoid(self.fc3(x))

def initialize(type_loss, alpha=0.8, lr=0.0001):
    """Build a fresh ``Net``, the criterion named by ``type_loss``, and an SGD optimizer.

    Args:
        type_loss: One of ``'l1'`` (``nn.L1Loss``), ``'l2'`` (``nn.MSELoss``)
            or ``'lyapunov'`` (``LyapunovLoss(alpha)``).
        alpha: Exponent passed to ``LyapunovLoss``. Ignored for other losses.
        lr: SGD learning rate.

    Returns:
        ``(model, criterion, optimizer)``.

    Raises:
        ValueError: If ``type_loss`` is not recognized. Without this check, an
            unknown name would fail later with an ``UnboundLocalError``.
    """
    if type_loss not in ('l1', 'l2', 'lyapunov'):
        raise ValueError("unknown type_loss {!r}; expected 'l1', 'l2' or "
                         "'lyapunov'".format(type_loss))
    model = Net()
    print(model)
    if type_loss == 'l1':
        criterion = nn.L1Loss()
    elif type_loss == 'l2':
        criterion = nn.MSELoss()
    else:
        criterion = LyapunovLoss(alpha)
    # plain stochastic gradient descent, default learning rate 1e-4
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    return model, criterion, optimizer

def test(model, criterion, test_loader, type_loss, gpu):
    if type_loss == "l1":
        model.load_state_dict(torch.load('model_l1.pt'))
    if type_loss == "l2":
        model.load_state_dict(torch.load('model_l2.pt'))
    if type_loss == "lyapunov":
        model.load_state_dict(torch.load('model_lyapunov.pt'))
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    model.eval() # prep model for evaluation
    for data, target in test_loader:
        # forward pass: compute predicted outputs by passing inputs to the model
        data = data.float()
        if gpu:
            data = data.to('cuda')
            target = target.to('cuda')
        target = target.float()
        output = model(data)
        # calculate the loss
        loss = criterion(output, target)
        # update test loss
        test_loss += loss.item()*data.size(0)
     # calculate and print avg test loss
    test_loss = test_loss/len(test_loader.sampler)
    if type_loss == 'l1':
        print('Test L1 Loss: {:.6f}\n'.format(test_loss))
    if type_loss == 'l2':
        print('Test L2 Loss: {:.6f}\n'.format(test_loss))
    if type_loss == 'lyapunov':
        print('Test Lyapunov Loss: {:.6f}\n'.format(test_loss))

def _prepare_batch(data, label, gpu):
    """Cast a ``(data, label)`` batch to float32 and move it to CUDA when ``gpu`` is set."""
    data = data.float()
    label = label.float()
    if gpu:
        data = data.to('cuda')
        label = label.to('cuda')
    return data, label


def _timestamp(gpu):
    """Return ``time.perf_counter()``, first waiting for queued CUDA work when ``gpu`` is set.

    CUDA kernels launch asynchronously. Without the synchronize, a timed region
    would mostly measure launch overhead rather than the actual compute.
    (``time.clock``, used previously, no longer exists in Python 3.8+.)
    """
    if gpu:
        torch.cuda.synchronize()
    return time.perf_counter()


def _lyapunov_reshape_grads(model, loss, alpha, beta=0.1):
    """Rewrite the weight gradients of ``model`` in place before the optimizer step.

    Each parameter whose name contains ``'weight'`` has its gradient ``g``
    replaced by ``|g| * |g|**alpha * sign(g) * loss**beta``. Other parameters,
    such as biases, keep their plain gradients. ``loss`` is detached, so the new
    gradients carry no autograd history.

    NOTE(review): ``|g| * |g|**alpha * sign(g)`` equals ``sign(g) * |g|**(1 + alpha)``.
    Confirm the leading ``|g|`` factor is intended for this Lyapunov-style update.
    """
    with torch.no_grad():
        e_beta = torch.pow(loss.detach(), beta)
        for name, param in model.named_parameters():
            if 'weight' in name and param.grad is not None:
                grad = param.grad
                abs_grad = torch.abs(grad)
                param.grad = (abs_grad * torch.pow(abs_grad, alpha)
                              * torch.sign(grad) * e_beta)


def train(train_loader, valid_loader, test_loader, model_1, criterion_1,
          optimizer_1, model_2, criterion_2, optimizer_2, model_3, criterion_3,
          optimizer_3, alpha, gpu, n_epochs=800, eval_every=20):
    """Train the L1, L2 and Lyapunov models side by side on the same batches.

    Each epoch has three steps:

    1. Every model takes one optimizer step per training batch. The Lyapunov
       model's weight gradients are first rewritten by ``_lyapunov_reshape_grads``.
    2. Every model is scored on ``valid_loader`` with its own criterion.
    3. Each model's state_dict is saved to ``model_<key>.pt`` whenever its
       validation loss is <= its best so far.

    Every ``eval_every`` epochs (starting at epoch 0), each model is also
    evaluated via ``test``. Its loss/time histories are then written to
    ``train_<key>_<epoch>.csv``, ``train_<key>_time_<epoch>.csv`` and
    ``valid_<key>_<epoch>.csv``. Reported losses are per-sample means.
    Training time is the cumulative seconds spent in forward/backward/step.

    Args:
        train_loader, valid_loader, test_loader: Loaders yielding ``(data, label)``.
        model_1/criterion_1/optimizer_1: The L1 run.
        model_2/criterion_2/optimizer_2: The L2 run.
        model_3/criterion_3/optimizer_3: The Lyapunov run.
        alpha: Exponent used when rewriting the Lyapunov model's gradients.
        gpu: Batches are moved to CUDA when true.
        n_epochs: Number of epochs.
        eval_every: Epoch interval for test evaluation and CSV dumps. Use 0 or
            None to disable.
    """
    runs = [
        {'key': 'l1', 'title': 'L1', 'model': model_1,
         'criterion': criterion_1, 'optimizer': optimizer_1},
        {'key': 'l2', 'title': 'L2', 'model': model_2,
         'criterion': criterion_2, 'optimizer': optimizer_2},
        {'key': 'lyapunov', 'title': 'Lyapunov', 'model': model_3,
         'criterion': criterion_3, 'optimizer': optimizer_3},
    ]
    for run in runs:
        run.update(best_valid=np.inf, train_time=0.0,
                   train_hist=[], time_hist=[], valid_hist=[])
    # Divide summed (loss * batch_size) by the sample count to get per-sample
    # means. Dividing by the batch count would inflate them by ~batch_size.
    n_train = max(len(train_loader.sampler), 1)
    n_valid = max(len(valid_loader.sampler), 1)

    for epoch in range(n_epochs):
        # ---- training pass ----
        for run in runs:
            run['model'].train()
            run['train_loss'] = 0.0
            run['valid_loss'] = 0.0
        for data, label in train_loader:
            data, label = _prepare_batch(data, label, gpu)
            for run in runs:
                run['optimizer'].zero_grad()
                start = _timestamp(gpu)
                output = run['model'](data)
                # Reshape the (batch,) label to the (batch, 1) output so the
                # loss is elementwise rather than broadcast to (batch, batch).
                loss = run['criterion'](output, label.view_as(output))
                loss.backward()
                if run['key'] == 'lyapunov':
                    _lyapunov_reshape_grads(run['model'], loss, alpha)
                run['optimizer'].step()
                run['train_time'] += _timestamp(gpu) - start
                run['train_loss'] += loss.item() * data.size(0)

        # ---- validation pass ----
        for run in runs:
            run['model'].eval()
        with torch.no_grad():
            for data, label in valid_loader:
                data, label = _prepare_batch(data, label, gpu)
                for run in runs:
                    output = run['model'](data)
                    # Each model is scored with its own criterion, so checkpoint
                    # selection tracks the objective it is trained on.
                    loss = run['criterion'](output, label.view_as(output))
                    run['valid_loss'] += loss.item() * data.size(0)

        # ---- reporting and history ----
        for run in runs:
            run['train_loss'] /= n_train
            run['valid_loss'] /= n_valid
            print('Epoch: {} \tTraining {} Loss: {:.6f} \tValidation {} Loss: {:.6f}'.format(
                epoch + 1, run['title'], run['train_loss'],
                run['title'], run['valid_loss']))
            run['train_hist'].append([run['train_loss']])
            run['valid_hist'].append([run['valid_loss']])
            run['time_hist'].append([float(run['train_time'])])

        # ---- checkpointing and periodic test evaluation ----
        for run in runs:
            key = run['key']
            if run['valid_loss'] <= run['best_valid']:
                print('Validation {} loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                    key, run['best_valid'], run['valid_loss']))
                torch.save(run['model'].state_dict(), 'model_{}.pt'.format(key))
                run['best_valid'] = run['valid_loss']
            if eval_every and epoch % eval_every == 0:
                test(run['model'], run['criterion'], test_loader, key, gpu)
                pd.DataFrame(run['train_hist']).to_csv(
                    'train_{}_{}.csv'.format(key, epoch))
                pd.DataFrame(run['time_hist']).to_csv(
                    'train_{}_time_{}.csv'.format(key, epoch))
                pd.DataFrame(run['valid_hist']).to_csv(
                    'valid_{}_{}.csv'.format(key, epoch))


def main():
    """Parse CLI flags, build the three models, train them together, then report test losses.

    ``--gpu`` moves models and criteria to CUDA. It is rejected up front with a
    usage error if CUDA is not available, instead of failing later inside
    ``.cuda()``.
    """
    parser = argparse.ArgumentParser(description='MLP IMDB Experiment!')
    parser.add_argument('--gpu', action='store_true')
    args = parser.parse_args()
    print(args.gpu)
    if args.gpu and not torch.cuda.is_available():
        parser.error('--gpu was given but CUDA is not available')

    # Single source for the Lyapunov exponent. It is used both by the loss
    # (via initialize) and by the gradient rewrite inside train, so the two
    # must agree.
    alpha = 0.8

    train_loader, valid_loader, test_loader = data_preparation()
    model_1, criterion_1, optimizer_1 = initialize("l1")
    model_2, criterion_2, optimizer_2 = initialize("l2")
    model_3, criterion_3, optimizer_3 = initialize("lyapunov", alpha=alpha)
    if args.gpu:
        # Module.cuda() moves parameters in place, so the optimizers built above
        # still reference the same Parameter objects.
        model_1 = model_1.cuda()
        criterion_1 = criterion_1.cuda()
        model_2 = model_2.cuda()
        criterion_2 = criterion_2.cuda()
        model_3 = model_3.cuda()
        criterion_3 = criterion_3.cuda()

    train(train_loader, valid_loader, test_loader, model_1,
          criterion_1, optimizer_1, model_2, criterion_2,
          optimizer_2, model_3, criterion_3, optimizer_3, alpha=alpha,
          gpu=args.gpu)

    # Final report: evaluate the best checkpoint saved for each loss.
    test(model_1, criterion_1, test_loader, 'l1', gpu=args.gpu)
    test(model_2, criterion_2, test_loader, 'l2', gpu=args.gpu)
    test(model_3, criterion_3, test_loader, 'lyapunov', gpu=args.gpu)


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
