#!/usr/bin/env python3
import os
import numpy as np
import sys
sys.path.append('..')
from collections import namedtuple
from glob import glob
import time
import gc
import pickle
from evaluate_pedastrain import evaluate
from train_utils import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

# Expose only GPUs 0-3 to this process (read by CUDA when it initialises).
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'
# Expand '~' explicitly: os.path.join() and open() treat it as a literal
# directory name, so the unexpanded path would never resolve to the home dir.
dataset_path = os.path.expanduser('~/particle/trajnetplusplus_dataset/')

# NOTE: don't forget to change the LR scheduler (see main()) when changing
# the settings below.

# Number of prediction steps per training sample; main() reads the keys
# 'pos0'..'pos<train_window>' and 'vel0'..'vel<train_window>' from each batch.
train_window = 6

# When True, main() passes every batch through normalize_input() together
# with normalize_scale.
use_normalize_input = False
normalize_scale = 3

# Number of consecutive sub-batches whose gradients are accumulated before
# one optimizer step in main().
batch_divide = 1
TrainParams = namedtuple('TrainParams', ['epochs', 'batches_per_epoch', 'base_lr', 'batch_size'])
# Floor division keeps batch_size an int; true division ('/') would produce
# the float 16.0, which is not a valid batch size for most loaders.
train_params = TrainParams(50, 1000 * batch_divide, 0.001, 16 // batch_divide)

# Batches are sliced to at most this many entries along axis 1, and the value
# is also handed to the training loader as max_num.
train_particle_num = 60

# Forwarded to ParticlesNetwork as its correction_scale argument.
correction_scale = 1 / 128.

# Base name for the saved weights and result pickles.
model_name = 'pedestrian_ctsconv_rel_pos'

checkpoint_path = os.path.join('checkpoint_models', model_name)

# os.system('mkdir ' + checkpoint_path)

# Train on GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from datasets.pedestrain_pkl_loader import read_pkl_data
# Per-split dataset directories, each a sub-folder of dataset_path.
val_path, train_path, test_path = (
    os.path.join(dataset_path, split_name)
    for split_name in ('val', 'train', 'test')
)

def create_model():
    """Returns an instance of the network for training and evaluation.

    The network is a ``ParticlesNetwork`` built with ``radius_scale=6`` and
    the module-level ``correction_scale``.
    """
    # Kept as a function-local import so the model package is only loaded
    # when a model is actually built. It must come after the docstring,
    # otherwise the string is a plain expression and __doc__ stays None.
    from models.pedestrain_model_rel_pos import ParticlesNetwork

    return ParticlesNetwork(radius_scale=6, correction_scale=correction_scale)


class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to call model's attributes.

    Attribute lookups that the DataParallel wrapper cannot satisfy itself are
    forwarded to the wrapped ``self.module``.
    """
    def __getattr__(self, name):
        """Return ``name`` from the wrapper, falling back to the wrapped module.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped module has
                the attribute, or if no module has been registered yet.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If 'module' itself is missing (e.g. on a bare instance before
            # __init__ has run), evaluating self.module would re-enter this
            # method forever; re-raise instead of recursing.
            if name == 'module':
                raise
            return getattr(self.module, name)

def main():
    """Train the network and keep the checkpoint with the best validation loss.

    Every epoch runs ``train_params.batches_per_epoch`` training batches and
    then calls ``evaluate`` on the validation dataset. The wrapped module is
    saved to ``weights/<model_name>.pth`` whenever the validation loss
    improves (and always after the first epoch). The list of per-epoch
    validation metrics is re-written to
    ``results/<model_name>_val_metrics.pickle`` after every epoch.
    """
    # Create the output folders up front so a missing directory cannot abort
    # the run only after a whole epoch of training has been spent.
    os.makedirs('weights', exist_ok=True)
    os.makedirs('results', exist_ok=True)

    val_dataset = read_pkl_data(val_path, batch_size=8, shuffle=False, repeat=False)

    dataset = read_pkl_data(train_path, batch_size=train_params.batch_size, repeat=True,
                            shuffle=True, max_num=train_particle_num)

    data_iter = iter(dataset)

    model_ = create_model()
    model = MyDataParallel(model_).to(device)
    optimizer = torch.optim.Adam(model.parameters(), train_params.base_lr,
                                 betas=(0.9, 0.999), weight_decay=4e-4)
    # The learning rate is multiplied by 0.95 after every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)  # 0.9968

    def train_one_batch(model, batch, train_window=2):
        """Roll the model forward ``train_window`` steps and return the loss.

        The first step feeds the encoded history stored in ``batch``; every
        later step feeds the previous prediction back in as the current
        state. The per-step losses are summed, then scaled by 128 and divided
        by the configured batch size.
        """
        batch_size = train_params.batch_size
        man_mask = batch['man_mask']

        def step_loss(pr_pos, gt_pos):
            # Third loss_fn argument: mask sum over the last axis minus one
            # (presumably the number of other agents -- TODO confirm against
            # loss_fn). Recomputed on every call, exactly as before.
            return 0.5 * loss_fn(pr_pos, gt_pos,
                                 torch.sum(man_mask, dim=-1).unsqueeze(-1) - 1,
                                 man_mask)

        inputs = ([
            batch['pos_enc'], batch['vel_enc'],
            batch['pos0'], batch['vel0'],
            batch['accel'], None,
            man_mask
        ])

        pr_pos1, pr_vel1, states = model(inputs)
        losses = step_loss(pr_pos1, batch['pos1'])

        pos0 = batch['pos0']
        vel0 = batch['vel0']
        for step in range(train_window - 1):
            # The previous state becomes a one-frame history for this step.
            pos_enc = torch.unsqueeze(pos0, 2)
            vel_enc = torch.unsqueeze(vel0, 2)
            accel = pr_vel1 - vel_enc[..., -1, :]
            inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, accel, None,
                      man_mask)
            pos0, vel0 = pr_pos1, pr_vel1

            pr_pos1, pr_vel1, states = model(inputs, states)

            # The first loop pass (step 0) is compared against 'pos2'.
            gt_pos1 = batch['pos' + str(step + 2)]
            clean_cache(device)

            losses += step_loss(pr_pos1, gt_pos1)

        return 128 * torch.sum(losses, axis=0) / batch_size

    def to_device_tensor(batch, key):
        """Stack ``batch[key]``, cut axis 1 to train_particle_num, move to device."""
        return torch.tensor(np.stack(batch[key])[:, :train_particle_num],
                            dtype=torch.float32, device=device)

    epochs = train_params.epochs
    # The dataset is too large to run through completely, so an "epoch" is a
    # fixed number of batches.
    batches_per_epoch = train_params.batches_per_epoch
    data_load_times = []  # Per batch
    train_losses = []
    valid_losses = []
    valid_metrics_list = []
    min_loss = None

    # Loop-invariant: the batch entries that get converted to tensors.
    convert_keys = (['pos' + str(j) for j in range(train_window + 1)] +
                    ['vel' + str(j) for j in range(train_window + 1)] +
                    ['pos_enc', 'vel_enc'])

    for epoch in range(epochs):
        epoch_start_time = time.time()

        model.train()
        epoch_train_loss = 0
        sub_idx = 0  # position inside the current gradient-accumulation group

        print("training ... epoch " + str(epoch + 1), end='')
        for batch_itr in range(batches_per_epoch):

            data_fetch_start = time.time()
            batch = next(data_iter)

            if sub_idx == 0:
                # Start of a new accumulation group.
                optimizer.zero_grad()
                if (batch_itr // batch_divide) % 25 == 0:
                    print("... batch " + str((batch_itr // batch_divide) + 1), end='', flush=True)
            sub_idx += 1

            batch_tensor = {k: to_device_tensor(batch, k) for k in convert_keys}

            if use_normalize_input:
                batch_tensor, _ = normalize_input(batch_tensor, normalize_scale, train_window)

            # The mask is added after normalization so it is never rescaled.
            batch_tensor['man_mask'] = to_device_tensor(batch, 'man_mask').squeeze(-1)
            batch_tensor['accel'] = batch_tensor['vel0'] - batch_tensor['vel_enc'][..., -1, :]
            del batch

            data_load_times.append(time.time() - data_fetch_start)

            current_loss = train_one_batch(model, batch_tensor, train_window=train_window)
            if sub_idx < batch_divide:
                # NOTE(review): every sub-batch builds its own graph, so
                # retain_graph=True looks unnecessary here -- confirm before
                # removing.
                current_loss.backward(retain_graph=True)
            else:
                current_loss.backward()
                optimizer.step()
                sub_idx = 0
            del batch_tensor

            epoch_train_loss += float(current_loss)
            del current_loss
            clean_cache(device)

            if batch_itr == batches_per_epoch - 1:
                print("... DONE", flush=True)

        # Sum (not mean) of the batch losses of this epoch.
        train_losses.append(epoch_train_loss)

        model.eval()
        with torch.no_grad():
            valid_total_loss, valid_metrics = evaluate(model.module, val_dataset,
                                                       train_window=train_window, max_iter=100,
                                                       device=device, batch_size=val_dataset.batch_size)

        valid_losses.append(float(valid_total_loss))
        valid_metrics_list.append(valid_metrics)

        # Save after the first epoch and whenever the validation loss improves.
        if min_loss is None or valid_losses[-1] < min_loss:
            min_loss = valid_losses[-1]
            torch.save(model.module, 'weights/' + model_name + ".pth")

        epoch_end_time = time.time()

        print('epoch: {}, train loss: {}, val loss: {}, epoch time: {}, lr: {}, {}'.format(
            epoch + 1, train_losses[-1], valid_losses[-1],
            round((epoch_end_time - epoch_start_time) / 60, 5),
            format(get_lr(optimizer), "5.2e"), model_name
        ))

        scheduler.step()

        # Rewritten every epoch so partial results survive an interrupted run.
        with open('results/{}_val_metrics.pickle'.format(model_name), 'wb') as f:
            pickle.dump(valid_metrics_list, f)
        

def final_evaluation():
    """Evaluate the saved checkpoint and pickle the metrics it produces.

    Loads ``weights/<model_name>.pth``, runs ``evaluate`` on it without
    gradient tracking, and writes the returned metrics to
    ``results/<model_name>_predictions.pickle``.
    """
    # NOTE(review): the variable is called "test" but the data comes from
    # val_path -- confirm that evaluating on the validation split is intended.
    test_dataset = read_pkl_data(val_path, batch_size=6, shuffle=False, repeat=False)

    weights_file = 'weights/' + model_name + '.pth'
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # The file holds a whole pickled module, so PyTorch versions where
    # weights_only defaults to True need weights_only=False; older versions
    # reject that keyword with a TypeError, hence the fallback.
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints written by this script.
    try:
        trained_model = torch.load(weights_file, map_location=device, weights_only=False)
    except TypeError:
        trained_model = torch.load(weights_file, map_location=device)
    trained_model.eval()

    with torch.no_grad():
        _, valid_metrics = evaluate(trained_model, test_dataset,
                                    train_window=train_window, max_iter=len(test_dataset),
                                    device=device, start_iter=120)

    # Make sure the output folder exists before writing the results.
    os.makedirs('results', exist_ok=True)
    with open('results/{}_predictions.pickle'.format(model_name), 'wb') as f:
        pickle.dump(valid_metrics, f)
        
        
if __name__ == '__main__':
    # Train first, then evaluate the checkpoint that training wrote.
    for stage in (main, final_evaluation):
        stage()
    
    
    