#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/
with additional modifications to adapt it to our implementation.

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz

This script generates objects holding the train and test split information for
each dataset. Each dataset has a predefined train and test partition; for more
information on the partitions, see the file <datasetSelections.py>.
'''

import os
import sys
import pickle
import numpy as np

sys.path.append('..')
import helperfunctions.CurriculumLib as CurLib
from helperfunctions.CurriculumLib import DataLoader_riteyes

# Root of the UnityEyes data; its 'All' subfolder (path2h5) is passed to every
# DataLoader_riteyes object created below.
path2data = r'D:\Soft\UnityEyes_Windows\UnityEyes_Windows'
path2h5 = os.path.join(path2data, 'All')
# When True and an output pickle already exists, its saved sample lists are
# reused instead of the freshly generated ones.
keepOld = False

# Per-dataset train/test selections (see datasetSelections.py). A context
# manager guarantees the file handle is closed after loading.
with open('dataset_selections.pkl', 'rb') as sel_file:
    DS_sel = pickle.load(sel_file)
AllDS = CurLib.readArchives(os.path.join(path2data, 'MasterKey'))
# Earlier selections that were processed by this script: 'OpenEDS', 'sequence', 'S'
list_ds = ['UEGaze']

# Options forwarded (as `args`) to every DataLoader_riteyes object.
# Each key is assigned exactly once.
args = {
    'train_data_percentage': 1.0,
    'net_ellseg_head': False,
    'loss_w_rend_pred_2_gt_edge': 0.1,
    'loss_w_rend_gt_2_pred': 0.1,
    'loss_w_rend_pred_2_gt': 0.0,
}

# Video IDs held out of the training split. Every training sample that
# references one of these videos is dropped, so these videos never appear in
# training and therefore survive the train/validation overlap filter below.
HOLDOUT_VID_IDS = [8, 9, 10, 12]

# Generate objects per dataset
for setSel in list_ds:
    # Train and validation objects are both built from the dataset's
    # predefined training partition.
    AllDS_cond = CurLib.selSubset(AllDS, DS_sel['train'][setSel])
    dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='none', notest=True)
    trainObj = DataLoader_riteyes(dataDiv_obj, path2h5, 'train', True, (480, 640),
                                  scale=0.5, num_frames=4, args=args)
    validObj = DataLoader_riteyes(dataDiv_obj, path2h5, 'valid', False, (480, 640),
                                  scale=0.5, num_frames=4, args=args)
    # No test object is built from DS_sel['test'] here; the validation object
    # is written into the test slot of the saved tuple instead.

    # imList[:, :, 1] is treated as the video ID of every frame in a sample
    # (NOTE(review): inferred from usage -- confirm against DataLoader_riteyes).
    # A sample is removed if ANY of its frames belongs to a held-out video.
    train_vids = trainObj.imList[:, :, 1]
    unique_train = np.unique(train_vids)
    for vid_id in unique_train[np.isin(unique_train, HOLDOUT_VID_IDS)]:
        # "keep" here means: held out of training (kept for validation).
        print(f'keep :{vid_id}')
    trainObj.imList = trainObj.imList[
        ~np.isin(train_vids, HOLDOUT_VID_IDS).any(axis=-1)]
    train_vid_ids = list(np.unique(trainObj.imList[:, :, 1]))
    print("--------1---------")
    print(train_vid_ids)

    # Drop every validation sample that shares a video with the (filtered)
    # training set, so no video ends up in both splits.
    valid_vids = validObj.imList[:, :, 1]
    unique_valid = np.unique(valid_vids)
    for vid_id in unique_valid[np.isin(unique_valid, train_vid_ids)]:
        print(f'Discarded valid overlap video_id:{vid_id}')
    validObj.imList = validObj.imList[
        ~np.isin(valid_vids, train_vid_ids).any(axis=-1)]

    print("--------2---------")
    val_vid_ids = list(np.unique(validObj.imList[:, :, 1]))
    print(val_vid_ids)

    print("--------3---------")

    # The 'S' selection is saved under the name 'OpenEDS_S'.
    save_name = 'OpenEDS_S' if setSel == 'S' else setSel
    save_dir = os.path.join(os.getcwd(), 'one_vs_one')
    # Create the output folder on first use instead of failing on open().
    os.makedirs(save_dir, exist_ok=True)
    path2save = os.path.join(save_dir, 'cond_' + save_name + '.pkl')

    if os.path.exists(path2save) and keepOld:
        print('Preserving old selections ...')

        # This ensures that the original selection remains the same: the
        # previously saved sample lists replace the freshly generated ones.
        with open(path2save, 'rb') as fh:
            trainObj_orig, validObj_orig, _ = pickle.load(fh)
        trainObj.imList = trainObj_orig.imList
        validObj.imList = validObj_orig.imList
    else:
        print('Save data')

    # Saved layout is (train, valid, test); the test slot holds validObj.
    # The context manager guarantees the pickle is flushed and closed.
    with open(path2save, 'wb') as fh:
        pickle.dump((trainObj, validObj, validObj), fh)

    print("len:", trainObj.imList.shape[0])
    print("len:", validObj.imList.shape[0])