import torch
import torch.nn as nn
import torch.nn.functional
import torch.utils.data as data
import numpy.linalg as la
import numpy as np
import scipy.linalg as sla
import torch.linalg as tla
import scipy.sparse as sp
import itertools
import sys
import pickle 
sys.path.append('.')
import argparse
import os
import csv 
from pathlib import Path
path = Path(__file__).parent.absolute()
os.chdir(path)
from model.tri_predictor import compute_auc, compute_loss_forex, MLPPredictor_forex, count_parameters, compute_error_curl, compute_error_curl_np
import time
from get_parser import get_parser


def main():
    """Run the Laplacian-regularised denoising baseline on the forex flows.

    Loads the node-edge incidence ``B1``, the transposed edge-triangle
    incidence ``B2t`` and the raw edge flows from ``./data`` (relative to the
    working directory). From these it builds the down/up Laplacians
    ``L1d = B1^T B1 / n`` and ``L1u = B2 B2^T / n``, where ``n`` is the
    number of nodes, and takes ``f_true = L1d @ f_init`` as the reference
    flows.

    For each of ``args.realizations`` realizations:

    * a random mask keeps each edge with probability ~0.5;
    * the masked column-2 flow ``f_test`` is smoothed as
      ``(I + alpha * L1u)^{-1} f_test`` for every ``alpha`` in
      ``np.linspace(0, 10, 11)``;
    * the alpha with the lowest error against ``f_true[:, 2]`` is kept.

    Finally it prints the mean/std of error and curl for the masked inputs
    and for the best smoothed outputs. Error and curl are whatever
    ``compute_error_curl_np`` returns.

    Hyperparameters come from ``get_parser()``. Only ``args.realizations`` is
    read here; ``args.noise_type`` is overwritten with ``'curl'``.

    Raises:
        RuntimeError: if no alpha yields an error comparable with ``inf``
            (e.g. every error is NaN), so no best alpha exists.
    """
    parser = get_parser()
    args = parser.parse_args()
    args.noise_type = 'curl'
    b1 = np.loadtxt('./data/B1_FX_1538755200.csv', delimiter=',')
    b2t = np.loadtxt('./data/B2t_FX_1538755200.csv', delimiter=',')
    b2 = b2t.T
    # f contains bid, ask, and mid prices (one column each, presumably in
    # that order -- columns 0/1/2 are used below as train/val/test).
    f_init = np.loadtxt('./data/flow_FX_1538755200.csv', delimiter=',')
    num_nodes, num_edges = b1.shape[0], b1.shape[1]
    # Both Laplacian parts are normalised by the number of nodes.
    l1d = b1.T @ b1 / num_nodes
    l1u = b2 @ b2t / num_nodes
    f_true = l1d @ f_init

    alphas = np.linspace(0, 10, 11)
    # I + alpha*L1u does not depend on the realization, so factor each system
    # once and reuse it. L1u = B2 B2^T / n is positive semi-definite, hence
    # I + alpha*L1u is symmetric positive definite for alpha >= 0 and a
    # Cholesky solve is valid. It is also better conditioned than forming an
    # explicit inverse. Cost: len(alphas) dense (num_edges x num_edges)
    # factors held in memory at once.
    identity = np.eye(num_edges)
    factors = [sla.cho_factor(identity + alpha * l1u) for alpha in alphas]

    all_errors = []
    all_curls = []
    all_input_errors = []
    all_input_curls = []

    true_err, true_curl = compute_error_curl_np(f_true[:, 2], f_true[:, 2], b2)
    print(true_err, true_curl)
    for rlz_id in range(args.realizations):
        np.random.seed(1337 * rlz_id)
        mask = (np.random.random(num_edges) > 0.5).astype(int)
        # Input data for train, val and test: the same edge mask is applied
        # to every column.
        f_train = f_init[:, 0] * mask
        f_val = f_init[:, 1] * mask
        f_test = f_init[:, 2] * mask

        print(f_val.shape)
        print('train input--error, curl', compute_error_curl_np(f_train, f_true[:, 0], b2))
        print('val input--error, curl', compute_error_curl_np(f_val, f_true[:, 1], b2))
        # Computed once: printed here and recorded as the input statistics.
        test_input_stats = compute_error_curl_np(f_test, f_true[:, 2], b2)
        print('test input--error, curl', test_input_stats)
        input_err, input_curl = test_input_stats

        # NOTE(review): alpha is selected by its error on f_true[:, 2], the
        # same column whose error is reported; f_val is never used for
        # selection, so the reported output error is optimistic. Confirm this
        # is the intended protocol.
        # Start from +inf (not a finite sentinel) so that some alpha always
        # wins; this prevents values from an earlier realization leaking in.
        best_err, best_curl, alpha_best = np.inf, None, None
        for alpha, factor in zip(alphas, factors):
            f_est_test = sla.cho_solve(factor, f_test)
            err, total_curl = compute_error_curl_np(f_est_test, f_true[:, 2], b2)
            if err < best_err:
                alpha_best = alpha
                best_err = err
                best_curl = total_curl
        if alpha_best is None:
            raise RuntimeError(
                f'realization {rlz_id}: no alpha produced a finite error'
            )

        print(best_err, best_curl, alpha_best)

        all_errors.append(best_err)
        all_curls.append(best_curl)
        all_input_errors.append(input_err)
        all_input_curls.append(input_curl)

    print('inputs:')
    print(np.mean(all_input_errors), np.std(all_input_errors))
    print(np.mean(all_input_curls), np.std(all_input_curls))
    print('outputs:')
    print(np.mean(all_errors), np.std(all_errors))
    print(np.mean(all_curls), np.std(all_curls))


if __name__ == '__main__':
    # Execute the experiment only when run as a script, not when imported.
    main()