#main for flourier10 SMD group 2
#epoch=1
import argparse
import json
import os
import random

import numpy as np
import torch

from CoLLaTe import detectAnomaly
from extractValue import funcExtractValue
from Utils.utils import readData,insertAnomaly,normaliseData,syntheticData,insertAnomaly2,insertNormality,readPublicData,InvalidData,readMutshangData

def _parse_list(text):
    """Convert a command-line string holding a JSON array into a Python list.

    ``type=list`` makes argparse call ``list("...")`` on the raw string, which
    splits it into single characters (``"1,2"`` -> ``['1', ',', '2']``).
    List-valued options are therefore decoded explicitly, e.g.
    ``--colrs "[0.001, 0.0001]"`` or ``--lambdas "[[1, 1, 1, 0.01, 0.7]]"``.

    Raises:
        argparse.ArgumentTypeError: if the text is not valid JSON or does not
            decode to a list.
    """
    try:
        value = json.loads(text)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            "expected a JSON list, got {!r}".format(text)) from err
    if not isinstance(value, list):
        raise argparse.ArgumentTypeError(
            "expected a JSON list, got {!r}".format(text))
    return value


# Command-line configuration. Non-string defaults are used as-is by argparse,
# so the list defaults below never pass through ``_parse_list``.
parser = argparse.ArgumentParser(description='[VAE]')
parser.add_argument('--fealen', type=int, required=False, default=17, help='feature length')
parser.add_argument('--winlen', type=int, required=False, default=4, help='window length')#4/40
parser.add_argument('--moduleNum', type=int, required=False, default=6, help="diff layers")#6/6
parser.add_argument('--loadMod', type=int, required=False, default=0, help='load model from file')
parser.add_argument('--device', type=str, required=False, default='cpu', help='train on which device')
parser.add_argument('--batchSize', type=int, required=False, default=100, help='batch size')
parser.add_argument('--needTrain', type=int, required=False, default=1, help='need train or just inference')
parser.add_argument('--trlr', type=float, required=False, default=0.001, help='learning rate')
parser.add_argument('--colr', type=float, required=False, default=0.001, help='learning rate')
# One entry per data group; the main loop copies entry i into args.colr.
parser.add_argument('--colrs', type=_parse_list, required=False, default=[0.001,0.0001,0.0001,0.0001,0.001])
parser.add_argument('--patchSize', type=int, required=False, default=2, help='attention patch length')#2/3
parser.add_argument('--klen', type=int, required=False, default=2, help='convolution kernel length')#2/5
parser.add_argument('--lamda', type=float, required=False, default=0.1, help='loss function weight')
parser.add_argument("--sigma", type=float, required=False, default=0.8, help='the ratio between attRes and gateIncrease')
parser.add_argument("--slidewinlen", type=int, required=False, default=2, help='the sliding window average length')#3/2
parser.add_argument('--exeTime', type=_parse_list, required=False, default=[5,15,25,35,55,90,130,170,210,255,305,355,405,455])
parser.add_argument('--edges', type=_parse_list, required=False, default=[0,10,20,30,40,70,110,150,190,230,280,330,380,430])
parser.add_argument('--invalid', type=_parse_list, required=False, default=[31])
parser.add_argument('--omit', type=int, required=False, default=0, help='whether omit the first column of data')
parser.add_argument('--needBreak', type=int, required=False, default=0)
parser.add_argument('--needCollaborate', type=int, required=False, default=1)
parser.add_argument('--LLMDistribution', type=_parse_list, required=False, default=[[1.,0.0,0.3104133018866854]])
# One row per data group; the main loop passes row i to detectAnomaly.
parser.add_argument('--lambdas', type=_parse_list, required=False, default=[[1,1,1,0.01,0.7],[1,0,1,0.01,0.7],[0.5,0.5,1,0.1,0.7],[0.5,0.5,1,1,0.9],[0.5,0.5,1,1,0.8]])
parser.add_argument('--smallLoss', type=str, required=False, default="loss2")
parser.add_argument('--needRecons', type=int, default=0, required=False)
parser.add_argument('--earlyQuit', type=int, default=3, required=False)
# One entry per data group; the main loop copies entry i into args.earlyQuit.
parser.add_argument('--earlyQuits', type=_parse_list, default=[2,5,3,3,3])
args = parser.parse_args()


# Hard-coded settings that replace whatever was parsed for --edges, --exeTime
# and --omit.
args.edges = [0, 5, 10, 20, 30, 40, 70, 110, 150, 190, 230, 280, 330, 380, 430, 900, 1200, 9000]
# One value per pair of neighbouring edges: their midpoint.
exeTime = [(low + high) / 2 for low, high in zip(args.edges[:-1], args.edges[1:])]
# The last entry gets a fixed value instead of its midpoint.
exeTime[-1] = 1230 * 60
args.exeTime = exeTime
args.omit = False

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds torch (CPU and all CUDA devices), NumPy and the stdlib ``random``
    module with the same value, then configures cuDNN for repeatable runs.

    Args:
        seed: integer seed shared by all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select a different convolution algorithm on each run,
    # which would undermine the deterministic flag set above.
    torch.backends.cudnn.benchmark = False

def divide(datas, ratio):
    """Cut a sequence in two at ``int(ratio * len(datas))``.

    Args:
        datas: any sliceable object with a length.
        ratio: fraction of the length that goes into the first part.

    Returns:
        A ``(head, tail)`` tuple: the items before the cut and the items
        from the cut onwards.
    """
    cut = int(ratio * len(datas))
    head = datas[:cut]
    tail = datas[cut:]
    return head, tail

def concatDatasLabels(datas,labels,number):
    """Flatten the first ``number`` entries of ``datas`` and ``labels``.

    Args:
        datas: array with at least three axes; the size of axis 2 becomes the
            column count of the flattened result.
        labels: array whose first ``number`` entries are flattened to 1-D.
        number: how many leading entries to keep from each array.

    Returns:
        ``(flatDatas, flatLabels)``: a 2-D array with ``datas.shape[2]``
        columns and a 1-D label array.
    """
    featureCount = datas.shape[2]
    flatDatas = datas[:number].reshape((-1, featureCount))
    flatLabels = labels[:number].reshape(-1)
    return flatDatas, flatLabels

def makeMaskFromExample(examples,responsesSet,LLMwinlen):
    """Build a 0/1 mask over the responses, flagging the example positions.

    Rows of ``examples`` are consumed two at a time. For the k-th pair
    (k = 0, 1, ...), column 0 of each row plus ``k * LLMwinlen`` gives the
    mask index that is set to 1.

    Args:
        examples: 2-D array; only column 0 of each row is read.
        responsesSet: sequence whose length fixes the mask length.
        LLMwinlen: offset added per pair of example rows.

    Returns:
        1-D float array of ``len(responsesSet)`` zeros with ones at the
        computed positions.
    """
    mask = np.zeros(len(responsesSet))
    for start in range(0, len(examples), 2):
        # start is even, so start / 2 is the pair number as an exact float.
        offset = start / 2 * LLMwinlen
        for row in (start, start + 1):
            mask[int(examples[row, 0] + offset)] = 1.
    return mask

def getLLMDistribution(responseSet):
    """Return mean and standard deviation of the responses and their negation.

    The input is concatenated with its element-wise negation before the
    statistics are taken, so the sample is symmetric around zero.

    Args:
        responseSet: NumPy array of response values.

    Returns:
        ``(mean, std)`` of the concatenated array.
    """
    mirrored = np.concatenate([responseSet, -responseSet])
    return mirrored.mean(), mirrored.std()

# Fixed seed for all random number generators.
seed = 1282028438
setup_seed(seed)

datas, labels = readMutshangData("./data/mutshangData.npy", "./data/mutshangLabel.npy")
print(datas.shape, labels.shape)

# Gather the values extracted from every response file, in sorted file-name
# order; sub-directories are skipped.
responseDir = "./Response/Mustang/response/"
responsesSet = []
for file in sorted(os.listdir(responseDir)):
    path = responseDir + file
    if os.path.isdir(path):
        continue
    responsesSet.extend(funcExtractValue(path))

# Load every example file (comma-separated) and stack them row-wise, again in
# sorted file-name order.
exampleDir = "./Response/Mustang/examples/"
examples = np.concatenate(
    [np.loadtxt(exampleDir + example_file, delimiter=',')
     for example_file in sorted(os.listdir(exampleDir))],
    axis=0,
)

# mask examples given in prompt out when testing model performance
mask = makeMaskFromExample(examples, responsesSet, 50)

responsesSet = np.array(responsesSet)
meanv, stdv = getLLMDistribution(responsesSet)
args.LLMDistribution = [[1, meanv, stdv]]

# Split responses and mask into 5 equal-length rows, one per loop iteration
# of the experiment below.
responsesSet = np.reshape(responsesSet, (5, -1))
mask = np.reshape(mask, (5, -1))

dataset = "SMD"
print(datas.shape)
results = []
F1s = []
scoress = []

for i, responses in enumerate(responsesSet):
    # Per-group hyper-parameters.
    args.colr = args.colrs[i]
    args.earlyQuit = args.earlyQuits[i]
    datasetID = "{}{}_".format(dataset, i)

    # Cut everything belonging to this group into a first and a second half.
    dataHalves = divide(datas[i], 0.5)
    responseHalves = divide(responses, 0.5)
    maskHalves = divide(mask[i], 0.5)
    labelHalves = divide(labels[i], 0.5)

    # Test on the second half, unless its labels are all zero; in that case
    # the halves swap roles and the first half becomes the test set.
    testSide = 0 if (labelHalves[1] == 0.).all() else 1
    fitSide = 1 - testSide
    tr_val, test = dataHalves[fitSide], dataHalves[testSide]
    tr_valr, testr = responseHalves[fitSide], responseHalves[testSide]
    testMask = maskHalves[testSide]
    testLabels = labelHalves[testSide]

    # 80/20 split of the non-test half into training and validation parts.
    trains, valdatas = divide(tr_val, 0.8)
    trainr, valdatar = divide(tr_valr, 0.8)

    # Each part is normalised by its own call.
    trains = normaliseData(trains, args.omit)
    valdatas = normaliseData(valdatas, args.omit)
    testdatas = normaliseData(test, args.omit)

    if args.needBreak:
        precision, recall, F1, scores = detectAnomaly(
            trains, valdatas, testdatas, testLabels,
            trainr, valdatar, testr,
            args, datasetID, args.lambdas[i], testMask)
        scoress.append(scores)
        results.append([precision, recall, F1])
    elif args.needCollaborate:
        result = detectAnomaly(
            trains, valdatas, testdatas, testLabels,
            trainr, valdatar, testr,
            args, datasetID, args.lambdas[i], testMask, i)
        results.append(result)

# Write one CSV row per group; the collaborate mode also prints the column means.
if args.needBreak or args.needCollaborate:
    results = np.array(results)
    if args.needBreak:
        outFile = "resultsBreak.csv"
    else:
        averageRes = results.mean(axis=0)
        print("average result (p,r,f1):", averageRes)
        outFile = "resultsCollaborate.csv"
    np.savetxt(outFile, results, fmt='%f', delimiter=',', newline='\n')