import torch
import torch.nn as nn
import torch.nn.functional
import torch.utils.data as data
import torch.linalg as tla
import numpy.linalg as la
import numpy as np
import scipy.linalg as sla
import scipy.sparse as sp
import itertools
import sys
sys.path.append('.')
import argparse
import os
import csv 
import pickle 
from pathlib import Path
path = Path(__file__).parent.absolute()
os.chdir(path)
from model.tri_predictor import compute_auc, compute_loss_forex, MLPPredictor_forex, count_parameters, compute_error_curl, compute_error_curl_np, compute_error_curl_interp, compute_error_curl_np_interp,compute_loss_forex_interp
import time
from get_parser import get_parser
'''the implementation of sccnn-node together with ablation study models'''
from model.scnn_model import scnn
from model.mpsn_model import mpsn

def _as_column(values, device):
    """Copy a 1-D array into a float32 tensor of shape (n, 1) on ``device``."""
    return torch.tensor(values, dtype=torch.float32, device=device).unsqueeze(1)


def _file_stems(args):
    """Return ``(log_stem, path_affix)`` for the selected model.

    Both strings start with '/' and encode the model's size hyperparameters.
    ``log_stem`` additionally carries the model name and is the prefix of the
    per-realization loss log; ``path_affix`` is shared by the checkpoint file
    and the final results file.

    Raises:
        ValueError: if ``args.model_name`` is not one of psnn/snn/scnn/mpsn.
    """
    name = args.model_name
    layers, feats = args.layers, args.hidden_features
    if name == 'snn':
        size = '%dlayers_%dorders_%dfeatures' % (layers, args.filter_order_snn, feats)
        return '/%s_%s' % (name, size), '/' + size
    if name == 'scnn':
        size = '%dlayers_%d_%dorders_%dfeatures' % (
            layers, args.filter_order_scnn_k1, args.filter_order_scnn_k2, feats)
        return '/%s_%s' % (name, size), '/' + size
    if name == 'psnn':
        size = '%dlayers_%dfeatures' % (layers, feats)
        return '/%s_%s' % (name, size), '/' + size
    if name == 'mpsn':
        size = '%dlayers_%dfeatures' % (layers, feats)
        return '/%s_%s_%s' % (args.aggregation, name, size), '/%s_%s' % (args.aggregation, size)
    raise ValueError(
        "unsupported model_name %r; expected one of 'psnn', 'snn', 'scnn', 'mpsn'" % (name,))


def main():
    """Train and evaluate a simplicial network on masked forex edge flows.

    Steps, repeated for ``args.realizations`` random edge masks:

    1. Load the incidence matrices and the edge flow (columns 0/1/2 are used
       as train/val/test inputs) from ``./data``.
    2. Print the error/curl of the masked inputs and of a regularized
       least-squares baseline ``(I + alpha * L1u)^-1 f`` over a grid of alpha.
    3. Train the selected model (optionally followed by the MLP decoder) with
       Adam, checkpointing whenever the validation loss improves.
    4. Reload the best checkpoint, evaluate on the test column, and log it.

    Finally the per-realization test errors and curls are pickled under
    ``./final_results``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model.
        RuntimeError: if no checkpoint was written during training (e.g. zero
            epochs or a validation loss that is never finite), since loading
            a file left over from an earlier run would report stale results.
    """
    parser = get_parser()
    args = parser.parse_args()

    model_name = args.model_name
    # Fail fast on an unknown model instead of hitting a NameError mid-run.
    log_stem, path_affix = _file_stems(args)

    activations = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh(),
        'elu': nn.ELU(),
        'selu': nn.SELU(),
        'leaky_relu': nn.LeakyReLU(0.01),
        'id': nn.Identity()
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    b1 = np.loadtxt('./data/B1_FX_1538755200.csv', delimiter=',')
    b2t = np.loadtxt('./data/B2t_FX_1538755200.csv', delimiter=',')
    b2 = b2t.T
    # f contains bid, ask, and mid prices
    f_init = np.loadtxt('./data/flow_FX_1538755200.csv', delimiter=',')
    num_nodes, num_edges = b1.shape[0], b1.shape[1]
    # Lower / upper edge Laplacians, both scaled by the number of nodes.
    l1d = b1.T @ b1 / num_nodes
    l1u = b2 @ b2t / num_nodes
    # Reference flow used as the target for every split.
    f_true = l1d @ f_init
    identity = np.eye(num_edges)

    # Operators are identical for every realization, so build them once.
    L1l = torch.tensor(l1d, dtype=torch.float32, device=device)
    L1u = torch.tensor(l1u, dtype=torch.float32, device=device)
    L1 = L1l + L1u
    B2 = torch.tensor(b2, dtype=torch.float32, device=device)

    sigma = activations[args.activations]
    # NOTE(review): the update activation is forced to equal `sigma`, so
    # --activations_update currently has no effect. Confirm this is intended.
    sigma_update = sigma
    hidden_features = args.hidden_features
    K = args.filter_order_snn
    K1 = args.filter_order_scnn_k1
    K2 = args.filter_order_scnn_k2
    F_intermediate = [hidden_features] * args.layers

    # Without the MLP decoder (never the case for mpsn) the network itself
    # emits the single-channel prediction.
    direct_readout = args.mlp_decoder == 'False' and model_name != 'mpsn'
    readout = 'no_readout'
    F_out = 1 if direct_readout else hidden_features

    run_tag = args.noise_type + '_noise' + '/' + model_name + '_' + args.activations
    readout_suffix = '_' + readout if direct_readout else ''
    loss_path = r'./loss_files/' + run_tag + readout_suffix
    model_path = r'./model_nn/' + run_tag + readout_suffix
    if args.model_name == 'mpsn':
        finalresults_path = './final_results/' + run_tag + '_' + args.aggregation
    else:
        finalresults_path = './final_results/' + run_tag + readout_suffix
    os.makedirs(loss_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)

    all_errors = []
    all_curls = []

    for rlz_id in range(args.realizations):
        # Only NumPy is seeded: the mask is reproducible per realization.
        np.random.seed(1337 * rlz_id)
        mask_np = (np.random.random(num_edges) > 0.5).astype(int)
        '''input data for train, val and test'''
        f_train_np = f_init[:, 0] * mask_np
        f_val_np = f_init[:, 1] * mask_np
        f_test_np = f_init[:, 2] * mask_np
        print('train input--error, curl', compute_error_curl_np_interp(f_train_np, f_true[:, 0], b2, mask_np))
        print('val input--error, curl', compute_error_curl_np_interp(f_val_np, f_true[:, 1], b2, mask_np))
        print('test input--error, curl', compute_error_curl_np_interp(f_test_np, f_true[:, 2], b2, mask_np))

        # Baseline: pick the alpha with the lowest error on the test column.
        best_err = float('inf')
        best_curl = None
        alpha_best = None
        for alpha in np.linspace(0, 100, 11):
            # Solve the linear system directly rather than forming the inverse.
            f_est_test = la.solve(identity + alpha * l1u, f_test_np)
            err, total_curl = compute_error_curl_np(f_est_test, f_true[:, 2], b2)
            if err < best_err:
                alpha_best = alpha
                best_err = err
                best_curl = total_curl
        print(best_err, best_curl, alpha_best)
        print(model_name)

        if model_name == 'mpsn':
            model = mpsn(F_in=1, F_intermediate=F_intermediate, F_out=F_out, l1l=L1l, l1u=L1u,
                         agg=args.aggregation, sigma_update=sigma_update, sigma=sigma,
                         model_name=model_name)
        else:
            model = scnn(F_in=1, F_intermediate=F_intermediate, F_out=F_out, K=K, K1=K1, K2=K2,
                         laplacian=L1, laplacian_l=L1l, laplacian_u=L1u, sigma=sigma,
                         model_name=model_name)
        pred = MLPPredictor_forex(h_feats=hidden_features)
        model.to(device)
        pred.to(device)

        # define the loss and optimizer
        optimizer = torch.optim.Adam(
            itertools.chain(model.parameters(), pred.parameters()), lr=args.learning_rate)

        f_train = _as_column(f_train_np, device)
        f_val = _as_column(f_val_np, device)
        f_test = _as_column(f_test_np, device)
        f_true_train = _as_column(f_true[:, 0], device)
        f_true_val = _as_column(f_true[:, 1], device)
        f_true_test = _as_column(f_true[:, 2], device)
        mask = torch.tensor(mask_np, dtype=torch.float32, device=device)

        def _predict(flow):
            """Run the network, followed by the MLP decoder when it is enabled."""
            out = model(flow)
            return out if direct_readout else pred(out)

        affix = '%s_%drlz' % (args.activations, rlz_id)
        model_save_path = model_path + path_affix + affix
        log_file = loss_path + log_stem + '_%drlz.txt' % rlz_id

        # The context manager guarantees the log is closed even on failure.
        with open(log_file, "w") as losslogf:
            best_val = float('inf')
            checkpoint_written = False
            '''
            training
            '''
            for e in range(args.epochs):
                f_hat_train = _predict(f_train)

                train_loss = compute_loss_forex_interp(f_hat_train, f_true_train, mask, B2)
                train_error, train_curl = compute_error_curl_interp(f_hat_train, f_true_train, B2, mask)

                losslogf.write("epoch %d, loss: %f, curl: %f\n" % (
                    e, train_loss.item(), tla.norm(B2.T @ f_hat_train).item()))
                losslogf.flush()

                '''val'''
                with torch.no_grad():
                    f_hat_val = _predict(f_val)
                    vloss = compute_loss_forex_interp(f_hat_val, f_true_val, mask, B2)
                    verror, vcurl = compute_error_curl_interp(f_hat_val, f_true_val, B2, mask)
                    if vloss < best_val:
                        # Keep only the weights with the best validation loss.
                        best_val = vloss
                        print('curent epoch:', e)
                        torch.save({'model': model.state_dict(), 'pred': pred.state_dict()}, model_save_path)
                        checkpoint_written = True
                        losslogf.write("model updated at epoch %d \n" % (e))
                    print('epoch {},\n train loss: {}, val loss: {}, train error, {}, v_error: {}, train curls: {}, val curls: {}'.format(e, train_loss, vloss, train_error, verror, train_curl, vcurl))

                losslogf.write("epoch %d, \n train loss: %f, val loss: %f, train error %f, val error %f, train curl %f, val curl %f \n  " % (e, train_loss.item(), vloss, train_error, verror, train_curl, vcurl))
                losslogf.flush()
                '''
                backward
                '''
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

            '''
            testing
            '''
            if not checkpoint_written:
                raise RuntimeError(
                    "no checkpoint was saved for realization %d (epochs=%d); "
                    "refusing to evaluate a possibly stale file at %s"
                    % (rlz_id, args.epochs, model_save_path))
            # Read the checkpoint from disk once and restore both modules.
            checkpoint = torch.load(model_save_path, map_location=device)
            model.load_state_dict(checkpoint['model'], strict=False)
            pred.load_state_dict(checkpoint['pred'], strict=False)

            with torch.no_grad():
                f_hat_test = _predict(f_test)
                print(f_hat_test)
                test_error, test_curl = compute_error_curl_interp(f_hat_test, f_true_test, B2, mask)

                losslogf.write("test results, \n test error: %f, test curl %f \n  " % (test_error, test_curl))
                losslogf.flush()
                # verror / vcurl are the values from the final training epoch.
                print('val error, test error', verror, test_error)
                print('v_curl: ', vcurl, 'test curls: ', test_curl)
        all_errors.append(test_error.detach().cpu().numpy())
        all_curls.append(test_curl.detach().cpu().numpy())

    '''
    store the per-realization test errors and curls of this configuration
    '''
    os.makedirs(finalresults_path, exist_ok=True)
    with open(finalresults_path + path_affix + '.pkl', 'wb') as f_results:
        pickle.dump([all_errors, all_curls], f_results)
            
# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()