
"""

### Generate Full MetaShift

Since the total number of subsets is very large, all of the following scripts generate only a subset of MetaShift. As specified in [dataset/Constants.py](./dataset/Constants.py), we only generate MetaShift for the following classes (subjects). You can add any additional classes (subjects) to the list. See [dataset/meta_data/class_hierarchy.json](./dataset/meta_data/class_hierarchy.json) for the full object vocabulary and its hierarchy.

```python
SELECTED_CLASSES = [
    'cat', 'dog',
    'bus', 'truck',
    'elephant', 'horse',
    'bowl', 'cup',
    ]
```

```sh
cd dataset/
python generate_full_MetaShift.py
```

The following files will be generated by executing the script. Modify the global variable `META_DATASET_FOLDER` to change the destination folder.

```plain
/data/MetaShift/MetaDataset-full
├── cat/
    ├── cat(keyboard)/
    ├── cat(sink)/ 
    ├── ... 
├── dog/
    ├── dog(surfboard)/
    ├── dog(boat)/ 
    ├── ...
├── bus/ 
├── ...
```

In addition, to save storage, all images are created as symbolic links by default. Set `use_symlink=False` in the code to copy the actual image files instead. If you really want to generate the **full** MetaShift, set `ONLY_SELECTED_CLASSES = False` in [dataset/Constants.py](./dataset/Constants.py).

"""
# Root folder under which the generated MetaShift subsets are written
# (one sub-folder per subject, then one per subset).
META_DATASET_FOLDER = '/data/MetaShift/MetaDataset-full'

import pickle
import numpy as np
import json, re, math
from collections import Counter, defaultdict
from itertools import repeat
import pprint
import os, errno
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil # for copy files
import networkx as nx # graph vis
import pandas as pd
from sklearn.decomposition import TruncatedSVD

import Constants
# Settings taken from the project's Constants module:
#   IMAGE_DATA_FOLDER     - folder holding the source .jpg images.
#   ONLY_SELECTED_CLASSES - when true, preprocess_groups skips every subject
#                           that is not in its `subject_classes` argument.
IMAGE_DATA_FOLDER = Constants.IMAGE_DATA_FOLDER
ONLY_SELECTED_CLASSES = Constants.ONLY_SELECTED_CLASSES

def parse_node_str(node_str):
    """Split a subset node name such as ``'cat(keyboard)'`` into its parts.

    Args:
        node_str: Node name of the form ``'<subject>(<tag>)'``.

    Returns:
        A ``(subject_str, tag)`` tuple, e.g. ``('cat', 'keyboard')``.
        ``subject_str`` is the text before the first ``'('`` with surrounding
        whitespace stripped. ``tag`` is the text after the last ``'('`` with a
        trailing ``')'`` removed. A name without ``'('`` yields an empty tag
        (instead of a copy of the name with its last character chopped off).
    """
    subject_str = node_str.split('(', 1)[0].strip()
    _, sep, tail = node_str.rpartition('(')
    if not sep:
        # No "(tag)" part at all.
        return subject_str, ''
    tail = tail.rstrip()
    # Remove only a genuine closing parenthesis, not an arbitrary last character.
    tag = tail[:-1] if tail.endswith(')') else tail
    return subject_str, tag


def load_candidate_subsets(pkl_save_path="./meta_data/full-candidate-subsets.pkl"):
    """Load the pickled candidate subsets from disk.

    Args:
        pkl_save_path: Path of the pickle file. Defaults to the path this
            script has always used (relative to the working directory).

    Returns:
        The unpickled object. ``preprocess_groups`` treats it as a dict mapping
        node names (e.g. ``'cat(sink)'``) to sets of image IDs.

    Note:
        ``pickle.load`` can execute arbitrary code while loading; only point
        this at files you trust.
    """
    with open(pkl_save_path, "rb") as pkl_f:
        load_data = pickle.load(pkl_f)
    print('pickle load', len(load_data), pkl_save_path)
    return load_data


##################################
# Copy Image Sets: Work at subject_str level
##################################
def copy_image_for_subject(root_folder, subject_str, subject_data, node_name_to_img_id, trainsg_dupes, use_symlink=True):
    """Write every subset (node) of one subject to disk as a folder of images.

    For each node name in ``subject_data`` (e.g. ``'cat(sink)'``), the folder
    ``<root_folder>/<subject>/<node_name>/`` is recreated from scratch and
    filled with that node's images, excluding any ID in ``trainsg_dupes``.

    Args:
        root_folder: Destination root folder.
        subject_str: Subject (class) name. Kept for interface compatibility;
            the folder's subject part is parsed from each node name instead.
        subject_data: Iterable of node names belonging to this subject.
        node_name_to_img_id: Maps node name -> set of image IDs.
        trainsg_dupes: Set of image IDs to leave out.
        use_symlink: If True, create symbolic links to the source images
            (saves disk space, for local use). If False, copy the .jpg files
            (e.g. for sharing the meta-dataset).
    """
    for node_name in subject_data:
        node_subject, _tag = parse_node_str(node_name)

        # Recreate the subset folder so images from a previous run never linger.
        subset_folder = os.path.join(root_folder, node_subject, node_name)
        if os.path.isdir(subset_folder):
            shutil.rmtree(subset_folder)
        os.makedirs(subset_folder, exist_ok=False)

        for image_id in node_name_to_img_id[node_name] - trainsg_dupes:
            file_name = image_id + '.jpg'
            # os.path.join works whether or not IMAGE_DATA_FOLDER ends with a separator.
            src_image_path = os.path.join(IMAGE_DATA_FOLDER, file_name)
            dst_image_path = os.path.join(subset_folder, file_name)
            if use_symlink:
                os.symlink(src_image_path, dst_image_path)
            else:
                shutil.copyfile(src_image_path, dst_image_path)



##################################
# Graph visualization 
################################## 


def build_subset_graph(subject_most_common_list, node_name_to_img_id, trainsg_dupes, subject_str, edge_threshold=0.2):
    """Build the image-overlap graph between the subsets of one subject.

    The weight between subsets A and B is the overlap coefficient
    |A & B| / min(|A|, |B|), computed on their image IDs after removing
    ``trainsg_dupes``. Weights below ``edge_threshold`` are zeroed, so those
    pairs get no edge. Node labels are the subset names with a newline
    inserted before ``'('`` so that they wrap when drawn.

    Args:
        subject_most_common_list: Ordered list of node names, e.g. ``'cat(sink)'``.
        node_name_to_img_id: Maps node name -> set of image IDs.
        trainsg_dupes: Set of image IDs to exclude from every subset.
        subject_str: Subject name; unused here, kept for interface compatibility.
        edge_threshold: Minimum overlap weight for an edge to be kept
            (default 0.2, the value that used to be hard-coded).

    Returns:
        An undirected ``networkx`` graph built with ``nx.from_pandas_adjacency``.
    """
    n_sets = len(subject_most_common_list)
    # Filter each subset once instead of once per pair.
    image_sets = [node_name_to_img_id[name] - trainsg_dupes for name in subject_most_common_list]

    adjacency = np.zeros((n_sets, n_sets))
    for i in range(n_sets):
        for j in range(i + 1, n_sets):
            set_a, set_b = image_sets[i], image_sets[j]
            smaller = min(len(set_a), len(set_b))
            # An empty subset overlaps with nothing; avoid dividing by zero.
            edge_weight = len(set_a & set_b) / smaller if smaller else 0.0
            adjacency[i, j] = adjacency[j, i] = edge_weight

    # Sparsify: drop weak overlaps so the drawn graph stays readable.
    adjacency[adjacency < edge_threshold] = 0

    labels = [name.replace('(', '\n(') for name in subject_most_common_list]
    # A plain 2-D ndarray is enough here; np.matrix is deprecated.
    adjacency_df = pd.DataFrame(adjacency, index=labels, columns=labels)
    return nx.from_pandas_adjacency(adjacency_df)



def draw_subject_set_graph(subject_most_common_list, node_name_to_img_id, trainsg_dupes, subject_str):
    """Draw one subject's subset-overlap graph, colored by community, and save it.

    The graph comes from ``build_subset_graph``. Communities are found with
    networkx's greedy modularity algorithm. Nodes and intra-community edges
    get a per-community color; edges between communities are drawn in silver.
    The figure is saved to ``./meta-graphs/<subject_str>_graph.jpg``; that
    folder must already exist.

    Args:
        subject_most_common_list: Node names of the subject, e.g. ``'cat(sink)'``.
        node_name_to_img_id: Maps node name -> set of image IDs.
        trainsg_dupes: Set of image IDs excluded from every subset.
        subject_str: Subject name, used for the output file name.
    """
    import networkx.algorithms.community as nxcom

    G = build_subset_graph(subject_most_common_list, node_name_to_img_id, trainsg_dupes, subject_str)

    # Modularity is normalized by the total edge weight, so it is undefined
    # for an edgeless graph (possible after sparsification). In that case every
    # node is its own community.
    if G.number_of_edges() == 0:
        communities = [{node} for node in G.nodes]
    else:
        communities = sorted(nxcom.greedy_modularity_communities(G), key=len, reverse=True)

    # Tag nodes with a community id starting at 1, so that 0 can mark edges
    # that run between two different communities.
    for community_id, members in enumerate(communities, start=1):
        for node in members:
            G.nodes[node]['community'] = community_id
    for v, w in G.edges:
        v_community = G.nodes[v]['community']
        same_community = v_community == G.nodes[w]['community']
        G.edges[v, w]['community'] = v_community if same_community else 0

    def get_color(i, r_off=1, g_off=1, b_off=1):
        """Map a community id to a deterministic RGB tuple with channels in [0.1, 0.9]."""
        n = 32  # modulus of the per-channel color cycle
        low, high = 0.1, 0.9
        span = high - low
        r = low + span * (((i + r_off) * 3) % n) / (n - 1)
        g = low + span * (((i + g_off) * 5) % n) / (n - 1)
        b = low + span * (((i + b_off) * 7) % n) / (n - 1)
        return (r, g, b)

    node_color = [get_color(G.nodes[v]['community']) for v in G.nodes]
    external = [(v, w) for v, w in G.edges if G.edges[v, w]['community'] == 0]
    internal = [(v, w) for v, w in G.edges if G.edges[v, w]['community'] > 0]
    internal_color = [get_color(G.nodes[v]['community']) for v, _w in internal]

    # Fixed seed -> the same layout on every run.
    pos = nx.spring_layout(G=G, seed=1234)

    # Explicit figure instead of permanently changing the global rcParams size.
    plt.figure(figsize=(22, 15))

    # External edges first: thin and silver; node_size=0 hides the node markers.
    nx.draw_networkx(
        G,
        width=0.5,
        pos=pos,
        node_size=0,
        edgelist=external,
        edge_color="silver")

    # Then the colored nodes and the internal edges on top.
    nx.draw_networkx(
        G,
        alpha=0.75,
        node_size=500,
        pos=pos,
        node_color=node_color,
        edgelist=internal,
        edge_color=internal_color)

    plt.savefig('./meta-graphs/' + subject_str + '_graph.jpg', bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close('all')


# Minimum number of images a subset (node) must contain to be kept.
IMAGE_SUBSET_SIZE_THRESHOLD = 25
# Legacy misspelled name, kept so existing references (e.g. preprocess_groups)
# keep working.
IMGAGE_SUBSET_SIZE_THRESHOULD = IMAGE_SUBSET_SIZE_THRESHOLD

def preprocess_groups(output_files_flag=True, subject_classes = Constants.SELECTED_CLASSES):
    """Select the MetaShift subsets to keep and optionally write them to disk.

    Steps:
      1. Load the candidate subsets (node name -> set of image IDs).
      2. Keep nodes with at least ``IMGAGE_SUBSET_SIZE_THRESHOULD`` images.
      3. Group the kept nodes by subject. If ``ONLY_SELECTED_CLASSES`` is true,
         only subjects listed in ``subject_classes`` are grouped.
      4. Drop subjects with 5 or fewer subsets. For each remaining subject,
         optionally draw its overlap graph and write its images.

    Args:
        output_files_flag: If True, save graphs under ``./meta-graphs`` and
            write the subsets under ``META_DATASET_FOLDER``.
        subject_classes: Subjects to keep when ``ONLY_SELECTED_CLASSES`` is true.

    Returns:
        ``(node_name_to_img_id, most_common_list, subjects_to_all_set,
        subject_group_summary_dict)``, where:
          * ``node_name_to_img_id`` is the loaded dict, updated in place;
          * ``most_common_list`` holds ``(node_name, size)`` pairs for nodes
            passing step 2, in the loaded dict's order (NOT sorted by size);
          * ``subjects_to_all_set`` maps each subject surviving step 4 to the
            union of its subsets' image IDs;
          * ``subject_group_summary_dict`` maps every subject from step 3
            (including ones dropped in step 4) to a Counter of node -> size.
    """
    os.makedirs('./meta-graphs', exist_ok=True)

    # Image IDs to exclude from every subset. Always empty in this script, so
    # the set subtractions below are currently no-ops.
    trainsg_dupes = set()

    node_name_to_img_id = load_candidate_subsets()

    # Step 2: size filter. Values are re-assigned (not keys added), which is
    # safe while iterating.
    group_name_counter = Counter()
    for node_name, image_ids in node_name_to_img_id.items():
        image_ids = image_ids - trainsg_dupes
        node_name_to_img_id[node_name] = image_ids
        if len(image_ids) >= IMGAGE_SUBSET_SIZE_THRESHOULD:
            group_name_counter[node_name] = len(image_ids)

    # Insertion order, not most_common() order, despite the name.
    most_common_list = list(group_name_counter.items())

    # Step 3: group node sizes by subject.
    subject_group_summary_dict = defaultdict(Counter)
    for node_name, subset_size in most_common_list:
        subject_str, _tag = parse_node_str(node_name)
        if ONLY_SELECTED_CLASSES and subject_str not in subject_classes:
            continue
        subject_group_summary_dict[subject_str][node_name] = subset_size

    # Subjects with the most images (summed over their subsets) first.
    subject_group_summary_list = sorted(
        subject_group_summary_dict.items(),
        key=lambda item: sum(item[1].values()),
        reverse=True,
    )

    # Step 4: discard subjects with too few subsets; output the rest.
    new_subject_group_summary_list = []
    subjects_to_all_set = defaultdict(set)
    for subject_str, subject_data in subject_group_summary_list:
        if len(subject_data) <= 5:
            continue
        new_subject_group_summary_list.append((subject_str, subject_data))

        if output_files_flag:
            draw_subject_set_graph(sorted(subject_data.keys()), node_name_to_img_id, trainsg_dupes, subject_str)
            # Pass use_symlink=False to copy the real files (e.g. for sharing).
            copy_image_for_subject(META_DATASET_FOLDER, subject_str, subject_data, node_name_to_img_id, trainsg_dupes, use_symlink=True)

        for node_name in subject_data:
            node_subject, _tag = parse_node_str(node_name)
            subjects_to_all_set[node_subject].update(node_name_to_img_id[node_name])

    pprint.pprint(new_subject_group_summary_list)

    print('Done! Please check ', META_DATASET_FOLDER)

    return node_name_to_img_id, most_common_list, subjects_to_all_set, subject_group_summary_dict

if __name__ == '__main__':
    # Script entry point: select the subsets, then draw graphs and write images.
    preprocess_groups(True)

