import torch
import torch.nn as nn
import torch.nn.functional
import torch.utils.data as data
import numpy.linalg as la
import numpy as np
import scipy.linalg as sla
import torch.linalg as tla
import scipy.sparse as sp
import itertools
import sys
import pickle 
sys.path.append('.')
import argparse
import os
import csv 
from pathlib import Path
path = Path(__file__).parent.absolute()
os.chdir(path)
from model.tri_predictor import compute_auc, compute_loss_forex, MLPPredictor_forex, count_parameters, compute_error_curl, compute_error_curl_np
import time
from get_parser import get_parser

'''the implementation of sccnn-node together with ablation study models'''
from model.scnn_model import scnn
from model.mpsn_model import mpsn

_SCNN_FAMILY = ('psnn', 'snn', 'scnn')
_KNOWN_MODELS = _SCNN_FAMILY + ('mpsn',)
_NOISE_TYPES = ('random', 'curl')


def _run_names(args):
    '''
    Build the file-name pieces that identify one hyperparameter setting.

    Returns a tuple ``(path_affix, log_stem)``: ``path_affix`` (with its
    leading '/') is used for the checkpoint and result files, ``log_stem``
    for the per-realization loss log.

    Raises ValueError when ``args.model_name`` is not a supported model.
    '''
    name = args.model_name
    layers = args.layers
    feats = args.hidden_features
    if name == 'snn':
        order = args.filter_order_snn
        return ('/%dlayers_%dorders_%dfeatures' % (layers, order, feats),
                '%s_%dlayers_%dorders_%dfeatures' % (name, layers, order, feats))
    if name == 'scnn':
        k1 = args.filter_order_scnn_k1
        k2 = args.filter_order_scnn_k2
        return ('/%dlayers_%d_%dorders_%dfeatures' % (layers, k1, k2, feats),
                '%s_%dlayers_%d_%dorders_%dfeatures' % (name, layers, k1, k2, feats))
    if name == 'psnn':
        return ('/%dlayers_%dfeatures' % (layers, feats),
                '%s_%dlayers_%dfeatures' % (name, layers, feats))
    if name == 'mpsn':
        return ('/%s_%dlayers_%dfeatures' % (args.aggregation, layers, feats),
                '%s_%s_%dlayers_%dfeatures' % (args.aggregation, name, layers, feats))
    raise ValueError('unsupported model_name %r; expected one of %s'
                     % (name, ', '.join(_KNOWN_MODELS)))


def _as_column(values, device):
    '''Convert a 1-D array into a float32 column tensor of shape (n, 1).'''
    return torch.tensor(values, dtype=torch.float32, device=device).reshape(-1, 1)


def _predict(model, pred, flow, direct_output):
    '''Run the network, applying the MLP decoder unless the model outputs the flow directly.'''
    if direct_output:
        return model(flow)
    return pred(model(flow))


def main():
    '''
    Train and evaluate one simplicial model on noisy forex edge flows.

    Command-line arguments come from ``get_parser()``. For every realization
    a noise vector is drawn and added to columns 0/1/2 of the flow file
    (used as train/val/test inputs), a fresh model is trained for
    ``args.epochs`` full-batch steps, the checkpoint with the lowest
    validation loss is reloaded and scored on the test column. Loss logs,
    checkpoints and a pickle with the per-realization test errors and curls
    are written under ./loss_files, ./model_nn and ./final_results.

    Raises:
        ValueError: unknown ``args.model_name`` or ``args.noise_type``.
        RuntimeError: a realization never produced a checkpoint (no epochs
            were run or the validation loss was never finite).
    '''
    parser = get_parser()
    args = parser.parse_args()

    activations = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh(),
        'elu': nn.ELU(),
        'selu': nn.SELU(),
        'leaky_relu' : nn.LeakyReLU(0.01),
        'id': nn.Identity()
    }

    # Fail fast on bad options instead of hitting a NameError much later.
    model_name = args.model_name
    path_affix, log_stem = _run_names(args)  # raises ValueError on unknown model
    if args.noise_type not in _NOISE_TYPES:
        raise ValueError('unsupported noise_type %r; expected one of %s'
                         % (args.noise_type, ', '.join(_NOISE_TYPES)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    b1 = np.loadtxt('./data/B1_FX_1538755200.csv', delimiter=',')
    b2t = np.loadtxt('./data/B2t_FX_1538755200.csv', delimiter=',')
    b2 = b2t.T
    # f contains bid, ask, and mid prices
    f_init = np.loadtxt('./data/flow_FX_1538755200.csv', delimiter=',')
    num_nodes, num_edges = b1.shape[0], b1.shape[1]
    num_tris = b2.shape[1]
    # lower / upper Laplacian-like operators, both scaled by the node count
    l1d = b1.T@b1/num_nodes
    l1u = b2@b2t/num_nodes
    # the training target is the input flow passed through the lower operator
    f_true = l1d@f_init

    '''hyperparameter-derived settings shared by every realization'''
    hidden_features = args.hidden_features
    K = args.filter_order_snn
    K1 = args.filter_order_scnn_k1
    K2 = args.filter_order_scnn_k2
    sigma = activations[args.activations]
    # NOTE(review): the original looked up args.activations_update and then
    # immediately overwrote it with sigma; that effective behaviour is kept.
    sigma_update = sigma
    F_intermediate = [hidden_features] * args.layers
    # Without the MLP decoder the network itself emits the one-dimensional
    # flow; mpsn always goes through the decoder.
    direct_output = args.mlp_decoder == 'False' and model_name != 'mpsn'
    readout = 'no_readout'
    F_out = 1 if direct_output else hidden_features

    run_dir = args.noise_type +'_noise' + '/' + model_name +'_'+ args.activations
    loss_path = r'./loss_files/' + run_dir
    model_path = r'./model_nn/' + run_dir
    if direct_output:
        loss_path = loss_path +'_'+ readout
        model_path = model_path +'_'+ readout
    os.makedirs(loss_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)

    '''tensors that do not depend on the realization'''
    L1l = torch.tensor(l1d,dtype=torch.float32,device=device)
    L1u = torch.tensor(l1u,dtype=torch.float32,device=device)
    L1 = L1l+L1u
    B2 = torch.tensor(b2,dtype=torch.float32,device=device)
    mask = torch.ones((num_edges,1),dtype=torch.float32,device=device)
    f_true_train = _as_column(f_true[:,0],device)
    f_true_val = _as_column(f_true[:,1],device)
    f_true_test = _as_column(f_true[:,2],device)

    snr = 10**(args.snr/10)
    power_flow = la.norm(f_init[:,1],2)
    power_noise = power_flow/snr/num_edges

    all_errors = []
    all_curls = []
    '''generate training data'''
    for rlz_id in range(args.realizations):
        torch.manual_seed(1337*rlz_id)
        # The noise is drawn with NumPy, so NumPy must be seeded as well for
        # a realization to be reproducible.
        np.random.seed(1337*rlz_id)

        if args.noise_type == 'random':
            noise = power_noise*np.random.normal(0,1,size=(num_edges,))
        else:  # 'curl': noise on triangles mapped to edges through b2
            noise_tri = power_noise*np.random.normal(0,1,size=(num_tris,))
            noise = b2@noise_tri

        print(la.norm(noise))
        print(f_init.shape,f_init[:,1].shape,noise.shape)
        # the same noise vector is added to all three splits
        f_train_np = f_init[:,0] + noise
        f_val_np = f_init[:,1] + noise
        f_test_np = f_init[:,2] + noise
        print(f_val_np.shape )
        print('train input--error, curl',compute_error_curl_np(f_train_np,f_true[:,0],b2))
        print('val input--error, curl',compute_error_curl_np(f_val_np,f_true[:,1],b2))
        print('test input--error, curl',compute_error_curl_np(f_test_np,f_true[:,2],b2))

        print(L1)
        print(model_name)

        if model_name in _SCNN_FAMILY:
            model = scnn(F_in=1, F_intermediate=F_intermediate, F_out=F_out, K=K, K1=K1, K2=K2, laplacian=L1, laplacian_l=L1l, laplacian_u=L1u, sigma=sigma, model_name=model_name)
        else:  # 'mpsn' (validated above)
            model = mpsn(F_in=1, F_intermediate=F_intermediate, F_out=F_out, l1l=L1l, l1u=L1u, agg=args.aggregation,sigma_update=sigma_update, sigma=sigma, model_name=model_name)
        pred = MLPPredictor_forex(h_feats=hidden_features)
        model.to(device)
        pred.to(device)

        print(count_parameters(model))

        # define the optimizer over both the network and the decoder
        optimizer = torch.optim.Adam(itertools.chain(model.parameters(), pred.parameters()), lr=args.learning_rate)

        f_train = _as_column(f_train_np,device)
        f_val = _as_column(f_val_np,device)
        f_test = _as_column(f_test_np,device)

        # checkpoint location for this realization
        affix = '%s_%drlz'%(args.activations,rlz_id)
        model_save_path = model_path+path_affix+affix

        # inf (rather than a finite cap) so that the first finite validation
        # loss always produces a checkpoint
        best_val = float('inf')
        checkpoint_saved = False

        # the context manager closes the log even if training raises
        with open(loss_path+"/"+log_stem+"_%drlz.txt" %(rlz_id),"w") as losslogf:
            '''
            training
            '''
            for e in range(args.epochs):
                f_hat_train = _predict(model,pred,f_train,direct_output)
                train_loss = compute_loss_forex(f_hat_train,f_true_train,mask,B2)
                train_error,train_curl = compute_error_curl(f_hat_train,f_true_train,B2)

                '''val'''
                with torch.no_grad():
                    f_hat_val = _predict(model,pred,f_val,direct_output)
                    vloss = compute_loss_forex(f_hat_val,f_true_val,mask,B2)
                    verror,vcurl = compute_error_curl(f_hat_val,f_true_val,B2)
                    if vloss < best_val:
                        best_val = vloss
                        print('curent epoch:', e)
                        # weights are saved before this epoch's optimizer
                        # step, i.e. the ones that produced vloss
                        torch.save({'model': model.state_dict(),'pred':pred.state_dict()},model_save_path)
                        checkpoint_saved = True
                        losslogf.write("model updated at epoch %d \n" %(e))
                    print('epoch {},\n train loss: {}, val loss: {}, train error, {}, v_error: {}, train curls: {}, val curls: {}'.format(e, train_loss, vloss, train_error, verror,train_curl,vcurl))

                losslogf.write("epoch %d, \n train loss: %f, val loss: %f, train error %f, val error %f, train curl %f, val curl %f \n  " %(e, train_loss.item(), vloss, train_error, verror, train_curl, vcurl))
                losslogf.flush()
                '''
                backward
                '''
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

            '''
            testing
            '''
            if not checkpoint_saved:
                # Otherwise a stale checkpoint from an earlier run could be
                # loaded silently and reported as this run's result.
                raise RuntimeError('no checkpoint was saved for realization %d '
                                   '(epochs=%d); nothing to evaluate'
                                   % (rlz_id, args.epochs))
            checkpoint = torch.load(model_save_path,map_location=device)
            model.load_state_dict(checkpoint['model'],strict=False)
            pred.load_state_dict(checkpoint['pred'],strict=False)

            with torch.no_grad():
                f_hat_test = _predict(model,pred,f_test,direct_output)
                print(f_hat_test)
                test_error,test_curl = compute_error_curl(f_hat_test,f_true_test,B2)

                losslogf.write("test results, \n test error: %f, test curl %f \n  " %(test_error, test_curl))
                losslogf.flush()
                # verror / vcurl are the values from the last epoch
                print('val error, test error',verror,test_error)
                print('v_curl: ',vcurl,'test curls: ',test_curl)
        all_errors.append(test_error.detach().cpu().numpy())
        all_curls.append(test_curl.detach().cpu().numpy())

    '''
    writing the per-realization test errors and curls
    '''
    finalresults_path = './final_results/'+args.noise_type +'_noise' + '/' + model_name +'_'+ args.activations
    if model_name == 'mpsn':
        finalresults_path = finalresults_path +'_'+ args.aggregation
    elif direct_output:
        finalresults_path = finalresults_path +'_'+ readout
    os.makedirs(finalresults_path, exist_ok=True)
    with open(finalresults_path+path_affix+'.pkl','wb') as f_results:
        pickle.dump([all_errors,all_curls],f_results)

# Run the experiment only when this file is executed as a script, so that
# importing the module does not start a training run.
if __name__ == "__main__":
    main()