#main for flourier10 SMD group 2
#epoch=1
import argparse
import json
import os
import random

import numpy as np
import torch

from CoLLaTe import detectAnomaly
from extractValue_MGAB import funcExtractValue



def _jsonList(text):
    """Parse a list-valued command-line option written as JSON, e.g. '[0.01,0.001]'.

    ``type=list`` would turn the option string into a list of its individual
    characters, so the list options below use this parser instead. Their
    defaults are real lists, which argparse leaves unconverted (``type`` is only
    applied to string defaults), so running without these options behaves
    exactly as before.

    Raises:
        argparse.ArgumentTypeError: if the JSON value is not a list.
        ValueError (json.JSONDecodeError): if ``text`` is not valid JSON;
            argparse reports this as an invalid option value.
    """
    value = json.loads(text)
    if not isinstance(value, list):
        raise argparse.ArgumentTypeError("expected a JSON list, got %r" % (text,))
    return value


parser = argparse.ArgumentParser(description='[VAE]')
parser.add_argument('--fealen', type=int, required=False, default=1, help='feature length')
parser.add_argument('--winlen', type=int, required=False, default=5, help='window length')  # 4/40
parser.add_argument('--moduleNum', type=int, required=False, default=10, help="diff layers")  # 6/6
parser.add_argument('--loadMod', type=int, required=False, default=0, help='load model from file')
parser.add_argument('--device', type=str, required=False, default='cpu', help='train on which device')
parser.add_argument('--batchSize', type=int, required=False, default=100, help='batch size')
parser.add_argument('--needTrain', type=int, required=False, default=1, help='need train or just inference')
parser.add_argument('--trlr', type=float, required=False, default=0.01, help='learning rate for train')
parser.add_argument('--colr', type=float, required=False, default=0.01, help='learning rate for collaboration')
# One collaboration learning rate per data group; the main loop copies colrs[i] into colr.
parser.add_argument('--colrs', type=_jsonList, required=False,
                    default=[0.01, 0.001, 0.01, 0.01, 0.01, 0.01, 0.001, 0.001, 0.001, 0.001, 0.001])
parser.add_argument('--patchSize', type=int, required=False, default=2, help='attention patch length')  # 2/3
parser.add_argument('--klen', type=int, required=False, default=2, help='convolution kernel length')  # 2/5
parser.add_argument('--lamda', type=float, required=False, default=0.01, help='loss function weight')
parser.add_argument("--sigma", type=float, required=False, default=0.8,
                    help='the ratio between attRes and gateIncrease')
parser.add_argument("--slidewinlen", type=int, required=False, default=2,
                    help='the sliding window average length')  # 3/2
# exeTime / edges are replaced by the dataset-specific overrides further down in this script.
parser.add_argument('--exeTime', type=_jsonList, required=False,
                    default=[5, 15, 25, 35, 55, 90, 130, 170, 210, 255, 305, 355, 405, 455])
parser.add_argument('--edges', type=_jsonList, required=False,
                    default=[0, 10, 20, 30, 40, 70, 110, 150, 190, 230, 280, 330, 380, 430])
parser.add_argument('--invalid', type=_jsonList, required=False, default=[31])
parser.add_argument('--omit', type=int, required=False, default=0, help='whether omit the first column of data')
parser.add_argument('--needBreak', type=int, required=False, default=0)
parser.add_argument('--needCollaborate', type=int, required=False, default=1)
# Overwritten later in this script with statistics computed from the LLM responses.
parser.add_argument('--LLMDistribution', type=_jsonList, required=False,
                    default=[[0.43888600349885126, 0.753947677168626, 0.23984450811102231],
                             [0.5611122371454451, 0.02875363717036958, 0.041661028938830155]])
# One 5-weight row per data group (indexed by group in the main loop).
# Previously tried alternative: [[0.5, 0.5, 1, 1, 1]] * 10.
parser.add_argument('--lambdas', type=_jsonList, required=False,
                    default=[[10, 1, 1, 1, 1.1], [1, 5, 1, 1, 1], [9, 1, 2, 1, 0.45], [7, 1, 1, 1, 1],
                             [9, 1, 1, 1, 0.7], [1, 1, 1, 1, 1.5], [5, 1, 1, 1, 0.5], [10, 1, 1, 1, 0],
                             [5, 1, 1, 1, 1], [5, 1, 1, 1, 0.7]])
parser.add_argument('--smallLoss', type=str, required=False, default="standard")
parser.add_argument('--needRecons', type=int, default=1, required=False)
parser.add_argument('--earlyQuit', type=int, default=3, required=False)
args = parser.parse_args()


# Dataset-specific overrides of the parsed arguments (original note: "mutshang
# data modify"). These replace whatever --edges / --exeTime / --omit were given.
args.edges = [0, 5, 10, 20, 30, 40, 70, 110, 150, 190, 230, 280, 330, 380, 430, 900, 1200, 9000]
# exeTime[k] is the midpoint of the interval [edges[k], edges[k+1]] ...
exeTime = [(lower + upper) / 2 for lower, upper in zip(args.edges[:-1], args.edges[1:])]
# ... except for the last interval, which is pinned to 1230*60 instead of its
# midpoint (presumably 1230 minutes expressed in seconds -- TODO confirm unit).
exeTime[-1] = 1230 * 60
args.exeTime = exeTime
args.omit = False

def setup_seed(seed):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (the CPU
    generator plus all CUDA devices), then asks cuDNN to use deterministic
    algorithms so repeated runs with the same seed are reproducible.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def divide(datas, ratio):
    """Split a sliceable sequence into a head and a tail.

    The cut index is ``int(ratio * len(datas))`` (truncated toward zero), so
    for ``ratio`` in [0, 1] the head holds roughly the first ``ratio`` share of
    the items and the tail holds the rest.

    Returns:
        ``(head, tail)`` -- slices of ``datas`` (views when ``datas`` is a NumPy array).
    """
    cut = int(ratio * len(datas))
    head, tail = datas[:cut], datas[cut:]
    return head, tail

def concatDatasLabels(datas, labels, number):
    """Flatten the first ``number`` entries of ``datas`` and of ``labels``.

    Both arguments must be array-like objects whose slices provide
    ``.reshape`` (e.g. NumPy arrays).

    Returns:
        ``(datas, labels)`` as two flattened 1-D arrays.
    """
    head = slice(None, number)
    return datas[head].reshape(-1), labels[head].reshape(-1)

def makeMaskFromExample(examples, responsesSet, LLMwinlen, examplesPerWindow=2):
    """Build a 0/1 mask marking the positions of the in-context examples.

    Row ``j`` of ``examples`` belongs to window ``j // examplesPerWindow``.
    Its absolute position is ``examples[j, 0] + window * LLMwinlen``,
    truncated to int. With the default of 2 rows per window this is exactly
    the pairing the script has always used. A trailing incomplete group, such
    as an odd row count, is now placed in its own window rather than raising
    IndexError.

    Args:
        examples: 2-D array-like; only column 0 is read.
        responsesSet: sized sequence; only its length is used, as the mask length.
        LLMwinlen: number of response positions per window.
        examplesPerWindow: how many consecutive rows share one window.

    Returns:
        float64 ndarray of length ``len(responsesSet)``, 1.0 at every example position.
    """
    mask = np.zeros(len(responsesSet))
    examples = np.asarray(examples)
    if len(examples) == 0:
        return mask
    # Window index of every row; the offset into the mask is window * LLMwinlen.
    windowIdx = np.arange(len(examples)) // examplesPerWindow
    positions = (examples[:, 0] + windowIdx * LLMwinlen).astype(int)
    mask[positions] = 1.
    return mask

def getLLMDistribution(responseSet):
    """Return (mean, std) of the responses after mirroring them around zero.

    Every value ``x`` is paired with ``-x`` before the statistics are taken.
    The returned mean is therefore zero up to floating-point rounding, and the
    std is (approximately) the root-mean-square of the responses.
    ``responseSet`` must support unary minus, e.g. a NumPy array.
    """
    mirrored = np.concatenate((responseSet, -responseSet))
    return mirrored.mean(), mirrored.std()

# Fixed seed for reproducibility (getSeed() is the commented-out alternative).
seed = 1282028438
setup_seed(seed)

datas = np.load("./data/MGABdatas.npy")
labels = np.load("./data/MGABlabels.npy")
print(datas.shape, labels.shape)

# Parse every response file with funcExtractValue and concatenate the results
# into one flat list, in sorted filename order.
responseDir = "./Response/MGAB/response/"
responsesSet = []
for responseFile in sorted(os.listdir(responseDir)):
    responsesSet += funcExtractValue(responseDir + responseFile)

# Stack all example files row-wise, again in sorted filename order.
exampleDir = "./Response/MGAB/examples/"
examples = np.concatenate(
    [np.loadtxt(exampleDir + exampleFile, delimiter=',')
     for exampleFile in sorted(os.listdir(exampleDir))],
    axis=0,
)
# 50 is passed as makeMaskFromExample's LLM window length.
mask = makeMaskFromExample(examples, responsesSet, 50)

responsesSet = np.array(responsesSet)
meanv, stdv = getLLMDistribution(responsesSet)
args.LLMDistribution = [[1, meanv, stdv]]

# Arrange everything as 10 groups, one row per group.
responsesSet = responsesSet.reshape((10, -1))
mask = mask.reshape((10, -1))
datas = datas.reshape((10, -1, 1))

dataset = "MGAB"
print(datas.shape, labels.shape, responsesSet.shape)
results = []
F1s = []
scoress = []  # per-group scores collected in needBreak mode
for groupIdx in range(len(responsesSet)):
    args.colr = args.colrs[groupIdx]
    datasetID = dataset + str(groupIdx) + "_"

    # Cut every per-group array in half: index 0 = first half, 1 = second half.
    dataHalves = divide(datas[groupIdx], 0.5)
    respHalves = divide(responsesSet[groupIdx], 0.5)
    labelHalves = divide(labels[groupIdx], 0.5)
    maskHalves = divide(mask[groupIdx], 0.5)

    # Test on the second half by default; if all of its labels are 0, test on
    # the first half instead and fit on the second.
    testSide = 0 if (labelHalves[1] == 0.).all() else 1
    fitSide = 1 - testSide
    test, testr = dataHalves[testSide], respHalves[testSide]
    testLabels, testMask = labelHalves[testSide], maskHalves[testSide]
    # The other half is split 80/20 into training and validation parts.
    trains, valdatas = divide(dataHalves[fitSide], 0.8)
    trainr, valdatar = divide(respHalves[fitSide], 0.8)

    if args.needBreak:
        precision, recall, F1, scores = detectAnomaly(
            trains, valdatas, test, testLabels, trainr, valdatar, testr,
            args, datasetID, args.lambdas[groupIdx], testMask)
        scoress.append(scores)
        results.append([precision, recall, F1])
    elif args.needCollaborate:
        results.append(detectAnomaly(
            trains, valdatas, test, testLabels, trainr, valdatar, testr,
            args, datasetID, args.lambdas[groupIdx], testMask))
# Aggregate the per-group results: print their column-wise mean and save all
# rows to a CSV whose name depends on the mode that produced them.
if args.needBreak or args.needCollaborate:
    resultFile = "resultsBreak.csv" if args.needBreak else "resultsCollaborate.csv"
    results = np.array(results)
    averageRes = results.mean(axis=0)
    print("average result (p,r,f1):", averageRes)
    np.savetxt(resultFile, results, fmt='%f', delimiter=',', newline='\n')