"""

## Section 4.2: Evaluating Subpopulation Shifts
Run the python script `dataset/subpopulation_shift_cat_dog_indoor_outdoor.py` to reproduce the MetaShift subpopulation shift dataset (based on Visual Genome images) in paper Appendix D. 
```sh
cd dataset/
python subpopulation_shift_cat_dog_indoor_outdoor.py
```
The python script generates a “Cat vs. Dog” dataset, where the general contexts “indoor/outdoor” have a natural spurious correlation with the class labels. 


The following files will be generated by executing the python script `dataset/subpopulation_shift_cat_dog_indoor_outdoor.py`. 

### Output files (mixed version: for reproducing experiments)

```plain
/data/MetaShift/MetaShift-subpopulation-shift
├── imageID_to_group.pkl
├── train/
    ├── cat/             (more cat(indoor) images than cat(outdoor))
    ├── dog/             (more dog(outdoor) images than dog(indoor)) 
├── val_out_of_domain/
    ├── cat/             (cat(indoor):cat(outdoor)=1:1)
    ├── dog/             (dog(indoor):dog(outdoor)=1:1) 
```
where `imageID_to_group.pkl` is a dictionary that maps each image ID to a one-element list holding the name of its group, which is one of: 
`'cat(indoor)'`, `'cat(outdoor)'`, `'dog(indoor)'`, `'dog(outdoor)'`. 
The images that belong to a given subset are therefore the image IDs whose value contains that group name. 
You can tune the `NUM_MINORITY_IMG` to control the amount of subpopulation shift.  

### Output files (unmixed version, for other potential uses)
To facilitate other potential uses, we also output an unmixed version, where the `'cat(indoor)'`, `'cat(outdoor)'`, `'dog(indoor)'`, `'dog(outdoor)'` subsets are written into 4 separate folders. 
```plain
/data/MetaShift/MetaShift-Cat-Dog-indoor-outdoor
├── imageID_to_group.pkl
├── train/
    ├── cat/             (all cat(indoor) images)
    ├── dog/             (all dog(outdoor) images) 
├── test/
    ├── cat/             (all cat(outdoor) images)
    ├── dog/             (all dog(indoor) images) 
```

"""

# Output root of the unmixed version: parse_dataset_scheme exports each split
# into a sub-folder of this directory. generate_splitted_metadaset deletes and
# recreates this directory on every run.
CUSTOM_SPLIT_DATASET_FOLDER = '/data/MetaShift/MetaShift-Cat-Dog-indoor-outdoor'

# Output root of the mixed (subpopulation-shift) version, which also holds
# imageID_to_group.pkl. It is likewise deleted and recreated on every run.
SUBPOPULATION_SHIFT_DATASET_FOLDER = '/data/MetaShift/MetaShift-subpopulation-shift'


import pandas as pd 
import seaborn as sns

import pickle
import numpy as np
import json, re, math
from collections import Counter, defaultdict
from itertools import repeat
import pprint
import os, errno
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil # for copy files
import networkx as nx # graph vis
import pandas as pd
from sklearn.decomposition import TruncatedSVD

import Constants
IMAGE_DATA_FOLDER          = Constants.IMAGE_DATA_FOLDER

from generate_full_MetaShift import preprocess_groups, build_subset_graph, copy_image_for_subject


def print_communities(subject_data, node_name_to_img_id, trainsg_dupes, subject_str):
    """Build the subset graph of one subject and print its detected communities.

    The graph comes from ``build_subset_graph``; communities are found with
    networkx's greedy modularity algorithm and reported from largest to
    smallest (by number of nodes). For each community this prints the number
    of distinct image IDs covered by its nodes, followed by the set of node
    names with any newline characters stripped.

    Returns:
        The graph object returned by ``build_subset_graph``.
    """
    G = build_subset_graph(subject_data, node_name_to_img_id, trainsg_dupes, subject_str)

    import networkx.algorithms.community as nxcom

    detected = nxcom.greedy_modularity_communities(G)
    ranked = sorted(detected, key=len, reverse=True)
    print("The graph has {} communities.".format(len(ranked)) )

    for members in ranked:
        # Node labels may contain line breaks; remove them before using the
        # label as a key into node_name_to_img_id.
        clean_names = [label.replace('\n', '') for label in members]

        # Union of the image IDs of every node in this community.
        covered_image_ids = set()
        for label in clean_names:
            covered_image_ids.update(node_name_to_img_id[label])

        print('total size:',len(covered_image_ids))
        print(set(clean_names), '\n\n')
    return G 



def parse_dataset_scheme(dataset_scheme, node_name_to_img_id, exclude_img_id=None, split='test'):
    """Resolve a dataset scheme to image IDs and export the images of each subject.

    Args:
        dataset_scheme: nested mapping ``subject -> community name -> node names``,
            e.g. ``{'cat': {'cat(outdoor)': {'cat(car)', 'cat(grass)'}}}``.
        node_name_to_img_id: mapping from a node name (e.g. ``'cat(car)'``) to
            the set of image IDs of that node.
        exclude_img_id: set of image IDs to leave out. It contains both
            trainsg_dupes and test images that we do not want to leak.
            ``None`` (the default) excludes nothing.
        split: name of the sub-folder of CUSTOM_SPLIT_DATASET_FOLDER that
            receives the exported images.

    Returns:
        A tuple ``(community_name_to_img_id, all_img_id)``: a mapping from
        community name to its set of kept image IDs, and the union of all
        kept image IDs.
    """
    # Avoid a mutable default argument: create the empty set per call.
    if exclude_img_id is None:
        exclude_img_id = set()

    community_name_to_img_id = defaultdict(set)
    all_img_id = set()

    ##################################
    # Iterate subject_str: e.g., cat
    ##################################
    for subject_str in dataset_scheme:
        ##################################
        # Iterate community_name: e.g., cat(sofa)
        ##################################
        for community_name in dataset_scheme[subject_str]:
            ##################################
            # Iterate node_name: e.g., 'cat(cup)', 'cat(sofa)', 'cat(chair)'
            ##################################
            for node_name in dataset_scheme[subject_str][community_name]:
                # Compute the filtered IDs once and reuse them for both sets.
                kept_img_ids = node_name_to_img_id[node_name] - exclude_img_id
                community_name_to_img_id[community_name].update(kept_img_ids)
                all_img_id.update(kept_img_ids)
            print(community_name, 'Size:', len(community_name_to_img_id[community_name]) )

        ##################################
        # Export the images of this subject under <CUSTOM_SPLIT_DATASET_FOLDER>/<split>
        ##################################
        root_folder = os.path.join(CUSTOM_SPLIT_DATASET_FOLDER, split)
        # The exclusions were already applied above, so nothing more is
        # passed as trainsg_dupes. use_symlink=False copies real files so the
        # exported folder can be shared.
        copy_image_for_subject(root_folder, subject_str, dataset_scheme[subject_str], community_name_to_img_id, trainsg_dupes=set(), use_symlink=False)

    return community_name_to_img_id, all_img_id


def get_all_nodes_in_dataset(dataset_scheme):
    """Return the set of every node name mentioned anywhere in a dataset scheme.

    ``dataset_scheme`` is a nested mapping ``subject -> community -> node names``,
    e.g. ``{'cat': {'cat(sofa)': {'cat(cup)', 'cat(sofa)', 'cat(chair)'}}}``.
    Node names that appear in several communities are returned once.
    """
    collected = set()
    # Outer level: subjects (e.g., cat); inner level: communities (e.g., cat(sofa)).
    for communities in dataset_scheme.values():
        for node_names in communities.values():
            collected.update(node_names)
    return collected

# Default number of minority-group images (cat(outdoor) / dog(indoor)) placed
# in each training class of the mixed dataset. Values tried: 10, 50, 100, 150.
NUM_MINORITY_IMG = 100


def generate_splitted_metadaset(num_minority_img=NUM_MINORITY_IMG):
    """Generate the unmixed and the mixed (subpopulation-shift) Cat-vs-Dog datasets.

    The function works in two stages:

    1. Unmixed export. CUSTOM_SPLIT_DATASET_FOLDER is deleted and recreated,
       the detected communities of 'cat' and 'dog' are printed, and the
       hand-written schemes below are exported through parse_dataset_scheme
       with split='test' (cat(outdoor), dog(indoor)) and split='train'
       (cat(indoor), dog(outdoor)). Images containing both a cat and a dog,
       and training images already used by the test scheme, are excluded.

    2. Simulating subpopulation shifts. SUBPOPULATION_SHIFT_DATASET_FOLDER is
       deleted and recreated with the following layout, meant for domain
       generalization algorithms:

           train/
               cat/
                   ...images
               dog/
                   ...images
           val_out_of_domain/
               cat/
                   ...images
               dog/
                   ...images
           imageID_to_group.pkl -- maps each image ID to [group name].

    Args:
        num_minority_img: hardness of the subpopulation shift, i.e. how many
            cat(outdoor) images go into train/cat and how many dog(indoor)
            images go into train/dog. The same number of majority-group
            images is dropped, so each training class keeps 850 images when
            all groups are fully populated. Only 294 - 144 = 150 minority
            images per class are available for training, so larger values
            cannot add more minority images.

    Raises:
        ValueError: if num_minority_img is negative.
    """
    if num_minority_img < 0:
        # A negative value would silently be interpreted as a slice counted
        # from the end of the lists below.
        raise ValueError('num_minority_img must be non-negative, got {}'.format(num_minority_img))

    # Function-scope imports, resolved before any folder is deleted.
    import random
    from sklearn.model_selection import train_test_split

    if os.path.isdir(CUSTOM_SPLIT_DATASET_FOLDER):
        shutil.rmtree(CUSTOM_SPLIT_DATASET_FOLDER)
    os.makedirs(CUSTOM_SPLIT_DATASET_FOLDER, exist_ok=False)

    # Only the node -> image-ID mapping and the per-subject summary are used.
    node_name_to_img_id, _, _, subject_group_summary_dict = preprocess_groups(output_files_flag=False)

    ##################################
    # Removing ambiguous images that have both cats and dogs
    ##################################
    trainsg_dupes = node_name_to_img_id['cat(dog)'] # can also use 'dog(cat)'

    for subject_str in ['cat', 'dog']:
        subject_data = [ x for x in subject_group_summary_dict[subject_str].keys() if x not in ['cat(dog)', 'dog(cat)'] ]
        ##################################
        # Print detected communities in Meta-Graph
        ##################################
        # The printed communities guide the hand-written train/test split below.
        print_communities(subject_data, node_name_to_img_id, trainsg_dupes, subject_str)

    test_set_scheme = {
        'cat': {
            'cat(outdoor)': {
                'cat(car)',
                'cat(fence)', 'cat(grass)', 'cat(roof)', 'cat(bench)', 'cat(bird)', 'cat(house)',
            },
        },
        'dog': {
            'dog(indoor)': {
                'dog(screen)', 'dog(shelf)', 'dog(desk)', 'dog(picture)', 'dog(laptop)',
                'dog(remote control)', 'dog(blanket)', 'dog(bed)', 'dog(sheet)', 'dog(lamp)', 'dog(books)', 'dog(pillow)', 'dog(curtain)',
                'dog(container)', 'dog(table)', 'dog(cup)', 'dog(plate)', 'dog(food)', 'dog(box)',
                'dog(rug)', 'dog(floor)', 'dog(cabinet)', 'dog(towel)',
                'dog(bowl)',
                'dog(television)', 'dog(carpet)',
                'dog(sofa)',
            },
        },
    }

    train_set_scheme = {
        'cat': {
            'cat(indoor)': {
                'cat(speaker)', 'cat(computer)', 'cat(screen)', 'cat(laptop)', 'cat(computer mouse)', 'cat(keyboard)', 'cat(monitor)', 'cat(desk)',
                'cat(sheet)', 'cat(bed)', 'cat(blanket)', 'cat(remote control)', 'cat(comforter)', 'cat(pillow)', 'cat(couch)',
                'cat(books)', 'cat(book)', 'cat(television)', 'cat(bookshelf)', 'cat(blinds)',
                'cat(sink)', 'cat(bottle)', 'cat(faucet)', 'cat(towel)', 'cat(counter)',
                'cat(curtain)', 'cat(toilet)', 'cat(pot)',
                'cat(carpet)', 'cat(toy)', 'cat(floor)',
                'cat(plate)', 'cat(rug)', 'cat(food)', 'cat(table)',
                'cat(box)', 'cat(paper)', 'cat(suitcase)', 'cat(bag)',
                'cat(container)', 'cat(vase)', 'cat(shelf)', 'cat(bowl)',
                'cat(picture)', 'cat(papers)', 'cat(lamp)',
                'cat(cup)', 'cat(sofa)',
            },
        },
        'dog': {
            'dog(outdoor)': {
                'dog(house)', 'dog(grass)', 'dog(horse)', 'dog(fence)', 'dog(cow)', 'dog(sheep)', 'dog(dirt)',
                'dog(car)', 'dog(motorcycle)', 'dog(truck)', 'dog(helmet)',
                'dog(snow)',
                'dog(flag)', 'dog(boat)', 'dog(rope)', 'dog(trees)', 'dog(frisbee)',
                'dog(bike)', 'dog(bicycle)',
                'dog(sand)', 'dog(surfboard)', 'dog(water)',
                'dog(fire hydrant)', 'dog(pole)',
                'dog(skateboard)',
                'dog(bench)', 'dog(trash can)',
            },
        },
    }

    print('========== test set info ==========')
    test_community_name_to_img_id, test_all_img_id = parse_dataset_scheme(test_set_scheme, node_name_to_img_id, exclude_img_id=trainsg_dupes, split='test')
    print('========== train set info ==========')
    # Exclude the test images as well so that no image leaks into training.
    train_community_name_to_img_id, train_all_img_id = parse_dataset_scheme(train_set_scheme, node_name_to_img_id, exclude_img_id=test_all_img_id.union(trainsg_dupes), split='train')

    ##################################
    # Simulating subpopulation shifts
    ##################################
    if os.path.isdir(SUBPOPULATION_SHIFT_DATASET_FOLDER):
        shutil.rmtree(SUBPOPULATION_SHIFT_DATASET_FOLDER)
    os.makedirs(SUBPOPULATION_SHIFT_DATASET_FOLDER, exist_ok=False)

    def shuffle_and_truncate(img_id_set, truncate=500):
        # Sort first so that the seeded shuffle is reproducible regardless of
        # the iteration order of the input set.
        img_id_list = sorted(img_id_set)
        random.Random(42).shuffle(img_id_list)
        return img_id_list[:truncate]

    # Minority groups: 294 = 144 for test + up to 150 for training.
    cat_outdoor_images = shuffle_and_truncate(test_community_name_to_img_id['cat(outdoor)'], 294)
    dog_indoor_images = shuffle_and_truncate(test_community_name_to_img_id['dog(indoor)'], 294)
    # Majority groups: 994 = 144 for test + 850 for training.
    cat_indoor_images = shuffle_and_truncate(train_community_name_to_img_id['cat(indoor)'], 994)
    dog_outdoor_images = shuffle_and_truncate(train_community_name_to_img_id['dog(outdoor)'], 994)

    group_to_imageID = {
        'cat(outdoor)': cat_outdoor_images,
        'dog(indoor)': dog_indoor_images,
        'cat(indoor)': cat_indoor_images,
        'dog(outdoor)': dog_outdoor_images,
    }
    # Each image ID maps to a one-element list holding its group name.
    imageID_to_group = dict()
    for group_str in group_to_imageID:
        for imageID in group_to_imageID[group_str]:
            imageID_to_group[imageID] = [group_str]
    with open(os.path.join(SUBPOPULATION_SHIFT_DATASET_FOLDER, 'imageID_to_group.pkl'), 'wb') as handle:
        pickle.dump(imageID_to_group, file=handle)

    # NOTE: train_test_split splits the two lists of one call with the same
    # indices and expects them to have the same length, so both groups of a
    # call must supply the full truncated number of images.
    cat_outdoor_train, cat_outdoor_test, dog_indoor_train, dog_indoor_test = train_test_split(cat_outdoor_images, dog_indoor_images, test_size=144, random_state=42)
    cat_indoor_train, cat_indoor_test, dog_outdoor_train, dog_outdoor_test = train_test_split(cat_indoor_images, dog_outdoor_images, test_size=144, random_state=42)

    # Validation: both groups of a class contribute their held-out images.
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'val_out_of_domain', 'cat', use_symlink=True,
        img_IDs = cat_outdoor_test + cat_indoor_test,
        )
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'val_out_of_domain', 'dog', use_symlink=True,
        img_IDs = dog_indoor_test + dog_outdoor_test,
        )

    # Training: the first num_minority_img minority images replace the first
    # num_minority_img majority images.
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'train', 'cat', use_symlink=True,
        img_IDs = cat_outdoor_train[:num_minority_img] + cat_indoor_train[num_minority_img:],
        )
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'train', 'dog', use_symlink=True,
        img_IDs = dog_indoor_train[:num_minority_img] + dog_outdoor_train[num_minority_img:],
        )
    print('NUM_MINORITY_IMG', num_minority_img)

    return


##################################
# Copy Images specified by IDs: 
# destination folder: os.path.join(root_folder, split, subset_str)
##################################
def copy_images(root_folder, split, subset_str, img_IDs, use_symlink=True):
    """Fill ``<root_folder>/<split>/<subset_str>`` with the images named by img_IDs.

    An existing destination folder is removed first, so the folder is always
    rebuilt from scratch. The source of every image is
    ``IMAGE_DATA_FOLDER + imageID + '.jpg'``.

    Args:
        root_folder: dataset root directory.
        split: split sub-folder name.
        subset_str: class sub-folder name.
        img_IDs: iterable of image IDs (file names without the '.jpg' suffix).
        use_symlink: True creates symbolic links (for local use, saving disk
            storage); False copies the whole jpg file (for sharing the
            meta-dataset).
    """
    target_dir = os.path.join(root_folder, split, subset_str)

    # Always start from a freshly created, empty folder.
    if os.path.isdir(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=False)

    # Pick the transfer operation once; both take (source, destination).
    transfer = os.symlink if use_symlink else shutil.copyfile

    for imageID in img_IDs:
        file_name = imageID + '.jpg'
        transfer(IMAGE_DATA_FOLDER + file_name, os.path.join(target_dir, file_name))


# Script entry point: when this file is run directly, rebuild both output
# dataset folders.
if __name__ == '__main__':
    generate_splitted_metadaset()

