import numpy as np
import os
import sys
import pandas as pd
import argparse
# Run directories are named <model>_<dataset>_masks_<n_masks>_<index>, one per repeat index.


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('-m', help='masks', type=int, default=0)
    parser.add_argument('-r', help='repeats', type=int, default=12)
    parser.add_argument('-nn', help='model', type=str, default="mlp")
    parser.add_argument('-p', help='path', type=str,
                        default='./sensor_results/')
    parser.add_argument('-d', help='dataset', type=str, default='swat')
    args = parser.parse_args()
    print("Gathering for", args.m)
    masks = args.m
    path = args.p
    repeats = args.r
    model=args.nn
    dataset = args.d
    if 'gridworld' in dataset or "monopoly" in dataset:
        is_node = 'node' in dataset
        line_numbers = (-4,-1) if 'node' in dataset else (-11,-8)

        print(line_numbers)
        dataset=dataset.split("_")[0]
        files = [f"{model}_{dataset}_masks_{masks}_{i}/log_{dataset}_{model}_{masks}.txt" for i in range(repeats)]
        
        if is_node:
            results = {'Per node normalized loss Best F1-Score': [], 'Precision': [], 'Recall': [],"MAP":[]}
        else:
            results = {'Per graph max normalized loss Best F1-Score': [], 'Precision': [], 'Recall': [],"MAP":[]}
        
        
        
        
    else:
        files = [f"{model}_{dataset}_masks_{masks}_{i}/log_{dataset}_{model}_{masks}.txt" for i in range(repeats)]
        results = {'Best F1-Score': [], 'Precision': [], 'Recall': []}
        line_numbers = (-4,-1)
        #line_numbers = (-10,-7)
    
    for f in files:
        print(f)
        try:
            with open(os.path.join(path, f), 'r') as fl:
                if line_numbers[1] ==-1:
                    lines = fl.readlines()[line_numbers[0]:]
                else:
                    lines = fl.readlines()[line_numbers[0]:line_numbers[1]+1]
                print(lines)
                #print(f)
                for l in lines:
                    key, val = l.split(':')
                    val = val.strip().strip('\n')
                    
                    results[key].append(float(val))
        except:
            print("Not Done yet")

    print(results)
    df = pd.DataFrame(results).T
    print(df.to_csv())
    print(df.mean(axis=1))
    print(df.std(axis=1))
