#!/usr/bin/env python3
import os
import sys
import argparse
import numpy as np
import re
from glob import glob
import time
import importlib
import torch
import pickle


sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from datasets.dataset_reader_argoverse import read_data_val
from datasets.helper import get_lane_direction
from traffic_evaluation_helper import TrafficErrors
from argoverse.evaluation.eval_forecasting import get_ade, get_fde
from train_utils import *


def get_agent(pr, gt, pr_id, gt_id, agent_id, device='cpu'):
    """Pick the entries of ``pr`` and ``gt`` whose track id equals ``agent_id``.

    Args:
        pr: Indexable array/tensor of predictions.
        gt: Indexable array/tensor of ground-truth values.
        pr_id: Ids compared against ``agent_id`` to build the mask for ``pr``.
        gt_id: Ids compared against ``agent_id`` to build the mask for ``gt``.
        agent_id: Id (or broadcastable array of ids) to select.
        device: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(pr_agent, gt_agent)`` with the selected entries.
    """
    # Boolean masks marking where each id array matches the requested agent.
    pr_matches = pr_id == agent_id
    gt_matches = gt_id == agent_id

    return pr[pr_matches, :], gt[gt_matches, :]


def _step_loss(pr_pos, gt_pos, car_mask):
    """Return half of ``loss_fn`` for one predicted step, using ``car_mask``."""
    return 0.5 * loss_fn(pr_pos, gt_pos,
                         torch.sum(car_mask, dim=-2) - 1, car_mask.squeeze(-1))


def _agent_output(agent_pos, scale=None):
    """Detach agent positions, add a step axis at dim 1 and move them to CPU.

    When ``scale`` is not None the result is multiplied by it.
    """
    out = agent_pos.unsqueeze(1).detach().cpu()
    if scale is not None:
        out = out * scale
    return out


def evaluate(model, val_dataset, fluid_errors=None, am=None, use_lane=False,
             train_window=3, max_iter=2500, device='cpu', start_iter=0, 
             batch_size=32, use_normalize_input=False, normalize_scale=3):
    """Roll the model out on validation batches and report displacement errors.

    For every evaluated batch the model is called once on the observed state
    and then 28 more times on its own predictions, giving 29 predicted steps
    that are compared against ``pos1`` .. ``pos29``. Progress, the aggregated
    metrics dict and ``'done'`` are printed to stdout.

    Args:
        model: Callable used as ``model(inputs)`` and ``model(inputs, states)``;
            must return ``(pos, vel, states)``.
        val_dataset: Iterable of batches. Each batch is a mapping providing
            ``pos0..pos29``, ``vel0..vel29``, ``pos_2s``, ``vel_2s``, ``lane``,
            ``lane_norm``, ``track_id0..track_id29``, ``city``, ``agent_id``,
            ``scene_idx``, ``car_mask`` and, when ``use_lane`` is true,
            ``lane_mask``.
        fluid_errors: Optional ``TrafficErrors``; created when None. Not
            otherwise used by this function.
        am: Optional ``ArgoverseMap``; created when None. Not otherwise used
            by this function.
        use_lane: When false, the batch's ``lane_mask`` is replaced by zeros.
        train_window: Number of leading predicted steps that contribute to
            the loss.
        max_iter: Batches with index ``>= max_iter`` are not evaluated.
        device: Device on which the input tensors are created.
        start_iter: Batches with index ``< start_iter`` are skipped.
        batch_size: Length of the zero ``lane_mask`` built when ``use_lane``
            is false.
        use_normalize_input: When true, inputs go through ``normalize_input``
            and stored predictions / ground truth are multiplied by
            ``normalize_scale``.
        normalize_scale: Scale passed to ``normalize_input`` and used to
            rescale the stored outputs.

    Returns:
        ``(total_loss, prediction_gt)`` where ``total_loss`` is 128 times the
        mean per-batch loss over the evaluated batches (detached from the
        autograd graph) and ``prediction_gt`` maps each ``scene_idx`` to a
        ``(prediction, ground_truth)`` pair of CPU tensors.

    Raises:
        RuntimeError: If no batch was evaluated.
    """
    print('evaluating.. ', end='', flush=True)

    if fluid_errors is None:
        fluid_errors = TrafficErrors()

    if am is None:
        from argoverse.map_representation.map_api import ArgoverseMap
        am = ArgoverseMap()

    # Key lists are identical for every batch, so build them once.
    float_keys = (['pos' + str(t) for t in range(30)] +
                  ['vel' + str(t) for t in range(30)] +
                  ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])
    raw_keys = (['track_id' + str(t) for t in range(30)] +
                ['city', 'agent_id', 'scene_idx'])

    # Stored outputs are only rescaled when the inputs were normalized.
    output_scale = normalize_scale if use_normalize_input else None

    count = 0
    prediction_gt = {}
    losses = []

    for batch_idx, sample in enumerate(val_dataset):

        if batch_idx >= max_iter:
            break

        if batch_idx < start_iter:
            continue

        print('{}'.format(count + 1), end=' ', flush=True)
        count += 1

        if not use_lane:
            # NOTE(review): sized from the batch_size argument, not from the
            # batch itself; a shorter final batch would get a mismatched
            # mask - confirm the loader always yields full batches.
            sample['lane_mask'] = [np.array([0])] * batch_size

        data = {}
        for k in float_keys:
            # Only the first two components of the last axis are kept.
            data[k] = torch.tensor(np.stack(sample[k])[..., :2],
                                   dtype=torch.float32, device=device)

        if use_normalize_input:
            data, _ = normalize_input(data, normalize_scale, 29)

        for k in raw_keys:
            data[k] = np.stack(sample[k])

        data['car_mask'] = torch.tensor(np.stack(sample['car_mask']),
                                        dtype=torch.float32, device=device)
        data['lane_mask'] = torch.tensor(np.stack(sample['lane_mask']),
                                         dtype=torch.float32,
                                         device=device).unsqueeze(-1)

        scenes = data['scene_idx'].tolist()
        data['agent_id'] = data['agent_id'][:, np.newaxis]
        data['accel'] = torch.zeros(1, 1, 2).to(device)
        agent_id = data['agent_id']

        # First step: predict pos1 from the observed state.
        inputs = ([
            data['pos_2s'], data['vel_2s'],
            data['pos0'], data['vel0'],
            data['accel'], None,
            data['lane'], data['lane_norm'],
            data['car_mask'], data['lane_mask']
        ])
        pr_pos1, pr_vel1, states = model(inputs)

        loss = _step_loss(pr_pos1, data['pos1'], data['car_mask'])

        pr_agent, gt_agent = get_agent(pr_pos1, data['pos1'],
                                       data['track_id0'],
                                       data['track_id1'],
                                       agent_id, device)
        pred = [_agent_output(pr_agent, output_scale)]
        gt = [_agent_output(gt_agent, output_scale)]
        del pr_agent, gt_agent
        clean_cache(device)

        # Remaining steps: feed the model its own previous prediction.
        pos0 = data['pos0']
        vel0 = data['vel0']
        for step in range(28):
            target_idx = str(step + 2)
            gt_pos = data['pos' + target_idx]

            pos_enc = torch.unsqueeze(pos0, 2)
            vel_enc = torch.unsqueeze(vel0, 2)
            inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, data['accel'], None,
                      data['lane'], data['lane_norm'],
                      data['car_mask'], data['lane_mask'])
            pos0, vel0 = pr_pos1, pr_vel1
            pr_pos1, pr_vel1, states = model(inputs, states)
            clean_cache(device)

            # Only the first train_window predicted steps enter the loss.
            if step < train_window - 1:
                loss += _step_loss(pr_pos1, gt_pos, data['car_mask'])

            pr_agent, gt_agent = get_agent(pr_pos1, gt_pos,
                                           data['track_id0'],
                                           data['track_id' + target_idx],
                                           agent_id, device)
            pred.append(_agent_output(pr_agent, output_scale))
            gt.append(_agent_output(gt_agent, output_scale))

            clean_cache(device)

        # Detach so the stored loss does not keep this batch's autograd graph
        # alive for the remainder of the evaluation.
        losses.append(loss.detach())

        pred_all = torch.cat(pred, dim=1)
        gt_all = torch.cat(gt, dim=1)
        for idx, scene_id in enumerate(scenes):
            prediction_gt[scene_id] = (pred_all[idx], gt_all[idx])

    if not losses:
        raise RuntimeError(
            'evaluate(): no batches were evaluated; '
            'check val_dataset, start_iter and max_iter')

    # Average over the batches actually evaluated rather than over max_iter,
    # which is wrong when the dataset is shorter or start_iter > 0. The 128
    # factor is kept from the original; its meaning is not visible here.
    total_loss = 128 * torch.sum(torch.stack(losses), dim=0) / len(losses)

    # Per-scene displacement error at every predicted step.
    de = {}
    for k, v in prediction_gt.items():
        de[k] = torch.sqrt((v[0][:, 0] - v[1][:, 0]) ** 2 +
                           (v[0][:, 1] - v[1][:, 1]) ** 2)

    ade = []
    de1s = []
    de2s = []
    de3s = []
    for v in de.values():
        errors = v.numpy()
        ade.append(np.mean(errors))
        # Steps 9, 19 and the last one are reported as the 1s/2s/3s errors.
        de1s.append(errors[9])
        de2s.append(errors[19])
        de3s.append(errors[-1])

    result = {}
    result['ADE'] = np.mean(ade)
    result['ADE_std'] = np.std(ade)
    result['DE@1s'] = np.mean(de1s)
    result['DE@1s_std'] = np.std(de1s)
    result['DE@2s'] = np.mean(de2s)
    result['DE@2s_std'] = np.std(de2s)
    result['DE@3s'] = np.mean(de3s)
    result['DE@3s_std'] = np.std(de3s)

    print(result)
    print('done')

    return total_loss, prediction_gt




