#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch 
import sys 
import os 
sys.path.append(os.getcwd()[:-9])

import clip
import numpy as np
from sklearn.metrics import jaccard_score as jsc

# HCDM stuff 
from networks import LinearClassifier, DiscoveryMechanism
from utils import compute_and_save_features, AverageMeter, accuracy, bin_accuracy, get_feature_dir, \
patchify_custom, visualize_patches
from data_utils import get_loaders, get_concepts, get_feature_dir, get_concept_indicators, data_loader
import scipy 

# Dataset identifier; the run configuration uses it both as the dataset name
# and as the concept-set name.
dataset = 'imagenet'
# Current working directory, used as the root from which data, checkpoint and
# output locations are derived.
base_path = os.getcwd()
print(base_path)
import matplotlib.pyplot as plt 

from torchvision import datasets


# In[2]:


class convert_to_dot_notation(dict):
    """
    Dictionary whose entries can also be read, assigned and removed with
    attribute syntax (``d.key`` instead of ``d['key']``).

    Reading an attribute that is not a key returns ``None`` rather than
    raising, mirroring ``dict.get``. Deleting a missing attribute raises
    ``KeyError``, mirroring ``del d[key]``.
    """

    def __getattr__(self, name):
        # Only reached when regular attribute lookup has failed, so real
        # dict methods (keys, items, ...) are never shadowed by entries.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment always lands in the mapping itself.
        self[name] = value

    def __delattr__(self, name):
        del self[name]

# Run configuration, exposed with attribute access (args.patchify, ...).
# Options that are read elsewhere but not listed here (e.g. args.only_images)
# resolve to None because convert_to_dot_notation falls back to dict.get.
args = convert_to_dot_notation({
    'dataset': dataset,
    'concept_name': dataset,
    'clip_version': 'ViT-B/16',
    'device': 'cuda',
    # os.path.join inserts the separator that plain string concatenation
    # would omit when base_path has no trailing slash.
    'save_dir': os.path.join(base_path, 'saved_models'),
    'num_workers': 8,
    'batch_size': 128,
    'patch_size': [3, 3],
    'patchify': True,
    'low_level_only': False,
    'discovery': True,
    'tie_indicators': True,
    'compute_similarities': False,
})


# In[3]:


# Both the whole-image and the patch dataset are needed only when patches are
# enabled and the run is not restricted to the low level.
dual_data = not args.low_level_only and args.patchify
args.dual_data = dual_data

if dual_data:
    train_loader, val_loader, classes = get_loaders(
        args, batch_size=args.batch_size, dual_data=dual_data, cs_path=base_path)
    # One concept set per level: index 0 is requested with the flag False,
    # index 1 with the flag True.
    concepts = [get_concepts(args.concept_name, patch_flag, cs_path=base_path)
                for patch_flag in (False, True)]
else:
    train_loader, val_loader, classes = get_loaders(
        args, patchify=args.patchify, batch_size=args.batch_size, cs_path=base_path)
    concepts = [get_concepts(args.concept_name, patchify=args.patchify, cs_path=base_path)]

num_concepts = [len(concept_set) for concept_set in concepts]
num_classes = len(classes)

# The stored concept indicators are used whenever patches are enabled (which
# includes the dual-data case); otherwise an all-ones matrix is used.
if args.patchify:
    binary_inds = get_concept_indicators(args.concept_name, cs_path=base_path)
else:
    binary_inds = torch.ones([num_classes, num_concepts[0]])
binary_inds = binary_inds.to(args.device)
print(binary_inds)


# In[4]:


# Build one classifier (and, if enabled, one discovery module) per active
# level, then restore their weights from the selected checkpoint.
if 'ViT' in args.clip_version:
    feat_emb = 512
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        'Unsupported clip_version {!r}: the feature size is only defined '
        'for ViT models'.format(args.clip_version))
prior = 1.
classifiers = []
discoverers = []

# Active levels in order: (checkpoint key suffix, number of concepts).
# The high level is skipped when low_level_only is set; the low level is
# present only when patches are enabled.
level_specs = []
if not args.low_level_only:
    level_specs.append(('high', num_concepts[0]))
if args.patchify:
    level_specs.append(('low', num_concepts[-1]))

# Discovery modules exist only when discovery is on and the classifiers work
# on concepts rather than directly on image features.
with_discovery = args.discovery and not args.only_images

for _, level_num_concepts in level_specs:
    classifier_in_dim = feat_emb if args.only_images else level_num_concepts
    classifiers.append(LinearClassifier(classifier_in_dim, num_classes).to(args.device))
    if with_discovery:
        discoverers.append(
            DiscoveryMechanism(feat_emb, level_num_concepts, prior=prior).to(args.device))

if dual_data:
    spec = 'HL'
elif args.patchify:
    spec = 'L'
else:
    spec = 'H'

print(spec)
# load the model
# os.path.join supplies the separator between base_path and 'saved_models';
# map_location lets the checkpoint load regardless of the device it was
# saved from.
ckpt_path = os.path.join(base_path, 'saved_models', 'selected_results',
                         dataset, spec, 'ckpt', 'checkpoint.pth.tar')
ckpt = torch.load(ckpt_path, map_location=args.device)

for level_idx, (level_key, _) in enumerate(level_specs):
    classifiers[level_idx].load_state_dict(ckpt['state_dict_{}'.format(level_key)])
    if with_discovery:
        discoverers[level_idx].load_state_dict(ckpt['disc_state_dict_{}'.format(level_key)])

# set all components to evaluation mode
for c in classifiers:
    print(c)
    c.eval()
for d in discoverers:
    print(d)
    d.eval()


# In[5]:


# Load the saved text features of every active level, high level first.
text_dir = get_feature_dir(args)

level_names = []
if not args.low_level_only:
    level_names.append('high')
if args.patchify:
    level_names.append('low')

texts = []
for level_name in level_names:
    text_path = text_dir + '{}_{}_level_text_features.pt'.format(args.concept_name, level_name)
    level_text = torch.load(text_path)
    texts.append(level_text.to(args.device))



# In[58]:


# Run the validation set through the discovery modules and classifiers and
# accumulate, per ground-truth class, the concept masks and the per-concept
# contributions to the predicted class, plus the average mask sparsity.
sparse_high = 0.
sparse_low = 0.
threshold = 0.01
# NOTE(review): these sizes are hard-coded. num_high is used both as the
# number of classes (rows) and as the number of high-level concepts;
# num_low as the number of low-level concepts. They must match the loaded
# dataset and models — confirm.
num_high = 200
num_low = 312

masks = [np.zeros([num_high, num_high]), np.zeros([num_high, num_low])]
contrs = [np.zeros([num_high, num_high]), np.zeros([num_high, num_low])]
examples_per_class = [0.]*num_high

for batch, data in enumerate(val_loader):

    # A batch of two items is a single (inputs, labels) pair; otherwise the
    # first two items form level 0 and the remaining ones level 1.
    if len(data) == 2:
        data = [data]
    else:
        data = [data[:2], data[2:]]

    # Each example is visited once per level, so weight every visit such
    # that an example counts exactly once overall (0.5 with two levels).
    count_per_visit = 1. / len(data)

    for level, (level_inputs, level_labels) in enumerate(data):

        with torch.no_grad():

            images = level_inputs.to(args.device, non_blocking=True)
            labels = level_labels.type(torch.LongTensor).to(args.device, non_blocking=True)

            text = texts[level]

            # cosine similarity between the inputs and the text features
            text = text / text.norm(dim=-1, keepdim=True)
            feats = images / images.norm(dim=-1, keepdim=True)
            similarity = (feats @ text.T) if not args.only_images else images

            mask, _ = discoverers[level](images, probs_only=False)

            if level == 1 and args.tie_indicators:
                # gate the level-1 mask with the mask derived at level 0
                mask *= nested_mask

            elif (level == 0 and args.patchify) and not args.low_level_only:
                # Project the level-0 mask through the concept indicators and
                # add two singleton axes so it broadcasts against the level-1
                # mask. (A debug print that read nested_mask before its first
                # assignment was removed here; it raised NameError.)
                nested_mask = mask @ binary_inds / binary_inds.shape[0]
                nested_mask = nested_mask.view([nested_mask.shape[0],
                                                1,
                                                1,
                                                nested_mask.shape[-1]])

            # fraction of mask entries above the threshold, kept as a float
            active_fraction = (mask > threshold).float().mean().item()
            if level == 1 or args.low_level_only:
                sparse_low += active_fraction
            else:
                sparse_high += active_fraction

            # to get the contribution get the classifier weights
            cur_prediction = classifiers[level](similarity, mask).argmax(-1)
            weights = classifiers[level].W

            for i in range(images.shape[0]):
                # plain int so it can index NumPy arrays and Python lists
                label = int(labels[i])
                class_weights = weights[:, cur_prediction[i]]
                cur_mask = mask[i]

                if level == 1:
                    # move the leading axis last so the weights line up with
                    # the per-sample mask — assumes a 3-D weight slice here
                    class_weights = class_weights.permute(1, 2, 0)

                contributions = (cur_mask * similarity[i]) * class_weights

                if level == 1:
                    # collapse the two leading axes: max for the mask, mean
                    # for the contributions
                    cur_mask = cur_mask.amax([0, 1])
                    contributions = contributions.mean([0, 1])

                masks[level][label, :] += cur_mask.cpu().numpy()
                contrs[level][label] += contributions.cpu().numpy()
                examples_per_class[label] += count_per_visit


print('Sparsity High: {:.3f}'.format(sparse_high/(batch+1.)))
print('Sparsity Low: {:.3f}'.format(sparse_low/(batch+1.)))
print(((masks[1][0]/examples_per_class[0])>0.1).sum())


# In[112]:


# For every class, compare the concepts that are active in the accumulated
# level-1 masks with the ground-truth attribute matrix, print the comparison
# and save it to one text file per class.
inferred_concepts = []
# Paths are built with os.path.join so a base_path without a trailing
# separator still resolves correctly; the trailing '' keeps savedir ending
# in a separator, as before.
savedir = os.path.join(base_path, 'per_class_descriptions_{}'.format(args.dataset), '')
os.makedirs(savedir, exist_ok=True)

data_dir = os.path.join(base_path, 'data')
attr_dir = os.path.join(data_dir, 'concept_sets_low', 'ImageNet')

with open(os.path.join(data_dir, 'imagenet_classes.txt'), 'r') as f:
    classes = [line.strip() for line in f]

with open(os.path.join(attr_dir, 'imagenet_attributes_cleaned.txt'), 'r') as f:
    attrs = np.array([line.strip() for line in f])
gt_binary = np.load(os.path.join(attr_dir, 'imagenet_attrs_per_class_binary.npy'))

for cls in range(gt_binary.shape[0]):
    active_concepts_gt = np.where(gt_binary[cls] > 0.)[0]
    active_concepts_inf = np.where(masks[1][cls] > 0.01)[0]
    inds_gt = set(active_concepts_gt)
    inds_inf = set(active_concepts_inf)
    common_inds = inds_inf.intersection(inds_gt)
    new_inds = inds_inf.difference(inds_gt)
    removed_inds = inds_gt.difference(inds_inf)

    # Format every line once; the same text is printed and written to file.
    # The first "Active concepts" line is the ground truth, the second the
    # inferred set.
    summary_lines = [
        'Class Name: {}'.format(classes[cls]),
        'Active concepts: {}/{}'.format(len(active_concepts_gt), attrs.shape[0]),
        'Active concepts: {}/{}'.format(len(active_concepts_inf), attrs.shape[0]),
        'Common concepts: {}'.format(len(common_inds)),
        'New concepts: {}'.format(len(new_inds)),
        'Removed concepts: {}'.format(len(removed_inds)),
    ]
    detail_blocks = [
        '\nSome common concepts: {}\n'.format(attrs[list(common_inds)[:10]]),
        '\nSome new concepts: {}\n'.format(attrs[list(new_inds)[:10]]),
        '\nSome removed concepts: {} \n'.format(attrs[list(removed_inds)[:10]]),
        '\nGround truth concepts full: {}\n'.format(attrs[list(inds_gt)]),
        '\nInferred concepts full: {}\n'.format(attrs[list(inds_inf)]),
    ]

    for text_line in summary_lines + detail_blocks:
        print(text_line)

    # A path separator inside a class name would be read as a directory.
    safe_name = classes[cls].replace(os.sep, '_').replace('/', '_')
    with open(os.path.join(savedir, 'class_{}.txt'.format(safe_name)), 'w') as f:
        f.writelines(text_line + '\n' for text_line in summary_lines)
        f.writelines(detail_blocks)
    
    







