import sys
import tensorflow as tf
import numpy as np
import random
import math
import statistics
import os
import data
import models
import cv2
from scipy.spatial.transform import Rotation as R
import argparse


def dictToArray(hypDict):
    """Convert a ``{index: [y, x]}`` keypoint mapping into an (N, 2) float array.

    Row ``index`` of the result holds ``[round(x), round(y)]``, i.e. the
    coordinates are swapped into x, y order and rounded with Python's
    ``round``. Keys are used directly as row indices, so they are expected to
    be the integers ``0 .. N-1``.
    """
    n_points = len(hypDict)
    xy = np.zeros((n_points, 2))
    for idx in hypDict:
        point = hypDict[idx]
        y_val, x_val = point[0], point[1]
        xy[idx] = (round(x_val), round(y_val))  # x, y format
    return xy


def ransacVal(y1, x1, v2):
    """Return the dot product of the unit vector (y1, x1) with ``v2`` normalized.

    ``v2`` is stored in (x, y) order: ``v2[0]`` pairs with ``x1`` and
    ``v2[1]`` with ``y1``. When (y1, x1) is already unit length, the result is
    cos(angle) between the two directions. A zero-length ``v2`` gives NaN.
    """
    unit = v2 / np.linalg.norm(v2)
    ux, uy = unit[0], unit[1]
    return x1 * ux + y1 * uy


def determineOutlier(input, yMean, yDev, xMean, xDev):
    """Tell whether a (y, x) point falls outside the allowed band on either axis.

    The point is an outlier when ``|y - yMean| > yDev`` or ``|x - xMean| > xDev``.
    The x test is only evaluated when the y test passes.
    """
    off_y = abs(input[0] - yMean) > yDev
    if off_y:
        return off_y
    return abs(input[1] - xMean) > xDev


def pruneHypsStdDev(hypDict, m=2):
    """Remove outlier hypotheses in place with a per-axis mean +/- m*stddev test.

    Args:
        hypDict: Mapping ``key -> [((y, x), weight), ...]``. Each list is
            replaced by the hypotheses that ``determineOutlier`` does not flag.
        m: Number of population standard deviations (``statistics.pstdev``)
            tolerated on each axis.

    Returns:
        None; ``hypDict`` is modified in place.

    Keys whose hypothesis list is empty are left empty. ``statistics.mean``
    raises ``StatisticsError`` on empty data, and there is nothing to prune.
    """
    for key, hyps in hypDict.items():
        if not hyps:
            continue
        yVals = [hyp[0][0] for hyp in hyps]
        xVals = [hyp[0][1] for hyp in hyps]
        yMean, xMean = statistics.mean(yVals), statistics.mean(xVals)
        yDev, xDev = statistics.pstdev(yVals) * m, statistics.pstdev(xVals) * m
        # Reassigning an existing key does not resize the dict, so this is safe
        # while iterating over items().
        hypDict[key] = [hyp for hyp in hyps
                        if not determineOutlier(hyp[0], yMean, yDev, xMean, xDev)]


def getMean(hypDict):
    """Return the weight-averaged (y, x) position of each key's hypotheses.

    Args:
        hypDict: Mapping ``key -> [((y, x), weight), ...]``.

    Returns:
        dict mapping each key to ``[yMean, xMean]``. If every hypothesis of a
        key has weight 0, the plain (unweighted) mean is used instead of
        dividing by a zero total weight.

    Raises:
        ValueError: if a key has no hypotheses, since there is nothing to average.
    """
    meanDict = {}
    for key, hyps in hypDict.items():
        if not hyps:
            raise ValueError(f"no hypotheses to average for key {key!r}")
        weights = [hyp[1] for hyp in hyps]
        totalWeight = sum(weights)
        if totalWeight == 0:
            # No hypothesis got any vote: treat them all equally rather than
            # producing NaN / ZeroDivisionError.
            weights = [1] * len(hyps)
            totalWeight = len(hyps)
        yMean = sum(hyp[0][0] * w for hyp, w in zip(hyps, weights)) / totalWeight
        xMean = sum(hyp[0][1] * w for hyp, w in zip(hyps, weights)) / totalWeight
        meanDict[key] = [yMean, xMean]
    return meanDict


def _intersect_rays(p1, v1, p2, v2, eps=1e-12):
    """Intersect two 2-D rays; return the meeting point as ``(y, x)`` or None.

    Each ray starts at a pixel ``p = (row, col)`` and points along
    ``v = (vx, vy)``. This is the same component order the voting step uses
    (``v[0]`` pairs with x, ``v[1]`` with y). A point is accepted only when it
    lies ahead of BOTH pixels (non-negative ray parameters), i.e. in the
    direction each predicted vector points. Parallel or non-finite input
    yields None.
    """
    v1x, v1y = float(v1[0]), float(v1[1])
    v2x, v2y = float(v2[0]), float(v2[1])
    cross = v1x * v2y - v1y * v2x
    if not math.isfinite(cross) or abs(cross) < eps:
        return None
    dy = float(p2[0]) - float(p1[0])
    dx = float(p2[1]) - float(p1[1])
    # Solve p1 + t1 * v1 == p2 + t2 * v2 for t1, t2 (Cramer's rule).
    t1 = (dx * v2y - dy * v2x) / cross
    t2 = (dx * v1y - dy * v1x) / cross
    # Written as `not (...)` so that NaN parameters are rejected as well.
    if not (t1 >= 0 and t2 >= 0):
        return None
    return float(p1[0]) + t1 * v1y, float(p1[1]) + t1 * v1x


def _count_inliers(y, x, pixels, kp_vecs, exclude, threshold):
    """Count pixels whose predicted vector points at hypothesis (y, x).

    ``pixels`` is an (N, 2) array of (row, col); ``kp_vecs`` is the (N, 2)
    array of (vx, vy) predictions for one keypoint. A pixel votes when the
    cosine between its vector and the direction pixel -> (y, x) exceeds
    ``threshold``. Pixels listed in ``exclude`` (the two that generated the
    hypothesis) do not vote.
    """
    dy = y - pixels[:, 0]
    dx = x - pixels[:, 1]
    with np.errstate(divide='ignore', invalid='ignore'):
        denom = np.hypot(dy, dx) * np.hypot(kp_vecs[:, 0], kp_vecs[:, 1])
        cos = (dx * kp_vecs[:, 0] + dy * kp_vecs[:, 1]) / denom
    votes = cos > threshold  # NaN (zero-length offset or vector) never votes
    votes[list(exclude)] = False
    return int(np.count_nonzero(votes))


def predict_pose(class_name, image, fps_points, vecModel, classModel,
                 num_iterations=50, class_threshold=0.1, inlier_threshold=0.99,
                 num_keypoints=8, camera_matrix=None):
    """Estimate an object's pose from one image via keypoint voting and PnP.

    Steps:
    1. Run both models on the image.
    2. Collect the segmented pixels.
    3. Repeatedly intersect the predicted vectors of two random pixels to
       form one hypothesis per keypoint, and weight each hypothesis by the
       number of pixels voting for it.
    4. Prune outliers, average the hypotheses, and solve PnP.

    Args:
        class_name: Not used by this function; kept for call compatibility.
        image: One input image; a batch axis is added before prediction.
        fps_points: 3-D object points handed to ``cv2.solvePnP``. They are
            assumed to have ``num_keypoints`` rows in keypoint-channel order
            (TODO confirm).
        vecModel: Model whose first ``predict`` output is indexed as
            ``[row, col, channel]``, with channels ``2*k, 2*k+1`` holding the
            (vx, vy) vector for keypoint ``k``.
        classModel: Segmentation model; pixels scoring above
            ``class_threshold`` vote.
        num_iterations: Number of random pixel pairs sampled.
        class_threshold: Segmentation score threshold.
        inlier_threshold: Minimum cosine for a pixel to vote for a hypothesis.
        num_keypoints: Keypoints used for PnP. The previous code also voted for
            a 9th keypoint and then discarded it.
        camera_matrix: 3x3 intrinsics; defaults to the hard-coded matrix.

    Returns:
        ``(rVec, tVec)`` as returned by ``cv2.solvePnP``.

    Raises:
        ValueError: fewer than two segmented pixels, or a keypoint got no
            valid hypothesis.
        RuntimeError: ``cv2.solvePnP`` reported failure.
    """
    if camera_matrix is None:
        camera_matrix = np.array(
            [[543.25272224, 0., 320.25], [0., 724.33696299, 240.33333333], [0., 0., 1.]])  # camera matrix GUIMOD

    nnInput = np.array([image])
    vecPred = vecModel.predict(nnInput)[0]
    classPred = classModel.predict(nnInput)[0]

    rows, cols = np.where(classPred > class_threshold)[:2]
    print("Len Population : ", len(rows))  # the number of class pixels found
    if len(rows) < 2:
        raise ValueError(f"need at least 2 segmented pixels to vote, found {len(rows)}")

    pixels = np.stack([rows, cols], axis=1).astype(np.float64)    # (N, 2) row, col
    pixel_vecs = np.asarray(vecPred[rows, cols], dtype=np.float64)  # (N, channels)

    hypDict = {kp: [] for kp in range(num_keypoints)}
    for _ in range(num_iterations):
        i1, i2 = random.sample(range(len(rows)), 2)  # two distinct pixels
        for kp in range(num_keypoints):
            ch = slice(2 * kp, 2 * kp + 2)
            hyp = _intersect_rays(pixels[i1], pixel_vecs[i1, ch],
                                  pixels[i2], pixel_vecs[i2, ch])
            if hyp is None:
                continue
            weight = _count_inliers(hyp[0], hyp[1], pixels, pixel_vecs[:, ch],
                                    [i1, i2], inlier_threshold)
            hypDict[kp].append((hyp, weight))

    missing = [kp for kp, hyps in hypDict.items() if not hyps]
    if missing:
        raise ValueError(f"no valid hypotheses for keypoint(s) {missing}")

    pruneHypsStdDev(hypDict)
    meanDict = getMean(hypDict)
    preds = dictToArray(meanDict)[:num_keypoints]

    ok, rVec, tVec = cv2.solvePnP(fps_points, preds, camera_matrix,
                                  np.zeros(shape=[8, 1], dtype='float64'),
                                  flags=cv2.SOLVEPNP_ITERATIVE)
    if not ok:
        raise RuntimeError("cv2.solvePnP failed to find a pose")
    return rVec, tVec


def main():
    """Predict and save a 3x4 ``[R | t]`` pose for every evaluation image of one class.

    Reads ``--class_name``, ``--path_data`` and ``--folder_evaluation`` from the
    command line, then loads the class's 3-D FPS keypoints and both trained
    models. For each image it writes ``<img_id>.npy`` into
    ``<path_data>/<folder_evaluation>/<class_name>/Pose_prediction<dataset>_<folder_evaluation>_<class_name>``.
    Images whose pose cannot be predicted are reported and skipped.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-cls_name", "--class_name", type=str,
                    help="[kiwi1, pear2, banana1, orange, peach1]", required=True)
    ap.add_argument("--path_data", type=str, required=True)
    ap.add_argument("--folder_evaluation", type=str, required=True)
    args = ap.parse_args()

    class_name = args.class_name
    path_data = args.path_data
    folder_evaluation = args.folder_evaluation

    basePath = f"{path_data}/{folder_evaluation}/{class_name}"
    fps = np.loadtxt(f'{path_data}/Generated/{class_name}/{class_name}_fps_3d.txt')

    # Only the images and their file names are used below.
    images_ls, _labels_ls, _mask_ls, choice_ls = data.getAllValDataFruits(
        path_data, "Generated_Worlds_Training", folder_evaluation, class_name)

    # Last component of the data root. normpath keeps a trailing '/' from
    # producing an empty name (path_data.split('/')[-1] returned '' then).
    dataset = os.path.basename(os.path.normpath(path_data))
    out_dir = f"{basePath}/Pose_prediction{dataset}_{folder_evaluation}_{class_name}"
    os.makedirs(out_dir, exist_ok=True)

    # Model predicting per-pixel unit vectors towards each keypoint.
    vecModel = models.stvNetNew(outVectors=True, outClasses=False)
    vecModel.load_weights(f'models/stvNet_new_coords_{class_name}')  # standard labels model
    vecModel.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.Huber())

    # Class model for image segmentation.
    classModel = models.uNet(outVectors=False, outClasses=True)
    classModel.load_weights(f'models/uNet_classes_{class_name}')
    classModel.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.BinaryCrossentropy())

    for img, choice in zip(images_ls, choice_ls):
        img_id = int(choice.split('.png')[0])
        try:
            r_pre, t_pre = predict_pose(class_name, img, fps, vecModel, classModel)
            res = np.zeros((3, 4))
            res[:3, :3] = R.from_rotvec(np.asarray(r_pre).reshape(3)).as_matrix()
            res[:3, 3] = np.asarray(t_pre).reshape(3)
        except Exception as err:  # e.g. too few segmented pixels, PnP failure
            # Previously a bare `except:` that also swallowed KeyboardInterrupt
            # and always blamed the segmentation.
            print(f"Skipping image {img_id}: pose prediction failed ({err!r})")
            continue
        print("saving : ", img_id)
        np.save(f'{out_dir}/{img_id}.npy', res)


if __name__ == '__main__':
    main()


