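"""
Generate a cat/dog MetaDataset with a custom train/test split.

Running this script:
  1. prints the communities detected in the cat and dog meta-graphs,
  2. exports a test split (cat outdoor / dog indoor) and a train split
     (cat indoor / dog outdoor) under CUSTOM_SPLIT_DATASET_FOLDER, and
  3. writes a subpopulation-shift dataset (train / val_out_of_domain folders
     plus imageID_to_group.pkl) built from the same groups.
"""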

# Root output folder of the custom train/test split. It is deleted and
# recreated on every run; each split is written to a sub-folder named after
# the split (e.g. 'train', 'test').
CUSTOM_SPLIT_DATASET_FOLDER = '/data/GQA/MetaDataset-Cat-Dog-indoor-outdoor'
import pandas as pd 
import seaborn as sns

import pickle
import numpy as np
import json, re, math
from collections import Counter, defaultdict
from itertools import repeat
import pprint
import os, errno
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil # for copy files
import networkx as nx # graph vis
import pandas as pd
from sklearn.decomposition import TruncatedSVD

import Constants
IMAGE_DATA_FOLDER          = Constants.IMAGE_DATA_FOLDER

from generate_full_metadataset import preprocess_groups, build_subset_graph, copy_image_for_subject


def print_communities(subject_data, node_name_to_img_id, trainsg_dupes, subject_str):
    """Detect communities in a subject's meta-graph, print them, return the graph.

    The graph is built by ``build_subset_graph`` from the given arguments and
    partitioned with networkx's greedy modularity algorithm. For every
    community (largest first) this prints the number of distinct image IDs
    covered by its nodes, followed by the set of node names.

    Args:
        subject_data: passed through to ``build_subset_graph``.
        node_name_to_img_id: mapping from node name (without newlines) to the
            image IDs of that node.
        trainsg_dupes: passed through to ``build_subset_graph``.
        subject_str: passed through to ``build_subset_graph``.

    Returns:
        The graph produced by ``build_subset_graph``.
    """
    import networkx.algorithms.community as nxcom

    graph = build_subset_graph(subject_data, node_name_to_img_id, trainsg_dupes, subject_str)

    # Largest communities first.
    detected = nxcom.greedy_modularity_communities(G=graph) if False else nxcom.greedy_modularity_communities(graph)
    communities = sorted(detected, key=len, reverse=True)
    print(f"The graph has {len(communities)} communities.")

    for members in communities:
        # Graph node labels contain a newline that the lookup keys do not.
        image_ids = set().union(
            *(node_name_to_img_id[label.replace('\n', '')] for label in members)
        )
        print('total size:',len(image_ids))

        clean_names = {label.replace('\n', '') for label in members}
        print(clean_names, '\n\n')

    return graph



def parse_dataset_scheme(dataset_scheme, node_name_to_img_id, exclude_img_id=frozenset(), split='test'):
    """Resolve a dataset scheme to image IDs and export the images per subject.

    Args:
        dataset_scheme: nested mapping
            ``{subject_str: {community_name: iterable of node names}}``,
            e.g. ``{'cat': {'cat(indoor)': {'cat(sofa)', 'cat(cup)'}}}``.
        node_name_to_img_id: mapping from node name to a set of image IDs.
        exclude_img_id: image IDs removed from every node before use. It
            contains both trainsg_dupes and the test images that must not
            leak into this split. Defaults to an empty (immutable) set.
        split: name of the sub-folder of ``CUSTOM_SPLIT_DATASET_FOLDER`` that
            receives the images.

    Returns:
        A tuple ``(community_name_to_img_id, all_img_id)``: a
        ``defaultdict(set)`` from community name to its image IDs, and the
        union of all image IDs selected by the scheme.
    """
    community_name_to_img_id = defaultdict(set)
    all_img_id = set()
    root_folder = os.path.join(CUSTOM_SPLIT_DATASET_FOLDER, split)

    ##################################
    # Iterate subject_str: e.g., cat
    ##################################
    for subject_str, communities in dataset_scheme.items():
        ##################################
        # Iterate community_name: e.g., cat(sofa)
        ##################################
        for community_name, node_names in communities.items():
            ##################################
            # Iterate node_name: e.g., 'cat(cup)', 'cat(sofa)', 'cat(chair)'
            ##################################
            for node_name in node_names:
                # Compute the filtered IDs once and reuse them for both sets.
                kept_img_ids = node_name_to_img_id[node_name] - exclude_img_id
                community_name_to_img_id[community_name].update(kept_img_ids)
                all_img_id.update(kept_img_ids)
            print(community_name, 'Size:', len(community_name_to_img_id[community_name]) )

        ##################################
        # Export the images of this subject
        ##################################
        # Exclusions were already applied above, hence the empty trainsg_dupes.
        copy_image_for_subject(root_folder, subject_str, communities, community_name_to_img_id, trainsg_dupes=set(), use_symlink=False) # use False to share

    return community_name_to_img_id, all_img_id


def get_all_nodes_in_dataset(dataset_scheme):
    """Return the set of every node name that appears in a dataset scheme.

    ``dataset_scheme`` is a nested mapping
    ``{subject_str: {community_name: iterable of node names}}``, for example
    ``{'cat': {'cat(indoor)': {'cat(cup)', 'cat(sofa)'}}}``. Subject and
    community names themselves are not part of the result.
    """
    return {
        node_name
        for communities in dataset_scheme.values()   # per subject, e.g. cat
        for node_names in communities.values()       # per community, e.g. cat(indoor)
        for node_name in node_names                  # e.g. 'cat(cup)', 'cat(sofa)'
    }

def generate_splitted_metadaset():
    """Build the cat/dog indoor/outdoor split and the subpopulation-shift dataset.

    Steps, in order:

    1. Delete and recreate ``CUSTOM_SPLIT_DATASET_FOLDER``.
    2. Load the groups with ``preprocess_groups`` and print the communities
       detected in the cat and dog meta-graphs.
    3. Plot the relative position of the train/test communities and save one
       JPEG per subject under ``./custom-community-geometry/``.
    4. Export the test scheme, then the train scheme (with every test image
       excluded), through ``parse_dataset_scheme``.
    5. Delete and recreate the subpopulation-shift folder, write
       ``imageID_to_group.pkl`` and symlink the images into its ``train`` and
       ``val_out_of_domain`` sub-folders.

    Takes no arguments and returns ``None``; all results are files and
    printed output.
    """
    import random

    from sklearn.model_selection import train_test_split

    if os.path.isdir(CUSTOM_SPLIT_DATASET_FOLDER):
        shutil.rmtree(CUSTOM_SPLIT_DATASET_FOLDER)
    os.makedirs(CUSTOM_SPLIT_DATASET_FOLDER, exist_ok = False)

    node_name_to_img_id, most_common_list, subjects_to_all_set, subject_group_summary_dict = preprocess_groups(output_files_flag=False)

    ##################################
    # Removing ambiguous images that have both cats and dogs
    ##################################
    trainsg_dupes = node_name_to_img_id['cat(dog)'] # can also use 'dog(cat)'
    subject_str_to_Graphs = dict()

    for subject_str in ['cat', 'dog']:
        subject_data = [ x for x in subject_group_summary_dict[subject_str].keys() if x not in ['cat(dog)', 'dog(cat)'] ]
        ##################################
        # Print detected communities in Meta-Graph
        ##################################
        G = print_communities(subject_data, node_name_to_img_id, trainsg_dupes, subject_str) # print detected communities, which guides us the train/test split.
        subject_str_to_Graphs[subject_str] = G

    # Hand-picked schemes: {subject: {community name: node names}}.
    test_set_scheme = {
        'cat': {
            'cat(outdoor)': {
                'cat(car)',
                'cat(fence)', 'cat(grass)', 'cat(roof)', 'cat(bench)', 'cat(bird)', 'cat(house)',
            },
        },
        'dog': {
            'dog(indoor)': {
                'dog(screen)', 'dog(shelf)', 'dog(desk)', 'dog(picture)', 'dog(laptop)',
                'dog(remote control)', 'dog(blanket)', 'dog(bed)', 'dog(sheet)', 'dog(lamp)', 'dog(books)', 'dog(pillow)', 'dog(curtain)',
                'dog(container)', 'dog(table)', 'dog(cup)', 'dog(plate)', 'dog(food)', 'dog(box)',
                'dog(rug)', 'dog(floor)', 'dog(cabinet)', 'dog(towel)',
                'dog(bowl)',
                'dog(television)', 'dog(carpet)',
                'dog(sofa)',
            },
        },
    }

    train_set_scheme = {
        'cat': {
            'cat(indoor)': {
                'cat(speaker)', 'cat(computer)', 'cat(screen)', 'cat(laptop)', 'cat(computer mouse)', 'cat(keyboard)', 'cat(monitor)', 'cat(desk)',
                'cat(sheet)', 'cat(bed)', 'cat(blanket)', 'cat(remote control)', 'cat(comforter)', 'cat(pillow)', 'cat(couch)',
                'cat(books)', 'cat(book)', 'cat(television)', 'cat(bookshelf)', 'cat(blinds)',
                'cat(sink)', 'cat(bottle)', 'cat(faucet)', 'cat(towel)', 'cat(counter)',
                'cat(curtain)', 'cat(toilet)', 'cat(pot)',
                'cat(carpet)', 'cat(toy)', 'cat(floor)',
                'cat(plate)', 'cat(rug)', 'cat(food)', 'cat(table)',
                'cat(box)', 'cat(paper)', 'cat(suitcase)', 'cat(bag)',
                'cat(container)', 'cat(vase)', 'cat(shelf)', 'cat(bowl)',
                'cat(picture)', 'cat(papers)', 'cat(lamp)',
                'cat(cup)', 'cat(sofa)',
            },
        },
        'dog': {
            'dog(outdoor)': {
                'dog(house)', 'dog(grass)', 'dog(horse)', 'dog(fence)', 'dog(cow)', 'dog(sheep)', 'dog(dirt)',
                'dog(car)', 'dog(motorcycle)', 'dog(truck)', 'dog(helmet)',
                'dog(snow)',
                'dog(flag)', 'dog(boat)', 'dog(rope)', 'dog(trees)', 'dog(frisbee)',
                'dog(bike)', 'dog(bicycle)',
                'dog(sand)', 'dog(surfboard)', 'dog(water)',
                'dog(fire hydrant)', 'dog(pole)',
                'dog(skateboard)',
                'dog(bench)', 'dog(trash can)',
            },
        },
    }
    additional_test_set_scheme = dict() # empty dict

    # TODO: add distance measurement.
    def layout_group_geometry(subject_str_to_Graphs, train_set_scheme, test_set_scheme, additional_test_set_scheme):
        """Plot, per subject, where the scheme communities sit relative to each other.

        Every node gets a 10-dimensional spring-layout position in its
        subject graph; a community's position is the mean of its node
        positions. The community positions of one subject are projected to
        2-D with PCA and saved as a labelled scatter plot.
        """
        from sklearn.decomposition import PCA

        def label_point(x, y, val, ax):
            # Write each community name next to its point.
            a = pd.concat({'x': x, 'y': y, 'val': val}, axis=1)
            for i, point in a.iterrows():
                ax.text(point['x']+.02, point['y'], str(point['val']))

        subject_community_name_to_pos = defaultdict(lambda: defaultdict(list))
        for dataset_scheme in [train_set_scheme, test_set_scheme, additional_test_set_scheme]:
            ##################################
            # Iterate subject_str: e.g., cat
            ##################################
            for subject_str in dataset_scheme:
                G = subject_str_to_Graphs[subject_str]
                # Fixed seed: the layout is identical on every call for the same graph.
                karate_pos = nx.spring_layout(
                    G=G,
                    seed=1234,
                    dim=10,
                    )

                ##################################
                # Iterate community_name: e.g., cat(sofa)
                ##################################
                for community_name in dataset_scheme[subject_str]:
                    community_node_pos_list = []
                    ##################################
                    # Iterate node_name: e.g., 'cat(cup)', 'cat(sofa)', 'cat(chair)'
                    ##################################
                    for node_name in dataset_scheme[subject_str][community_name]:
                        # Graph node labels have a newline before the parenthesis.
                        node_pos = karate_pos[node_name.replace('(', '\n(')] # numpy.ndarray
                        community_node_pos_list.append(node_pos)

                    community_pos = np.mean(community_node_pos_list, axis=0)
                    subject_community_name_to_pos[subject_str]['name'].append(community_name)
                    subject_community_name_to_pos[subject_str]['community_pos'].append(community_pos)

        ##################################
        # Visualize
        ##################################
        # savefig does not create missing folders, so make sure the output
        # folder exists before the first plot is written.
        geometry_plot_folder = './custom-community-geometry/'
        os.makedirs(geometry_plot_folder, exist_ok=True)

        for subject_str in subject_community_name_to_pos.keys():
            df = pd.DataFrame(subject_community_name_to_pos[subject_str])
            pca = PCA(n_components=2)
            pca_result = pca.fit_transform(np.vstack(df['community_pos'].values))
            df['pca_one'] = pca_result[:,0]
            df['pca_two'] = pca_result[:,1]
            sns.set(font_scale=2)  # large fonts for the figure
            sns.lmplot(
                    x='pca_one', # Horizontal axis
                    y='pca_two', # Vertical axis
                    data=df, # Data source
                    fit_reg=False, # Don't fit a regression line
                    height = 10,
                    aspect =2 ) # size and dimension

            label_point(df.pca_one, df.pca_two, df.name, plt.gca())
            plt.savefig(geometry_plot_folder + subject_str + '_graph.jpg', bbox_inches='tight', pad_inches=0, dpi=100)
            plt.close('all')

        return

    layout_group_geometry(subject_str_to_Graphs, train_set_scheme, test_set_scheme, additional_test_set_scheme)

    print('========== test set info ==========')
    test_community_name_to_img_id, test_all_img_id = parse_dataset_scheme(test_set_scheme, node_name_to_img_id, exclude_img_id=trainsg_dupes, split='test')
    print('========== train set info ==========')
    # Every test image is excluded from the train split to avoid leakage.
    train_community_name_to_img_id, train_all_img_id = parse_dataset_scheme(train_set_scheme, node_name_to_img_id, exclude_img_id=test_all_img_id.union(trainsg_dupes), split='train')
    # print('========== additional test set info ==========')
    # additional_test_community_name_to_img_id, additional_test_all_img_id = parse_dataset_scheme(additional_test_set_scheme, node_name_to_img_id, exclude_img_id=train_all_img_id.union(trainsg_dupes), split='test')

    # Report the nodes that no scheme uses.
    for subject_str in ['cat', 'dog']:
        subject_data = subject_group_summary_dict[subject_str].keys()
        print('[Nodes Left]',
            subject_data - get_all_nodes_in_dataset(test_set_scheme) - get_all_nodes_in_dataset(train_set_scheme) - get_all_nodes_in_dataset(additional_test_set_scheme)
        )

    ##################################
    # Simulating subpopulation shifts
    #
    # TODO: re-export dataset for domain generalization algorithms.
    #   parameter 1: ratio: hardness of subpopulation shift
    #   parameter 2: name of the output dir
    #
    # Planned structure of the dataset dir:
    #     train/
    #         cat/
    #             ...images
    #         dog/
    #             ...images
    #     test/
    #         cat/
    #             ...images
    #         dog/
    #             ...images
    #     imageID_to_group.pkl -- specifying group information.
    #
    # The code below currently writes 'train' and 'val_out_of_domain'
    # sub-folders plus imageID_to_group.pkl.
    ##################################
    SUBPOPULATION_SHIFT_DATASET_FOLDER = '/data/GQA/MetaDataset-subpopulation-shift'
    if os.path.isdir(SUBPOPULATION_SHIFT_DATASET_FOLDER):
        shutil.rmtree(SUBPOPULATION_SHIFT_DATASET_FOLDER)
    os.makedirs(SUBPOPULATION_SHIFT_DATASET_FOLDER, exist_ok = False)

    def shuffle_and_truncate(img_id_set, truncate=500):
        # Sort first so the seeded shuffle does not depend on set ordering.
        img_id_list = sorted(img_id_set)
        random.Random(42).shuffle(img_id_list)
        return img_id_list[:truncate]

    # Minority groups: at most 294 images each (144 held out below, up to 150 left for training).
    cat_outdoor_images = shuffle_and_truncate(test_community_name_to_img_id['cat(outdoor)'], 294)
    dog_indoor_images = shuffle_and_truncate(test_community_name_to_img_id['dog(indoor)'], 294)
    # Majority groups: at most 994 images each (144 held out below, up to 850 left for training).
    cat_indoor_images = shuffle_and_truncate(train_community_name_to_img_id['cat(indoor)'], 994)
    dog_outdoor_images = shuffle_and_truncate(train_community_name_to_img_id['dog(outdoor)'], 994)

    with open(SUBPOPULATION_SHIFT_DATASET_FOLDER + '/' + 'imageID_to_group.pkl', 'wb') as handle:
        imageID_to_group = dict()
        group_to_imageID = {
            'cat(outdoor)': cat_outdoor_images,
            'dog(indoor)': dog_indoor_images,
            'cat(indoor)': cat_indoor_images,
            'dog(outdoor)': dog_outdoor_images,
        }
        for group_str in group_to_imageID:
            for imageID in group_to_imageID[group_str]:
                imageID_to_group[imageID] = [group_str]
        pickle.dump(imageID_to_group, file=handle)

    # NOTE: each call splits two lists jointly, which requires both lists to
    # have the same length (i.e. both groups reached the truncation size).
    cat_outdoor_train, cat_outdoor_test, dog_indoor_train, dog_indoor_test = train_test_split(cat_outdoor_images, dog_indoor_images, test_size=144, random_state=42)
    cat_indoor_train, cat_indoor_test, dog_outdoor_train, dog_outdoor_test = train_test_split(cat_indoor_images, dog_outdoor_images, test_size=144, random_state=42)

    ##################################
    # Copy Images: using only IDs
    ##################################
    # Held-out set: 144 images from each of the four groups.
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'val_out_of_domain', 'cat', use_symlink=True,
        img_IDs = cat_outdoor_test + cat_indoor_test,
        )
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'val_out_of_domain', 'dog', use_symlink=True,
        img_IDs = dog_indoor_test + dog_outdoor_test,
        )

    # Training set per class: the first NUM_MINORITY_IMG minority-group images
    # plus the majority-group images from index NUM_MINORITY_IMG onwards, so
    # the class size stays the same when NUM_MINORITY_IMG changes.
    NUM_MINORITY_IMG = 100 # 10-150: 10, 50, 100, 150
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'train', 'cat', use_symlink=True,
        img_IDs = cat_outdoor_train[:NUM_MINORITY_IMG] + cat_indoor_train[NUM_MINORITY_IMG:],
        )
    copy_images(
        SUBPOPULATION_SHIFT_DATASET_FOLDER, 'train', 'dog', use_symlink=True,
        img_IDs = dog_indoor_train[:NUM_MINORITY_IMG] + dog_outdoor_train[NUM_MINORITY_IMG:],
        )
    print('NUM_MINORITY_IMG', NUM_MINORITY_IMG)

    return


##################################
# Copy Images specified by IDs: 
# destination folder: os.path.join(root_folder, split, subset_str)
##################################
def copy_images(root_folder, split, subset_str, img_IDs, use_symlink=True):
    """Populate ``root_folder/split/subset_str`` with the images named by ``img_IDs``.

    The destination folder is removed first if it already exists and is then
    created from scratch. Each image is read from
    ``IMAGE_DATA_FOLDER + imageID + '.jpg'`` and placed in the destination
    folder as ``imageID + '.jpg'``.

    Args:
        root_folder: root of the dataset being written.
        split: first sub-folder level under ``root_folder``.
        subset_str: second sub-folder level under ``root_folder``.
        img_IDs: iterable of image ID strings.
        use_symlink: when true, create symbolic links (for local use, saving
            disk storage); otherwise copy the whole jpg file (for sharing the
            meta-dataset).
    """
    ##################################
    # Start from a fresh destination folder
    ##################################
    destination_folder = os.path.join(root_folder, split, subset_str)
    if os.path.isdir(destination_folder):
        shutil.rmtree(destination_folder)
    os.makedirs(destination_folder, exist_ok = False)

    ##################################
    # Option B (symlink) or Option A (full file copy), chosen once
    ##################################
    place_image = os.symlink if use_symlink else shutil.copyfile

    for imageID in img_IDs:
        file_name = imageID + '.jpg'
        place_image(IMAGE_DATA_FOLDER + file_name, os.path.join(destination_folder, file_name))

    return


# Script entry point. Running this file deletes and regenerates both output
# dataset folders (see generate_splitted_metadaset).
if __name__ == '__main__':
    generate_splitted_metadaset()

