from asyncio import as_completed
from importlib.util import LazyLoader
import torch
import matplotlib.pyplot as plt
import sys

nb_samples = sys.argv[1]
name1 = sys.argv[2]
name2 = sys.argv[3]

streme_performances = torch.empty(9,0)
seism_performances = torch.empty(9,0)
meme_performances = torch.empty(9,0)
ckn_performances = torch.empty(9,0)

for i in range(1,101):
    streme_performances = torch.cat((streme_performances, torch.load('STREME_ANALYSIS_FOR_'+nb_samples+'_SAMPLES/DATASET_'+str(i)+'/STREME_ANALYSIS_torch_tensor.pt').view(9,1)), dim=1)
    seism_performances = torch.cat((seism_performances, torch.load('SEISM_ANALYSIS_FOR_'+nb_samples+'_SAMPLES/DATASET_'+str(i)+'/SEISM_ANALYSIS_torch_tensor.pt').view(9,1)), dim=1)
    meme_performances = torch.cat((meme_performances, torch.load('MEME_ANALYSIS_FOR_'+nb_samples+'_SAMPLES/DATASET_'+str(i)+'/MEME_ANALYSIS_torch_tensor.pt').view(9,1)), dim=1)
    ckn_performances = torch.cat((ckn_performances, torch.load('CKN_ANALYSIS_FOR_'+nb_samples+'_SAMPLES/DATASET_'+str(i)+'/CKN_ANALYSIS_torch_tensor.pt').view(9,1)), dim=1)

streme_mean = torch.mean(streme_performances, dim=1)
streme_std = torch.std(streme_performances, dim=1)

seism_mean = torch.mean(seism_performances, dim=1)
seism_std = torch.std(seism_performances, dim=1)

meme_mean = torch.mean(meme_performances, dim=1)
meme_std = torch.std(meme_performances, dim=1)

ckn_mean = torch.mean(ckn_performances, dim=1)
ckn_std = torch.std(ckn_performances, dim=1)

x = torch.Tensor([2,3,4,5,6,7,8,9,10])
plt.errorbar(x, seism_mean, seism_std, c = 'r', label = "SEISM")
plt.errorbar(x, streme_mean, streme_std, c = 'b', label = 'STREME')
plt.errorbar(x, meme_mean, meme_std, c = 'g', label = "MEME")
plt.errorbar(x, ckn_mean, ckn_std, c='k', label = "CKN")
plt.axis([1.9, 10, 0, 100])
plt.xlabel('Motif accuracy score (-log10 TOMTOM p-values)')
plt.ylabel('Succes Rate (%)')
plt.legend()
plt.savefig(name1)
plt.clf()

score_2 = seism_performances[0,:].unsqueeze(0)
score_2 = torch.cat((score_2, streme_performances[0,:].unsqueeze(0)), dim=0)
score_2 = torch.cat((score_2, meme_performances[0,:].unsqueeze(0)), dim=0)
score_2 = torch.cat((score_2, ckn_performances[0,:].unsqueeze(0)), dim=0)
plt.ylabel('Succes Rate (%) for a TOMTOM score >= 2')
plt.boxplot(score_2, labels = ["SEISM", "STREME", "MEME", "CKN"])
plt.axis([0.5,4.5, 0, 100])
plt.savefig(name2)
