import argparse
import copy
import os
import pickle

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

from utils.data import flatten_batch, prepare_datasets
from utils.evaluation import label_correctness
from utils.logging import log_and_print, print_program
from utils.training import process_batch


def compute_pehe_score(best_program, y_scaler, test_data, test_labels, output_type, output_size, num_labels, device='cpu', verbose=False):
    """Return the absolute error of the program's ATT estimate on the test set.

    Despite its name, this computes |ATT_true - ATT_pred| (average treatment
    effect on the treated), not PEHE.

    The program is run twice: once on the factual inputs and once on a copy
    whose first feature (column 0 of each input row) is flipped between 0 and
    1, giving factual and counterfactual outcome predictions. Both are mapped
    back to the outcome scale with ``y_scaler.inverse_transform``.

    Args:
        best_program: program to evaluate (passed to ``process_batch``).
        y_scaler: fitted scaler whose ``inverse_transform`` is applied to the
            predictions.
        test_data: test inputs; element 0 of each row is treated as a 0/1
            treatment flag. This object is NOT modified.
        test_labels: 2-D array. Column 0 == 1 selects treated rows, column 0
            == 0 together with column 1 == 1 selects the comparison rows for
            the true ATT (presumably the experimental control subsample --
            TODO confirm), and column 2 is used as the observed outcome.
        output_type, output_size: forwarded to ``process_batch``.
        num_labels: unused; kept for interface compatibility.
        device: device argument forwarded to ``process_batch``.
        verbose: if True, the logged program text includes its constants.

    Returns:
        float: absolute difference between the true and predicted ATT.

    Raises:
        ValueError: if ``test_labels`` has no treated rows or no comparison
            rows (column 0 == 0 and column 1 == 1).
    """
    log_and_print("\n")
    log_and_print("Evaluating program {} on TEST SET".format(print_program(best_program, ignore_constants=(not verbose))))

    treated = test_labels[:, 0] == 1
    control = np.logical_and(test_labels[:, 0] == 0, test_labels[:, 1] == 1)
    # Guard against empty groups: the means below would otherwise be NaN.
    if not np.any(treated):
        raise ValueError("test_labels contains no treated rows (column 0 == 1)")
    if not np.any(control):
        raise ValueError("test_labels contains no control rows with column 1 == 1")

    # Flip the treatment flag on a copy so the caller's test_data is left intact
    # (flipping in place would corrupt the data for any later use or re-run).
    flipped_input = copy.deepcopy(test_data)
    for row in flipped_input:
        row[0] = 1 if row[0] == 0 else 0

    with torch.no_grad():
        factual = process_batch(best_program, test_data, output_type, output_size, device=device).detach().cpu().numpy()
        counterfactual = process_batch(best_program, flipped_input, output_type, output_size, device=device).detach().cpu().numpy()
    factual = y_scaler.inverse_transform(factual.reshape(-1, 1))
    counterfactual = y_scaler.inverse_transform(counterfactual.reshape(-1, 1))

    # Means over the selected rows only (len() of a boolean mask would count
    # every row, which is what the previous implementation divided by).
    att_true = np.mean(test_labels[treated, 2]) - np.mean(test_labels[control, 2])
    # For treated rows the factual prediction uses flag 1 and the flipped one flag 0.
    att_pred = np.mean(factual[treated] - counterfactual[treated])

    return float(abs(att_true - att_pred))

def test_set_eval(program, testset, output_type, output_size, num_labels, device='cpu', verbose=False):
    """Run ``program`` over ``testset`` and log its F1 score plus extra metrics.

    Args:
        program: program to evaluate; executed through ``process_batch``.
        testset: iterable of (input, output) pairs; it is transposed into a
            list of inputs and a list of outputs.
        output_type, output_size: forwarded to ``process_batch``.
        num_labels: forwarded to ``label_correctness``.
        device: device the target tensor is moved to; also passed to
            ``process_batch``.
        verbose: if True, the logged program text includes its constants.
    """
    log_and_print("\n")
    program_text = print_program(program, ignore_constants=not verbose)
    log_and_print("Evaluating program {} on TEST SET".format(program_text))

    with torch.no_grad():
        # Transpose the (input, output) pairs into two parallel lists.
        inputs, targets = (list(column) for column in zip(*testset))
        target_tensor = torch.tensor(flatten_batch(targets)).to(device)
        predictions = process_batch(program, inputs, output_type, output_size, device)
        error, extra_metrics = label_correctness(predictions, target_tensor, num_labels=num_labels)

    # label_correctness reports an error value; F1 is its complement.
    f1 = 1 - error
    log_and_print("F1 score achieved is {:.4f}".format(f1))
    log_and_print("Additional performance parameters: {}\n".format(extra_metrics))

def parse_args():
    """Parse command-line arguments for evaluating a saved program.

    Returns:
        argparse.Namespace with the program path, data/label paths, input and
        output types and sizes, number of labels, and the ``normalize`` flag.
    """
    parser = argparse.ArgumentParser()
    # Args for experiment setup
    parser.add_argument('--program_path', type=str, required=True,
                        help="path to program")

    # Args for data
    parser.add_argument('--train_data', type=str, required=True,
                        help="path to train data")
    parser.add_argument('--test_data', type=str, required=True,
                        help="path to test data")
    parser.add_argument('--train_labels', type=str, required=True,
                        help="path to train labels")
    parser.add_argument('--test_labels', type=str, required=True,
                        help="path to test labels")
    parser.add_argument('--input_type', type=str, required=True, choices=["atom", "list"],
                        help="input type of data")
    parser.add_argument('--output_type', type=str, required=True, choices=["atom", "list"],
                        help="output type of data")
    parser.add_argument('--input_size', type=int, required=True,
                        help="dimension of features of each frame")
    parser.add_argument('--output_size', type=int, required=True,
                        help="dimension of output of each frame (usually equal to num_labels)")
    parser.add_argument('--num_labels', type=int, required=True,
                        help="number of class labels")
    parser.add_argument('--normalize', action='store_true', required=False, default=False,
                        help='whether or not to normalize the data')

    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()

    # Load program.
    # NOTE: unpickling can execute arbitrary code; only load program files
    # from a trusted source.
    if not os.path.isfile(args.program_path):
        raise FileNotFoundError("program file not found: {}".format(args.program_path))
    with open(args.program_path, "rb") as program_file:
        program = pickle.load(program_file)

    # Load data. prepare_datasets takes train and test splits (no validation
    # split is passed); only the resulting test set is evaluated below.
    train_data = np.load(args.train_data)
    test_data = np.load(args.test_data)
    train_labels = np.load(args.train_labels)
    test_labels = np.load(args.test_labels)
    batched_trainset, validset, testset = prepare_datasets(train_data, None, test_data, train_labels, None, test_labels, normalize=args.normalize)

    # TODO allow user to choose device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    test_set_eval(program, testset, args.output_type, args.output_size, args.num_labels, device=device, verbose=False)
