#!/usr/bin/env python3

import os
import json
import sys
import math
import argparse
import random
import signal
import datetime
import itertools

from utils import mkdir, Logger
from dataloader import SimulationDataSet
from models import define_G, define_D, get_scheduler, update_learning_rate, GANLoss
from evaluate import Evaluator
from config import config

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.sampler import Sampler
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data.sampler import RandomSampler
from torch.utils.data import Subset

from torchsummary import summary

import time

# Silence all Python warnings unless the user configured warning handling
# explicitly (e.g. with -W or PYTHONWARNINGS); in that case sys.warnoptions
# is non-empty and the user's choice is left untouched.
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")


# Run-state flags read and written by signal_handler: while training_started
# is False a SIGINT exits immediately, afterwards it only sets STOP_TRAINING.
training_started = False
STOP_TRAINING = False

# Dataset index of the first sample of each simulation, grouped by the split
# the simulation belongs to (used for the 's' model type).
train_start_indices_speed = [
    0, 334, 1002, 1670, 2004, 2672, 3006, 3340, 4008, 4676,
    5010, 5678, 6012, 6346, 6680, 7348, 7682, 8350, 9018, 9352,
    9686, 10020, 10688, 11022, 11690, 12024, 12692, 13026, 13360,
]
test_start_indices_speed = [668, 2338, 3674, 5344, 7014, 8684, 10354, 12358]
val_start_indices_speed = [1336, 4342, 8016, 11356]

# Order matters: test_validation_test_split reads this as (train, val, test).
collector = [
    train_start_indices_speed,
    val_start_indices_speed,
    test_start_indices_speed,
]

def signal_handler(sig, frame):
    """Handle an interrupt signal.

    Before training has started the process exits right away with status 0.
    Once training is running, only the module-level STOP_TRAINING flag is
    raised so that the training loop can react to it.

    Both arguments are the standard signal-handler parameters and are not
    used.
    """
    global STOP_TRAINING

    if training_started:
        print('Stoping the training...')
        STOP_TRAINING = True
        return

    sys.exit(0)


def create_directories():
    """Create the output directory and the per-model sub-directory inside it.

    Reads the module-level ``config['output_dir']`` and ``args.model_name``.
    """
    output_root = config['output_dir']
    mkdir(output_root)
    model_dir = os.path.join(output_root, args.model_name)
    mkdir(model_dir)


def get_dataconf_file(args):
    """Return the file name of the data configuration.

    ``args`` is accepted but not consulted: the same name is returned
    regardless of the parsed options.
    """
    dataconf_name = 's_dataconf.txt'
    return dataconf_name

# splits the indices of all data into train-/val-/testset
def test_validation_test_split(dataset, model = 's', batch_size = 3, sim_length = 334):
    """Build the train, validation and test index lists.

    For every simulation start index listed in the module-level ``collector``
    the consecutive indices ``start .. start + k * batch_size - 1`` are
    emitted, where ``k = sim_length // batch_size``. Trailing samples of a
    simulation that do not fill a whole batch are dropped, so that a batch
    never spans two simulations.

    Args:
        dataset: Not used; kept for interface compatibility.
        model: Model type. Only ``'s'`` is supported.
        batch_size: Batch size the simulations are cut to; must be positive.
        sim_length: Number of samples per simulation.

    Returns:
        Tuple ``(train_indices, val_indices, test_indices)`` of lists of ints.

    Raises:
        NotImplementedError: If ``model`` is not ``'s'``. Previously the
            function returned ``None`` here, which only surfaced as an
            unpacking error at the call site.
        ValueError: If ``batch_size`` is not positive.
    """
    if model != 's':
        raise NotImplementedError(
            "No train/val/test split is defined for model type {!r}; only 's' is supported".format(model))
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    # splitting speed-model indices according to batch-size and simulation-length
    print('Splitting for s')
    usable_length = (sim_length // batch_size) * batch_size

    # collector[0]: train, collector[1]: val, collector[2]: test
    train_starts, val_starts, test_starts = collector[0], collector[1], collector[2]
    splits = []
    for starts in (train_starts, val_starts, test_starts):
        split_indices = []
        for batch_start in starts:
            split_indices.extend(range(batch_start, batch_start + usable_length))
        splits.append(split_indices)

    return splits[0], splits[1], splits[2]

# saving the net
def save_models(net_g, net_d, args, epoch):
    """Save the generator network for the given epoch.

    The whole module object is written with ``torch.save`` to
    ``<config['output_dir']>/<args.model_name>/netG_<args.model_name>_model_epoch_<epoch>.pth``.

    Args:
        net_g: Generator network, optionally wrapped in ``nn.DataParallel``.
        net_d: Discriminator network. Currently not saved; kept for
            interface compatibility.
        args: Parsed options; ``model_name`` and ``parallel`` are read.
        epoch: Epoch number used in the file name.
    """
    file_name = "netG_{0}_model_epoch_{1}.pth".format(args.model_name, epoch)
    # os.path.join keeps an absolute output directory absolute; prefixing
    # "./" would turn it into a relative path that does not match the
    # directory created by create_directories().
    net_g_model_out_path = os.path.join(config['output_dir'], args.model_name, file_name)

    # With --parallel the net is only wrapped when more than one GPU is
    # available, so check the actual type instead of trusting the flag.
    if args.parallel and isinstance(net_g, nn.DataParallel):
        net_g = net_g.module

    torch.save(net_g, net_g_model_out_path)

# splits given indices in parts with given size and shuffles the parts
def shuffleDataset(indices, size):
    """Shuffle ``indices`` in chunks of ``size`` consecutive entries.

    The sequence is cut into ``len(indices) // size`` chunks, the chunks are
    shuffled with ``np.random.shuffle`` and then concatenated again. The
    order inside a chunk is preserved; trailing entries that do not fill a
    whole chunk are dropped.

    Args:
        indices: Sequence of indices. It is not modified (the previous
            implementation deleted the consumed entries from the caller's
            list).
        size: Chunk length; must be positive.

    Returns:
        A new flat list with the shuffled chunks.

    Raises:
        ValueError: If ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError("size must be positive, got {}".format(size))

    usable_length = (len(indices) // size) * size
    chunks = [list(indices[start:start + size]) for start in range(0, usable_length, size)]
    np.random.shuffle(chunks)
    return [index for chunk in chunks for index in chunk]


# Command line interface. Almost every default is taken from the project's
# `config` mapping, so a flag only has to be passed to deviate from it.
# Note: the `--no-*` switches use store_false, i.e. passing them turns the
# corresponding feature off.
parser = argparse.ArgumentParser(description='The training script of the flowPredict pytorch implementation')
parser.add_argument('--data', dest='data_dir',default=config['data_dir'], required=False, help='Root directory of the generated data.')
parser.add_argument('--model-name', dest='model_name', default=config['model_name'], required=False, help='Name of the current model being trained. res or unet')
#parser.add_argument('--use-pressure', dest='use_pressure', required=False, action='store_true', default=config['use_pressure'], help='Should the pressure field images to considered by the models')
parser.add_argument('--model-type', dest='model_type', action='store', default=config['model_type'], choices=['c', 'vd', 's', 'o'], required=False,
                    help='Type of model to be build. \'c\' - baseline, \'vd\' - fluid viscosity and density, \'s\' - inflow speed, \'o\' - object')
parser.add_argument('--cuda', dest='cuda', action='store_true', default=config['cuda'], help='Should CUDA be used or not')
parser.add_argument('--threads', dest='threads', type=int, default=config['threads'], help='Number of threads for data loader to use')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=config['batch_size'], help='Training batch size.')
parser.add_argument('--seed', dest='seed', type=int, default=config['seed'], help='Random seed to use. Default=123')
parser.add_argument('--niter', type=int, dest='niter', default=config['niter'], help='Number of iterations at starting learning rate')
parser.add_argument('--epochs', dest='epochs', type=int, default=config['epochs'], help='Number of epochs for which the model will be trained')
parser.add_argument('--niter_decay', type=int, dest='niter_decay', default=config['niter_decay'], help='Number of iterations to linearly decay learning rate to zero')
parser.add_argument('--lr_policy', type=str, default=config['lr_policy'], help='learning rate policy: lambda|step|plateau|cosine')
parser.add_argument('--evaluate', default=config['evaluate'], action='store_true', dest='evaluate' , help='Evaluate the trained model at the end')
parser.add_argument('--no-train', default=config['no_train'], action='store_true', dest='no_train' , help='Do not train the model with the training data')
parser.add_argument('--model-path', default=config['model_path'], action='store', dest='model_path' , help='Optional path to the model\'s weights.')
# -1 means "not given": the value from `config` is kept in that case.
parser.add_argument('--g_nfg', type=int, dest='g_nfg', default=-1, help='Number of feature maps in the first layers of ResNet')
parser.add_argument('--g_layers', type=int, dest='g_layers', default=-1, help='ResNet blocks in the middle of the network')
parser.add_argument('--g_output_nc', type=int, dest='g_output_nc', default=-1, help='Number of output channels of the generator network')
parser.add_argument('--g_input_nc', type=int, dest='g_input_nc', default=-1, help='Number of input channels of the generator network')
parser.add_argument('--output-dir', dest='output_dir', default=None, help='The output directory for the model files')
parser.add_argument('--no-mask', dest='mask', default=config['mask'], action='store_false', help='Disable the mask for the model')
parser.add_argument('--no-noise', dest='noise', default=config['noise'], action='store_false', help='Do not use noise for the input')
parser.add_argument('--no-crops', dest='crops', default=config['crops'], action='store_false', help='Do not use crops for the input')
parser.add_argument('--parallel', dest='parallel', default=config['parallel'], action='store_true', help='Use multiple gpus parallel')
parser.add_argument('--no-lstm', dest='lstm', default=config['lstm'], action='store_false', help='Use no lstm in the model')
parser.add_argument('--lstm-layers', dest='lstm_layers', type=int, default=config['lstm_layers'], help='Number of LSTM-layers used')
parser.add_argument('--random', dest='random', default=config['random'], action='store_true', help='Use random training')
parser.add_argument('--no-batchreset', dest='batchreset', action='store_false', default=config['batchreset'], help='Reset the LSTM not after every batch, instead after given number of iterations with --reset-lstm-iter')
#parser.add_argument('--batchshuffle', dest='shuffle_batch', default=config['shuffle_batch'], action='store_true', help='Shuffle val and train data batchwise')
parser.add_argument('--lstm-use-params', dest='use_params', default=config['use_params'], action='store_true', help='Give also the LSTM the parameter')
# type=int is required here: the value is compared against an integer batch
# counter in the training loop, and a string from the command line would
# never compare equal, so the LSTM would never be reset.
parser.add_argument('--reset-lstm-iter', dest='reset_iter', type=int, default=config['reset_iter'], help='When the LSTM is not resetted after every batch, specify the number of batches after which the LSTM should be resetted')
#parser.add_argument('--random-until-epoch', dest='random_until', type=int, default=config['random_until'], help='If positive then after the given number of epochs the training is changed from random to batch-wise ordered training')
#parser.add_argument('--freezing', dest='freezing', default=config['freezing'], action='store_true', help='Use freeze training after random_until number of epochs')

args = parser.parse_args()

# Only the inflow-speed model ('s') passes simulation parameters to the net.
parameterized = (args.model_type == 's')

# Generator: two input and two output channels as the starting point.
config['g_input_nc'] = 2
config['g_output_nc'] = 2

# Discriminator: twice the base generator input, plus one channel for the
# mask. This is computed before the generator input is extended below.
config['d_input_nc'] = 2 * config['g_input_nc'] + (1 if args.mask else 0)

if parameterized:
    # One extra input channel for the 's' parameter.
    # TODO: 'vd' would need two extra channels, but is not treated as
    # parameterized yet.
    config['g_input_nc'] += 1

# The mask occupies one more generator input channel.
if args.mask:
    config['g_input_nc'] += 1

# Command line overrides: -1 means that the option was not given.
for _option in ('g_layers', 'g_nfg', 'g_input_nc', 'g_output_nc'):
    _override = getattr(args, _option)
    if _override != -1:
        config[_option] = _override
if args.output_dir is not None:
    config['output_dir'] = args.output_dir

# Full model name: <type>_<name>_l<layers>_ngf<features>, followed by one
# suffix per enabled feature (_m mask, _n noise, _c crops).
args.model_name = '{}_{}_l{}_ngf{}'.format(args.model_type, args.model_name, config['g_layers'], config['g_nfg'])
args.model_name += ''.join(
    '_' + tag
    for enabled, tag in ((args.mask, 'm'), (args.noise, 'n'), (args.crops, 'c'))
    if enabled
)

create_directories()

# From here on everything printed goes through the project's Logger, which
# is handed the path of log.txt inside the output directory.
sys.stdout = Logger(os.path.join(config['output_dir'], 'log.txt'))


print('===> Setting up basic structures ')

_gpu_requested = args.cuda or args.parallel
if _gpu_requested and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

device = torch.device("cuda:0" if _gpu_requested else "cpu")

# Ctrl-C handling is only installed when a training run takes place.
if not args.no_train:
    signal.signal(signal.SIGINT, signal_handler)

# Short aliases for frequently used options.
model_type = args.model_type
batch_size = args.batch_size
num_epochs = args.epochs
threads = args.threads
model_name = args.model_name
use_params = args.use_params

# Report the effective configuration, one "label value" pair per line.
for _label, _value in (
    ('--Model name:', args.model_name),
    ('--model type:', model_type),
    ('--batch size:', batch_size),
    ('--num epochs:', num_epochs),
    ('--learning rate policiy:', args.lr_policy),
    ('--worker threads:', threads),
    ('--cuda:', args.cuda),
    ('--device:', device),
    ('--gen. input channels:', config['g_input_nc']),
    ('--gen. output channels:', config['g_output_nc']),
    ('--desc. input channels:', config['d_input_nc']),
    ('--mask:', args.mask),
    ('--lstm:', args.lstm),
    ('--lstm-layers:', args.lstm_layers),
    ('--lstm-use-params: ', use_params),
    ('--random: ', args.random),
):
    print(_label, _value)

print('===> Loading datasets')

dataconf_file = get_dataconf_file(args)
dataset = SimulationDataSet(args.data_dir, dataconf_file, args)

# number of samples per simulation
sim_length = 334

train_indices, val_indices, test_indices = test_validation_test_split(dataset, model=model_type, batch_size =batch_size, sim_length=sim_length)

print('dataset length:', len(dataset))
trainset = Subset(dataset, train_indices)
valset = Subset(dataset, val_indices)
testset = Subset(dataset, test_indices)

# Random training draws the samples of each subset in random order.
# Otherwise the subsets are traversed in index order. The samplers have to
# exist in both cases: previously they were only created with --random, so
# every run without that flag failed with a NameError at the DataLoader
# construction below.
if args.random:
    print("hallo random")
    train_sampler = RandomSampler(trainset)
    val_sampler = RandomSampler(valset)
    test_sampler = RandomSampler(testset)
else:
    train_sampler = SequentialSampler(trainset)
    val_sampler = SequentialSampler(valset)
    test_sampler = SequentialSampler(testset)

# Pinned host memory is only requested when a GPU is used.
_pin_memory = bool(args.cuda or args.parallel)

train_loader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, shuffle=False, num_workers=threads, pin_memory=_pin_memory)
val_loader = DataLoader(valset, batch_size=batch_size, sampler=val_sampler, shuffle=False, num_workers=threads, pin_memory=_pin_memory)
test_loader = DataLoader(testset, batch_size=batch_size, sampler=test_sampler, shuffle=False, num_workers=threads, pin_memory=_pin_memory)

print('--training samples count:', len(train_indices))
print('--validation samples count:', len(val_indices))
print('--test samples count:', len(test_indices))

now = datetime.datetime.now()
date = now.strftime("%d-%m-%Y:%H:%M:%S")
print('--date:', date)

# Marker file whose name carries the start time of this run.
with open(os.path.join(config['output_dir'], 'date_{}'.format(date)), 'w') as dh:
    dh.write(date)

print('===> Loading model')

# Generator first, then discriminator.
net_g = define_G(
    config['g_input_nc'],
    config['g_output_nc'],
    config['g_nfg'],
    n_blocks=config['g_layers'],
    use_lstm=args.lstm,
    lstm_layers=args.lstm_layers,
    gpu_id=device,
    use_params=use_params,
    args=args,
).float()
net_d = define_D(
    config['d_input_nc'],
    config['d_nfg'],
    n_layers_D=config['d_layers'],
    gpu_id=device,
    args=args,
).float()

for _title, _net in (
    ('-> Printing generator structure', net_g),
    ('-> Printing discriminator structure', net_d),
):
    print(_title)
    print(_net)

# Wrap both nets for multi-GPU use only if more than one GPU is present;
# in every case the nets end up on `device`.
if args.parallel and torch.cuda.device_count() > 1:
    print("--using", torch.cuda.device_count(), "GPUs")
    net_g = nn.DataParallel(net_g)
    net_d = nn.DataParallel(net_d)
net_d.to(device)
net_g.to(device)

# Both optimizers share the same Adam settings from the config.
_adam_options = {
    'lr': config['adam_lr'],
    'betas': (config['adam_b1'], config['adam_b2']),
}
optimizer_g = optim.Adam(net_g.parameters(), **_adam_options)
optimizer_d = optim.Adam(net_d.parameters(), **_adam_options)

net_g_scheduler = get_scheduler(optimizer_g, args)
net_d_scheduler = get_scheduler(optimizer_d, args)

criterionGAN = GANLoss().to(device)
criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device)

# Number of batches per epoch and the files the loss values are appended to.
train_loader_len = len(train_loader)
losses_path = os.path.join(config['output_dir'], 'losses.txt')
val_losses_path = os.path.join(config['output_dir'], 'val_losses_test.txt')

# The LSTM has to be switched between recursive and non-recursive mode, reset,
# and (if it uses them) given the simulation parameters. To call these
# methods later on, the LSTM module is looked up inside the generator here.
# Only the first module whose class is named "LSTMblock" is used.
# `lstm` stays None when the model is built without an LSTM.
lstm = None
if args.lstm:
    lstm = next(
        (module for module in net_g.modules() if module.__class__.__name__ == "LSTMblock"),
        None,
    )
    if lstm is None:
        # Fail here with a clear message instead of with an AttributeError
        # on None at the first later use of `lstm`.
        raise RuntimeError("LSTM is enabled but the generator contains no LSTMblock module; run with --no-lstm")
    lstm.set_recursive(False, device)



# Prepare the module-level state that the training loop and the evaluation
# rely on. The statements are order dependent and are executed even when
# --no-train is given (only the banner is skipped in that case).
if not args.no_train:
    print('===> Starting the training loop')

# MASK only exists when masking is enabled; every later use is guarded by
# args.mask.
if args.mask:
    MASK = dataset.get_mask().to(device)

# From now on signal_handler no longer exits immediately on SIGINT but asks
# the training loop to stop via STOP_TRAINING.
training_started = True
# NOTE(review): presumably switches the dataset into its test mode before
# training starts - confirm against SimulationDataSet that this is intended.
dataset.test()

if args.lstm:
    # If the LSTM should not be reset after every batch but only after a
    # given number of iterations, set up the limit ("begrenzer") and the
    # batch counter ("cou") used by the training loop.
    if not args.batchreset:
        begrenzer = args.reset_iter
        cou = 0
    lstm.set_recursive(False, device)

# cuDNN benchmark mode lets cuDNN pick the best convolution algorithm.
# It costs a little time at start-up but reduces the computation time of the
# convolutions afterwards.
torch.backends.cudnn.benchmark = True

# Reference point for the running-time reports.
starttime = time.time()
# Not read anywhere in the visible part of the script.
freezed = False
# Main training loop. Per epoch: one pass over train_loader with alternating
# discriminator and generator updates, learning-rate update, loss logging,
# periodic checkpointing and periodic validation. Skipped with --no-train.
for epoch in range(num_epochs if not args.no_train else 0):
    epoch_loss_d = 0
    epoch_loss_g = 0

    # `iteration` is the 1-based number of the batch being processed.
    for iteration, batch in enumerate(train_loader, 1):
        net_g.train()
        net_d.train()

        # lstm stuff for every batch
        if args.lstm:
            if use_params:
                lstm.set_parameter(batch[2].to(device))
            if args.batchreset:
                lstm.set_recursive(False, device)
            else:
                # reset only after `begrenzer` batches
                if cou == begrenzer:
                    lstm.set_recursive(False, device)
                    cou = 0
                cou += 1

        real_a, real_b = batch[0].to(device), batch[1].to(device)

        # generate fake image
        if parameterized:
            params = batch[2].to(device)
            fake_b = net_g((real_a, params))
        else:
            fake_b = net_g(real_a)

        ##############################
        # Training the descriminator #
        ##############################
        optimizer_d.zero_grad()

        # detach: the discriminator step must not backpropagate into the generator
        fake_ab = torch.cat((real_a, fake_b), 1)
        pred_fake = net_d(fake_ab.detach())
        loss_d_fake = criterionGAN(pred_fake, False)

        real_ab = torch.cat((real_a, real_b), 1)
        pred_real = net_d(real_ab)
        loss_d_real = criterionGAN(pred_real, True)

        loss_d = (loss_d_fake + loss_d_real) * 0.5

        loss_d.backward()
        optimizer_d.step()

        ##############################
        #   Training the generator   #
        ##############################
        optimizer_g.zero_grad()

        # the generator is rewarded when the discriminator takes the fake for real
        fake_ab = torch.cat((real_a, fake_b), 1)
        pred_fake = net_d(fake_ab)
        loss_g_gan = criterionGAN(pred_fake, True)
        loss_g_l1 = criterionL1(fake_b, real_b) * config['lambda_L1']

        loss_g = loss_g_gan + loss_g_l1

        loss_g.backward()
        optimizer_g.step()

        epoch_loss_d += loss_d.item()
        epoch_loss_g += loss_g.item()

        print("> Epoch[{}]({}/{}): Loss_D: {:.5f} Loss_G: {:.5f}".format(
            epoch, iteration, train_loader_len, loss_d.item(), loss_g.item()))

        # set by signal_handler on SIGINT: save the current state and quit
        if STOP_TRAINING:
            totaltime = time.time() - starttime
            # The mean refers to the epochs actually processed (including the
            # fraction of the current one), not to the planned number.
            epochs_done = epoch + iteration / train_loader_len
            print('> Mean running time for one epoch: {} seconds'.format(totaltime/epochs_done))
            print('> Total running time for {} epochs and ({}/{}): {} seconds'.format(epoch, iteration, train_loader_len, totaltime))
            print('> Saving the model now...')
            save_models(net_g, net_d, args, epoch)
            print('> Model saved.')
            sys.exit(0)

    update_learning_rate(net_g_scheduler, optimizer_g)
    update_learning_rate(net_d_scheduler, optimizer_d)

    epoch_loss_d /= train_loader_len
    epoch_loss_g /= train_loader_len

    with open(losses_path, 'a') as losses_hand:
        losses_hand.write('epoch: {}, gen:{:.5f}, desc:{:.5f}\n'.format(epoch, epoch_loss_g, epoch_loss_d))

    # checkpoints: first epoch, every 10th epoch after 20, every 5th after 30
    if epoch == 0 or (epoch > 20 and epoch % 10  == 0) or (epoch > 30 and epoch % 5  == 0):
        save_models(net_g, net_d, args, epoch)
        # report the directory save_models actually writes to
        print("> Checkpoint saved to {}".format(os.path.join(config['output_dir'], args.model_name)))

    # validation every 5th epoch and after the last one
    # NOTE(review): net_g is still in train() mode here; confirm whether
    # validating without eval() is intended.
    if epoch % 5  == 0 or epoch == num_epochs-1:
        avg_psnr = 0
        avg_mse = 0
        with torch.no_grad():
            for batch in val_loader:
                if args.lstm:
                    lstm.set_recursive(False, device)
                    # only touch the LSTM when there is one (was unguarded)
                    if use_params:
                        lstm.set_parameter(batch[2].to(device))
                input_img, target = batch[0].to(device), batch[1].to(device)
                if parameterized:
                    params = batch[2].to(device)
                    prediction = net_g((input_img, params))
                else:
                    prediction = net_g(input_img)

                # apply the mask to every channel of every sample
                if args.mask:
                    for i,j in itertools.product(range(prediction.shape[0]), range(prediction.shape[1])):
                        prediction[i][j] = MASK * prediction[i][j]

                # plain float, so that no tensors are accumulated
                mse = criterionMSE(prediction, target).item()
                # a perfect prediction has no finite PSNR
                psnr = 10 * math.log10(1 / mse) if mse > 0 else float('inf')
                avg_mse += mse
                avg_psnr += psnr
            avg_psnr /= len(val_loader)
            avg_mse /= len(val_loader)

            print("> Val Avg. PSNR: {:.5} dB".format(avg_psnr))
            with open(val_losses_path, 'a') as losses_hand:
                losses_hand.write('epoch:{}, psnr:{:.5f}, mse:{:.5f}\n'.format(epoch, avg_psnr, avg_mse))

# After a completed training run: report the timings and store the final
# generator under the epoch number `num_epochs`.
if not args.no_train:
    totaltime = time.time() - starttime
    # no mean can be given for a run with zero epochs
    if num_epochs > 0:
        print('> Mean running time for one epoch: {} seconds'.format(totaltime/num_epochs))
    print('> Total running time for training {} epochs: {} seconds'.format(num_epochs,totaltime))
    save_models(net_g, net_d, args, num_epochs)
    # report the directory save_models actually writes to
    print("> Checkpoint saved to {}".format(os.path.join(config['output_dir'], args.model_name)))
# SIGINT exits immediately again from here on (see signal_handler)
training_started = False

torch.backends.cudnn.benchmark = False

# The evaluator is always constructed; it is only used with --evaluate.
evaluator = Evaluator(
    args,
    config['output_dir'],
    MASK if args.mask else None,
    device=device,
    parameterized=parameterized,
)

if args.evaluate:
    print('===> Evaluating model')

    net_g.eval()
    dataset.test()

    with torch.no_grad():
        if args.lstm:
            lstm.set_recursive(False, device)

        print('===> Evaluating with test set:')
        evaluator.set_output_name('test')
        # NOTE: the snapshot and individual-image evaluations on the test set
        # are currently disabled.

        if args.model_type == 's':
            print('===> Running simulations for s:')
            evaluator.set_output_name('simulations')

            # Timing run: use_params is always passed, the LSTM only if present.
            timing_kwargs = {'use_params': use_params}
            if args.lstm:
                timing_kwargs['lstm'] = lstm
            evaluator.run_full_simulation(
                net_g, dataset, 100, 300,
                sim_name='simulation_timings', saving_imgs=False, **timing_kwargs)

            # Extra keyword arguments for the remaining calls: given only
            # when an LSTM is used, otherwise the evaluator defaults apply.
            lstm_kwargs = {'use_params': use_params, 'lstm': lstm} if args.lstm else {}

            # every test simulation, started at frame 0 and at frame 120
            indices = [
                (first_index + offset, 's')
                for first_index in test_start_indices_speed
                for offset in (0, 120)
            ]
            for first_frame, tag in indices:
                sim_num = first_frame // sim_length + 1
                startpoint = first_frame % sim_length
                if args.lstm:
                    lstm.set_recursive(True, device)
                print((first_frame, tag))
                evaluator.run_full_simulation(
                    net_g, dataset, first_frame,
                    config['full_simulaiton_samples'],
                    sim_name='simulation_{}_sim{}_i{}'.format(tag, sim_num, startpoint),
                    **lstm_kwargs)

            print('===> Evaluating recursively:')
            # every test simulation, started at its first frame only
            indices_full = [(first_index, 's') for first_index in test_start_indices_speed]
            for first_frame, tag in indices_full:
                # simulation number (1-based) and starting point inside it
                sim_num = first_frame // sim_length + 1
                startpoint = first_frame % sim_length
                if args.lstm:
                    lstm.set_recursive(True, device)
                evaluator.set_output_name('recursive_{}_sim{}_i{}'.format(tag, sim_num, startpoint))
                evaluator.recusive_application_performance(
                    net_g, dataset, first_frame,
                    samples=config['evaluation_recursive_samples'],
                    **lstm_kwargs)
                # NOTE: the "cheatstart" variant (a few iterations with real
                # images before running recursively) and the "morph" variant
                # of the recursive evaluation are currently disabled.

        print('===> Finished evaluation')