import argparse
import csv
import os
import time

import numpy as np
from scipy.linalg import norm
from scipy.stats import trim_mean

import dq_sync as dqs


# ---------------------------------------------------------------------------
# Parameter setting
# ---------------------------------------------------------------------------
# Rotation-noise levels to sweep. synthe_exp converts sigma_r with
# dqs.angle2radians, so these are presumably degrees.
sigma_r_list = np.linspace(0, 20, 5)
no_trial = 100      # independent trials per (p, sigma_r) setting
maxiter = 500       # maximum number of GPM iterations per trial
tol_gpm = 1e-5      # GPM stopping tolerance (its exponent is encoded in output filenames)
# GPM stopping rule; see the delta_gpm parameter of synthe_exp.
delta_gpm = False

n = 100                      # problem size handed to synthe_exp as `size`
p_list = [0.05, 0.08, 0.3]   # values swept for synthe_exp's `p` (forwarded to generateObservations_addi)


parser = argparse.ArgumentParser(
    description='Synthetic comparison of the GPM iteration and the Arrigoni method.')
parser.add_argument('--seed', type=int, default=12345,
                    help='base random seed; trial i is seeded with seed + i')
args = parser.parse_args()


def synthe_exp(size=1, sigma_r=1, sigma_t=1, p=1, maxiter_gpm=100, tol_gpm=0.01, delta_gpm=True):
    """Run one synthetic trial comparing the GPM iteration with Arrigoni's method.

    A ground truth is drawn with ``dqs.generateGroundTruth(size)``. Noisy
    observations are generated with ``dqs.generateObservations_addi``. Both
    estimators are run on them and scored with ``dqs.calculateError``.

    Args:
        size: problem size passed to ``dqs.generateGroundTruth``.
        sigma_r: rotation noise level. It is converted with
            ``dqs.angle2radians`` here, so callers should presumably pass
            degrees (do NOT pre-convert).
        sigma_t: translation noise level, forwarded unchanged.
        p: forwarded as ``p`` to ``dqs.generateObservations_addi``
            (presumably an observation probability -- TODO confirm).
        maxiter_gpm: maximum number of GPM iterations.
        tol_gpm: GPM stopping tolerance.
        delta_gpm: stopping rule. If True, stop when the change between
            consecutive translation step sizes falls below ``tol_gpm``.
            If False, stop when the translation step size itself falls
            below ``tol_gpm``. The iteration index is printed on convergence.

    Returns:
        ``(time_gpm, time_arri, rerrors_gpm, rerrors_arri, terrors_gpm,
        terrors_arri)``. The first two values are wall-clock seconds spent in
        each method. The others are the means of the rotation/translation
        errors returned by ``dqs.calculateError``.
    """
    origin = dqs.generateGroundTruth(size)

    # C_gpm feeds the power iteration, C_arri feeds Arrigoni's method.
    C_gpm, C_arri = dqs.generateObservations_addi(
        origin, sigma_r=dqs.angle2radians(sigma_r), sigma_t=sigma_t, p=p)

    def _translation(block_mat):
        """Return the translation part of the normalized estimate ``block_mat``."""
        _, t = dqs.dq2rt(dqs.dqmat2dq(dqs.normalize(dqs.bm2mb(block_mat))))
        return t

    # ---- GPM: power iteration from a random initial estimate ----
    out = dqs.mb2bm(dqs.dq2dqmat(dqs.rt2dq(*dqs.randomUniformGaussian((size, 1)))))

    time_gpm = 0
    descent_t_curr = 0
    # Translation of the current iterate. It is carried into the next
    # iteration as the "previous" value, so each iterate is converted once.
    t_curr = _translation(out)

    for i in range(maxiter_gpm):
        descent_t_pre = descent_t_curr
        t_pre = t_curr

        # Only the multiply + renormalize step is counted as GPM time.
        t1 = time.time()
        out = C_gpm @ out
        out = dqs.mb2bm(dqs.normalize(dqs.bm2mb(out)))
        time_gpm += time.time() - t1

        t_curr = _translation(out)
        descent_t_curr = norm(t_pre - t_curr)
        if delta_gpm:
            converged = abs(descent_t_curr - descent_t_pre) < tol_gpm
        else:
            converged = descent_t_curr < tol_gpm
        if converged:
            print(i)
            break

    estimate_gpm = dqs.normalize(dqs.bm2mb(out))
    rerror_gpm, terror_gpm = dqs.calculateError(estimate_gpm, origin)
    rerrors_gpm = np.mean(rerror_gpm)
    terrors_gpm = np.mean(terror_gpm)

    # ---- Arrigoni's method (timing includes conversion back to an estimate) ----
    t2 = time.time()
    x_arri = dqs.bm2mb(dqs.arrigoni(C_arri))
    estimate_arri = dqs.dq2dqmat(dqs.rt2dq(*dqs.mat2rt(x_arri)))
    time_arri = time.time() - t2

    rerror_arri, terror_arri = dqs.calculateError(estimate_arri, origin)
    rerrors_arri = np.mean(rerror_arri)
    terrors_arri = np.mean(terror_arri)

    return time_gpm, time_arri, rerrors_gpm, rerrors_arri, terrors_gpm, terrors_arri



# open() below does not create directories, so make sure the output one exists.
os.makedirs('syntResult', exist_ok=True)

# Per-trial quantities returned by synthe_exp, in its return order.
col_names = ['time_gpm', 'time_arri',
             'rerrors_gpm', 'rerrors_arri',
             'terrors_gpm', 'terrors_arri']

for p in p_list:
    # Output files are tagged with p in percent and the exponent of tol_gpm.
    # round() guards against p*100 landing just below an integer
    # (e.g. 0.29*100 == 28.999999999999996), which int() alone would truncate.
    file_tag = 'p{:03d}_tol{}'.format(int(round(p * 100)), int(-np.log10(tol_gpm)))
    all_results_filename = 'syntResult/all_results_rtboth_{}.csv'.format(file_tag)
    stats_results_filename = 'syntResult/stats_results_rtboth_{}.csv'.format(file_tag)

    with (open(stats_results_filename, 'w', newline='') as f_stats,
          open(all_results_filename, 'w', newline='') as f):

        writer_stats = csv.writer(f_stats)
        writer = csv.writer(f)

        # Raw results: one row per (sigma_r, trial).
        writer.writerow(['sigma_r', 'sigma_t', 'trial', *col_names])
        f.flush()

        # Summary statistics: one row per sigma_r, aggregated over trials.
        writer_stats.writerow(
            ['sigma_r', 'sigma_t'] +
            [f'mean_{col}' for col in col_names] +
            [f'std_{col}' for col in col_names] +
            [f'median_{col}' for col in col_names] +
            [f'trim10_{col}' for col in col_names] +
            [f'trim15_{col}' for col in col_names]
        )
        f_stats.flush()

        for sigma_r in sigma_r_list:
            # Translation noise is tied to the rotation noise level.
            sigma_t = 0.01 * sigma_r
            trial_results = []
            for i in range(no_trial):
                # Seeds depend only on the trial index, so every (p, sigma_r)
                # setting sees the same sequence of seeds.
                np.random.seed(args.seed + i)

                # sigma_r is passed unconverted: synthe_exp applies
                # dqs.angle2radians itself. delta_gpm comes from the
                # parameter settings instead of being hard-coded.
                metrics = synthe_exp(size=n, sigma_r=sigma_r, sigma_t=sigma_t, p=p,
                                     maxiter_gpm=maxiter, tol_gpm=tol_gpm,
                                     delta_gpm=delta_gpm)

                # Write each trial immediately so partial runs are not lost.
                writer.writerow([sigma_r, sigma_t, i, *metrics])
                f.flush()
                trial_results.append(metrics)

                print('p = {:3}'.format(p), 'finish trial {:2} of sigma_r {:<1} \t sigma_t:{:1}'.format(i, sigma_r, sigma_t))

            # Shape (no_trial, len(col_names)); columns follow col_names.
            results = np.asarray(trial_results, dtype=float).reshape(-1, len(col_names))

            writer_stats.writerow([
                sigma_r, sigma_t,
                *results.mean(axis=0),
                *results.std(axis=0),
                *np.median(results, axis=0),
                *trim_mean(results, proportiontocut=0.10, axis=0),
                *trim_mean(results, proportiontocut=0.15, axis=0),
            ])
            f_stats.flush()








