import os
import pandas as pd
import glob
import torch
from torchvision import transforms
from util import seed_worker
from dataset import df_to_dict, balance_stat, dict_to_stat, split_df, ActionDataset
from visualize import show_stat, show_split

# Noun classes (spelled exactly as in the metadata's ``noun_class`` column)
# that the epickitchens pipeline treats as in-distribution.
IID_NOUNS = [
    'meat', 'pan', 'potato', 'salt', 'sauce', 'hob',
    'lid', 'tap', 'egg', 'bottle', 'cupboard', 'dishwasher',
    'drawer', 'fridge', 'package', 'spice', 'aubergine', 'bacon',
    'carrot', 'mushroom', 'omelette', 'onion', 'tomato', 'broccoli',
    'dough', 'plate', 'apple', 'food', 'salad', 'bowl',
    'coffee', 'jar', 'rack:drying', 'sink', 'liquid', 'paper',
    'knife', 'liquid:washing', 'banana', 'tray', 'mixture', 'onion:spring',
    'seed',
]
# Noun classes held out as out-of-distribution; disjoint from IID_NOUNS.
OOD_NOUNS = [
    'oil', 'bag', 'oven', 'kettle', 'chicken', 'garlic',
    'bin', 'box', 'container', 'bread', 'cheese', 'courgette',
    'cucumber', 'olive', 'peach', 'pepper', 'pizza', 'sausage',
    'pasta', 'rice', 'spoon', 'cup', 'glass', 'pot',
    'board:chopping', 'mat', 'squash', 'leaf',
]


class DataLoaderFactory:
    """Load action metadata and split it into train / test / OOD-validation / OOD-test frames.

    Despite the name, ``get_data_loaders`` returns pandas DataFrames (plus the
    class-name -> integer-id maps), not loader objects; any wrapping into
    loaders happens outside this class.
    """

    # Column names of the per-scene procthor annotation files, which have no header row.
    _PROCTHOR_COLUMNS = ['scene', 'idx', 'figure', 'noun_class', 'verb_class', 'xmin', 'ymin', 'xmax', 'ymax']
    # Verb classes removed from procthor before ids are assigned.
    _PROCTHOR_EXCLUDED_VERBS = ['none', 'cook']

    @staticmethod
    def _load_procthor(root):
        """Load procthor annotations, building and caching ``{root}/annotations.csv`` on first use.

        Rows whose ``verb_class`` is ``'none'`` or ``'cook'`` are dropped. The
        result is an independent copy (original index labels preserved) so
        callers can add columns without pandas' chained-assignment warning.

        Raises:
            FileNotFoundError: if neither the cached CSV nor any
                ``{root}/proc_*/annotations.csv`` file exists.
        """
        filecsv = f'{root}/annotations.csv'
        if os.path.exists(filecsv):
            df = pd.read_csv(filecsv)
            print(f'Loaded annotations from {filecsv}')
        else:
            files = sorted(glob.glob(f'{root}/proc_*/annotations.csv'))
            if not files:
                # Without this check pd.concat([]) fails with an opaque
                # "No objects to concatenate" error.
                raise FileNotFoundError(
                    f'No procthor annotations found: neither {filecsv} '
                    f'nor any {root}/proc_*/annotations.csv exists'
                )
            stack = [
                pd.read_csv(file, header=None, names=DataLoaderFactory._PROCTHOR_COLUMNS)
                for file in files
            ]
            df = pd.concat(stack, ignore_index=True)
            # Cache the merged table so later calls take the fast path above.
            df.to_csv(filecsv, index=False)

        keep = ~df['verb_class'].isin(DataLoaderFactory._PROCTHOR_EXCLUDED_VERBS)
        return df[keep].copy()

    @staticmethod
    def _load_epickitchens(root):
        """Load epickitchens metadata restricted to frequent verbs and the listed nouns.

        Keeps verbs with at least 10 rows, then keeps only rows whose noun is in
        ``IID_NOUNS`` or ``OOD_NOUNS``. IID rows come first, then OOD rows (each
        in file order), re-indexed 0..n-1. This order determines the integer ids
        assigned later from ``unique()``.
        """
        df = pd.read_csv(f'{root}/valid_image_pairs_metadata.csv')

        verb_counts = df['verb_class'].value_counts()
        df = df[df['verb_class'].isin(verb_counts[verb_counts >= 10].index)]

        df_iid = df[df['noun_class'].isin(IID_NOUNS)]
        df_ood = df[df['noun_class'].isin(OOD_NOUNS)]
        return pd.concat([df_iid, df_ood], ignore_index=True)

    @staticmethod
    def _print_combinations_count(df, name):
        """Print how many distinct (verb_class, noun_class) pairs occur in ``df``."""
        combinations = df.groupby(['verb_class', 'noun_class']).size()
        print(f'{name} dataset has {len(combinations)} unique verb-noun combinations.')

    @staticmethod
    def get_data_loaders(dataset, root, ood, seed, train_size, rebalance=True):
        """Load ``dataset`` metadata from ``root`` and build the experiment splits.

        Args:
            dataset: ``'procthor'`` or ``'epickitchens'``.
            root: directory holding the dataset's annotation CSV file(s).
            ood: forwarded as ``axis`` to ``split_df`` to pick the IID/OOD split
                (procthor only; epickitchens splits by the fixed noun lists).
            seed: forwarded to ``split_df`` (procthor only).
            train_size: the training frame is truncated to its first
                ``train_size`` rows after shuffling.
            rebalance: if True, keep only the rows listed in the verb dictionary
                returned by ``balance_stat``.

        Returns:
            ``(df_train, df_test, df_valid, df_ood, dict_verb_index, dict_noun_index)``.
            ``df_train``/``df_test`` are disjoint IID subsets. ``df_valid``/``df_ood``
            are OOD subsets, which overlap when fewer than ``2 * num_valid`` OOD
            rows exist. The dicts map class names to the integer ids stored in the
            ``verb_index`` / ``noun_index`` columns.

        Raises:
            NotImplementedError: for an unsupported ``dataset``.
            FileNotFoundError: if no procthor annotation files are found.
        """
        # load meta data
        if dataset == 'procthor':
            df = DataLoaderFactory._load_procthor(root)
        elif dataset == 'epickitchens':
            df = DataLoaderFactory._load_epickitchens(root)
        else:
            raise NotImplementedError

        # attributes: ids follow first-appearance order of each class in ``df``
        dict_noun_index = {k: v for v, k in enumerate(df['noun_class'].unique())}
        dict_verb_index = {k: v for v, k in enumerate(df['verb_class'].unique())}
        df['noun_index'] = df['noun_class'].map(dict_noun_index)
        df['verb_index'] = df['verb_class'].map(dict_verb_index)

        num_instance = len(df)
        num_noun = len(dict_noun_index)
        num_verb = len(dict_verb_index)
        print(f'Dataset stat: # instance {num_instance}, # noun {num_noun}, # verb {num_verb}')

        # rebalance data.
        # NOTE(review): balance_stat runs even when rebalance=False, as before;
        # it may consume global RNG state that affects the shuffles below, so
        # verify before moving it under the flag.
        dict_verb, _, _ = df_to_dict(df)
        dict_verb = balance_stat(dict_verb, dict_to_stat(dict_verb))

        if rebalance:
            # dict_verb values are collections of df index labels to keep.
            indices = [name for names in dict_verb.values() for name in names]
            print(f'{len(indices)} / {len(df)} instances are kept from rebalance')
            df = df[df.index.isin(indices)].reset_index(drop=True)

        if dataset == 'epickitchens':
            df_iid = df[df['noun_class'].isin(IID_NOUNS)]
            df_ood = df[df['noun_class'].isin(OOD_NOUNS)]
        else:
            df_iid, df_ood = split_df(df, axis=ood, seed=seed)
        _, _, dict_verb_noun_iid = df_to_dict(df_iid)
        _, _, dict_verb_noun_ood = df_to_dict(df_ood)

        show_split(dict_verb_noun_iid, dict_verb_noun_ood, figname='syst_split.svg')

        # split: size of the IID test set and of each OOD split. It is at most
        # half the IID rows (capped at 5000), no more than the OOD rows
        # available, and at least 1.
        max_num_ood = min(5000, int(0.5 * len(df_iid)))
        if dataset == 'procthor' and train_size == 1000:
            # NOTE(review): special-case shrink for the small procthor setting;
            # the origin of the factor 14 is not visible here.
            max_num_ood //= 14
        num_valid = max(min(len(df_ood), max_num_ood), 1)

        # NOTE(review): the shuffles below do not use ``seed`` (pandas falls back
        # to the global NumPy RNG), so reproducibility relies on the caller
        # seeding NumPy.
        df_iid = df_iid.sample(frac=1)  # shuffle order
        df_test = df_iid.iloc[:num_valid]
        df_train = df_iid.iloc[num_valid:]

        # ood validation set for model selection
        if len(df_ood) < num_valid * 2:
            print("duplication between ood validation and ood test")
        else:
            print("disjoint ood validation and test set")
        df_ood = df_ood.sample(frac=1)
        df_valid = df_ood.iloc[-num_valid:]
        df_ood = df_ood.iloc[:num_valid]

        df_train = df_train.iloc[:train_size]

        DataLoaderFactory._print_combinations_count(df_train, "Train")
        DataLoaderFactory._print_combinations_count(df_test, "Test")
        DataLoaderFactory._print_combinations_count(df_valid, "Validation")
        DataLoaderFactory._print_combinations_count(df_ood, "OOD")

        return df_train, df_test, df_valid, df_ood, dict_verb_index, dict_noun_index
