import json
import typing
from pathlib import Path
from transformers import AutoTokenizer, pipeline,AutoModelForCausalLM
#from transformers import pipeline
import torch
from datasets import Dataset
# Load the full evaluation set.  JSON text is UTF-8, so the encoding is pinned
# here instead of being left to the platform's locale default (the output
# files below already pin utf-8 as well).
with open("/data/ssliangruc3/data/zsre_test.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Bucket label for each leading record of `dataset`: alldata[i] is the bucket
# (0-4) that dataset[i] is routed to.  Records past len(alldata) are ignored.
alldata = [
    1,4,1,4,2,2,2,2,2,2,3,3,3,3,1,3,1,3,0,1,1,3,3,2,2,0,4,3,3,3,3,3,3,3,3,4,4,
    0,0,1,0,0,0,2,1,2,1,3,2,4,4,4,4,3,4,0,0,2,2,1,1,1,1,1,3,3,3,4,0,3,2,3,3,3,
    3,3,3,3,3,3,3,0,1,2,2,1,1,1,4,2,1,1,0,0,0,2,2,3,1,3,1,2,0,2,1,2,2,2,2,2,1,
    1,3,1,4,1,1,2,3,3,3,4,0,1,1,1,1,4,2,0,0,0,2,0,0,0,2,2,2,0,0,0,0,2,2,2,2,4,
    1,1,4,4,4,4,4,0,2,3,0,0,0,2,2,2,4,1,0,1,0,0,1,1,1,1,1,2,3,3,3,3,0,1,1,0,0,
    4,4,4,4,4,4,3,3,3,1,4,4,4,0,4,
]

# One list per bucket id.  Every bucket gets an output file even when it ends
# up empty, so the table is created up front rather than on demand.
buckets = {bucket_id: [] for bucket_id in range(5)}

# Walk the label list itself instead of a hard-coded count, so the loop bound
# can never drift from the number of labels.  A dataset shorter than the label
# list still fails here, before any output file is written.  A label outside
# 0-4 raises KeyError rather than being dropped without notice.
for idx, label in enumerate(alldata):
    buckets[label].append(dataset[idx])

# Keep the per-bucket module-level names available for any later code.
data0, data1, data2, data3, data4 = (buckets[k] for k in range(5))

# Write each bucket as JSON Lines: one compact JSON document per line.
for bucket_id, records in buckets.items():
    with open(f"./data{bucket_id}.json", "w", encoding="utf-8") as out:
        for record in records:
            json.dump(record, out)
            out.write("\n")
"""
with open("./data5.json",'w',encoding='utf-8') as f:
   for data in data5:
      json.dump(data,f)
      f.write('\n')

with open("./data6.json",'w',encoding='utf-8') as f:
   for data in data6:
      json.dump(data,f)
      f.write('\n')
with open("./data7.json",'w',encoding='utf-8') as f:
   for data in data7:
      json.dump(data,f)
      f.write('\n')

with open("./data8.json",'w',encoding='utf-8') as f:
   for data in data8:
      json.dump(data,f)
      f.write('\n')
with open("./data9.json",'w',encoding='utf-8') as f:
   for data in data9:
      json.dump(data,f)
      f.write('\n')
"""
