#!/usr/bin/env python3
import os
import numpy as np
import sys
sys.path.append('..')
from datasets.helper import get_lane_direction
from collections import namedtuple
from glob import glob
import time
import gc
import pickle
from utils.deeplearningutilities.tf import Trainer, MyCheckpointManager
from evaluate_network_equivariant import evaluate
from argoverse.map_representation.map_api import ArgoverseMap
from train_utils import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

# Restrict this process to GPUs 2 and 3. This only takes effect if it is set
# before CUDA is first initialised in the process.
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
# `~` is expanded explicitly: os.path.join() does not expand it, so without
# this the paths below would refer to a literal directory named "~".
dataset_path = os.path.expanduser('~/particle/argoverse/argoverse_forecasting/')
lane_path = os.path.expanduser('~/particle/TrafficFluids/datasets/')

# Selects the lane-aware loader / data directory below.
use_lane = False

# When True, later rollout steps are fed ground truth instead of predictions.
teacher_forcing = False
train_window = 3

use_normalize_input = False
normalize_scale = 3

# Each logical batch is split into `batch_divide` sub-batches whose gradients
# are accumulated before one optimizer step.
batch_divide = 4
TrainParams = namedtuple('TrainParams', ['epochs', 'batches_per_epoch', 'base_lr', 'batch_size'])
# Integer division: the per-sub-batch size is handed to the data loader as a
# batch size and must be an int (`16 / 4` would give the float 4.0).
train_params = TrainParams(100, 10 * batch_divide, 0.001, 16 // batch_divide)

model_name = 'reg_ctsconv_first_10_3'

checkpoint_path = os.path.join('checkpoint_models', model_name)

# Create the checkpoint directory (and any missing parents) without spawning a
# shell; an already existing directory is not an error.
os.makedirs(checkpoint_path, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if use_lane:
    from datasets.argoverse_lane_loader import read_pkl_data
    val_path = os.path.join(dataset_path, 'val', 'lane_data')
    train_path = os.path.join(dataset_path, 'train', 'lane_data')
else:
    from datasets.argoverse_pickle_loader import read_pkl_data
    val_path = os.path.join(dataset_path, 'val', 'clean_data')
    train_path = os.path.join(dataset_path, 'train', 'clean_data')

def create_model(radius_scale=40, layer_channels=None):
    """Return a new ParticlesNetwork instance for training and evaluation.

    Args:
        radius_scale: Forwarded to ``ParticlesNetwork`` as ``radius_scale``.
        layer_channels: Forwarded to ``ParticlesNetwork`` as ``layer_channels``;
            ``None`` selects the default ``[8, 16, 8, 8, 1]``.

    Returns:
        A freshly constructed ``ParticlesNetwork``.
    """
    # Imported inside the function so that importing this module does not
    # import the model package.
    from models.reg_equivariant_model import ParticlesNetwork
    if layer_channels is None:
        # Build a fresh list per call rather than sharing a mutable default.
        layer_channels = [8, 16, 8, 8, 1]
    return ParticlesNetwork(radius_scale=radius_scale, layer_channels=layer_channels)

class MyDataParallel(torch.nn.DataParallel):
    """``nn.DataParallel`` that falls back to the wrapped module for attributes.

    Lookups that ``nn.DataParallel`` itself cannot resolve are retried on
    ``self.module``, so ``wrapper.attr`` works in place of
    ``wrapper.module.attr``.
    """

    def __getattr__(self, name):
        try:
            found = super().__getattr__(name)
        except AttributeError:
            found = getattr(self.module, name)
        return found

def main():
    """Train the model, validating and checkpointing once per epoch.

    Each epoch draws ``train_params.batches_per_epoch`` sub-batches from the
    repeating, shuffled training stream and accumulates gradients over
    ``batch_divide`` sub-batches per optimizer step. It then calls
    ``evaluate`` on the validation set with ``max_iter=50`` and saves
    ``model.module`` to ``model_name + ".pth"`` whenever the validation loss
    improves. The first epoch always counts as an improvement. Finally it
    steps the LR scheduler and pickles the list of validation metrics so far
    to ``results/<model_name>_val_metrics.pickle``.
    """
    am = ArgoverseMap()

    val_dataset = read_pkl_data(val_path, batch_size=8, shuffle=False, repeat=False)

    dataset = read_pkl_data(train_path, batch_size=train_params.batch_size, repeat=True, shuffle=True)

    data_iter = iter(dataset)

    model_ = create_model()
    # Scale the initial weights of `conv_fluid` and of the first four `convs` by 2.
    model_.conv_fluid.bullseye_weights.data *= 2.
    model_.conv_fluid.outer_weights.data *= 2.
    for conv_idx in range(4):
        model_.convs[conv_idx].bullseye_weights.data *= 2.
        model_.convs[conv_idx].outer_weights.data *= 2.
    # To resume from the saved checkpoint instead of a fresh model:
    # model_ = torch.load(model_name + ".pth")
    model = MyDataParallel(model_).to(device)
    optimizer = torch.optim.Adam(model.parameters(), train_params.base_lr, betas=(0.9, 0.999), weight_decay=4e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9968)

    def train_one_batch(model, batch, train_window=2):
        """Roll ``model`` out for ``train_window`` steps and return the scaled loss.

        The first step is fed the observed history in ``batch``. Each later
        step is fed either the ground truth (``teacher_forcing``) or the
        model's own previous prediction. Per-step losses are halved and
        summed, then scaled by ``128 / train_params.batch_size``.
        """
        batch_size = train_params.batch_size

        # Loss arguments that do not change across rollout steps.
        # NOTE(review): presumably the neighbour count (active cars minus the
        # car itself) and a per-car mask; confirm against loss_fn.
        num_neighbors = torch.sum(batch['car_mask'], dim=-2) - 1
        loss_mask = batch['car_mask'].squeeze(-1)

        inputs = ([
            batch['pos_2s'], batch['vel_2s'],
            batch['pos0'], batch['vel0'],
            batch['accel'], None,
            batch['lane'], batch['lane_norm'],
            batch['car_mask'], batch['lane_mask']
        ])
        pr_pos1, pr_vel1, states = model(inputs)
        losses = 0.5 * loss_fn(pr_pos1, batch['pos1'], num_neighbors, loss_mask)

        pos0 = batch['pos0']
        vel0 = batch['vel0']
        for step in range(train_window - 1):
            # The previous step's current state becomes a one-frame history.
            pos_enc = torch.unsqueeze(pos0, 2)
            vel_enc = torch.unsqueeze(vel0, 2)
            if teacher_forcing:
                pos_in, vel_in = batch['pos' + str(step + 1)], batch['vel' + str(step + 1)]
            else:
                pos_in, vel_in = pr_pos1, pr_vel1
            inputs = (pos_enc, vel_enc, pos_in, vel_in, batch['accel'], None,
                      batch['lane'],
                      batch['lane_norm'], batch['car_mask'], batch['lane_mask'])
            pos0, vel0 = pos_in, vel_in

            pr_pos1, pr_vel1, states = model(inputs, states)

            losses += 0.5 * loss_fn(pr_pos1, batch['pos' + str(step + 2)], num_neighbors, loss_mask)

        total_loss = 128 * torch.sum(losses, axis=0) / batch_size
        return total_loss

    epochs = train_params.epochs
    # Sub-batches per epoch; the dataset is too large to sweep fully each epoch.
    batches_per_epoch = train_params.batches_per_epoch
    data_load_times = []  # per sub-batch fetch + tensor-conversion time, seconds
    train_losses = []
    valid_losses = []
    valid_metrics_list = []
    min_loss = None

    # The per-epoch metrics pickle below is written into this directory.
    os.makedirs('results', exist_ok=True)

    for epoch in range(epochs):
        epoch_start_time = time.time()

        model.train()
        epoch_train_loss = 0
        sub_idx = 0  # sub-batches accumulated since the last optimizer step

        print("training ... epoch " + str(epoch + 1), end='')
        for batch_itr in range(batches_per_epoch):

            data_fetch_start = time.time()
            batch = next(data_iter)

            if sub_idx == 0:
                # Start of a new accumulation group.
                optimizer.zero_grad()
                if (batch_itr // batch_divide) % 25 == 0:
                    print("... batch " + str((batch_itr // batch_divide) + 1), end='', flush=True)
            sub_idx += 1

            batch_size = len(batch['pos0'])

            if not use_lane:
                # Dummy lane mask so the model always receives the same inputs.
                batch['lane_mask'] = [np.array([0])] * batch_size

            batch_tensor = {}
            convert_keys = (['pos' + str(t) for t in range(train_window + 1)] +
                            ['vel' + str(t) for t in range(train_window + 1)] +
                            ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])

            for k in convert_keys:
                # Keep only the first two components of the last axis
                # (presumably x, y).
                batch_tensor[k] = torch.tensor(np.stack(batch[k])[..., :2], dtype=torch.float32, device=device)

            if use_normalize_input:
                batch_tensor, _ = normalize_input(batch_tensor, normalize_scale, train_window)

            for k in ['car_mask', 'lane_mask']:
                batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device).unsqueeze(-1)

            for k in ['track_id' + str(t) for t in range(30)] + ['city']:
                batch_tensor[k] = batch[k]

            batch_tensor['car_mask'] = batch_tensor['car_mask'].squeeze(-1)
            batch_tensor['accel'] = torch.zeros(batch_size, 1, 2).to(device)
            del batch

            data_load_times.append(time.time() - data_fetch_start)

            current_loss = train_one_batch(model, batch_tensor, train_window=train_window)

            if sub_idx < batch_divide:
                # Intermediate sub-batch: accumulate gradients only.
                # NOTE(review): retain_graph=True looks unnecessary because each
                # sub-batch builds a fresh graph; kept to preserve behaviour.
                current_loss.backward(retain_graph=True)
            else:
                current_loss.backward()
                optimizer.step()
                sub_idx = 0
            del batch_tensor

            epoch_train_loss += float(current_loss)
            del current_loss
            clean_cache(device)

            if batch_itr == batches_per_epoch - 1:
                print("... DONE", flush=True)

        train_losses.append(epoch_train_loss)

        model.eval()
        with torch.no_grad():
            valid_total_loss, valid_metrics = evaluate(model.module, val_dataset, am=am,
                                                       train_window=train_window, max_iter=50,
                                                       device=device, use_lane=use_lane,
                                                       use_normalize_input=use_normalize_input,
                                                       normalize_scale=normalize_scale,
                                                       batch_size=val_dataset.batch_size)

        valid_losses.append(float(valid_total_loss))
        valid_metrics_list.append(valid_metrics)

        # Save whenever validation improves. The first epoch always saves, so a
        # checkpoint exists even if no later epoch beats it.
        if min_loss is None or valid_losses[-1] < min_loss:
            min_loss = valid_losses[-1]
            torch.save(model.module, model_name + ".pth")

        epoch_end_time = time.time()

        print('epoch: {}, train loss: {}, val loss: {}, epoch time: {}, lr: {}, {}'.format(
            epoch + 1, train_losses[-1], valid_losses[-1],
            round((epoch_end_time - epoch_start_time) / 60, 5),
            format(get_lr(optimizer), "5.2e"), model_name
        ))

        scheduler.step()

        with open('results/{}_val_metrics.pickle'.format(model_name), 'wb') as f:
            pickle.dump(valid_metrics_list, f)
        

def final_evaluation():
    """Evaluate the saved best checkpoint on the validation set and pickle the metrics.

    Loads ``model_name + '.pth'`` (the file ``main`` saves on improvement).
    Runs ``evaluate`` with ``max_iter=len(val_dataset)`` and
    ``start_iter=200``, then writes the returned metrics to
    ``results/<model_name>_predictions.pickle``.
    """
    am = ArgoverseMap()

    val_dataset = read_pkl_data(val_path, batch_size=8, shuffle=False, repeat=False)

    # map_location lets a checkpoint written during a GPU run load onto the
    # device available now, e.g. a CPU-only machine.
    # NOTE(review): the checkpoint is a whole pickled module; on torch versions
    # where torch.load defaults to weights_only=True this call will also need
    # weights_only=False.
    trained_model = torch.load(model_name + '.pth', map_location=device)
    trained_model.eval()

    with torch.no_grad():
        _, valid_metrics = evaluate(trained_model, val_dataset, am=am,
                                    train_window=train_window, max_iter=len(val_dataset),
                                    device=device, start_iter=200, use_lane=use_lane,
                                    use_normalize_input=use_normalize_input,
                                    normalize_scale=normalize_scale,
                                    batch_size=val_dataset.batch_size)

    os.makedirs('results', exist_ok=True)
    with open('results/{}_predictions.pickle'.format(model_name), 'wb') as f:
        pickle.dump(valid_metrics, f)
        
        
if __name__ == '__main__':
    # Run training. To score the best saved checkpoint instead, call
    # final_evaluation() here.
    main()
    
    
    