import itertools
import numpy as np
from ICR import ICR
###############
def _load_npy(path):
    """Open *path* in binary mode and return the array ``np.load`` reads from it."""
    with open(path, 'rb') as f:
        return np.load(f)


# Point each path below at the matching OPL file before running.
opl_features_train = _load_npy("train_feat.npy")  # OPL training features
opl_labels_train = _load_npy("train_lab.npy")     # OPL training labels
opl_features_test = _load_npy("test_feat.npy")    # OPL test features
opl_labels_test = _load_npy("test_lab.npy")       # OPL test labels
    
    
##########
# Class names indexed by integer label: entry i is the name for label value i.
# The grouping and vocabulary code below relies on this position <-> label order.
# NOTE(review): 100 alphabetical names that look like the CIFAR-100 fine labels —
# confirm against the dataset the OPL features were extracted from.
ClassLabels_list = """
    apple aquarium_fish baby bear beaver bed bee beetle bicycle bottle
    bowl boy bridge bus butterfly camel can castle caterpillar cattle
    chair chimpanzee clock cloud cockroach couch crab crocodile cup dinosaur
    dolphin elephant flatfish forest fox girl hamster house kangaroo keyboard
    lamp lawn_mower leopard lion lizard lobster man maple_tree motorcycle mountain
    mouse mushroom oak_tree orange orchid otter palm_tree pear pickup_truck pine_tree
    plain plate poppy porcupine possum rabbit raccoon ray road rocket
    rose sea seal shark shrew skunk skyscraper snail snake spider
    squirrel streetcar sunflower sweet_pepper table tank telephone television tiger tractor
    train trout tulip turtle wardrobe whale willow_tree wolf woman worm
""".split()

##########
# Every unordered pair of distinct class names, in itertools.combinations order.
# NOTE(review): not referenced elsewhere in the visible script — confirm it is still needed.
all_pairs = [*itertools.combinations(ClassLabels_list, 2)]
   
##########
from collections import defaultdict
# Class name -> the training-feature rows (axis 0) whose label matches that class.
dictOfEmb = defaultdict(list)

##########
# A class's integer label is its position in ClassLabels_list, so the rows for
# `Class` are those where opl_labels_train equals `idx`.
for idx, Class in enumerate(ClassLabels_list):
    class_idx = np.nonzero(opl_labels_train == idx)[0]
    dictOfEmb[Class] = opl_features_train[class_idx]
  
##########
# Pair each training sample with a synthetic token "word-<k>" (k starts at 1 and
# follows the order of opl_labels_train) and the name of that sample's class.
# zip stops at the shorter of the feature count and the label sequence.
all_wordsVocab_class = [
    (f"word-{sample_no}", ClassLabels_list[int(label)])
    for sample_no, label in zip(range(1, len(opl_features_train) + 1), opl_labels_train)
]

from collections import defaultdict
# Class name -> list of the "word-<k>" tokens whose sample carries that class.
ClassVocab = defaultdict(list)
for class_word, classlabel in all_wordsVocab_class:
    ClassVocab[classlabel].append(class_word)

##########
iter_ICR = 5  # forwarded unchanged to both ICR() calls
    
def get_ICR_features(opl_feat_train, opl_label_train, opl_feat_test, opl_label_test,
                     dictOfEmb, ClassLabels_list,  ClassVocab, iter_ICR):
    """Run ICR on the training features, then on the test features.

    Both ``ICR`` calls receive the same ``dictOfEmb``, ``ClassLabels_list``,
    ``ClassVocab`` and ``iter_ICR`` arguments. Both also pass
    ``opl_feat_train`` as the fifth positional argument. Only the first
    argument (train vs. test features) and ``mode`` differ between the calls.

    Args:
        opl_feat_train: Training features; processed with ``mode='train_ICR'``.
        opl_label_train: Not used by this function; kept for caller compatibility.
        opl_feat_test: Test features; processed with ``mode='test_ICR'``.
        opl_label_test: Not used by this function; kept for caller compatibility.
        dictOfEmb: Forwarded unchanged to ``ICR``.
        ClassLabels_list: Forwarded unchanged to ``ICR``.
        ClassVocab: Forwarded unchanged to ``ICR``.
        iter_ICR: Forwarded unchanged to ``ICR``.

    Returns:
        A ``(train_result, test_result)`` tuple holding whatever ``ICR``
        returned for the train and test calls. Entries are ``None`` if
        ``ICR`` returns nothing.
    """
    # NOTE(review): ICR may also persist its output itself (the mode strings
    # hint at that) — returning the results is purely additive either way.
    train_result = ICR(opl_feat_train, dictOfEmb, ClassLabels_list, ClassVocab,
                       opl_feat_train, iter_ICR, mode='train_ICR')
    test_result = ICR(opl_feat_test, dictOfEmb, ClassLabels_list, ClassVocab,
                      opl_feat_train, iter_ICR, mode='test_ICR')
    return train_result, test_result
    
    
    
    
# Script entry: run ICR over the loaded train/test arrays with the structures built above.
get_ICR_features(
    opl_feat_train=opl_features_train,
    opl_label_train=opl_labels_train,
    opl_feat_test=opl_features_test,
    opl_label_test=opl_labels_test,
    dictOfEmb=dictOfEmb,
    ClassLabels_list=ClassLabels_list,
    ClassVocab=ClassVocab,
    iter_ICR=iter_ICR,
)
    
    
    