#!/usr/bin/env python
import argparse
import sys
import math
import torch
import logging
from torch import optim
from datetime import datetime
from ignite.metrics import Loss
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping
from ignite.handlers.param_scheduler import LRScheduler
from ChamferLoss import chamfer_distance_with_batch, chamfer_distance_with_batch_mean
from PointAutoencNet import PointAutoencNet
# from utils import *
from datetime import timedelta
from torch.utils.data import DataLoader
from TearingNet_main.dataloaders.kittimulobj_loader import KITTIMultiObjectDataset
from TearingNet_main.dataloaders.modelnet_loader import ModelNet40
from TearingNet_main.dataloaders.shapenet_55_loader import ShapeNet


# Shared, mutable run state and hyper-parameters. The training/evaluation
# callbacks read and update this dict in place (batch counter, running loss
# sum, longest observed durations, best validation loss, ...).
settings = dict(
    # Early-stopping patience and optimiser hyper-parameters.
    patience=200,
    learning_rate=0.004,
    l2=0.0001,

    # Counters / accumulators.
    epoch=1,
    batch=0,
    # Longest observed durations used for the "Time rem." estimates; seeded
    # with 1 ms so an estimate is available before the first full pass.
    ltime_train=timedelta(milliseconds=1),
    ltime_eval_train=timedelta(milliseconds=1),
    ltime_eval_val=timedelta(milliseconds=1),
    n_train_samples=0,
    loss_sum=0,
    # path to store the best models
    MODEL_PATH='./models/',
)

def train_step(engine, batch):
    """Run one optimisation step of the autoencoder on a single batch.

    Relies on the module-level ``model``, ``optimizer``, ``trainer`` and
    ``options`` created in the ``__main__`` section, and on
    ``settings['device']`` / ``settings['file_date']``.

    Args:
        engine: The ignite ``Engine`` running this step (not used directly;
            the LR schedule reads ``trainer.state.epoch``).
        batch: Indexable batch whose first element is the input point-cloud
            tensor. Only that element is used.

    Returns:
        Tuple ``(x, x_pred, chamloss, loss1, loss2, total_loss)``:
        ``x`` is the raw input on the device, ``x_pred`` / ``loss1`` /
        ``loss2`` are detached tensors from the model, ``chamloss`` is the
        Chamfer distance as a float and ``total_loss`` the optimised loss as
        a float.

    Raises:
        ValueError: If ``options.model`` is not a supported model name.
    """
    # Only the 'own' model has a forward/loss path; fail clearly instead of
    # hitting a NameError on unbound locals further down.
    if options.model != 'own':
        raise ValueError(f"Unsupported model: {options.model!r}")

    model.train()
    optimizer.zero_grad()
    # Use the device chosen in __main__ (CUDA or CPU fallback) instead of
    # unconditionally calling .cuda().
    x = batch[0].to(settings['device'])
    # x = x.transpose(2,1) # potentially change the two last dimensions depending on the dataset

    # Centre each cloud and scale it so its farthest point has unit norm.
    # The reductions assume x is (batch, coords, n_points) — TODO confirm
    # against the dataset loaders.
    x_norm = x - torch.mean(x, -1, keepdim=True)
    x_norm /= torch.max(torch.linalg.norm(x_norm, dim=-2, keepdim=True), -1, keepdim=True)[0]

    x_pred, loss1, loss2, _, _, _, _ = model(x_norm)
    cham = chamfer_distance_with_batch(x_norm, x_pred)
    chamloss = cham.item()

    # Continuous exponential decay of the configured base LR: the rate
    # halves every 50 epochs.
    optimizer.param_groups[0]['lr'] = settings['learning_rate'] * 0.5 ** (trainer.state.epoch / 50)

    # Chamfer term is up-weighted relative to the two auxiliary model losses.
    loss = cham * 100 + (loss1 + loss2)
    loss.backward()
    optimizer.step()

    log_path = 'logs/' + settings['file_date'].strftime(options.model + '_' + options.dataset + '_log_%m-%d-%Y_%I-%M-%S.csv')
    with open(log_path, mode='a') as file:
        file.write(f"{settings['batch']};{loss.item()};{chamloss};{loss1.item()};{loss2.item()}\n")

    # Detach tensors handed to the event handlers so they do not keep the
    # autograd graph of this step alive until the next iteration.
    return x, x_pred.detach(), chamloss, loss1.detach(), loss2.detach(), loss.item()

def validation_step(engine, batch):
    """Reconstruct one batch in eval mode and map the prediction back to input scale.

    Relies on the module-level ``model`` and ``options`` created in the
    ``__main__`` section and on ``settings['device']``.

    Args:
        engine: The ignite ``Engine`` running this step (unused).
        batch: Indexable batch whose first element is the input point-cloud
            tensor. Only that element is used.

    Returns:
        Tuple ``(x, x_pred_scale)``: the raw input on the device and the
        model's reconstruction with the per-cloud normalisation undone, so a
        metric can compare both in the original coordinate frame.

    Raises:
        ValueError: If ``options.model`` is not a supported model name.
    """
    if options.model != 'own':
        raise ValueError(f"Unsupported model: {options.model!r}")

    model.eval()
    with torch.no_grad():
        # Use the configured device rather than hard-coding CUDA.
        x = batch[0].to(settings['device'])

        # Same centring/scaling as in training, but keep the statistics so
        # the reconstruction can be de-normalised afterwards.
        # Assumes x is (batch, coords, n_points) — TODO confirm.
        x_mean = torch.mean(x, -1, keepdim=True)
        x_norm = x - x_mean
        x_scale = torch.max(torch.linalg.norm(x_norm, dim=-2, keepdim=True), -1, keepdim=True)[0]
        x_norm /= x_scale

        x_pred, _, _, _, _, _, _ = model(x_norm)

        # Undo the normalisation: scale back up, then re-add the centroid.
        x_pred_scale = x_pred * x_scale + x_mean

    return x, x_pred_scale

def score_function_autoenc(engine):
    """Return the early-stopping score for the autoencoder.

    The score is the negated ``'chamloss'`` metric of ``engine``, so a lower
    Chamfer loss yields a higher score.
    """
    chamfer = engine.state.metrics['chamloss']
    return -chamfer

def log_set_time_wrapper(log_func):
    """Decorate ``log_func`` so ``settings['time_last_log']`` is refreshed after each call.

    The wrapped function's return value is passed through unchanged. Its
    name, docstring and signature metadata are preserved with
    ``functools.wraps``, so introspection of the handler (e.g. signature
    checks or debugging output) sees the original function.
    """
    from functools import wraps  # local import keeps this change self-contained

    @wraps(log_func)
    def new_log_func(*args, **kwargs):
        res = log_func(*args, **kwargs)
        # Stamp *after* the call so the next elapsed-time measurement starts
        # from the end of this logging step.
        settings['time_last_log'] = datetime.now()
        return res
    return new_log_func

def getOptions(args):
    """Build the command-line parser and parse ``args``.

    Args:
        args: List of argument strings, e.g. ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the attributes ``dataset``, ``model``,
        ``batch_size``, ``n_knn``, ``model_file`` (the string ``'None'`` when
        not given), ``epoch``, ``ablation`` (list of ints) and
        ``codeword_channels``.
    """
    cli = argparse.ArgumentParser(description="Parses command.")
    opt = cli.add_argument  # short alias; options are registered in help order

    # Data / model selection.
    opt("-d", "--dataset", default="shapenet", help="The dataset used.")
    opt("-m", "--model", default="own", help="The model used.")

    # Batch and neighbourhood sizes.
    opt("-b", "--batch_size", type=int, default=64, help="The batch size.")
    opt("-k", "--n_knn", type=int, default=10,
        help="The number of nearest neighbors processed in one knn_conv.")

    # Checkpoint to load and epoch to resume from.
    opt("-f", "--model_file", default='None', help="The location of the saved model.")
    opt("-e", "--epoch", type=int, default=-1, help="The epoch to start with.")

    # Architecture shape.
    opt("-a", '--ablation', nargs='+', type=int, default=[6, 6, 6, 6],
        help="The decrease and increase factors")
    opt("-cc", "--codeword_channels", type=int, default=9,
        help="The number of channels in the codeword.")

    return cli.parse_args(args)

if __name__ == '__main__':
    import os  # only the script entry point needs filesystem helpers

    settings['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(settings['device'])

    options = getOptions(sys.argv[1:])
    print(options.dataset)
    print(options.model)
    print(options.ablation)
    print(options.codeword_channels)
    print(options.batch_size)
    print(options.n_knn)
    print(options.model_file)
    print(options.epoch)

    # Only the 'own' autoencoder is wired up below; fail fast with a clear
    # message instead of a NameError after the datasets have been built.
    if options.model != 'own':
        raise ValueError(f"Unsupported model: {options.model!r}")

    settings['min_sel_loss'] = 10000
    settings['loss_wf'] = 0.0
    settings['file_date'] = datetime.now()

    # Common file-name stem for this run's .log, .csv and .pt artefacts.
    run_stem = settings['file_date'].strftime(options.model + '_' + options.dataset + '_log_%m-%d-%Y_%I-%M-%S')
    log_dir = settings['MODEL_PATH'] + 'CV_LOG/'
    model_path = settings['MODEL_PATH'] + run_stem + '.pt'

    # logging.basicConfig(filename=...) and open() do not create missing
    # directories, so create every output directory up front.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs('logs', exist_ok=True)

    logging.basicConfig(format='%(asctime)s %(message)s',
                        datefmt='%m/%d/%Y %I:%M:%S %p',
                        filename=log_dir + run_stem + '.log',
                        level=logging.ERROR)
    logging.error('Device used: %s', settings['device'])

    # Header of the per-batch CSV that train_step appends to.
    with open('logs/' + run_stem + '.csv', mode='w') as file:
        file.write('batch;cum_loss;cd;loss1;loss2\n')

    settings['time_last_log'] = datetime.now()
    print('Bulding dataset at:', settings['time_last_log'])
    logging.error('Building dataset ...')

    # KITTI multi-object variants differ only in their grid split suffix.
    kitti_grids = {'kitti3': '3x3', 'kitti4': '4x4', 'kitti5': '5x5', 'kitti6': '6x6'}
    if options.dataset in kitti_grids:
        grid = kitti_grids[options.dataset]
        valid_ds = KITTIMultiObjectDataset(split='test_' + grid, transpose_fp=True)
        train_ds = KITTIMultiObjectDataset(split='train_' + grid, transpose_fp=True)
    elif options.dataset == 'modelnet':
        train_ds = ModelNet40(phase='train', transpose_fp=True)
        valid_ds = ModelNet40(phase='test', transpose_fp=True)
    elif options.dataset == 'shapenet':
        train_ds = ShapeNet(subset='train')
        valid_ds = ShapeNet(subset='test')
    else:
        raise ValueError(f"Unsupported dataset: {options.dataset!r}")

    train_loader = DataLoader(dataset=train_ds, batch_size=options.batch_size, shuffle=True, num_workers=5)
    valid_loader = DataLoader(dataset=valid_ds, batch_size=options.batch_size, num_workers=5)

    print('Dataset constructed at:', datetime.now(), 'after:', datetime.now() - settings['time_last_log'])
    logging.error('Time elapsed for Dataset construction: %s', str(datetime.now() - settings['time_last_log']))

    test_results = []
    cv_fold = 0
    # Single (train, val) split, kept in cross-validation loop form.
    for train_loader, val_loader in [(train_loader, valid_loader)]:

        cv_fold += 1
        settings['best_val_loss'] = math.inf

        model = PointAutoencNet(n_knn=options.n_knn, fdfu=options.ablation, cc=options.codeword_channels).to(device=settings['device'])
        if options.model_file != 'None':
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            model.load_state_dict(torch.load(options.model_file, map_location=settings['device']))
        optimizer = optim.AdamW(model.parameters(), lr=settings['learning_rate'], weight_decay=settings['l2'])

        # set up training and validation engines
        trainer = Engine(train_step)
        evaluator = Engine(validation_step)

        # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
        handler = EarlyStopping(patience=settings['patience'], score_function=score_function_autoenc, trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, handler)

        # validation_step returns (x, x_pred_scale); feed both to the Chamfer metric.
        me_loss = Loss(chamfer_distance_with_batch_mean, output_transform=lambda out: (out[0], out[1]), device=settings['device'])
        me_loss.attach(evaluator, 'chamloss')

        # batch/epoch logging
        settings['time_last_log'] = datetime.now()
        settings['n_batch_train'] = len(train_loader)
        settings['n_batch_val'] = len(val_loader)

        @trainer.on(Events.ITERATION_COMPLETED(every=1))
        def log_training_loss(trainer):
            # out = (x, x_pred, chamloss, loss1, loss2, total_loss) from train_step
            out = trainer.state.output
            # metrics for train
            n_sample = out[0].shape[0]
            settings['n_train_samples'] += n_sample
            settings['loss_sum'] += out[2]

            knn_loss1 = out[3]
            knn_loss2 = out[4]

            settings['batch'] += 1
            timediff = datetime.now() - settings['time_last_log']
            cur_prog = settings['batch'] / settings['n_batch_train']
            # Resume support: relabel the first epoch with the requested start epoch.
            if trainer.state.epoch == 1 and options.epoch > 0:
                trainer.state.epoch = options.epoch
            if cur_prog < 1.0:
                print(f"\rFold {cv_fold} Epoch {trainer.state.epoch} - "
                      f"[" + (math.floor(20 * cur_prog) - 1) * '=' + '>' + math.ceil((1 - cur_prog) * 20) * ' ' +
                      f"] Batch {settings['batch']}/{settings['n_batch_train']} \t"
                      f"Loss: {out[5]:.2f}\tlosscd:  {out[2]:.4f}\tloss1: {knn_loss1:.2f}\tloss2: {knn_loss2:.2f}\t"
                      f"LR: {optimizer.param_groups[0]['lr']:.4f}\t"
                      f"Time rem.: {abs(settings['ltime_train'].total_seconds() - timediff.total_seconds()):.1f}s",
                      end='')
            else:
                # The weight-decay penalty is only reported at epoch end, so
                # compute it here, once, without building an autograd graph.
                with torch.no_grad():
                    l2_norm = settings['l2'] * float(sum(p.pow(2.0).sum() for p in model.parameters()))

                log_string = (
                    f"Fold {cv_fold} Epoch {trainer.state.epoch} - "
                    f"[" + 20 * '=' + f"] Batch {settings['batch']}/{settings['n_batch_train']} \t"
                    f"Loss: {out[2]:.2f}\tL2-norm: {l2_norm:.2f}\t"
                    f"Time el: {timediff.total_seconds():.1f}s"
                )
                logging.error(log_string)
                print("\r" + log_string)
                settings['batch'] = 0
                settings['time_last_log'] = datetime.now()
                if timediff > settings['ltime_train']:
                    settings['ltime_train'] = timediff
                log_string = (
                    f"\rTraining Dataset Results - Fold {cv_fold} Epoch: {trainer.state.epoch} "
                    f"\n\tLR: {optimizer.param_groups[0]['lr']:.5f}  "
                    f"\n\tAvg loss: {settings['loss_sum'] / settings['n_batch_train']:.3f} "
                )
                logging.error(log_string)
                print(log_string)

                # reset values for metrics
                settings['n_train_samples'] = 0
                settings['loss_sum'] = 0.0

        @evaluator.on(Events.ITERATION_COMPLETED(every=1))
        def log_evaluation(evaluator):
            settings['batch'] += 1
            timediff = datetime.now() - settings['time_last_log']
            dset_type = 'validation'
            ltime_dif = settings['ltime_eval_val']
            progress = settings['batch'] / settings['n_batch_val']
            print(f"\rEvaluating model on the " + dset_type +
                  f" dataset:\t{(progress * 100):2.1f}%\t"
                  f"Time rem.: {abs(ltime_dif.total_seconds() - timediff.total_seconds()):.1f}s", end='')
            if progress >= 1.0:
                if timediff > settings['ltime_eval_val']:
                    settings['ltime_eval_val'] = timediff

        @trainer.on(Events.EPOCH_COMPLETED)
        @log_set_time_wrapper
        def log_validation_results(trainer):
            evaluator.run(val_loader)
            settings['batch'] = 0
            chamloss = evaluator.state.metrics['chamloss']
            timediff = datetime.now() - settings['time_last_log']

            log_string = (
                f"\rValidation Dataset Results - Fold {cv_fold} Epoch: {trainer.state.epoch}  "
                f"\n\tTime el. for val: {timediff.total_seconds():.1f}s"
                f"\n\tAvg loss: {chamloss:.3f}"
            )
            logging.error(log_string)
            print(log_string)

            # Keep only the checkpoint with the lowest validation Chamfer loss;
            # store exactly the value that was compared.
            if chamloss < settings['best_val_loss']:
                settings['best_val_loss'] = chamloss
                torch.save(model.state_dict(), model_path)
                print('Model saved at: ' + model_path)
                logging.error('Model saved as: ' + run_stem + '.pt')

        trainer.run(train_loader, max_epochs=300)
