import os
import csv
import sys
import json
import random

def save_json(obj, path):
    '''
        Write an iterable of JSON-serialisable objects to ``path``, one
        object per line (JSON Lines format).

        If ``path`` already exists the user is asked whether to overwrite
        it; any answer other than 'y' (case-insensitive) leaves the file
        untouched and nothing is written.

        obj:  iterable of JSON-serialisable objects.
        path: destination file path.
    '''
    if os.path.isfile(path):
        reply = str(input('File already exists. Overwrite? (y/n):'))
        # Honour the answer. No explicit delete is needed: opening with 'w'
        # truncates the file, and avoids building a shell command from
        # ``path``.
        if reply.strip().lower() != 'y':
            print('Skipping {}'.format(path))
            return

    with open(path, 'w') as f:
        for d in obj:
            json.dump(d, f)
            # Text mode already translates '\n' to the platform line ending;
            # writing os.linesep here would yield '\r\r\n' on Windows.
            f.write('\n')


# Species keywords used to label class directories (e.g. '087.Mallard').
# Keywords are matched as whole words; multi-word keywords use spaces.
# Directories matching neither list are labelled 'land'.
seabird = ['albatross', 'auklet', 'cormorant', 'frigatebird', 'fulmar', 'gull',
           'jaeger', 'kittiwake', 'pelican', 'puffin', 'tern']

waterfowl = ['gadwall', 'grebe', 'mallard', 'merganser', 'guillemot', 'pacific loon']


def _as_words(text):
    """Return *text* lower-cased as ' w1 w2 ... ' (space-padded).

    Every non-alphanumeric character ('.', '_', ' ', ...) is treated as a
    word separator, so a padded keyword can be found with a plain ``in``
    test without matching inside a longer word.
    """
    cleaned = ''.join(ch if ch.isalnum() else ' ' for ch in text.lower())
    return ' {} '.format(' '.join(cleaned.split()))


def classify_dir(dir_name):
    """Return 'waterfowl', 'seabird' or 'land' for a class directory name.

    Keywords are matched as whole words. Plain substring matching made
    'tern' hit 'Western'/'Eastern' (e.g. '088.Western_Meadowlark'), and
    'pacific loon' could never match the underscore-separated
    'Pacific_Loon'. A name matching both lists is reported and labelled
    'waterfowl'.
    """
    words = _as_words(dir_name)
    is_seabird = any(_as_words(kw) in words for kw in seabird)
    is_waterfowl = any(_as_words(kw) in words for kw in waterfowl)

    if is_seabird and is_waterfowl:
        print('error')
        print(dir_name)
    if is_waterfowl:
        return 'waterfowl'
    if is_seabird:
        return 'seabird'
    return 'land'


split = {
    'waterfowl': [],
    'seabird': [],
    'land': [],
}

# Every sub-directory of the working directory is treated as one class.
for dir_name in os.listdir('./'):
    if not os.path.isdir(dir_name):
        continue
    split[classify_dir(dir_name)].append(dir_name)

# Shuffle the land classes reproducibly, then give each task its own
# disjoint slice of them, sized to match that task's water-bird classes.
random.seed(1)
random.shuffle(split['land'])

print('Task Waterfowl vs Land')

print('\n\nTask Seabird vs Land')

n_waterfowl = len(split['waterfowl'])
n_seabird = len(split['seabird'])
land_for_water = split['land'][:n_waterfowl]
land_for_seabird = split['land'][n_waterfowl:n_waterfowl + n_seabird]

task_dict = {
    'water': split['waterfowl'] + land_for_water,
    'seabird': split['seabird'] + land_for_seabird,
}

def _pos_neg_ratio(pos, neg):
    """Return ``pos / neg``, or ``float('inf')`` when ``neg`` is 0.

    Only used for the diagnostic prints below, so an environment with no
    disagreeing examples is reported instead of crashing the script.
    """
    return pos / neg if neg else float('inf')


for task_name, task_classes in task_dict.items():
    print('Generating task {}'.format(task_name))
    print(task_classes)
    print(len(task_classes))

    # Examples bucketed by '<y>_<place>' from the CSV. In '0_0' and '1_1'
    # the y and place values are equal; in '0_1' and '1_0' they differ.
    # NOTE(review): assumes y and place are always '0' or '1'; any other
    # value raises KeyError below.
    data_dict = {
        '0_0': [],
        '0_1': [],
        '1_0': [],
        '1_1': [],
    }
    # Set to 1 once at least one CSV row belongs to the class.
    class_checklist = dict.fromkeys(task_classes, 0)
    task_class_set = set(task_classes)  # O(1) membership test per CSV row

    with open(sys.argv[1], newline='') as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',')
        for row in reader:
            # The class directory is the first path component of
            # img_filename. partition() yields the whole name when there is
            # no '/', whereas slicing up to find('/') == -1 would silently
            # drop the last character.
            dir_name = row['img_filename'].partition('/')[0]
            if dir_name in task_class_set:
                class_checklist[dir_name] = 1
                k = '{}_{}'.format(row['y'], row['place'])
                data_dict[k].append({
                    'x': row['img_filename'],
                    'y': int(row['y']),
                    'c': int(row['place']),
                })

    # Report task classes that never appeared in the CSV.
    for k, v in class_checklist.items():
        if v == 0:
            print('FATAL ERROR')
            print(k)

    # Report group sizes, then shuffle each group in place (seeded earlier).
    for k, v in data_dict.items():
        print(k, ':  ', len(v))
        random.shuffle(v)

    # env3: the same number of examples from each of the four groups, so y
    # and place carry no correlation there. Those examples are removed from
    # the pools used for env0-env2.
    min_data_size = min(len(v) for v in data_dict.values())
    take = min_data_size // 3
    env3 = []
    for k, v in data_dict.items():
        env3 += v[:take]
        data_dict[k] = v[take:]
    print(len(env3))

    # Positive correlation: split the remaining agreeing examples
    # ('0_0', '1_1') into thirds, one third per environment.
    g00, g11 = data_dict['0_0'], data_dict['1_1']
    n00, n11 = len(g00), len(g11)
    env0 = g00[:n00 // 3] + g11[:n11 // 3]
    env1 = g00[n00 // 3:n00 // 3 * 2] + g11[n11 // 3:n11 // 3 * 2]
    env2 = g00[n00 // 3 * 2:] + g11[n11 // 3 * 2:]

    # Negative correlation: env0 gets the first fifth of the disagreeing
    # examples ('0_1', '1_0'), env1 and env2 two fifths each. env0 therefore
    # has the strongest positive y/place correlation. Each block prints the
    # agreeing:disagreeing ratio and the environment size.
    g01, g10 = data_dict['0_1'], data_dict['1_0']
    n01, n10 = len(g01), len(g10)

    env0_pos = len(env0)
    env0 = env0 + g01[:n01 // 5] + g10[:n10 // 5]
    env0_neg = len(env0) - env0_pos
    print(_pos_neg_ratio(env0_pos, env0_neg))
    print(len(env0))

    env1_pos = len(env1)
    env1 = env1 + g01[n01 // 5:3 * n01 // 5] + g10[n10 // 5:3 * n10 // 5]
    env1_neg = len(env1) - env1_pos
    print(_pos_neg_ratio(env1_pos, env1_neg))
    print(len(env1))

    env2_pos = len(env2)
    env2 = env2 + g01[3 * n01 // 5:] + g10[3 * n10 // 5:]
    env2_neg = len(env2) - env2_pos
    print(_pos_neg_ratio(env2_pos, env2_neg))
    print(len(env2))

    print('\n\n')

    # Saving is currently disabled; uncomment to write the environments.
    # # # save data to json
    # save_json(env1, '{}_env_train1.json'.format(task_name))
    # save_json(env0, '{}_env_train2.json'.format(task_name))
    # save_json(env2, '{}_env_val.json'.format(task_name))
    # save_json(env3, '{}_env_test.json'.format(task_name))
#
# Console output captured from an earlier run, kept for reference:
#['090.Red_breasted_Merganser', '058.Pigeon_Guillemot', '051.Horned_Grebe', '050.Eared_Grebe', '087.Mallard', '053.Western_Grebe', '046.Gadwall', '089.Hooded_Merganser', '052.Pied_billed_Grebe', '170.Mourning_Warbler', '105.Whip_poor_Will', '009.Brewer_Blackbird', '179.Tennessee_Warbler', '199.Winter_Wren', '111.Loggerhead_Shrike', '161.Blue_winged_Warbler', '132.White_crowned_Sparrow', '043.Yellow_bellied_Flycatcher']
#0_0 :   362
#0_1 :   164
#1_0 :   178
#1_1 :   360
#216
#4.434782608695652
#250
#2.1702127659574466
#298
#2.1914893617021276
#300
#
#
#
#File already exsits. Overwrite? (y/n):y
#File already exsits. Overwrite? (y/n):y
#File already exsits. Overwrite? (y/n):y
#File already exsits. Overwrite? (y/n):y
#Generating task seabird
#['061.Heermann_Gull', '084.Red_legged_Kittiwake', '008.Rhinoceros_Auklet', '101.White_Pelican', '007.Parakeet_Auklet', '066.Western_Gull', '065.Slaty_backed_Gull', '044.Frigatebird', '088.Western_Meadowlark', '071.Long_tailed_Jaeger', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '023.Brandt_Cormorant', '001.Black_footed_Albatross', '102.Western_Wood_Pewee', '146.Forsters_Tern', '060.Glaucous_winged_Gull', '072.Pomarine_Jaeger', '003.Sooty_Albatross', '141.Artic_Tern', '059.California_Gull', '106.Horned_Puffin', '005.Crested_Auklet', '145.Elegant_Tern', '144.Common_Tern', '006.Least_Auklet', '045.Northern_Fulmar', '064.Ring_billed_Gull', '063.Ivory_Gull', '002.Laysan_Albatross', '147.Least_Tern', '142.Black_Tern', '143.Caspian_Tern', '100.Brown_Pelican', '062.Herring_Gull', '021.Eastern_Towhee', '176.Prairie_Warbler', '082.Ringed_Kingfisher', '155.Warbling_Vireo', '047.American_Goldfinch', '159.Black_and_white_Warbler', '197.Marsh_Wren', '037.Acadian_Flycatcher', '153.Philadelphia_Vireo', '123.Henslow_Sparrow', '041.Scissor_tailed_Flycatcher', '055.Evening_Grosbeak', '070.Green_Violetear', '014.Indigo_Bunting', '019.Gray_Catbird', '118.House_Sparrow', '151.Black_capped_Vireo', '182.Yellow_Warbler', '107.Common_Raven', '175.Pine_Warbler', '131.Vesper_Sparrow', '188.Pileated_Woodpecker', '185.Bohemian_Waxwing', '026.Bronzed_Cowbird', '187.American_Three_toed_Woodpecker', '183.Northern_Waterthrush', '083.White_breasted_Kingfisher', '040.Olive_sided_Flycatcher', '129.Song_Sparrow', '124.Le_Conte_Sparrow', '110.Geococcyx', '054.Blue_Grosbeak', '190.Red_cockaded_Woodpecker', '148.Green_tailed_Towhee', '103.Sayornis', '119.Field_Sparrow', '181.Worm_eating_Warbler']

