
import os

import torchvision.utils

# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# Default to GPU 6, but keep any CUDA_VISIBLE_DEVICES the caller already exported.
# The variable must be set before CUDA is initialised, hence its place above `import torch`.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '6')

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import scipy.io as scio
import gen_BEV.utils as utils
import ssl
import math
# SECURITY NOTE(review): this swaps the process-wide default HTTPS context factory for an
# unverified one, so every HTTPS connection that relies on the default context (e.g. urllib
# downloads) skips certificate verification, not only the weight download. Consider removing
# it once the pretrained weights are cached locally.
ssl._create_default_https_context = ssl._create_unverified_context  # for downloading pretrained VGG weights
from visualize_utils import line_point, save_img
from RANSAC_lib.euclidean_trans import Least_Squares_weight, rt2edu_matrix, Least_Squares
import numpy as np
import os
import argparse
from models_vigor import BEV_corr
import random
from gen_BEV.utils import gps2distance
import time
from dataLoader.Vigor_dataset import load_vigor_data
from op_flow.loss_fun import  fetch_optimizer,corr_test_loss

try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # dummy GradScaler for PyTorch < 1.6
    class GradScaler:
        """No-op stand-in for ``torch.cuda.amp.GradScaler`` when AMP is unavailable.

        It exposes the same ``scale`` / ``unscale_`` / ``step`` / ``update`` calls so that
        training code can run unchanged, just without loss scaling.
        """

        def __init__(self, *args, **kwargs):
            # Accept (and ignore) the real constructor's arguments, e.g. ``enabled=...``.
            pass

        def scale(self, loss):
            """Return ``loss`` unchanged (no loss scaling)."""
            return loss

        def unscale_(self, optimizer):
            """Nothing to unscale."""
            pass

        def step(self, optimizer):
            """Take a plain optimizer step."""
            optimizer.step()

        def update(self):
            """No scale factor to update."""
            pass

def coords_grid(batch, ht, wd, device):
    """Build a dense grid of pixel coordinates.

    Returns a float tensor of shape (batch, 2, ht, wd) where channel 0 holds the
    column index (x) and channel 1 the row index (y) of every pixel; the same grid
    is repeated for each batch entry.
    """
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    # torch.meshgrid defaults to 'ij' indexing: first output varies along rows.
    grid_y, grid_x = torch.meshgrid(rows, cols)
    grid = torch.stack((grid_x, grid_y), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)

#batch_size 1
def test1(net_test, args, save_path, best_rank_result, epoch):
    """Evaluate ``net_test`` on the VIGOR test split and report pose (and optionally flow) metrics.

    Args:
        net_test: network to evaluate. It is switched to eval mode and called on every test batch.
        args: parsed command-line options (see ``parse_args``). Reads ``batch_size``, ``area``,
            ``rotation_range``, ``end2end``, ``test_flow`` and ``gamma``.
        save_path: directory that receives ``Test1_results.txt`` (created if missing).
        best_rank_result: unused; kept for call compatibility.
        epoch: label written into the report header.

    When ``args.end2end`` is falsy, the network returns no pose estimate. In that case only
    the flow metrics (if ``args.test_flow``) are printed, and no results file is written.
    """
    net_test.eval()
    # NOTE(review): inference is not wrapped in torch.no_grad(). The model is called with
    # mode='train' and may rely on autograd internally, so confirm before adding it.

    dataloader = load_vigor_data(args.batch_size, area=args.area, rotation_range=args.rotation_range,
                                 train=False, weak_supervise=True)

    pred_shifts, pred_headings = [], []
    gt_shifts, gt_headings = [], []
    # Running sums of the flow metrics, in the order 'epe', '5px', '15px', '25px', '50px'.
    test_met = [0, 0, 0, 0, 0]
    num_images = 0

    start_time = time.time()
    for i, data in enumerate(dataloader):
        grd, sat, gt_shift_u, gt_shift_v, gt_heading, meter_per_pixel = [item.cuda() for item in data]
        num_images += gt_shift_u.shape[0]

        # Ground-truth shifts are in pixels here (converted to meters below via meter_per_pixel).
        # The heading sign is flipped and scaled by rotation_range into degrees.
        gt_shift_u = gt_shift_u[:, None]
        gt_shift_v = gt_shift_v[:, None]
        gt_heading = -gt_heading[:, None]
        gt_theta = gt_heading * args.rotation_range
        vis_u = -gt_shift_u
        vis_v = gt_shift_v
        gt_shift_u = gt_shift_u * meter_per_pixel[:, None]
        gt_shift_v = gt_shift_v * meter_per_pixel[:, None]

        # The network always starts from a zero pose offset; the ground truth is only used for scoring.
        outputs = net_test(sat, grd, meter_per_pixel,
                           torch.zeros_like(gt_shift_u),
                           torch.zeros_like(gt_shift_v),
                           torch.zeros_like(gt_heading), mode='train',
                           file_name=0)

        if args.end2end:
            flow_predictions, _flow_conf, mask, pre_u, pre_v, pre_theta = outputs
            mpp = meter_per_pixel[:, None].float()
            shifts = torch.cat([pre_v * mpp, pre_u * mpp], dim=-1)
            gt_shift = torch.cat([gt_shift_v, gt_shift_u], dim=-1)
            pred_shifts.append(shifts.data.cpu().numpy())
            gt_shifts.append(gt_shift.data.cpu().numpy())
            pred_headings.append(pre_theta.data.cpu().numpy())
            gt_headings.append(gt_theta.data.cpu().numpy())
        else:
            flow_predictions, _flow_conf, mask = outputs

        if args.test_flow:
            # Ground-truth flow: warp the pixel grid by the GT pose and keep only masked pixels.
            coords0 = coords_grid(mask.size()[0], mask.size()[2], mask.size()[3], device=mask.device)
            coords1 = match(mask, gt_theta, vis_u, vis_v, coords0)
            flow_gt = (coords1 - coords0) * mask
            _, flow_metrics = corr_test_loss(flow_predictions, flow_gt, mask.repeat(1, 2, 1, 1), args.gamma)
            for j, key in enumerate(flow_metrics.keys()):
                test_met[j] = test_met[j] + flow_metrics[key]

        if i % 20 == 0:
            print(i, "/", len(dataloader))

    # Divide by the number of images (not batches) so the value matches its "per image" label.
    duration = (time.time() - start_time) / max(num_images, 1)

    if args.test_flow:
        n_batches = len(dataloader)
        print('Time per image (second): ' + str(duration))
        print('epe:{:.3f}'.format(test_met[0] / n_batches))
        for j, name in enumerate(('5px', '15px', '25px', '50px'), start=1):
            print('{}:{:.3f}'.format(name, test_met[j] / n_batches * 100))

    if not pred_shifts:
        # No pose estimates were produced (end2end disabled or empty loader): nothing more to report.
        return

    _report_pose_metrics(pred_shifts, pred_headings, gt_shifts, gt_headings, duration, save_path, epoch)


def _report_pose_metrics(pred_shifts, pred_headings, gt_shifts, gt_headings, duration, save_path, epoch):
    """Print pose-error statistics and append them to ``<save_path>/Test1_results.txt``.

    The inputs are lists of per-batch numpy arrays, concatenated along axis 0.
    Shift arrays have columns [v, u] in meters; headings are in degrees.
    The "init" figures are the errors of the zero-offset starting pose, i.e. the ground truth itself.
    """
    pred_shifts = np.concatenate(pred_shifts, axis=0)
    pred_headings = np.concatenate(pred_headings, axis=0)
    gt_shifts = np.concatenate(gt_shifts, axis=0)
    gt_headings = np.concatenate(gt_headings, axis=0)

    distance = np.sqrt(np.sum((pred_shifts - gt_shifts) ** 2, axis=1))
    # Absolute heading error folded into [0, 180] degrees.
    angle_diff = np.remainder(np.abs(pred_headings - gt_headings), 360)
    wrapped = angle_diff > 180
    angle_diff[wrapped] = 360 - angle_diff[wrapped]
    diff_shifts = np.abs(pred_shifts - gt_shifts)

    init_dis = np.sqrt(np.sum(gt_shifts ** 2, axis=1))
    init_angle = np.abs(gt_headings)

    metrics = [1, 3, 5]  # meter thresholds
    angles = [1, 3, 5]   # degree thresholds

    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, "Test1_results.txt"), 'a') as f:
        def emit(line):
            # Same text to stdout and to the results file.
            print(line)
            f.write(line + '\n')

        f.write('====================================\n')
        f.write('       EPOCH: ' + str(epoch) + '\n')
        f.write('Time per image (second): ' + str(duration) + '\n')
        f.write('Validation results:' + '\n')
        f.write('Pred distance average: ' + str(np.mean(distance)) + '\n')
        f.write('Pred distance median: ' + str(np.median(distance)) + '\n')
        f.write('Pred angle average: ' + str(np.mean(angle_diff)) + '\n')
        f.write('Pred angle median: ' + str(np.median(angle_diff)) + '\n')
        print('====================================')
        print('       EPOCH: ' + str(epoch))
        print('Time per image (second): ' + str(duration) + '\n')
        print('Validation results:')
        print('Init distance average: ', np.mean(init_dis))
        print('Pred distance average: ', np.mean(distance))
        print('Pred distance median: ', np.median(distance))
        print('Init angle average: ', np.mean(init_angle))
        print('Pred angle average: ', np.mean(angle_diff))
        print('Pred angle median: ', np.median(angle_diff))

        for m in metrics:
            pred = np.sum(distance < m) / distance.shape[0] * 100
            init = np.sum(init_dis < m) / init_dis.shape[0] * 100
            emit('distance within ' + str(m) + ' meters (pred, init): ' + str(pred) + ' ' + str(init))
        print('-------------------------')
        f.write('------------------------\n')

        for m in metrics:
            pred = np.sum(diff_shifts[:, 0] < m) / diff_shifts.shape[0] * 100
            init = np.sum(np.abs(gt_shifts[:, 0]) < m) / init_dis.shape[0] * 100
            emit('lateral      within ' + str(m) + ' meters (pred, init): ' + str(pred) + ' ' + str(init))

            pred = np.sum(diff_shifts[:, 1] < m) / diff_shifts.shape[0] * 100
            init = np.sum(np.abs(gt_shifts[:, 1]) < m) / diff_shifts.shape[0] * 100
            emit('longitudinal within ' + str(m) + ' meters (pred, init): ' + str(pred) + ' ' + str(init))

        print('-------------------------')
        f.write('------------------------\n')
        for a in angles:
            pred = np.sum(angle_diff < a) / angle_diff.shape[0] * 100
            init = np.sum(init_angle < a) / angle_diff.shape[0] * 100
            emit('angle within ' + str(a) + ' degrees (pred, init): ' + str(pred) + ' ' + str(init))

        print('-------------------------')
        f.write('------------------------\n')

        # Joint success: column-0 shift and heading both within their paired thresholds.
        for m, a in zip(metrics, angles):
            pred = np.sum((angle_diff[:, 0] < a) & (diff_shifts[:, 0] < m)) / angle_diff.shape[0] * 100
            init = np.sum((init_angle[:, 0] < a) & (np.abs(gt_shifts[:, 0]) < m)) / angle_diff.shape[0] * 100
            emit('lat within ' + str(m) + ' & angle within ' + str(a) +
                 ' (pred, init): ' + str(pred) + ' ' + str(init))

        print('====================================')
        f.write('====================================\n')

def _stack_rows(*rows):
    """Assemble (1, 1, 1) scalar tensors into a (1, len(rows), len(row)) matrix, row by row."""
    return torch.cat([torch.cat(row, dim=-1) for row in rows], dim=1)


def match(mask, rot, tran_x, tran_y, coords0):
    """Warp a pixel-coordinate grid by a per-sample rotation and translation.

    Each homogeneous point p = [c0, c1, 1]^T of ``coords0`` is mapped to
        p' = C^-1 @ R @ C @ T @ p
    where:
        - T shifts channel 0 by -tran_x and channel 1 by +tran_y;
        - C moves the origin to the grid centre;
        - R rotates by ``rot`` degrees.

    Args:
        mask: tensor whose ``size()`` unpacks as (B, C, H, W); only B, H and W are used.
        rot: per-sample rotation in degrees, indexed as ``rot[i]`` (assumes shape (B, 1) — TODO confirm).
        tran_x, tran_y: per-sample translations, indexed like ``rot``.
        coords0: (B, 2, H, W) coordinate grid. Channel 0 is assumed to index columns (x, range
            [0, W)) and channel 1 rows (y, range [0, H)), as produced by ``coords_grid``.

    Returns:
        (B, 2, H, W) tensor of warped coordinates.
    """
    B, _, H, W = mask.size()
    points = coords0.permute(0, 2, 3, 1)  # (B, H, W, 2)
    coords1 = []

    for i in range(B):
        pts = points[i].reshape(H * W, 2)
        # Homogeneous column vectors, shape (H*W, 3, 1).
        pts_h = torch.cat((pts, torch.ones_like(pts[:, :1])), dim=-1).unsqueeze(-1)

        theta = rot[i][None, :] / 180 * math.pi
        cos = torch.cos(theta)[:, None]  # (1, 1, 1)
        sin = torch.sin(theta)[:, None]
        zero = torch.zeros_like(sin)
        one = torch.ones_like(sin)
        tx = tran_x[i][None, :][:, None]
        ty = tran_y[i][None, :][:, None]

        rotation = _stack_rows((cos, sin, zero), (-sin, cos, zero), (zero, zero, one))
        # Channel 0 spans the width and channel 1 the height, so the centre is (W/2, H/2).
        to_center = _stack_rows((one, zero, one * (-W / 2)), (zero, one, one * (-H / 2)), (zero, zero, one))
        translation = _stack_rows((one, zero, -tx), (zero, one, ty), (zero, zero, one))

        # Build a single 3x3 transform per sample and broadcast it over all H*W points,
        # instead of materialising (and inverting) H*W identical copies.
        transform = torch.inverse(to_center) @ rotation @ to_center @ translation  # (1, 3, 3)
        warped = transform @ pts_h  # (H*W, 3, 1)

        coords1.append(warped[:, :2, 0].reshape(1, H, W, 2).permute(0, 3, 1, 2))

    return torch.cat(coords1, dim=0)

def _str2bool(value):
    """Parse a command-line boolean ('1'/'true'/'yes'/... -> True, '0'/'false'/'no'/... -> False).

    ``type=bool`` calls ``bool(str)``, which is True for every non-empty string, so
    ``--end2end 0`` could never disable the flag. This converter parses the text instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def parse_args():
    """Define and parse the command-line options for VIGOR evaluation.

    Returns:
        argparse.Namespace with every option below as an attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument('--resume', type=int, default=0, help='resume the trained model')
    parser.add_argument('--area', type=str, default='same', help='same or cross')

    parser.add_argument('--test', type=int, default=1, help='test with trained model')
    parser.add_argument('--test_flow', type=int, default=0, help='also evaluate optical-flow metrics (epe, px accuracy)')
    parser.add_argument('--debug', type=int, default=0, help='debug to dump middle processing images')

    #DPP
    parser.add_argument('--dpp', type=_str2bool, default=0)
    parser.add_argument('--local_rank', type=int, default=0, help='node rank for distributed training')
    parser.add_argument('--n_gpus', type=int, default=2, help='number of GPUs for distributed training')

    parser.add_argument('--end2end', type=_str2bool, default=1)

    parser.add_argument('--epochs', type=int, default=30, help='number of training epochs')

    parser.add_argument('--stereo', type=int, default=0, help='use left and right ground image')
    parser.add_argument('--sequence', type=int, default=1, help='use n images merge to 1 ground image')

    parser.add_argument('--rotation_range', type=float, default=180, help='degree')
    parser.add_argument('--shift_range_lat', type=float, default=20., help='meters')
    parser.add_argument('--shift_range_lon', type=float, default=20., help='meters')

    parser.add_argument('--coe_shift_lat', type=float, default=100., help='meters')
    parser.add_argument('--coe_shift_lon', type=float, default=100., help='meters')
    parser.add_argument('--coe_heading', type=float, default=100., help='degree')
    parser.add_argument('--coe_L1', type=float, default=100., help='feature')
    parser.add_argument('--coe_L2', type=float, default=100., help='meters')
    parser.add_argument('--coe_L3', type=float, default=100., help='degree')
    parser.add_argument('--coe_L4', type=float, default=100., help='feature')

    parser.add_argument('--metric_distance', type=float, default=5., help='meters')
    parser.add_argument('--loss_method', type=int, default=0, help='0, 1, 2, 3')

    parser.add_argument('--level', type=int, default=-1, help='2, 3, 4, -1, -2, -3, -4')
    parser.add_argument('--N_iters', type=int, default=5, help='any integer')
    parser.add_argument('--using_weight', type=int, default=0, help='weighted LM or not')
    parser.add_argument('--damping', type=float, default=0.1, help='coefficient in LM optimization')
    parser.add_argument('--train_damping', type=int, default=0, help='coefficient in LM optimization')

    # parameters below are used for the first-step metric learning training
    parser.add_argument('--negative_samples', type=int, default=32, help='number of negative samples '
                                                                         'for the metric learning training')
    parser.add_argument('--use_conf_metric', type=int, default=0, help='0  or 1 ')

    parser.add_argument('--direction', type=str, default='S2GP', help='G2SP or S2GP')
    parser.add_argument('--Load', type=int, default=0, help='0 or 1, load_metric_learning_weight or not')
    parser.add_argument('--Optimizer', type=str, default='LM', help='LM or SGD or ADAM')

    parser.add_argument('--level_first', type=int, default=0, help='0 or 1, estimate grd depth or not')
    parser.add_argument('--proj', type=str, default='geo', help='geo, polar, nn')
    parser.add_argument('--use_gt_depth', type=int, default=0, help='0 or 1')

    parser.add_argument('--dropout', type=int, default=0, help='0 or 1')
    parser.add_argument('--use_hessian', type=int, default=0, help='0 or 1')

    parser.add_argument('--visualize', type=int, default=0, help='0 or 1')

    parser.add_argument('--beta1', type=float, default=0.9, help='coefficients for adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='coefficients for adam optimizer')

    parser.add_argument('--lr', type=float, default=0.00002, help='learning rate')  # 1e-2
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--wdecay', type=float, default=.00005)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--num_steps', type=int, default=110000)
    parser.add_argument('--iters', type=int, default=12)
    # float, not int: the default is 0.8 and int('0.9') would reject a fractional override.
    parser.add_argument('--gamma', type=float, default=0.8)
    parser.add_argument('--clip', type=float, default=1.0)

    args = parser.parse_args()

    return args


def getSavePath(args):
    """Return (and print) the experiment directory derived from area, shift ranges and rotation range."""
    leaf = f'lat{args.shift_range_lat}m_lon{args.shift_range_lon}m_rot{args.rotation_range}'
    save_path = f'./ModelsKitti/VIGOR/geometry_{args.area}/{leaf}'

    print('save_path:', save_path)
    return save_path


if __name__ == '__main__':
    np.random.seed(2022)

    args = parse_args()

    save_path = getSavePath(args)

    # Construct the model directly; eval("BEV_corr") resolved to this same class.
    net = BEV_corr(args)

    # Single-GPU DataParallel wrapper; presumably matches checkpoints saved from a
    # DataParallel-wrapped model (state_dict keys prefixed with 'module.') — confirm.
    net = torch.nn.DataParallel(net.cuda(), device_ids=[0])

    if args.test:
        test_model = [27]  # checkpoint epochs to evaluate
        for i in test_model:
            print("test" + str(i))
            net.load_state_dict(torch.load(os.path.join(save_path, 'model_' + str(i) + '.pth'), map_location='cpu'))
            test1(net, args, save_path, 0., epoch=str(i))

    else:
        # Training is not implemented in this script.
        print(0)

