# -*- coding: utf-8 -*-
"""
This is an experiment to test the theoretical error lower bound predicted
in the paper. For any fixed PointNet, it is shown in the error lower-bound
proof how to produce point clouds which exhibit errors larger than a predicted
threshold. Here we implement this algorithm for a 2D PointNet produced in the
train.py code.

This tests the lower bound for a given saved model. In the main function we 
collect the experiment data to be used for plotting and call get_lower_bound()
and lower_bound_test().
"""

import numpy as np
import torch
import matplotlib.pyplot as plt
import sampler as smp

def main(model, directory, model_type, k, tau, N_tests):
    """Run the lower-bound experiment N_tests times and plot the outcome.

    Each trial calls lower_bound_test() to build a problematic point cloud
    and records its error, the distance between its seed points p & q, and
    the slack (error - theoretical lower bound). The errors are scattered
    against the lower-bound curve from get_lower_bound(); the figure is
    saved under `directory` and then shown.

    Args:
        model: trained model, forwarded unchanged to lower_bound_test().
        directory: folder the figure is written into.
        model_type: label used in the figure file name and title.
        k: point-cloud size, forwarded to lower_bound_test() and
            get_lower_bound().
        tau: trade-off parameter with 0 < tau < 1 (validated by the callees).
        N_tests: number of tests of the theoretical lower bound (>= 1).

    Raises:
        ValueError: if N_tests < 1, since no "closest call" can be computed
            from zero trials.
    """
    if N_tests < 1:
        raise ValueError(
            'Requires N_tests >= 1 but value was {}'.format(N_tests))

    # Loop to create many examples. The failure-mode cloud itself is not
    # needed for the plot, so it is not retained.
    errors, dists, slacks = [], [], []
    for _ in range(N_tests):
        _, error_max, dist_pq, slack = lower_bound_test(
            model, k, tau, verbose=False)

        errors.append(error_max)  # failure-mode cloud's error
        dists.append(dist_pq)     # distance of seed points
        slacks.append(slack)      # how close to theoretical lower-bound

    # Compute closest call (minimum of errors - lowerbound).
    # Note: This is a diagnostic number to check that all discovered errors
    #       are indeed worse than the error lower-bound. It does not indicate
    #       how poorly the PointNet model performs on the task.
    slack_min = min(slacks)
    print('Closest call yields error - lowerbound = {}'.format(slack_min))

    # Lower-bound curve:
    #   x_lb: distance between p & q
    #   y_lb: lower bound for x_lb
    # NOTE(review): the range [0, 2] presumably matches the diameter of the
    # region sampler.disc draws from -- confirm against the sampler.
    x_lb = np.linspace(0, 2, 100)
    y_lb = get_lower_bound(x_lb, k, tau)

    # Create, save, & show the figure
    fig_path = (directory + '/ErrorLowerBound-ptnet' + model_type
                + '_tau{}.png'.format(tau))

    plt.figure()
    plt.plot(x_lb, y_lb, linewidth=2.5,
             label="Lower Bound $\\tau=${}".format(tau), color='orange')
    plt.scatter(dists, errors, label="Problematic Example Error", marker='.',
                s=0.9, alpha=1)
    # Raw string: '\|' is not a valid Python escape, so the non-raw form
    # triggers an invalid-escape warning even though the text is identical.
    plt.xlabel(r'$\|\| p - q \|\|$')
    plt.ylabel('Error')
    plt.title('Problematic Example Errors vs Predicted Error Lower Bound\n '
              + 'PointNet(' + model_type + ')')
    plt.legend(loc='best')
    plt.savefig(fig_path, dpi=300)
    plt.show()
        
def get_lower_bound(dist_pq, k, tau=0.5, return_eps=False, optimal=False):
    """Compute the theoretical sup-norm error lower bound.

    With gap = (k-2)/k * dist(p,q), the theoretical max distance between
    centers in the limit, the paper proves that for 0 < tau < 1 and

        eps = tau * (k-2)/(4k) * dist(p,q) = tau * gap/4

    the sup-norm error satisfies

        error >= (k-2)/(2k) * dist(p,q) - 2*eps = (1-tau) * gap/2.

    Args:
        dist_pq: distance between p and q; a scalar or an array (the
            arithmetic below is applied elementwise).
        k: point-cloud size, must be > 0.
        tau: trade-off parameter, must satisfy 0 < tau < 1. Ignored when
            optimal=True.
        return_eps: if True, return the pair (eps, lower_bound) instead of
            just lower_bound.
        optimal: if True, return the ultimate lower bound gap/2, i.e. the
            tau=0 limit (with eps = 0).

    Returns:
        lower_bound, or (eps, lower_bound) when return_eps is True.

    Raises:
        ValueError: if k <= 0, or if tau is outside (0, 1) while
            optimal is False.
    """
    if k <= 0:
        # '{}' rather than '%d' so a non-integer k is reported as given
        # instead of being truncated in the message.
        raise ValueError('Requires k>0 but value was {}'.format(k))

    if optimal:
        # The optimal bound is the tau -> 0 limit of the formulas below.
        tau = 0.0
    elif tau <= 0 or tau >= 1:
        # tau=0 is not allowed here because the search in lower_bound_test
        # might not converge (eps would be 0).
        raise ValueError('Requires 0< tau < 1 but value was {}'.format(tau))

    gap = (k - 2) * dist_pq / k
    eps = tau * gap / 4
    lower_bound = (1 - tau) * gap / 2

    if return_eps:
        return eps, lower_bound
    return lower_bound
    
    

def _shrink_until_close(model, make_cloud, m_ref, eps):
    """Halve t (starting at 1) until model(make_cloud(t)) is within eps of m_ref.

    Returns the pair (cloud, model prediction for that cloud).

    Raises:
        RuntimeError: if t underflows to 0 without reaching the tolerance.
            Further halving cannot change the cloud at that point, so the
            search would otherwise loop forever (e.g. when eps == 0).
    """
    t = 1.0
    while True:
        cloud = make_cloud(t)
        m_cloud = model(cloud)
        deviation = torch.norm(m_cloud - m_ref).item()
        if not deviation >= eps:
            return cloud, m_cloud
        if t == 0.0:
            raise RuntimeError(
                'Could not bring the prediction within eps={} of the '
                'prediction for {{p,q}} (deviation={})'.format(eps, deviation))
        t = t / 2


def lower_bound_test(model, k=10, tau=0.5, verbose=False):
    """Test the paper's theoretical error lower bound on clouds of size k.

    Follows the construction in the lower-bound proof:
      0. Pick seed points p, q and form the two-point cloud {p,q}.
      1. Compute the predicted and true center of mass for {p,q}.
      2. Compute eps and the theoretical lower bound via get_lower_bound():
             gap         = (k-2)/k * dist(p,q)
             eps         = tau * gap/4
             lower_bound = (1-tau) * gap/2
      3. Sample k-2 more points A_smp.
      4. Build Ap = {A_smp pulled towards p} U {p,q}, shrinking until the
         model's prediction for Ap is within eps of its prediction for
         {p,q}; build Aq the same way with the roles of p & q reversed.
      5. Report whichever of {p,q}, Ap, Aq has the largest error.

    Args:
        model: callable taking a (1, n, 2) float tensor and returning the
            predicted center of mass.
        k: point-cloud size; must be >= 3 so that eps > 0 and there are
            extra points to sample.
        tau: trade-off parameter, must satisfy 0 < tau < 1.
        verbose: if True, print all three errors (the largest marked with
            '*') and the comparison with the lower bound. A violation of the
            lower bound is printed regardless of this flag.

    Returns:
        (A_argmax, error_max, dist_pq, slack): the point cloud with the
        largest error (a tensor), that error, the distance between p & q,
        and error_max - lower_bound, the last three as Python floats.

    Raises:
        ValueError: if tau is outside (0, 1) or k < 3.
        RuntimeError: if the shrinking search cannot reach the eps tolerance.
    """
    if tau <= 0 or tau >= 1:
        # We exclude tau=0 because the algorithm might not converge.
        raise ValueError(
            'Requires 0< tau < 1. The value of tau was: {}'.format(tau))
    if k < 3:
        # With k <= 2 there are no extra points to sample and eps <= 0, so
        # the shrinking search below could never terminate.
        raise ValueError('Requires k >= 3 but value was {}'.format(k))

    # Only forward passes are needed, so skip gradient tracking.
    with torch.no_grad():
        # step 0: Pick points p,q and create associated point cloud.
        # NOTE(review): smp.disc(N) presumably returns N points in 2D --
        # confirm against the sampler module.
        pq_cld = torch.FloatTensor(smp.disc(N=2)).view(1, -1, 2)
        p, q = pq_cld[0][0], pq_cld[0][1]

        # step 1: Compute predicted & true center of mass for {p,q}.
        m_pq = model(pq_cld)
        c_pq = (p + q) / 2
        pq_error = torch.norm(m_pq - c_pq)

        # step 2: Compute eps and the theoretical sup-norm lower bound.
        dist_pq = torch.norm(p - q).item()
        eps, lower_bound = get_lower_bound(dist_pq, k, tau, return_eps=True)

        # step 3: Pick k-2 more points to make set A_smp.
        A_smp = torch.FloatTensor(smp.disc(N=k - 2))
        maxdist_A2p = torch.norm(A_smp - p, dim=1).max()
        maxdist_A2q = torch.norm(A_smp - q, dim=1).max()

        # step 4: Pull the points of A_smp towards an anchor (p or q) so
        # they lie within t*eps of it, then append p and q. As t -> 0 the
        # cloud converges to {p,q}, so by the argument in the proof a small
        # enough t brings the prediction within eps of the one for {p,q}.
        def make_cloud(anchor, maxdist, t):
            A_shrunk = anchor + t * eps * (A_smp - anchor) / maxdist
            return torch.cat([A_shrunk, pq_cld[0]]).view(1, -1, 2)

        Ap, m_Ap = _shrink_until_close(
            model, lambda t: make_cloud(p, maxdist_A2p, t), m_pq, eps)
        Aq, m_Aq = _shrink_until_close(
            model, lambda t: make_cloud(q, maxdist_A2q, t), m_pq, eps)

        # True centers of mass and the resulting errors.
        c_Ap = torch.mean(Ap, 1, keepdim=False)
        c_Aq = torch.mean(Aq, 1, keepdim=False)

        Ap_error = torch.norm(m_Ap - c_Ap)
        Aq_error = torch.norm(m_Aq - c_Aq)

    # Determine A_argmax, error_max
    clouds = [pq_cld, Ap, Aq]
    errors = [pq_error.item(), Ap_error.item(), Aq_error.item()]

    error_argmax = int(np.argmax(errors))
    error_max = errors[error_argmax]
    A_argmax = clouds[error_argmax]

    # Compute wiggle room (difference between experiment and theory)
    slack = error_max - lower_bound

    # Written as "not >=" so that a NaN error is also reported as a violation.
    violated = not (error_max >= lower_bound)

    if verbose or violated:
        # Mark the case achieving the max error, unless the bound is violated.
        marks = ['', '', '']
        if not violated:
            marks[error_argmax] = '*'
        print('\n pq Error: {}{} \n Ap Error: {}{} \n Aq Error: {}{}'.format(
            pq_error, marks[0], Ap_error, marks[1], Aq_error, marks[2]))
        if violated:
            print('* * * VIOLATION OF LOWER BOUND * * *')

    if verbose:
        print('\n Max Error:   {} \n Lower Bound: {}'.format(
            error_max, lower_bound))
        print('\n Max Error - Lower Bound: {}'.format(slack))

    return A_argmax, float(error_max), float(dist_pq), float(slack)

if __name__ == '__main__':

    # Experiment settings
    N_tests = 10**4
    tau = 0.01    # Need 0 < tau < 1
    k = 10

    # Pieces that identify the saved model on disk
    numclds_in_training = 10**5
    model_type = '500-500-500'
    model_epochs = '50'
    batch_size = '64'

    directory = 'numclds_{}'.format(numclds_in_training)
    model_path = '{}/CoMptnet{}_{}epoch_{}bs.pkl'.format(
        directory, model_type, model_epochs, batch_size)

    # Load model
    # NOTE: torch.load unpickles the file, so only load checkpoints that
    # come from a trusted source.
    model = torch.load(model_path)

    # Run test
    main(model, directory, model_type, k, tau, N_tests)