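# Main script: run ADSec anomaly detection on each GECCO subset.
# (Originally the "flourier10 SMD group 2" main with epoch=1; the per-subset
#  epoch counts now come from --epochs.)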
from model.ADSec import detectAnomaly
import numpy as np
import os
import argparse
import torch
import random
import time





def parseList(text):
    """Parse a list-valued command-line option given as JSON.

    ``type=list`` turned an option string such as "[0.1,0.2]" into a list of
    single characters, so none of the list options below could be overridden
    from the command line. The defaults are real lists rather than strings,
    so argparse uses them unchanged and never calls this function for them.

    Example: ``--epochs "[7,7,10]"`` or ``--lambdass "[[1,0.1,1],[1,1,1]]"``.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not valid JSON or does not
            decode to a list; argparse reports it as a usage error.
    """
    import json  # local import: only needed when a list option is given on the CLI

    try:
        value = json.loads(text)
    except ValueError as err:
        raise argparse.ArgumentTypeError("expected a JSON list, got %r (%s)" % (text, err)) from err
    if not isinstance(value, list):
        raise argparse.ArgumentTypeError("expected a JSON list, got %r" % (text,))
    return value


parser = argparse.ArgumentParser(description='[VAE]')
parser.add_argument('--fealen',type=int, required=False, default=9,help='feature length')
parser.add_argument('--winlen',type=int,required=False,default=5,help='window length')#4/40
parser.add_argument('--moduleNum',type=int,required=False,default=10,help="diff layers")#6/6
parser.add_argument('--loadMod',type=int,required=False,default=0,help='load model from file')
parser.add_argument('--device',type=str,required=False,default='cpu',help='train on which device')
parser.add_argument('--batchSize',type=int,required=False,default=100,help='batch size')
parser.add_argument('--needTrain',type=int,required=False,default=1,help='need train or just inference')
parser.add_argument('--trlr',type=float,required=False,default=0.0001,help='learning rate for train')
parser.add_argument('--colr',type=float,required=False,default=0.01,help='learning rate for collaboration')
# Per-subset list options (colrs, lambdass, epochs, lrs) are indexed per subset
# by the main loop below and override colr / lambdas / epoch / trlr there.
parser.add_argument('--colrs',type=parseList,required=False,default=[0.001,0.001,0.001,0.01,0.01,0.01,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001])
parser.add_argument('--patchSize',type=int,required=False,default=2,help='attention patch length')#2/3
parser.add_argument('--klen',type=int,required=False,default=2,help='convolution kernel length')#2/5
parser.add_argument('--lambdas',type=float,required=False,default=0.01,help='loss function weight')
parser.add_argument("--sigma",type=float,required=False,default=0.8,help='the ratio between attRes and gateIncrease')
parser.add_argument("--slidewinlen",type=int,required=False,default=2,help='the sliding window average length')#3/2
parser.add_argument('--exeTime',type=parseList,required=False,default=[5,15,25,35,55,90,130,170,210,255,305,355,405,455])
parser.add_argument('--edges',type=parseList,required=False,default=[0,10,20,30,40,70,110,150,190,230,280,330,380,430])
parser.add_argument('--invalid',type=parseList,required=False,default=[31])
parser.add_argument('--omit',type=int,required=False,default=0,help='whether omit the first column of data')
parser.add_argument('--needBreak',type=int,required=False,default=0)
parser.add_argument('--needTest',type=int,required=False,default=1)
parser.add_argument('--LLMDistribution',type=parseList,required=False,default=[[0.43888600349885126,0.753947677168626,0.23984450811102231],[0.5611122371454451,0.02875363717036958,0.041661028938830155]])
parser.add_argument('--lambdass',type=parseList,required=False,default=[[1,0.1,1],[1,0.1,1],[1,0.1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1], [1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1]])
parser.add_argument('--smallLoss',type=str,required=False,default="standard")
parser.add_argument('--needRecons',type=int,default=1,required=False)
parser.add_argument('--earlyQuit',type=int,default=3,required=False)
parser.add_argument('--impactLength',type=int,default=5,required=False)
parser.add_argument('--epoch',type=int,default=2,required=False)
parser.add_argument('--epochs',type=parseList,default=[7,7,7,7,10,1,10,7,1,3,3,10],required=False)
parser.add_argument('--lrs',type=parseList,default=[0.001,0.001,0.001,0.001,0.0001,0.0001,0.0001,0.001,0.0001,0.001,0.001,0.0001],required=False)

args = parser.parse_args()

def setup_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Also sets ``torch.backends.cudnn.deterministic = True``, asking cuDNN to
    use deterministic algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seedFn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seedFn(seed)
    torch.backends.cudnn.deterministic = True

def divide(datas, ratio):
    """Split ``datas`` in order into ``(head, tail)``.

    The cut point is ``int(ratio * len(datas))``, so ``head`` holds the
    first ``ratio`` fraction (rounded toward zero) and ``tail`` the rest.
    """
    cut = int(ratio * len(datas))
    head = datas[:cut]
    tail = datas[cut:]
    return head, tail

def concatDatasLabels(datas, labels, number):
    """Flatten the first ``number`` entries of ``datas`` and ``labels``.

    Each input is sliced along its first axis to ``[:number]`` and reshaped
    to 1-D (row-major order); the two flattened arrays are returned.
    """
    flatDatas, flatLabels = (arr[:number].reshape(-1) for arr in (datas, labels))
    return flatDatas, flatLabels

def makeMaskFromExample(examples, responsesSet, LLMwinlen):
    """Build a 0/1 float mask of length ``len(responsesSet)``.

    Rows of ``examples`` are consumed in pairs; for pair ``p`` (rows ``2p``
    and ``2p+1``) the positions ``int(examples[row, 0] + p * LLMwinlen)`` are
    set to 1. An odd number of rows raises ``IndexError`` on the missing
    second row, as does a position outside the mask.
    """
    mask = np.zeros(len(responsesSet))
    for start in range(0, len(examples), 2):
        # Offset of this pair's window; kept as ``start / 2 * LLMwinlen`` (float).
        offset = start / 2 * LLMwinlen
        for row in (start, start + 1):
            mask[int(examples[row, 0] + offset)] = 1.
    return mask

def getLLMDistribution(responseSet):
    """Return ``(mean, std)`` of ``responseSet`` mirrored about zero.

    The values are concatenated with their negations before the statistics
    are taken, so the distribution is symmetric around 0. ``responseSet``
    must support unary negation (e.g. a NumPy array).
    """
    mirrored = np.concatenate([responseSet, -responseSet])
    return mirrored.mean(), mirrored.std()

def getSeed():
    """Derive an integer seed from the current wall-clock time in milliseconds.

    The value is reduced modulo ``2**32 - 1`` so it always fits NumPy's
    32-bit seed range.
    """
    millis = int(time.time() * 1000)
    return millis % (2**32 - 1)

# The seed is time-based, so print it: a run can only be reproduced by
# re-seeding with the exact same value via setup_seed().
seed = getSeed()
print("random seed:", seed)
setup_seed(seed)

relaMat = np.load("relaMat.npy")
relaMat = torch.tensor(relaMat)
subsetNumber = 12
# Built with os.path.join: the previous literal relied on the invalid string
# escapes "\s" / "\A" and only resolved on Windows.
dataPath = os.path.join(".", "slowDiagnose", "ADSec", "dataset", "gecco")

trdatas = []
tsdatas = []
tslabels = []
subsetIds = []  # file index of each loaded subset (missing subsets are skipped)
for subsetId in range(subsetNumber):
    trainPath = os.path.join(dataPath, "subset_%d_train.npy" % subsetId)
    if not os.path.exists(trainPath):
        print("subset %d: %s not found, skipping" % (subsetId, trainPath))
        continue
    trdatas.append(np.load(trainPath))
    tsdatas.append(np.load(os.path.join(dataPath, "subset_%d_test.npy" % subsetId)))
    tslabels.append(np.load(os.path.join(dataPath, "subset_%d_test_labels.npy" % subsetId)))
    subsetIds.append(subsetId)

if not subsetIds:
    raise FileNotFoundError("no subset_*_train.npy files found under %s" % dataPath)

trdatas = np.stack(trdatas, axis=0)
tsdatas = np.stack(tsdatas, axis=0)
tslabels = np.stack(tslabels, axis=0)
print(trdatas.shape)


dataset = "gecco"

results = []
for i, subsetId in enumerate(subsetIds):
    # Per-subset settings are looked up by the subset's file index, not its
    # position among the loaded subsets, so a skipped subset does not shift
    # the hyper-parameters and names of the ones after it.
    args.colr = args.colrs[subsetId]
    args.lambdas = args.lambdass[subsetId]
    args.trlr = args.lrs[subsetId]
    args.epoch = args.epochs[subsetId]
    datasetID = dataset + str(subsetId) + "_"

    trains, valdatas = divide(trdatas[i], 0.8)  # first 80% train, last 20% validation
    test = tsdatas[i]
    testLabels = tslabels[i]

    # NOTE(review): the last argument used to be the loop position; it is now the
    # subset file index (identical when no subset is skipped) -- confirm that is
    # how detectAnomaly uses it.
    result = detectAnomaly(trains, valdatas, test, testLabels, args, datasetID, relaMat.clone(), subsetId)
    results.append(result)

results = np.array(results)
# NOTE(review): assumes every detectAnomaly result has the same fixed length
# (the label below suggests precision, recall, F1) -- confirm.
averageRes = results.mean(axis=0)
print("average result (p,r,f1):", averageRes)
np.savetxt("resultsWE.csv", results, fmt='%f', delimiter=',', newline='\n')
