#!/usr/bin/env python3
import os
import sys
import argparse
import numpy as np
import re
from glob import glob
import time
import importlib
import torch
import pickle


sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from datasets.dataset_reader_argoverse import read_data_val
from traffic_evaluation_helper import TrafficErrors
from argoverse.evaluation.eval_forecasting import get_ade, get_fde
from train_utils import *


def _displacement_metrics(prediction_gt):
    """Summarise per-scene displacement errors as ADE/FDE statistics.

    For every scene the Euclidean error between prediction and ground truth
    is computed per time step from the first two coordinates. ADE is the mean
    of those errors over the steps and FDE is the error at the last step.

    Args:
        prediction_gt: dict mapping scene id -> (pred, gt) CPU tensors that
            are indexed as [step, coord] (``.numpy()`` is called on them).

    Returns:
        dict with keys 'ADE', 'ADE_std', 'fde', 'fde_std' holding the mean
        and standard deviation across scenes (NaN when there are no scenes).
    """
    ade = []
    fde = []
    for pred, gt in prediction_gt.values():
        err = torch.sqrt((pred[:, 0] - gt[:, 0]) ** 2 +
                         (pred[:, 1] - gt[:, 1]) ** 2).numpy()
        ade.append(np.mean(err))
        fde.append(err[-1])
    return {
        'ADE': np.mean(ade),
        'ADE_std': np.std(ade),
        'fde': np.mean(fde),
        'fde_std': np.std(fde),
    }


def evaluate(model, val_dataset, fluid_errors=None, am=None, use_lane=False,
             train_window=3, max_iter=2500, device='cpu', start_iter=0,
             batch_size=32, use_normalize_input=False, normalize_scale=3):
    """Roll the model out over validation batches and report ADE/FDE.

    For each batch whose index lies in [start_iter, max_iter), the model is
    called once on the encoded history (``pos_enc``/``vel_enc``) and the
    current state (``pos0``/``vel0``). It is then fed its own predictions
    autoregressively for 11 more steps, giving predictions for frames 1..12.
    Only agent index 0 (``[:, 0]``) is collected. ADE/FDE statistics are
    printed, not returned.

    Args:
        model: callable used as ``model(inputs)`` and
            ``model(inputs, states)``, returning ``(pos, vel, states)``.
        val_dataset: iterable of batch dicts with keys ``pos0``..``pos12``,
            ``vel0``..``vel12``, ``pos_enc``, ``vel_enc``, ``man_mask`` and
            ``scene_idx``; each value must be stackable with ``np.stack``.
        fluid_errors: accepted for compatibility. A ``TrafficErrors`` is
            created when None, but nothing is recorded into it here.
        am, use_lane, batch_size, use_normalize_input, normalize_scale:
            accepted for interface compatibility; unused by this function.
        train_window: predicted frames 1..train_window contribute to the loss.
        max_iter: batch index at which evaluation stops (exclusive).
        device: torch device on which the batch tensors are created.
        start_iter: batches with a smaller index are skipped.

    Returns:
        ``(total_loss, prediction_gt)``.
        - ``total_loss`` is 128 times the mean per-batch loss over the
          batches actually evaluated, or NaN if none were. The factor 128 is
          inherited; its meaning is not established here.
        - ``prediction_gt`` maps each scene id to a ``(pred, gt)`` pair of
          detached CPU tensors for agent 0, presumably shaped
          (12, coords). TODO confirm.
    """
    print('evaluating.. ', end='', flush=True)

    if fluid_errors is None:
        fluid_errors = TrafficErrors()

    num_frames = 13  # each sample carries pos0..pos12 and vel0..vel12
    convert_keys = (['pos' + str(t) for t in range(num_frames)] +
                    ['vel' + str(t) for t in range(num_frames)] +
                    ['pos_enc', 'vel_enc', 'man_mask'])

    count = 0
    prediction_gt = {}
    losses = []

    for batch_idx, sample in enumerate(val_dataset):
        if batch_idx >= max_iter:
            break
        if batch_idx < start_iter:
            continue

        if count % 10 == 0:
            print('{}'.format(count + 1), end=' ', flush=True)
        count += 1

        data = {k: torch.tensor(np.stack(sample[k]), dtype=torch.float32,
                                device=device)
                for k in convert_keys}
        scenes = np.stack(sample['scene_idx']).tolist()

        man_mask = data['man_mask'].squeeze(-1)
        data['man_mask'] = man_mask
        data['accel'] = data['vel0'] - data['vel_enc'][..., -1, :]
        # Presumably the number of neighbours per agent (mask sum minus the
        # agent itself) that loss_fn normalises by; constant over the rollout.
        num_neighbors = torch.sum(man_mask, dim=-1).unsqueeze(-1) - 1

        pred = []
        gt = []

        inputs = (data['pos_enc'], data['vel_enc'],
                  data['pos0'], data['vel0'],
                  data['accel'], None,
                  man_mask)
        pr_pos, pr_vel, states = model(inputs)
        loss = 0.5 * loss_fn(pr_pos, data['pos1'], num_neighbors, man_mask)

        pred.append(pr_pos[:, 0].unsqueeze(1).detach().cpu())
        gt.append(data['pos1'][:, 0].unsqueeze(1).detach().cpu())
        clean_cache(device)

        # Autoregressive rollout predicting frames pos2..pos12: each step
        # feeds the previous prediction back as the new current state.
        pos_prev, vel_prev = data['pos0'], data['vel0']
        for step in range(num_frames - 2):
            target = data['pos' + str(step + 2)]
            pos_enc = torch.unsqueeze(pos_prev, 2)
            vel_enc = torch.unsqueeze(vel_prev, 2)
            accel = pr_vel - vel_enc[..., -1, :]
            inputs = (pos_enc, vel_enc, pr_pos, pr_vel, accel, None, man_mask)
            pos_prev, vel_prev = pr_pos, pr_vel
            pr_pos, pr_vel, states = model(inputs, states)
            clean_cache(device)

            if step < train_window - 1:
                loss = loss + 0.5 * loss_fn(pr_pos, target, num_neighbors,
                                            man_mask)

            pred.append(pr_pos[:, 0].unsqueeze(1).detach().cpu())
            gt.append(target[:, 0].unsqueeze(1).detach().cpu())
            clean_cache(device)

        # Detach so the stored loss does not keep this batch's whole
        # autograd graph (12 model calls) alive until the end of evaluation.
        losses.append(loss.detach())

        pred_batch = torch.cat(pred, dim=1)
        gt_batch = torch.cat(gt, dim=1)
        for idx, scene_id in enumerate(scenes):
            prediction_gt[scene_id] = (pred_batch[idx], gt_batch[idx])

    if losses:
        # Average over the batches actually evaluated, not max_iter, so the
        # value is correct for short datasets and start_iter > 0.
        total_loss = 128 * torch.sum(torch.stack(losses), dim=0) / len(losses)
    else:
        total_loss = torch.tensor(float('nan'))

    result = _displacement_metrics(prediction_gt)

    print(result)
    print('done')

    return total_loss, prediction_gt





