import torch
import numpy as np
from tqdm import tqdm
from code.data.datasets.posenet import cal_quat_angle_error


class SevenScenesEvaluator:
    @staticmethod
    def evaluate(model, data_loader, device):
        model.eval()
        with torch.no_grad():
            q_err_all, t_err_all = [], []
            encoder = model['encoder']
            t_decoder, q_decoder = model['left'], model['right']
            for data in tqdm(data_loader, total=len(data_loader)):
                t_gt, q_gt = data['t_gt'].to(device), data['q_gt'].to(device)
                img = data['img'].to(device)

                enc_out = encoder(img)
                t_est, q_est = t_decoder(enc_out), q_decoder(enc_out)

                # compute errors for translation and orientation
                t_err = torch.norm(t_gt - t_est, p=2, dim=1)
                q_err = torch.FloatTensor(cal_quat_angle_error(q_gt.cpu().numpy(), q_est.cpu().numpy())).to(device)

                q_err_all.append(q_err)
                t_err_all.append(t_err)

            q_err_all = torch.cat(q_err_all).cpu().numpy()
            t_err_all = torch.cat(t_err_all).cpu().numpy()

        t_err_med, q_err_med = np.median(t_err_all), np.median(q_err_all)
        print(f'Acc (median error): pos: {t_err_med}, ornt: {q_err_med}')
        return t_err_med, q_err_med
