import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
from utils import *
import os
from datetime import datetime

def run_sweep_n(sweep_type, sweep_range, params, filename):
    """Run Monte Carlo metric-estimation trials while sweeping one parameter.

    For every value in ``sweep_range`` and every value of an inner sweep
    variable, the experiment is repeated ``params['MC']`` times. Each trial
    records the normalized Frobenius error ``||Sig - Sig_est||_F / ||Sig||_F``.

    Args:
        sweep_type: Which parameter is swept.
            * ``'d'``: sweep ``d``; the inner sweep is over ``n`` in ``params['n_v']``.
            * ``'r'``: sweep ``r``; the inner sweep is over ``n`` in ``params['n_v']``.
            * ``'m'``: sweep ``(N, d)`` pairs; the inner sweep is over ``m`` in
              ``params['m_v'][0]`` (when N/d == 1000) or ``params['m_v'][1]``
              (when N/d == 400).
        sweep_range: Values of the swept parameter. For ``'m'``, each entry
            is an ``(N, d)`` tuple.
        params: Experiment configuration. Keys read: ``'MC'``, ``'max_iter'``,
            ``'y'``, ``'eta_up'``, ``'r'`` (unless sweeping r), ``'d'`` (only
            when sweeping r), ``'n_v'`` (for 'd'/'r') or ``'m_v'`` (for 'm').
        filename: Output path passed through to ``save_exp``.

    Raises:
        ValueError: If ``sweep_type`` is not 'd', 'r' or 'm', or if an 'm'
            sweep value has an N/d ratio with no configured ``m_v`` list.
    """

    # Fixed seed so repeated runs draw the same random problems.
    np.random.seed(0)

    # Number of Monte Carlo trials and max iterations for the solver.
    MC = params['MC']
    max_iter = params['max_iter']
    y, eta_up = params['y'], params['eta_up']

    # Fixed values for the parameters that are not being swept, plus the
    # configured values of the inner sweep variable.
    if sweep_type == 'd':
        r = params['r']
        x_values = params['n_v']
    elif sweep_type == 'r':
        d = params['d']
        x_values = params['n_v']
    elif sweep_type == 'm':
        # r is needed by get_gt_orth/get_lda in every sweep, including this one.
        r = params['r']
        # Tuple of m lists: [0] is for N/d == 1000, [1] is for N/d == 400.
        x_values = params['m_v']
    else:
        raise ValueError("sweep_type must be 'd', 'r' or 'm', got %r" % (sweep_type,))

    # Uniform noise variance: eta_up**2 / 3 is the variance of U(-eta_up, eta_up).
    # It does not depend on any loop variable, so compute it once.
    noise_var = (eta_up ** 2) / 3

    # Output dict: sweep value -> list (one entry per inner value) of per-trial errors.
    norm_err_sweep_all = {}

    # Sweep over parameter range, setting the value of the swept parameter.
    for sweep_val in sweep_range:
        if sweep_type == 'd':
            d = sweep_val
            x_range = x_values
        elif sweep_type == 'r':
            r = sweep_val
            x_range = x_values
        else:  # 'm': sweep_val is an (N, d) pair
            N, d = sweep_val[0], sweep_val[1]
            ratio = N / d
            if ratio == 1000:
                x_range = x_values[0]
            elif ratio == 400:
                x_range = x_values[1]
            else:
                # Without this, the previous iteration's x_range would be silently reused.
                raise ValueError("no m_v list configured for N/d = %g (N = %d, d = %d)"
                                 % (ratio, N, d))

        if sweep_type == 'm':
            print('%s, N = %0d, d = %0d -----------------------' % (sweep_type, sweep_val[0], sweep_val[1]))
        else:
            print('%s = %0d -----------------------------------' % (sweep_type, sweep_val) )

        # Error for fixed value of sweep parameter.
        norm_err_sweep = []

        # Sweep over the inner variable (n for 'd'/'r', m for 'm').
        for x in x_range:
            if sweep_type == 'm':
                m = x
                print('m = %0d -----------------------------------' % (m) )

                n = int(N / m)
            else:
                n = x
                # Theoretically suggested value for m in terms of n.
                # NOTE(review): floor is applied to sqrt(n / d) only, before
                # scaling by noise_var; confirm this placement is intended.
                m = int(np.floor(np.sqrt(n / d)) * noise_var)
                N = int(m * n)
                # Printed only after N is computed; N is not yet bound on the first pass.
                print("n = %0d, N = %0d --------------------------" % (n, N))

            # Store error for each trial.
            norm_err_mc = []

            # Repeat over MC trials.
            for mc in range(MC):
                if (mc + 1) % 5 == 0:
                    print("mc = %0d / %0d" % (mc+1, MC))

                # Generate ground truth metric Sig and tau.
                Sig, tau = get_gt_orth(d, r, y, eta_up, n)

                # Use Sig and tau to get sensing matrices As and responses gammas.
                As, gammas = get_meas(d, m, n, tau, Sig, y = y, eta_up = eta_up)
                A_flat = As.reshape(d*d, -1)

                # Regularization parameter lda, with constant C = 1.
                lda = get_lda(d, n, m, r, 1)

                # Estimate metric from sensing matrices, lda, etc.
                # NOTE(review): gammas is computed but unused; get_est receives y.
                # Confirm this is intended.
                Sig_est = get_est(A_flat, y, lda, d, n, max_iter=max_iter)

                # Store normalized Frobenius error for this trial.
                norm_err_mc.append(np.linalg.norm(Sig - Sig_est, 'fro') / np.linalg.norm(Sig, 'fro'))

            print('Error mean, std = %0f, %0f' % (np.mean(norm_err_mc), np.std(norm_err_mc)))

            # Store value of all sweeps.
            norm_err_sweep.append(norm_err_mc)

        # Store sweeps for all values of mc, n for fixed sweep parameter value.
        norm_err_sweep_all[sweep_val] = norm_err_sweep

    # Save output, passing the configured inner-sweep values.
    # NOTE(review): the previous call referenced an undefined name `n_v`.
    # Confirm save_exp expects params['n_v'] / params['m_v'] here.
    save_exp(norm_err_sweep_all, x_values, sweep_type, params, filename)




# Parameters to sweep, in the order the experiments are run.
sweep_types = ['d', 'r', 'm']
# Values taken by each swept parameter.
sweep_vals = {
    'd' : [40, 50, 60],
    'r' : [5, 8, 9, 10, 15, 20],
    'm' : [(50000, 50), (40000, 40), (20000, 50), (16000, 40)] # (N, d) combination
}

# Shared configuration passed to run_sweep_n. 'MC' is the number of Monte
# Carlo trials, 'max_iter' is the solver iteration cap, 'n_v' holds the n
# values swept for the 'd'/'r' experiments, and 'm_v' holds the m values
# swept for the 'm' experiment.
params = {
    'r' : 9,
    'y' : 200,
    'eta_up' : 10,
    'd' : 50,
    'MC' : 20,
    'n_v' : np.arange(500, 5001, 100),
    'm_v' : ([1, 2, 5, 10, 20, 25, 40, 50], [1, 2, 5, 8, 10, 20, 25, 50]), #[0] is for N/d = 1000, [1] is for N/d = 400
    'max_iter' : 50000
}


# Outputs are grouped in a per-day directory.
exp_dir = os.path.join("./exp_out", datetime.today().strftime('%Y-%m-%d'))
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(exp_dir, exist_ok=True)

for sweep_type in sweep_types:
    print('Running sweep for %s ====================================' % sweep_type)
    sweep_range = sweep_vals[sweep_type]

    filename = os.path.join(exp_dir, "error_vary_" + sweep_type)

    run_sweep_n(sweep_type, sweep_range, params, filename)
