import torch
import torch.nn as nn
import torch.nn.functional
import torch.utils.data as data
import numpy.linalg as la
import numpy as np
import scipy.linalg as sla
import torch.linalg as tla
import scipy.sparse as sp
import itertools
import sys
import pickle 
sys.path.append('.')
import argparse
import os
import csv 
from pathlib import Path
# Make the directory containing this script the working directory so that the
# relative './data/...' paths read in main() resolve no matter where the script
# is launched from. NOTE: this is an import-time side effect of the module.
path = Path(__file__).parent.absolute()
os.chdir(path)
from model.tri_predictor import compute_auc, compute_loss_forex, MLPPredictor_forex, count_parameters, compute_error_curl, compute_error_curl_np
import time
from get_parser import get_parser
 

def _sample_noise(noise_type, power_noise, b2, num_edges, num_tris):
    '''
    Draw one additive noise vector over the edges.

    Parameters
    ----------
    noise_type : str
        'random' -> i.i.d. standard-normal draws on every edge;
        'curl'   -> i.i.d. standard-normal draws on every triangle, mapped
                    onto the edges through ``b2`` (``b2 @ noise_tri``).
    power_noise : float
        Scale multiplied onto the standard-normal draws.
    b2 : np.ndarray
        Matrix of shape (num_edges, num_tris); presumably the edge/triangle
        incidence matrix -- TODO confirm.
    num_edges, num_tris : int
        Sizes of the edge and triangle draws.

    Returns
    -------
    np.ndarray of shape (num_edges,).

    Raises
    ------
    ValueError
        If ``noise_type`` is neither 'random' nor 'curl'.

    Uses NumPy's global RNG, so the caller is responsible for seeding it.
    '''
    if noise_type == 'random':
        return power_noise * np.random.normal(0, 1, size=(num_edges,))
    if noise_type == 'curl':
        noise_tri = power_noise * np.random.normal(0, 1, size=(num_tris,))
        return b2 @ noise_tri
    raise ValueError(
        f"unknown noise_type {noise_type!r}; expected 'random' or 'curl'")


def main():
    '''
    Baseline curl-denoising experiment on the FX flow data.

    For each realization, the same noise vector is added to the train, val
    and test flows. Each noisy flow f is then denoised with the regularized
    filter (I + alpha * L1u)^{-1} f for alpha in {0, 1, ..., 10}. alpha is
    selected on the validation split, and the test error/curl at that alpha
    is recorded. At the end, the mean and std over realizations of the noisy
    input and the denoised output error and curl are printed.

    Reads ``args.snr`` (in dB) and ``args.realizations`` from get_parser().
    ``args.noise_type`` is always forced to 'curl'.
    '''
    parser = get_parser()
    args = parser.parse_args()
    # This script studies curl noise only; override whatever the CLI passed.
    args.noise_type = 'curl'

    b1 = np.loadtxt('./data/B1_FX_1538755200.csv', delimiter=',')
    b2t = np.loadtxt('./data/B2t_FX_1538755200.csv', delimiter=',')
    b2 = b2t.T
    # Columns 0, 1, 2 are used below as the train / val / test flows.
    # NOTE(review): an earlier note said this file holds bid, ask and mid
    # prices -- confirm which column is which.
    f_init = np.loadtxt('./data/flow_FX_1538755200.csv', delimiter=',')
    num_nodes, num_edges = b1.shape[0], b1.shape[1]
    num_tris = b2.shape[1]

    # B1^T B1 and B2 B2^T (presumably the down/up parts of the Hodge
    # 1-Laplacian), both scaled by 1/num_nodes.
    l1d = b1.T @ b1 / num_nodes
    l1u = b2 @ b2t / num_nodes
    # Reference flows that inputs and denoised estimates are scored against.
    f_true = l1d @ f_init

    all_errors = []
    all_curls = []
    all_input_errors = []
    all_input_curls = []
    eye = np.eye(num_edges)
    true_err, true_curl = compute_error_curl_np(f_init[:, 0], f_true[:, 0], b2)
    print(true_err, true_curl)

    # The noise scale does not depend on the realization; compute it once.
    snr = 10 ** (args.snr / 10)
    power_flow = la.norm(f_init[:, 1], 2)
    power_noise = power_flow / snr / num_edges
    alphas = np.linspace(0, 10, 11)

    for rlz_id in range(args.realizations):
        # The noise comes from NumPy's global RNG, which torch.manual_seed
        # does not affect; seed NumPy as well so realizations are reproducible.
        torch.manual_seed(1337 * rlz_id)
        np.random.seed(1337 * rlz_id)

        noise = _sample_noise(args.noise_type, power_noise, b2,
                              num_edges, num_tris)

        print(la.norm(noise))
        print(f_init.shape, f_init[:, 1].shape, noise.shape)
        # The same noise vector is added to all three splits.
        f_train = f_init[:, 0] + noise
        f_val = f_init[:, 1] + noise
        f_test = f_init[:, 2] + noise
        print(f_val.shape)
        test_input = compute_error_curl_np(f_test, f_true[:, 2], b2)
        print('train input--error, curl',
              compute_error_curl_np(f_train, f_true[:, 0], b2))
        print('val input--error, curl',
              compute_error_curl_np(f_val, f_true[:, 1], b2))
        print('test input--error, curl', test_input)

        # Pick alpha on the validation split: choosing it by test error would
        # leak the test targets into model selection.
        alpha_best = None
        best_val_err = np.inf
        for alpha in alphas:
            # solve() avoids forming the explicit inverse (cheaper, more stable).
            f_est_val = la.solve(eye + alpha * l1u, f_val)
            val_err, _ = compute_error_curl_np(f_est_val, f_true[:, 1], b2)
            # `alpha_best is None` guarantees a choice even if errors are NaN.
            if alpha_best is None or val_err < best_val_err:
                alpha_best = alpha
                best_val_err = val_err

        f_est_test = la.solve(eye + alpha_best * l1u, f_test)
        best_err, best_curl = compute_error_curl_np(f_est_test, f_true[:, 2], b2)

        print(best_err, best_curl, alpha_best)
        all_errors.append(best_err)
        all_curls.append(best_curl)
        input_err, input_curl = test_input
        all_input_errors.append(input_err)
        all_input_curls.append(input_curl)

    print('inputs:')
    print(np.mean(all_input_errors), np.std(all_input_errors))
    print(np.mean(all_input_curls), np.std(all_input_curls))
    print('outputs:')
    print(np.mean(all_errors), np.std(all_errors))
    print(np.mean(all_curls), np.std(all_curls))




if __name__ == "__main__":
    # Run the experiment only when this file is executed as a script.
    main()