import os
from concurrent import futures
from copy import deepcopy
from datetime import datetime

import numpy as np
import pandas as pd
from scipy.stats import truncnorm

from models.ERM import ERM
from models.Gradient_Based import GP_Moreau, GP_Phased_ERM, GP_NSGD
from models.utils_AOP import Evaluator



def _build_tasks(exp_param, algo_param, ns, n_repeats):
    """Create one solver instance per (sample size ``n``, repetition) pair.

    Each task receives its own deep copy of ``exp_param`` with ``'n'`` set,
    so tasks never share the experiment dictionary.

    Raises:
        ValueError: if ``algo_param['grad_type']`` names no known solver
            (previously this silently produced zero experiments).
    """
    grad_type = algo_param['grad_type']
    if grad_type == 'nsgd':
        solver_cls = GP_NSGD
    elif grad_type in ['moreau', 'imprved_moreau']:
        solver_cls = GP_Moreau
    else:
        raise ValueError(f'unknown grad_type: {grad_type!r}')

    tasks = []
    for n in ns:
        for _ in range(n_repeats):
            new_exp_param = deepcopy(exp_param)
            new_exp_param.update({'n': n})
            tasks.append(solver_cls(new_exp_param, algo_param))
    return tasks


def _run_tasks(tasks, max_workers=12):
    """Run every task's ``Run`` method in a process pool and collect the results.

    A task that raises is reported and skipped, so one failing run does not
    discard the results of all runs that finish after it. A pool-level
    failure is reported too, and whatever finished before it is still
    returned (best effort, as in the original script).

    Returns:
        list: return values of the successful tasks, in completion order
        (not submission order).
    """
    outputs = []
    try:
        with futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            pending = [executor.submit(task.Run) for task in tasks]
            for k, future in enumerate(futures.as_completed(pending)):
                try:
                    outputs.append(future.result())
                except Exception as error:  # per-task boundary: keep the others
                    print(f'----{k}----Failed: {error!r}----')
                    continue
                print(f'----{k}----' + 'Done----')
    except Exception as error:
        print(error)
    return outputs


def _write_output(outputs, *, cost_optimal, x_star, grad_type, varepsilon, delta,
                  m, d, loss_func):
    """Save the run records as a CSV under ./data/, augmented with gap metrics.

    Every record must provide the keys 'cost' and 'l2_dist', which are used to
    compute the percentage gaps. If post-processing fails, the raw records are
    dumped to a timestamped fallback file instead, so expensive results are
    not lost and earlier fallbacks are not overwritten.
    """
    if not outputs:
        print('----no successful runs, nothing to write----')
        return
    # Without this, a missing ./data directory made both writes fail.
    os.makedirs('./data', exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    method = grad_type.upper()
    try:
        output_df = pd.DataFrame.from_records(outputs)
        output_df['cost_optimal'] = cost_optimal
        # Relative gaps (in percent) w.r.t. the ERM reference solution x_star.
        output_df['cost_gap%'] = (output_df['cost'] - cost_optimal) / cost_optimal * 100
        output_df['l2_dist_gap%'] = (output_df['l2_dist']) / np.linalg.norm(x_star, 2) * 100
        output_df['method'] = method
        output_df['varepsilon'], output_df['delta'] = varepsilon, delta
        output_df.to_csv(
            f'./data/{method}_m={m}_d={d}_(varepsilon,delta)=({varepsilon},{delta})_f={loss_func}_{timestamp}.csv',
            index=False)
    except Exception as error:  # best-effort: report, then dump raw records
        print(f'failed to post-process results ({error!r}); writing raw records')
        pd.DataFrame.from_records(outputs).to_csv(
            f'./data/default_name_{method}_{timestamp}.csv', index=False)


if __name__ == '__main__':
    for varepsilon, delta in zip([0.2, 0.3, 0.5], [0.05] * 3):
        # -- exp params
        x_true = np.asarray([0.5, -0.5, 1, -1, 1])
        D = 5  # NOTE(review): stays 5 although d becomes 10 below; meaning is defined by the model classes — confirm
        A_mean = np.asarray([[1, 0.5, 0, 1, 1],
                             [0.5, 0.5, 0, 1, 1],
                             [0, 0, -0.5, 1, 1]])
        # Duplicate the columns of A_mean (and entries of x_true) to double the dimension.
        A_mean = np.column_stack([A_mean, A_mean])
        x_true = np.concatenate([x_true, x_true])
        m, d = A_mean.shape
        # presumably an upper bound on the operator norm of A; TODO confirm against models
        A_oper_norm_ub = 3 * A_mean.shape[1] / 5

        exp_param = {'n': None, 'm': m, 'd': d,
                     'x_true': x_true, 'A_mean': A_mean, 'A_oper_norm_ub': A_oper_norm_ub, 'D': D,
                     'loss_func': 'l1',  # options: l1, piecewise
                     'pieces': np.asarray([[1, 1, 1, 0],
                                           [1, 1, -1, 0],
                                           [1, -1, 1, 0],
                                           [-1, 1, 1, 0],
                                           [1, -1, -1, 0]])}
        loss_func = exp_param['loss_func']
        # Reference solution and its cost; the gap columns in the CSV are relative to these.
        x_star = ERM(exp_param).Solve_x(30000)
        exp_param['x_star'] = x_star
        cost_optimal = Evaluator(exp_param, 30000).Evaluate(exp_param['x_star'])['cost']

        # -- construct and run tasks, one algorithm at a time
        for grad_type in ['moreau']:  # options: 'nsgd', 'moreau', 'imprved_moreau'
            algo_param = {'varepsilon': varepsilon, 'delta': delta, 'grad_type': grad_type}

            ns = [1000, 500, 400, 300, 200, 150, 100, 50]
            n_repeats = 108  # independent repetitions per sample size
            tasks = _build_tasks(exp_param, algo_param, ns, n_repeats)

            print(f'----in total, {len(tasks)} experiments----')
            t = datetime.now()
            outputs = _run_tasks(tasks, max_workers=12)

            # ---- write output
            try:
                _write_output(outputs, cost_optimal=cost_optimal, x_star=x_star,
                              grad_type=grad_type, varepsilon=varepsilon, delta=delta,
                              m=m, d=d, loss_func=loss_func)
            except Exception as error:  # even the fallback write failed
                print(error)
            print(datetime.now() - t)






