import dq_sync as dqs
import numpy as np
from scipy.io import loadmat
from scipy.io import savemat
from scipy.linalg import norm
import time
import matlab.engine
import os
import io
import csv
import argparse



# Available datasets and measurement densities ('bunny' exists only as 'sparse').
SHAPE_CHOICES = ('bunny', 'happy', 'dragon', 'arma')
DENSITY_CHOICES = ('sparse', 'dense')


# parameter setting
# Rotation noise levels; each is converted with dqs.angle2radians before use.
rlevel_list = [1]
# Translation noise sigmas, scaled from the rotation levels.
tlevel_list = [0.001 * x for x in rlevel_list]
no_trial = 100   # trials per (rlevel, tlevel) pair
maxiter = 100    # maximum GPM iterations per trial


parser = argparse.ArgumentParser(
    description='Run the point-set synchronization experiment (GPM vs SDR-Rosen vs Arrigoni).')
parser.add_argument('--seed', type=int, default=12345,
                    help='base RNG seed; trial i is seeded with seed + i')
parser.add_argument('--shape', choices=SHAPE_CHOICES, default='arma',
                    help='dataset to load')
parser.add_argument('--density', choices=DENSITY_CHOICES, default='dense',
                    help="measurement density ('bunny' supports only 'sparse')")
args = parser.parse_args()

# Reject combinations whose data files do not exist, with a clear message
# instead of a later file-not-found from loadmat.
if args.shape == 'bunny' and args.density != 'sparse':
    parser.error("'bunny' is only available with --density sparse")

shape_option = args.shape
density_option = args.density


# Load the pairwise measurements for the selected dataset.
mat_data = loadmat(f'icpData/meapyMat_{shape_option}_{density_option}.mat')
rawdata = mat_data['meapyMat_q_refine']

# The first 3 channels are passed to rt2dq as the translation, the rest as the rotation.
t = rawdata[:, :, 0:3]
r = rawdata[:, :, 3:]
n = np.shape(t)[0]

# Convert to the dq_sync matrix form and hermitize it.
C_partial = dqs.dq2dqmat(dqs.rt2dq(r, t))
C = dqs._hermitizeDQmat(C_partial)


# load groundtruth
gt_data = loadmat(f'gtData/gtVec_{shape_option}.mat')
raw_data = gt_data['gtVec2']

# Rows 0..n-1 of gtVec2: columns 0:3 are the translation, columns 3: the
# rotation (expected to be exactly 4 values, matching r_ori's last axis).
r_ori = np.zeros((n, 1, 4))
t_ori = np.zeros((n, 1, 3))
r_ori[:, 0, :] = raw_data[:n, 3:]
t_ori[:, 0, :] = raw_data[:n, 0:3]

origin = dqs.dq2dqmat(dqs.rt2dq(r_ori, t_ori))



def pointset_exp(C, origin, rlevel = 1, tlevel = 1, maxiter_gpm = 100, gaussianRotation = True, uniformRotation = False):
    """Run one noisy trial comparing the GPM, SDR-Rosen (MATLAB) and Arrigoni solvers.

    The measurement matrix ``C`` is multiplied by random dual-quaternion
    noise, re-hermitized, and handed to three synchronization methods.
    Each estimate is scored against ``origin`` with ``dqs.calculateError``.

    NOTE(review): relies on the module-level ``n`` (number of poses) for the
    noise shape and for reshaping the MATLAB results.

    Args:
        C: measurement matrix in ``dqs.dq2dqmat`` form.
        origin: ground-truth poses in ``dqs.dq2dqmat`` form.
        rlevel: rotation noise level; ``sigma_r`` for Gaussian noise,
            ``uprange`` for uniform noise.
        tlevel: translation noise ``sigma_t``.
        maxiter_gpm: maximum number of GPM iterations.
        gaussianRotation, uniformRotation: select the rotation-noise model;
            exactly one of them must be True.

    Returns:
        Tuple ``(time_gpm, time_rosen, time_arri, rerrors_gpm, rerrors_rosen,
        rerrors_arri, terrors_gpm, terrors_rosen, terrors_arri)``. GPM and
        Arrigoni times are wall-clock seconds from ``time.time()``;
        ``time_rosen`` is whatever MATLAB's ``pointset_rosen`` returns. Errors
        are means of the ``dqs.calculateError`` outputs.

    Raises:
        ValueError: if not exactly one noise flag is set.
    """
    if (gaussianRotation * uniformRotation != 0) or (gaussianRotation + uniformRotation != 1):
        raise ValueError('exactly one of gaussianRotation / uniformRotation must be True')

    # Apply noise to measurements and hermitize the resulting matrix.
    if gaussianRotation:
        noise = dqs.dq2dqmat(dqs.rt2dq(*dqs.randomGaussianGaussian((n, n), sigma_r=rlevel, sigma_t=tlevel)))
    else:
        noise = dqs.dq2dqmat(dqs.rt2dq(*dqs.randomUniformGaussianNoise((n, n), uprange=rlevel, sigma_t=tlevel)))
    # Diagonal (self-)entries are left unperturbed: each is set to eye(4),
    # presumably the identity in dq_sync's matrix representation.
    inds = np.diag_indices(n)
    noise[inds[0], inds[1], :, :] = np.eye(4)

    C_noisy = dqs._hermitizeDQmat(C @ noise)

    C_arri = dqs.mb2bm(dqs.rt2mat(*dqs.dq2rt(dqs.dqmat2dq(C_noisy))))
    C_gpm = dqs.mb2bm(C_noisy)
    C_rosen = C_arri

    # Anchor every sdr_rosen file to the same directory the MATLAB engine is
    # cd'd into. Using os.path.join avoids the invalid '\e' escape and the
    # Windows-only backslash separators.
    sdr_rosen_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'sdr_rosen'))
    examples_dir = os.path.join(sdr_rosen_path, 'examples')
    savemat(os.path.join(examples_dir, 'C_rosen.mat'), {'C_rosen': C_rosen})


    # do experiment for gpm method
    shape = (C_gpm.shape[0] // 4, 1)
    out = dqs.mb2bm(dqs.dq2dqmat(dqs.rt2dq(*dqs.randomUniformGaussian(shape))))

    time_gpm = 0
    # Translation of the previous iterate. It is carried across iterations, so
    # each iterate is decoded once instead of twice.
    _, t_pre = dqs.dq2rt(dqs.dqmat2dq(dqs.normalize(dqs.bm2mb(out))))

    for _ in range(maxiter_gpm):
        # Only the matrix product and normalization are timed.
        t1 = time.time()
        out = C_gpm @ out
        out = dqs.mb2bm(dqs.normalize(dqs.bm2mb(out)))
        time_gpm += time.time() - t1

        _, t_curr = dqs.dq2rt(dqs.dqmat2dq(dqs.normalize(dqs.bm2mb(out))))
        # Stop once the translation part has stopped moving.
        if norm(t_pre - t_curr) < 1e-4:
            break
        t_pre = t_curr

    estimate_gpm = dqs.normalize(dqs.bm2mb(out))
    rerror_gpm, terror_gpm = dqs.calculateError(estimate_gpm, origin)
    rerrors_gpm = np.mean(rerror_gpm)
    terrors_gpm = np.mean(terror_gpm)


    # do experiment for sdr_rosen method
    eng = matlab.engine.start_matlab()
    try:
        eng.cd(sdr_rosen_path, nargout=0)
        # Capture (and discard) MATLAB console output.
        matlab_log = io.StringIO()
        eng.dealpath(nargout=0, stdout=matlab_log)
        time_rosen = eng.pointset_rosen(nargout=1, stdout=matlab_log)
    finally:
        # Always release the MATLAB process, even if a call above raised.
        eng.quit()

    data_rosen_r = loadmat(os.path.join(examples_dir, 'r_rosen.mat'))
    data_rosen_t = loadmat(os.path.join(examples_dir, 't_rosen.mat'))
    r_rosen = data_rosen_r['r_rosen'].reshape(n, 1, 4)
    t_rosen = data_rosen_t['t_rosen'].reshape(n, 1, 3)
    estimate_rosen = dqs.dq2dqmat(dqs.rt2dq(r_rosen, t_rosen))
    rerror_rosen, terror_rosen = dqs.calculateError(estimate_rosen, origin)
    rerrors_rosen = np.mean(rerror_rosen)
    terrors_rosen = np.mean(terror_rosen)


    # do experiment for arrigoni method
    t2 = time.time()
    x_arri = dqs.bm2mb(dqs.arrigoni(C_arri))
    estimate_arri = dqs.dq2dqmat(dqs.rt2dq(*dqs.mat2rt(x_arri)))
    time_arri = time.time() - t2

    rerror_arri, terror_arri = dqs.calculateError(estimate_arri, origin)
    rerrors_arri = np.mean(rerror_arri)
    terrors_arri = np.mean(terror_arri)

    return time_gpm, time_rosen, time_arri, rerrors_gpm, rerrors_rosen, rerrors_arri, terrors_gpm, terrors_rosen, terrors_arri



# results saving
all_results_filename = f'realResult/all_results_{shape_option}_{density_option}.csv'
average_results_filename = f'realResult/average_results_{shape_option}_{density_option}.csv'
# Create the output directory so open() below does not fail on a fresh checkout.
os.makedirs(os.path.dirname(all_results_filename), exist_ok=True)


with open(average_results_filename, 'w', newline='') as f_avg, \
     open(all_results_filename, 'w', newline='') as f:

    writer_avg = csv.writer(f_avg)
    writer = csv.writer(f)

    writer_avg.writerow(['rlevel', 'tlevel',
                         'mean_time_gpm', 'mean_time_rosen', 'mean_time_arri',
                         'mean_rerrors_gpm', 'mean_rerrors_rosen', 'mean_rerrors_arri',
                         'mean_terrors_gpm', 'mean_terrors_rosen', 'mean_terrors_arri'])
    f_avg.flush()

    writer.writerow(['rlevel', 'tlevel', 'trial', 'time_gpm', 'time_rosen', 'time_arri',
                     'rerrors_gpm', 'rerrors_rosen', 'rerrors_arri',
                     'terrors_gpm', 'terrors_rosen', 'terrors_arri'])
    f.flush()

    for rlevel in rlevel_list:
        for tlevel in tlevel_list:
            trial_metrics = []
            for i in range(no_trial):
                # Seed depends only on the trial index, so every
                # (rlevel, tlevel) pair sees the same sequence of seeds.
                np.random.seed(args.seed + i)

                # Order: times (gpm, rosen, arri), rotation errors, translation
                # errors; this matches the CSV headers above.
                metrics = pointset_exp(C, origin, rlevel = dqs.angle2radians(rlevel), tlevel = tlevel,
                                       maxiter_gpm = maxiter, gaussianRotation = True, uniformRotation = False)

                # Write each trial result immediately so partial runs are kept.
                writer.writerow([rlevel, tlevel, i, *metrics])
                f.flush()

                trial_metrics.append(metrics)

                print('finish trial {:1} of rlevel {:<1} \t tlevel:{:1}'.format(i, rlevel, tlevel))

            # Average once per (rlevel, tlevel) pair. This was previously
            # outside the tlevel loop, so only the last tlevel was averaged.
            if trial_metrics:
                means = np.asarray(trial_metrics, dtype=float).mean(axis=0)
                writer_avg.writerow([rlevel, tlevel, *means])
                f_avg.flush()







