"""
Generate MetaDataset with train/test split 

"""

# Output root of the generated split. generate_splitted_metadaset() deletes this
# folder if it already exists and recreates it before writing anything.
CUSTOM_SPLIT_DATASET_FOLDER = '/data/GQA/MetaDataset-0924-Bus-Truck'
# Shell commands (run manually, not executed by this script) to zip the
# generated MetaDataset-0924-* folders:
# cd /data/GQA/
# zip -r MetaDataset-0924.zip MetaDataset-0924-*
# 
from networkx.algorithms.traversal.depth_first_search import dfs_edges
import pandas as pd 
import seaborn as sns

import pickle
import numpy as np
import json, re, math
from collections import Counter, defaultdict
from itertools import repeat
import pprint
import os, errno
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil # for copy files
import networkx as nx # graph vis
import pandas as pd
from sklearn.decomposition import TruncatedSVD

import Constants
IMAGE_DATA_FOLDER          = Constants.IMAGE_DATA_FOLDER

from generate_full_metadataset import preprocess_groups, build_subset_graph, copy_image_for_subject


def print_communities(subject_data, node_name_to_img_id, trainsg_dupes, subject_str):
    """
    Build the meta-graph for one subject and print its detected communities.

    The graph comes from build_subset_graph(); communities are found with
    networkx's greedy modularity algorithm and reported from largest to
    smallest. For every community this prints the number of distinct image IDs
    covered by its nodes, followed by the set of node names.

    Graph node labels may contain newlines; they are stripped before the label
    is used as a key into node_name_to_img_id.

    Returns the graph that was built.
    """
    graph = build_subset_graph(subject_data, node_name_to_img_id, trainsg_dupes, subject_str)

    import networkx.algorithms.community as nxcom

    # Largest communities first.
    detected = list(nxcom.greedy_modularity_communities(graph))
    detected.sort(key=len, reverse=True)
    print(f"The graph has {len(detected)} communities.")

    for members in detected:
        # Node labels without the line breaks, in the community's iteration order.
        plain_names = [label.replace('\n', '') for label in members]

        # Union of the image IDs of every node in this community.
        merged_image_ids = set()
        for plain_name in plain_names:
            merged_image_ids.update(node_name_to_img_id[plain_name])

        print('total size:', len(merged_image_ids))
        print(set(plain_names), '\n\n')

    return graph



def parse_dataset_scheme(dataset_scheme, node_name_to_img_id, exclude_img_id=None, split='test'):
    """
    Resolve a dataset scheme to image IDs and copy the images of each subject.

    dataset_scheme is nested as {subject_str: {community_name: iterable of
    node_name}}, e.g. {'cat': {'cat(sofa)': {'cat(cup)', 'cat(sofa)'}}}.

    Args:
        dataset_scheme: the nested scheme described above.
        node_name_to_img_id: maps a node name to the set of its image IDs.
        exclude_img_id: image IDs to drop from every node. It contains both
            trainsg_dupes and test images that we do not want to leak.
            None (the default) excludes nothing.
        split: sub-folder of CUSTOM_SPLIT_DATASET_FOLDER the images are
            copied into.

    Returns:
        (community_name_to_img_id, all_img_id): a defaultdict(set) from
        community name to its remaining image IDs, and the union of all of
        those IDs.
    """
    # A None sentinel replaces the former `set()` default, which was a single
    # object shared by every call of the function.
    if exclude_img_id is None:
        exclude_img_id = set()

    community_name_to_img_id = defaultdict(set)
    all_img_id = set()

    ##################################
    # Iterate subject_str: e.g., cat
    ##################################
    for subject_str in dataset_scheme:
        ##################################
        # Iterate community_name: e.g., cat(sofa)
        ##################################
        for community_name in dataset_scheme[subject_str]:
            ##################################
            # Iterate node_name: e.g., 'cat(cup)', 'cat(sofa)', 'cat(chair)'
            ##################################
            for node_name in dataset_scheme[subject_str][community_name]:
                # Compute the filtered IDs once and reuse them for both sets.
                kept_img_id = node_name_to_img_id[node_name] - exclude_img_id
                community_name_to_img_id[community_name].update(kept_img_id)
                all_img_id.update(kept_img_id)
            print(community_name, 'Size:', len(community_name_to_img_id[community_name]) )

        ##################################
        # Copy the images of this subject into <root>/<split>
        ##################################
        # Exclusion already happened above, hence trainsg_dupes=set().
        # Original note on use_symlink: "use False to share".
        root_folder = os.path.join(CUSTOM_SPLIT_DATASET_FOLDER, split)
        copy_image_for_subject(root_folder, subject_str, dataset_scheme[subject_str], community_name_to_img_id, trainsg_dupes=set(), use_symlink=False)

    return community_name_to_img_id, all_img_id


def get_all_nodes_in_dataset(dataset_scheme):
    """
    Return the set of every node name listed anywhere in a dataset scheme.

    dataset_scheme is nested as {subject_str: {community_name: iterable of
    node_name}}, e.g. {'cat': {'cat(sofa)': {'cat(cup)', 'cat(sofa)'}}}.
    """
    collected = set()
    # subject level (e.g. cat) -> community level (e.g. cat(sofa)) -> node names
    for communities in dataset_scheme.values():
        for member_nodes in communities.values():
            collected.update(member_nodes)
    return collected

def generate_splitted_metadaset():
    """
    Build the bus/truck MetaDataset with a hand-written train/test split.

    Steps:
      1. Delete and recreate CUSTOM_SPLIT_DATASET_FOLDER.
      2. Load the node -> image-ID groups for 'bus' and 'truck' with
         preprocess_groups().
      3. Print the communities detected in each subject's meta-graph; that
         output is what guides the hand-written schemes below.
      4. Plot a 2-D PCA view of the community positions to
         ./custom-community-geometry/<subject>_graph.jpg.
      5. Copy images for the test, train and additional-test schemes, then
         print the nodes that no scheme uses.

    Returns None.
    """

    # Always start from an empty output folder.
    if os.path.isdir(CUSTOM_SPLIT_DATASET_FOLDER): 
        shutil.rmtree(CUSTOM_SPLIT_DATASET_FOLDER) 
    os.makedirs(CUSTOM_SPLIT_DATASET_FOLDER, exist_ok = False)

    subject_classes = ['bus', 'truck']
    # The second and third return values are not needed here.
    node_name_to_img_id, _, _, subject_group_summary_dict = preprocess_groups(output_files_flag=False, subject_classes = subject_classes)


    ##################################
    # Removing ambiguous images: everything in the 'truck(bus)' node
    # (presumably images that show both a truck and a bus) is excluded
    # from every split below.
    ##################################
    trainsg_dupes = node_name_to_img_id['truck(bus)'] 
    subject_str_to_Graphs = dict()
    
    for subject_str in subject_classes:
        # The two cross-subject nodes are left out of the meta-graph.
        subject_data = [ x for x in subject_group_summary_dict[subject_str].keys() if x not in ['truck(bus)', 'bus(truck)']  ]
        ##################################
        # Print detected communities in Meta-Graph
        ##################################
        G = print_communities(subject_data, node_name_to_img_id, trainsg_dupes, subject_str) # print detected communities, which guides us the train/test split. 
        subject_str_to_Graphs[subject_str] = G


    train_set_scheme = {
        'bus': {
            'bus(tower)': {'bus(clock)', 'bus(tower)', 'bus(bridge)', 'bus(cone)', 'bus(lamp)', 'bus(bench)'},
            'bus(traffic light)': {'bus(traffic light)', 'bus(street light)', 'bus(fire hydrant)'}, 
        }, 
        'truck': {
            'truck(cone)': {'truck(cones)', 'truck(cone)', 'truck(fire hydrant)', },  
            'truck(fence)': {'truck(horse)', 'truck(fence)'}, 

            'truck(bike)': {'truck(motorcycle)', 'truck(bicycle)', 'truck(bike)', 'truck(helmet)'}, 
            'truck(mirror)': {'truck(taxi)', 'truck(van)', 'truck(mirror)'} , 

            'truck(flag)': {'truck(flag)', 'truck(american flag)'}, 
            'truck(tower)': {'truck(clock)', 'truck(tower)', 'truck(bench)', 'truck(lamp)'}, 

            'truck(traffic light)': {'truck(traffic light)'}, 
            'truck(dog)': {'truck(dog)', 'truck(bed)',} 
        }
    }

    test_set_scheme = {
        'bus': {
            'bus(fence)': {'bus(statue)', 'bus(fence)',}, 
            'bus(bridge)': {'bus(water)', 'bus(bridge)', 'bus(train)'},
            'bus(house)': {'bus(house)', 'bus(chimney)', }, 

            'bus(driver)': {'bus(bus driver)', 'bus(driver)', 'bus(passengers)', 'bus(passenger)', }, 
        },
        'truck': {
            'truck(airplane)': {'truck(airplane)'}, 
            'truck(boat)': {'truck(trailer)', 'truck(boat)', }, # 
        },
    }

    additional_test_set_scheme = {
        'bus': {
            'bus(bike)': {'bus(bike)', 'bus(bicycle)', 'bus(motorcycle)', 'bus(helmet)', 'bus(horse)'}, 
            'bus(sign)': {'bus(sign)'},             
            'bus(woman)': {'bus(woman)', 'bus(lady)', 'bus(bag)', 'bus(cell phone)', 'bus(glasses)',}, 
            'bus(car)': {'bus(car)', 'bus(cars)', 'bus(suv)'},
            'bus(pole)': {'bus(pole)'}, 
            'bus(tree)': {'bus(tree)', 'bus(trees)', 'bus(palm tree)'}
        },
        'truck': {
            'truck(sign)': {'truck(sign)'}, 
            'truck(grass)': {'truck(grass)'}, 
            'truck(car)': {'truck(car)', 'truck(cars)'}, 
            'truck(train)': {'truck(train)', 'truck(bridge)', 'truck(sign)'}, 
            'truck(tree)': {'truck(tree)', 'truck(trees)'},
            'truck(woman)': {'truck(woman)', 'truck(lady)', 'truck(girl)', 'truck(purse)', 'truck(umbrella)', 'truck(bag)'},
            'truck(chair)': {'truck(woman)', 'truck(table)', 'truck(chair)', 'truck(cell phone)'},
        },
    }

    # TODO: add distance measurement. 
    def layout_group_geometry(subject_str_to_Graphs, train_set_scheme, test_set_scheme, additional_test_set_scheme):
        """
        Plot, per subject, where each community of the three schemes lies.

        Every node gets a 10-dimensional spring-layout position in its
        subject graph; a community's position is the mean of its nodes'
        positions. The community positions are reduced to 2-D with PCA and
        saved as a labelled scatter plot.
        """
        subject_community_name_to_pos = defaultdict(lambda: defaultdict(list))
        # The layout depends only on the subject graph and the fixed seed, so
        # it is computed once per subject instead of once per scheme.
        subject_str_to_layout = dict()

        for dataset_scheme in [train_set_scheme, test_set_scheme, additional_test_set_scheme]: 
            ##################################
            # Iterate subject_str: e.g., cat
            ##################################
            for subject_str in dataset_scheme:        
                if subject_str not in subject_str_to_layout:
                    subject_str_to_layout[subject_str] = nx.spring_layout(
                        G=subject_str_to_Graphs[subject_str], 
                        seed=1234,
                        dim=10,
                        )
                karate_pos = subject_str_to_layout[subject_str]

                ##################################
                # Iterate community_name: e.g., cat(sofa)
                ##################################
                for community_name in dataset_scheme[subject_str]:
                    community_node_pos_list = []
                    ##################################
                    # Iterate node_name: e.g., 'cat(cup)', 'cat(sofa)', 'cat(chair)'
                    ##################################
                    for node_name in dataset_scheme[subject_str][community_name]:
                        # Layout keys carry a line break before '(' (graph node labels).
                        node_pos = karate_pos[node_name.replace('(', '\n(')] # numpy.ndarray 
                        community_node_pos_list.append(node_pos)
                        
                    community_pos = np.mean(community_node_pos_list, axis=0)
                    subject_community_name_to_pos[subject_str]['name'].append(community_name)
                    subject_community_name_to_pos[subject_str]['community_pos'].append(community_pos)

        def label_point(x, y, val, ax):
            # Write each community name next to its point.
            a = pd.concat({'x': x, 'y': y, 'val': val}, axis=1)
            for i, point in a.iterrows():
                ax.text(point['x']+.02, point['y'], str(point['val']))

        # savefig does not create missing folders; make sure the target exists.
        figure_folder = './custom-community-geometry/'
        os.makedirs(figure_folder, exist_ok=True)

        from sklearn.decomposition import PCA

        ##################################
        # Visualize 
        ##################################
        for subject_str in subject_community_name_to_pos.keys():
            df = pd.DataFrame(subject_community_name_to_pos[subject_str])
            # Project the 10-D community positions onto two principal components.
            pca = PCA(n_components=2)
            pca_result = pca.fit_transform(np.vstack(df['community_pos'].values))
            df['pca_one'] = pca_result[:,0]
            df['pca_two'] = pca_result[:,1] 
            sns.set(font_scale=2)  # crazy big
            sns.lmplot(
                    x='pca_one', # Horizontal axis
                    y='pca_two', # Vertical axis
                    data=df, # Data source
                    fit_reg=False, # Don't fit a regression line
                    height = 10,
                    aspect =2 ) # size and dimension

            label_point(df.pca_one, df.pca_two, df.name, plt.gca())
            plt.savefig(figure_folder + subject_str + '_graph.jpg', bbox_inches='tight', pad_inches=0, dpi=100)
            plt.close('all') 

        return 


    layout_group_geometry(subject_str_to_Graphs, train_set_scheme, test_set_scheme, additional_test_set_scheme)

    # The test split is resolved first so that its images can be excluded from
    # the train split; the additional test split in turn excludes train images.
    print('========== test set info ==========')
    _, test_all_img_id = parse_dataset_scheme(test_set_scheme, node_name_to_img_id, exclude_img_id=trainsg_dupes, split='test')
    print('========== train set info ==========')
    _, train_all_img_id = parse_dataset_scheme(train_set_scheme, node_name_to_img_id, exclude_img_id=test_all_img_id.union(trainsg_dupes), split='train')
    print('========== additional test set info ==========')
    parse_dataset_scheme(additional_test_set_scheme, node_name_to_img_id, exclude_img_id=train_all_img_id.union(trainsg_dupes), split='test')

    # Node sets of the three schemes do not depend on the subject: build them once.
    test_nodes = get_all_nodes_in_dataset(test_set_scheme)
    train_nodes = get_all_nodes_in_dataset(train_set_scheme)
    additional_test_nodes = get_all_nodes_in_dataset(additional_test_set_scheme)

    for subject_str in subject_classes:
        subject_data = subject_group_summary_dict[subject_str].keys()
        
        # Nodes of this subject that no scheme uses.
        print('[Nodes Left]', 
            subject_data - test_nodes - train_nodes - additional_test_nodes
        )

    ##################################
    # TODO: calculate the distance metric (not implemented yet).
    ##################################
    return

# Script entry point: generate the split dataset when this file is run directly.
if __name__ == '__main__':
    generate_splitted_metadaset()

