
import torch
from tqdm import tqdm
from data.data_process import ETHDataset, NBADataset, SDDDataset, JrdbDataset, seq_collate, jrdb_seq_collate
from torch.utils.data import DataLoader
from tool.tools import ETHdata_process, SDDdata_process, NBAdata_process, JRDBdata_process, dct_transform


def calculate_collision(cfg, fut_traj, pred_traj, seq_start_end_list, threshold=0.1):
    """Count inter-agent collisions among each agent's best-of-K prediction.

    For every agent, the prediction mode with the smallest summed (over time)
    displacement error against ``fut_traj`` is selected. Within each sequence
    given by ``seq_start_end_list``, every unordered pair of agents is compared
    at every predicted time step. A step counts as one collision when the
    pair's distance, multiplied by ``cfg.data_scale``, is below ``threshold``.

    Args:
        cfg: config object; ``cfg.data_scale`` and ``cfg.pred_len`` are read.
        fut_traj: ground-truth futures. ``pred_traj`` must have one extra mode
            dimension at position 1 relative to it (presumably ``fut_traj`` is
            (agents, T, D) and ``pred_traj`` is (agents, K, T, D) -- TODO confirm).
        pred_traj: predicted futures with K candidate modes per agent.
        seq_start_end_list: iterable of ``(start, end)`` agent-index ranges,
            one per sequence; entries may be Python ints or scalar tensors.
        threshold: distance (in scaled units) below which a pair collides.

    Returns:
        ``(total_collisions, total_possible_collisions)`` as 0-d tensors.
        The possible count is ``N * (N - 1) / 2 * cfg.pred_len`` summed over
        sequences. The collision rate is their ratio.
    """
    # Pick, per agent, the mode whose trajectory has the smallest summed error.
    errors = torch.norm(pred_traj - fut_traj.unsqueeze(1), dim=-1) * cfg.data_scale
    best_index = errors.sum(dim=-1).argmin(dim=-1)
    agent_index = torch.arange(pred_traj.shape[0], device=pred_traj.device)
    best_preds = pred_traj[agent_index, best_index]
    # To compute the ground-truth collision rate instead, use: best_preds = fut_traj

    total_collisions = torch.zeros((), dtype=torch.long, device=pred_traj.device)
    total_possible_collisions = 0
    for seq in seq_start_end_list:
        # int() accepts both Python ints and scalar tensors, so either form works.
        start, end = int(seq[0]), int(seq[1])
        num_agents = end - start
        total_possible_collisions += (num_agents * (num_agents - 1) // 2) * cfg.pred_len
        if num_agents < 2:
            continue  # no pairs, hence no possible collisions in this sequence
        trajs = best_preds[start:end]
        # pair_dist[i, j, t] = ||trajs[i, t] - trajs[j, t]|| * data_scale
        pair_dist = torch.norm(trajs[:, None] - trajs[None, :], dim=-1) * cfg.data_scale
        # Keep only i < j so each unordered pair is counted exactly once.
        rows, cols = torch.triu_indices(num_agents, num_agents, offset=1, device=trajs.device)
        total_collisions += (pair_dist[rows, cols] < threshold).sum()

    # Collision rate = total_collisions / total_possible_collisions
    return total_collisions, torch.tensor(total_possible_collisions, device=pred_traj.device)



def val_jrdb(cfg, model, scene=None, dct_m=None, idct_m=None):
    """Evaluate ``model`` on the JRDB test split and print ADE/FDE/JADE/JFDE.

    Metrics are reported at four cumulative horizons of 3, 6, 9 and 12
    predicted frames. The print labels pair every 3 frames with 1.2 s.

    Args:
        cfg: config; ``dataset_type`` (must be ``'JRDB'``), ``obs_len``,
            ``pred_len``, ``n_pre`` and ``data_scale`` are read here, and ``cfg``
            is forwarded to ``JRDBdata_process``.
        model: network called as ``model(past_traj, dct_traj, exp='Test', mod=...)``
            and returning a pair whose first element is the predicted futures.
        scene: unused; kept for interface compatibility.
        dct_m: DCT matrix forwarded to ``dct_transform``; required.
        idct_m: unused; kept for interface compatibility.

    Raises:
        ValueError: if ``cfg.dataset_type`` is not ``'JRDB'`` or ``dct_m`` is None.
        RuntimeError: if the test loader yields no samples.
    """
    # The evaluation loop below relies on JRDBdata_process outputs and
    # per-horizon metric lists, so only JRDB can be evaluated here.
    if cfg.dataset_type != 'JRDB':
        raise ValueError(
            "val_jrdb only supports cfg.dataset_type == 'JRDB', got {!r}".format(cfg.dataset_type))
    if dct_m is None:
        raise ValueError('val_jrdb requires a DCT matrix (dct_m), got None')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    dct_m = dct_m.float().to(device)

    test_dataset = JrdbDataset(obs_len=cfg.obs_len, pred_len=cfg.pred_len, training=False)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=1,
                             collate_fn=jrdb_seq_collate)

    num_horizons = 4
    frames_per_horizon = 3      # horizon h (0-based) covers the first 3 * (h + 1) frames
    seconds_per_horizon = 1.2   # time label printed per 3-frame step
    performance = {key: [0.0] * num_horizons for key in ('FDE', 'ADE', 'JADE', 'JFDE')}
    samples = 0

    model.eval()
    with torch.no_grad():
        for data in tqdm(test_loader):
            past_traj, pad_traj, fut_traj, _, _ = JRDBdata_process(cfg, data)

            dct_traj = dct_transform(pad_traj, dct_m, cfg.n_pre)
            mod_dct_traj = dct_transform(pad_traj, dct_m, cfg.n_pre)
            fut_pred, _ = model(past_traj, dct_traj, exp='Test', mod=mod_dct_traj)

            # Scaled per-step L2 error; fut_traj gets a mode axis at dim 1 to
            # broadcast against the K predicted modes.
            distances = torch.norm(fut_pred - fut_traj.unsqueeze(1), dim=-1) * cfg.data_scale

            for h in range(num_horizons):
                horizon_end = frames_per_horizon * (h + 1)
                avg_err = distances[:, :, :horizon_end].mean(dim=-1)
                final_err = distances[:, :, horizon_end - 1]
                # ADE/FDE: best mode per agent, summed over agents (averaged later).
                performance['ADE'][h] += avg_err.min(dim=-1).values.sum().item()
                performance['FDE'][h] += final_err.min(dim=-1).values.sum().item()
                # JADE/JFDE: average over agents first, then best mode.
                # NOTE(review): this averages over every agent in the batch, so the
                # "joint" metric is per batch; whether a batch is one scene depends
                # on jrdb_seq_collate -- confirm.
                performance['JADE'][h] += avg_err.mean(dim=0).min().item()
                performance['JFDE'][h] += final_err.mean(dim=0).min().item()

            samples += distances.shape[0]

    if samples == 0:
        raise RuntimeError('JRDB test loader produced no samples; cannot compute metrics')

    num_batches = len(test_loader)
    for h in range(num_horizons):
        secs = seconds_per_horizon * (h + 1)
        print(
            'Evaluation\t--ADE({:.1f}s): {:.4f}\t--FDE({:.1f}s): {:.4f}\t--JADE({:.1f}s): {:.4f}\t--JFDE({:.1f}s): {:.4f}'.format(
                secs, performance['ADE'][h] / samples,
                secs, performance['FDE'][h] / samples,
                secs, performance['JADE'][h] / num_batches,
                secs, performance['JFDE'][h] / num_batches))
