import random
import numpy as np
import tensorflow as tf
import io
from PIL import Image
import matplotlib.pyplot as plt
#import torchvision.transforms.functional as F
#import torch.nn.functional as F
import torch 
from torch.nn import init
from itertools import permutations

# Fixed seed shared by every random number generator seeded below.
# NOTE: these calls execute at import time, so importing this module
# re-seeds the global RNGs of torch, random and numpy.
SEED = 2021
torch.manual_seed(SEED)  # PyTorch generator
torch.cuda.manual_seed(SEED)  # PyTorch CUDA generator
random.seed(SEED)  # Python standard-library generator
np.random.seed(SEED)  # NumPy global generator

def get_data(data_type, collect, attribute_part_label_tensor, indices, device):
    '''
    Build one mini-batch of tensors on `device` from the pre-loaded arrays.

    data_type: str, 'CUB' or 'alpha-CLEVR'
    collect: tuple, consist X (image feature), Y (image label), P (part location), A (attribute label);
             each entry is indexed with `indices` and passed to torch.from_numpy
    attribute_part_label_tensor: tensor[att_numX15], 1 if the part is in the attribute, 0 otherwise
                                 (only used for 'CUB')
    indices: list (len=BS), the indices that generated by BatchSampler
    device: str, the device name
    return: (features, y, attribute_labels, part_masks), all float tensors on `device`
    raises: ValueError if data_type is not one of the supported tags
    '''
    # Fail early with a clear message instead of an UnboundLocalError at the return statement.
    if data_type not in ('CUB', 'alpha-CLEVR'):
        raise ValueError(
            "Unsupported data_type {!r}; expected 'CUB' or 'alpha-CLEVR'".format(data_type))

    X_collect, Y_collect, P_collect, A_collect = collect
    features = torch.from_numpy(X_collect[indices]).float().to(device)
    y = torch.from_numpy(Y_collect[indices]).float().to(device)
    locations = torch.from_numpy(P_collect[indices]).float().to(device)
    attribute_labels = torch.from_numpy(A_collect[indices]).float().to(device)

    if data_type == 'CUB':
        # Clamp so that coordinate/32 (see turn_parts_to_picknum) stays below 7.
        locations = torch.clamp(locations, 0, 223.9)
        part_masks = get_part_masks_cub(locations, attribute_part_label_tensor)
        # Zero the mask of every attribute that is not labelled for the image;
        # broadcasting over the patch axis replaces the former repeat(1, 1, 49).
        part_masks = part_masks * attribute_labels.unsqueeze(-1)
    else:
        part_masks = get_part_masks_clevr(locations)

    return features, y, attribute_labels, part_masks


def get_part_masks_cub(parts, attribute_part_label_tensor):
    '''
    Turn part locations into per-attribute patch masks.

    parts: tensor[BSX15X3], locations of 15 parts
    attribute_part_label_tensor: tensor[att_numX15], 1 if the part is in the attribute, 0 otherwise
    return: tensor[BSXatt_numX49], 1 if the attribute is on the patch, 0 otherwise
    '''
    # BSX15 patch index per part (0~48), -1 for a part that is not shown
    patch_idx = turn_parts_to_picknum(parts).clone()
    # send unseen parts to an extra 50th class that is sliced away after one-hot encoding
    patch_idx[patch_idx == -1] = 49
    one_hot = torch.nn.functional.one_hot(patch_idx.long(), 50)
    part_location_mask = one_hot[..., :-1].to(parts.device).float()  # BSX15X49

    # 15 parts -> 49 locations: broadcast the attribute/part table over the batch
    batch_size = len(parts)
    attribute_part = attribute_part_label_tensor.unsqueeze(0).to(parts.device)
    attribute_part = attribute_part.expand(batch_size, -1, -1)  # BSXatt_numX15
    # several parts of one attribute may fall on the same patch, so cap the sum at 1
    return torch.clamp(torch.matmul(attribute_part, part_location_mask), 0., 1.)

def get_part_masks_clevr(locations):
    '''
    One-hot encode the patch on which each location falls.

    locations: BSX24X3 -> BSX24 -> BSX24X49
    return: BSX24X49, all-zero row for a location that is not shown
    '''
    patch_idx = turn_parts_to_picknum(locations).clone()  # BSX24
    # unseen entries (-1) go to an extra 50th class, dropped right after encoding
    patch_idx[patch_idx == -1] = 49
    one_hot = torch.nn.functional.one_hot(patch_idx.long(), 50)
    return one_hot[..., :-1].to(locations.device).float()


def turn_parts_to_picknum(parts):
    '''
    Map each part location to the index of the patch it falls on.

    parts: tensor[BSXpart_numX3], part_num=15 for CUB dataset, part_num=24 for ɑ-CLEVR,
           recording the locations of parts: (x, y, is_show)
    return: tensor[BSXpart_num] in the dtype of `parts`, patch 0~48 or -1 if the part is unseen
    '''
    # 32 pixels per patch; truncating gives the grid coordinate along each axis
    grid = parts[:, :, :2] / 32
    first_axis = grid[:, :, 0].to(int)
    second_axis = grid[:, :, 1].to(int)
    # flatten the two grid coordinates into a single index (7 patches per axis)
    patch = (first_axis * 7 + second_axis).to(parts.dtype)
    shown = parts[:, :, 2] != 0
    return torch.where(shown, patch, torch.full_like(patch, -1))

def get_att_score(x, a, should_be_peak=np.array([])):
    '''
    Score every attribute against every patch of a feature map.

    The cosine similarity between each attribute vector and each patch feature is
    computed with a 1x1 convolution over L2-normalised inputs, then squashed with
    sigmoid(25 * (2 * s - 1)).

    x: tensor[BSXChXHXW] (HXW = 7X7 in the original setting, any size is accepted)
    a: tensor[att_numXCh]
    should_be_peak: tensor[BSXatt_numX(H*W)], patch mask per attribute; an empty
                    array/tensor (the default) or None means "no location information"
    return:
        att_score: tensor[BSXatt_num]
        attention: tensor[BSXatt_numXHXW]
    '''
    BS = x.shape[0]
    att_num = a.shape[0]

    # attribute vectors become 1x1 convolution kernels: att_numXChX1X1
    kernels = a.unsqueeze(-1).unsqueeze(-1)
    x = torch.nn.functional.normalize(x, p=2, dim=1)
    kernels = torch.nn.functional.normalize(kernels, p=2, dim=1)
    attention = torch.nn.functional.conv2d(input=x, weight=kernels)  # BSXatt_numXHXW

    if should_be_peak is None or should_be_peak.shape[0] == 0:
        # without location information: take the best patch of the whole map
        # (pooling window = full spatial size instead of a hard-coded 7)
        att_score = torch.nn.functional.max_pool2d(
            attention, kernel_size=tuple(attention.shape[-2:])).view(BS, -1)
    else:
        # with location information
        flatten_attention = attention.reshape((BS, att_num, -1))
        should_be_peak = should_be_peak.clone().detach()  # BSXatt_numX(H*W)
        peak_count = torch.sum(should_be_peak, -1)
        # mean similarity over the labelled patches; the clamp avoids dividing by zero
        feature_pick_peak_by_label = (
            torch.sum(flatten_attention * should_be_peak, -1) / torch.clamp(peak_count, min=1e-8))
        feature_average = torch.mean(flatten_attention, -1)

        # 1 if the parts for the attribute show in an image, else 0
        attribute_part_show_or_not = (peak_count != 0).float()  # BSXatt_num
        # if the attribute is in the image we select the right patch; otherwise, use the average of all patches
        att_score = (attribute_part_show_or_not * feature_pick_peak_by_label
                     + (1 - attribute_part_show_or_not) * feature_average)

    att_score = torch.sigmoid(25 * (2 * att_score - 1))
    attention = torch.sigmoid(25 * (2 * attention - 1))
    return att_score, attention

def get_cub_attribute_name(path = './data/CUB/attributes.txt'):
    '''
    Read the attribute file and return its lines.

    path: str
    return: list of str, one entry per line of the file (line endings are kept)
    (for cub dataset only)
    '''
    # the context manager closes the handle, which the plain open() call never did
    with open(path, 'r') as file:
        return list(file)

def get_part_and_adj_dicts(path = './data/CUB/attributes.txt'):
    '''
    Assign a base-attribute index to every part and every adjective in the file.

    Each line is expected to look like "<has>_<part tokens>_<category>::<adjective>".
    Parts are numbered first, in order of first appearance; adjectives follow,
    grouped by category in the order shape, color, pattern, length, size.

    path: str
    return:
        parts: dict, part name -> index ('primary' when the line names no part)
        adj: dict, adjective -> index
    (for cub dataset only)
    '''
    # read the file once and close it, instead of six open() calls that were never closed
    with open(path, 'r') as file:
        lines = file.readlines()

    parts = {}
    entries = []  # (category, adjective) per line, reused for every category pass
    for line in lines:
        has_part, adj = line.split("::")
        tokens = has_part.split("_")
        part_tokens = tokens[1:-1]
        # joining also handles part names of three or more tokens, which used to stay
        # a list and fail as a dict key
        part = '_'.join(part_tokens) if part_tokens else 'primary'
        if part not in parts:
            parts[part] = len(parts)
        # NOTE: adj[:-1] drops the final character (the newline); kept as is so the
        # keys match the ones get_combinations looks up
        entries.append((tokens[-1], adj[:-1]))

    adjs = {}
    next_index = len(parts)
    for wanted in ('shape', 'color', 'pattern', 'length', 'size'):
        # numbered per category so that an adjective shared by two categories
        # ends up with the index of the later category
        found = {}
        for category, adjective in entries:
            if category == wanted and adjective not in found:
                found[adjective] = next_index
                next_index += 1
        adjs.update(found)
    return parts, adjs


def get_base_attribute_name(data_type='CUB'):
    '''
    List the names of the base attributes of a dataset.

    data_type: str, the dataset we use ('CUB' or 'alpha-CLEVR')
    return: list of base attribute names (parts then adjectives for CUB,
            shapes then colors for alpha-CLEVR); None for any other data_type
    '''
    if data_type == 'alpha-CLEVR':
        shape_names = ['cube', 'cylinder', 'sphere']
        color_names = ['gray', 'red', 'blue', 'green', 'brown', 'purple', 'cyan', 'yellow']
        return shape_names + color_names
    if data_type == 'CUB':
        parts, adjs = get_part_and_adj_dicts()
        # dict iteration yields the keys in insertion order
        return list(parts) + list(adjs)
    return None
    
def get_combinations(path = './data/CUB/attributes.txt'):
    '''
    Express every attribute of the file as a [part index, adjective index] pair.

    path: str, attribute file; also used to build the part/adjective lookup dicts
    return: 2d list, one [part_idx, adj_idx] pair per line of the file
    (for cub dataset only)
    '''
    # build the lookup tables from the SAME file; previously the default path was
    # always used here, whatever `path` the caller passed
    parts, adjs = get_part_and_adj_dicts(path)
    combinations = []
    # the context manager closes the handle, which the plain open() call never did
    with open(path, 'r') as file:
        for line in file:
            has_part, adj = line.split("::")
            # tokens between the leading "has" and the trailing category name the part
            part_tokens = has_part.split("_")[1:-1]
            part = '_'.join(part_tokens) if part_tokens else 'primary'
            # adj[:-1] drops the final character (the newline), matching the dict keys
            combinations.append([parts[part], adjs[adj[:-1]]])
    return combinations

def get_attribute_vectors(data_type):
    '''
    Build the multi-hot matrix linking attributes to their base attributes.

    data_type: str, the dataset we use ('CUB' or 'alpha-CLEVR')
    return: numpy[att_numXbase_att_num]; float 24X11 for alpha-CLEVR, int 312X88 for CUB,
            None for any other data_type
    '''
    if data_type == 'alpha-CLEVR':
        attribute_ids = np.arange(24)
        vectors = np.zeros((24, 11))
        vectors[attribute_ids, attribute_ids % 3] = 1  # shape
        vectors[attribute_ids, attribute_ids // 3 + 3] = 1  # color
        return vectors
    if data_type == 'CUB':
        pairs = get_combinations()
        vectors = np.zeros((312, 88))
        for row in range(312):
            part_idx, adj_idx = pairs[row][0], pairs[row][1]
            vectors[row, part_idx] = 1
            vectors[row, adj_idx] = 1
        return vectors.astype(int)  # array: 312X88
    return None

def base_attribute_could_be_and_from(att_batt_matrix, use_attribute_idx):
    '''
    For each base attribute, list the usable attributes that contain it.

    att_batt_matrix: numpy[att_numXbase_att_num]
    use_attribute_idx: numpy[seen_att_num], the accessible attribute idx
    return: list (len=base_att_num), record the idx of attribute to build base attribute
    e.x. red = red wing, red breast, ...... red tail
    '''
    usable = set(use_attribute_idx)
    groups = []
    for base_idx in range(att_batt_matrix.shape[1]):
        # attributes whose column entry for this base attribute equals 1
        containing = set(np.where(att_batt_matrix[:, base_idx] == 1)[0])
        groups.append(list(containing.intersection(usable)))
    return groups

def attribute_could_be_or_by(att_batt_matrix):
    '''
    For each attribute, list the base attributes it is made of.

    att_batt_matrix: numpy[att_numXbase_att_num]
    return: list (len=att_num), record the idx of base attribute to build attribute
    e.x. red wing = red, wing
    '''
    components = []
    for att_idx in range(att_batt_matrix.shape[0]):
        active = np.nonzero(att_batt_matrix[att_idx, :] == 1)[0]
        components.append(active.tolist())
    return components



class Tensorboard:
    '''
    Small logging helper around tf.summary.FileWriter.

    Every log_* method wraps one value in a tf.Summary, writes it for the given
    global_step and flushes the writer immediately.
    '''

    def __init__(self, logdir):
        '''logdir: directory handed to tf.summary.FileWriter'''
        self.writer = tf.summary.FileWriter(logdir)

    def close(self):
        '''Close the underlying writer.'''
        self.writer.close()

    def _write(self, global_step, **value_fields):
        '''Add one summary value (tag plus payload) and flush.'''
        summary = tf.Summary()
        summary.value.add(**value_fields)
        self.writer.add_summary(summary, global_step=global_step)
        self.writer.flush()

    def log_scalar(self, tag, value, global_step):
        '''Log a single scalar under `tag`.'''
        self._write(global_step, tag=tag, simple_value=value)

    def log_histogram(self, tag, values, global_step, bins):
        '''Log the histogram of the numpy array `values`; `bins` is passed to np.histogram.'''
        counts, bin_edges = np.histogram(values, bins=bins)

        histogram = tf.HistogramProto()
        histogram.min = float(np.min(values))
        histogram.max = float(np.max(values))
        histogram.num = int(np.prod(values.shape))
        histogram.sum = float(np.sum(values))
        histogram.sum_squares = float(np.sum(values ** 2))

        # only the right edge of each bin is stored, so the first edge is skipped
        for right_edge in bin_edges[1:]:
            histogram.bucket_limit.append(right_edge)
        for count in counts:
            histogram.bucket.append(count)

        self._write(global_step, tag=tag, histo=histogram)

    def log_image(self, tag, img, global_step):
        '''Log the numpy image `img` after scaling it by 255 and casting to uint8.'''
        pixels = (img.copy()*255).astype(np.uint8)
        png_buffer = io.BytesIO()
        Image.fromarray(pixels).save(png_buffer, format='png')

        encoded = tf.Summary.Image(encoded_image_string=png_buffer.getvalue(),
                                   height=pixels.shape[0],
                                   width=pixels.shape[1])
        self._write(global_step, tag=tag, image=encoded)

    def log_plot(self, tag, figure, global_step):
        '''Render a matplotlib `figure` to PNG and log it as an image.'''
        png_buffer = io.BytesIO()
        figure.savefig(png_buffer, format='png')
        png_buffer.seek(0)
        # decode the PNG once, only to learn its height and width
        rendered = np.array(Image.open(png_buffer))

        encoded = tf.Summary.Image(encoded_image_string=png_buffer.getvalue(),
                                   height=rendered.shape[0],
                                   width=rendered.shape[1])
        self._write(global_step, tag=tag, image=encoded)

    def plot_bin(self, tag, x, global_step):
        '''Draw the 1-D array `x` as a bar chart and log it under tag + "_histogram".'''
        title = tag+'_histogram'
        fig = plt.figure()
        axes = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        axes.set_title(title)
        positions = list(range(x.shape[0]))
        axes.bar(positions, x.tolist())
        self.log_plot(title, fig, global_step)
        plt.close(fig)
        
'''
get index of attributes/base attributes
'''
# attribute indices of each diagonal, keyed by diagonal number
_CLEVR_DIAGS = {
    0: [0, 4, 8, 9, 13, 17, 18, 22],  # 8
    1: [1, 5, 6, 10, 14, 15, 19, 23],  # 8
    2: [2, 3, 7, 11, 12, 16, 20, 21],  # 8
}
_CUB_DIAGS = {
    0: [278, 10, 26, 42, 109, 63, 85, 127, 142, 161, 177, 193, 209, 261, 277, 293],
    1: [279, 11, 27, 43, 110, 64, 86, 128, 143, 162, 178, 194, 210, 262, 263, 294],
    2: [280, 12, 28, 44, 111, 65, 87, 129, 144, 163, 179, 195, 211, 248, 264, 295],
    3: [281, 13, 29, 45, 112, 66, 88, 130, 145, 164, 180, 196, 197, 249, 265, 296],
    4: [282, 14, 30, 46, 113, 67, 89, 131, 146, 165, 181, 182, 198, 250, 266, 297],
    5: [283, 15, 31, 47, 114, 68, 90, 132, 147, 166, 167, 183, 199, 251, 267, 298],
    6: [284, 16, 32, 48, 115, 69, 91, 133, 148, 152, 168, 184, 200, 252, 268, 299],
    7: [285, 17, 33, 49, 116, 70, 92, 134, 135, 153, 169, 185, 201, 253, 269, 300],
    8: [286, 18, 34, 50, 117, 71, 93, 120, 136, 154, 170, 186, 202, 254, 270, 301],
    9: [287, 19, 35, 51, 118, 72, 79, 121, 155, 171, 187, 203, 255, 271, 302],  # only 15 entries
    10: [288, 20, 36, 52, 119, 58, 80, 122, 137, 156, 172, 188, 204, 256, 272, 303],
    11: [289, 21, 37, 53, 105, 59, 81, 123, 138, 157, 173, 189, 205, 257, 273, 304],
    12: [290, 22, 38, 39, 106, 60, 82, 124, 139, 158, 174, 190, 206, 258, 274, 305],
    13: [291, 23, 24, 40, 107, 61, 83, 125, 140, 159, 175, 191, 207, 259, 275, 306],
    14: [292, 9, 25, 41, 108, 62, 84, 126, 141, 160, 176, 192, 208, 260, 276, 307],
}


def get_diag(data_type, diag):
    '''
    Return the attribute indices that form one diagonal.

    data_type: str, the dataset we use ('CUB'/'cub_base' or 'alpha-CLEVR'/'clevr_v2_base')
    diag: int, the diag we want to select (0~2 for alpha-CLEVR, 0~14 for CUB)
    return: numpy array of attribute indices; None when data_type or diag is not listed
    '''
    if data_type in ('alpha-CLEVR', 'clevr_v2_base'):
        table = _CLEVR_DIAGS
    elif data_type in ('CUB', 'cub_base'):
        table = _CUB_DIAGS
    else:
        return None
    members = table.get(diag)
    if members is None:
        return None
    # a fresh array per call, so callers can modify the result freely
    return np.array(members)
        
def get_use_attribute_idx(data_type, diag_upto=1, use_all = False):
    '''
    Return the indices of the attributes that may be used.

    data_type: str, the dataset we use ('CUB'/'cub_base' or 'alpha-CLEVR'/'clevr_v2_base')
    diag_upto: int >= 1, diagonals 0..diag_upto (inclusive) are concatenated
    use_all: bool, if we want to get all the attributes (312 for CUB, 24 for alpha-CLEVR)
    return: numpy array of attribute indices
    raises: ValueError if diag_upto < 1, or if get_diag has no entry for one of the
            requested diagonals
    '''
    if use_all:
        if data_type == 'CUB' or data_type == 'cub_base':
            return np.arange(312)  # 312
        elif data_type == 'alpha-CLEVR' or data_type == 'clevr_v2_base':
            return np.arange(24)  # 24
    # the recursive version never terminated for diag_upto < 1
    if diag_upto < 1:
        raise ValueError("diag_upto must be >= 1, got {!r}".format(diag_upto))
    diags = [get_diag(data_type, diag=d) for d in range(diag_upto + 1)]
    if any(d is None for d in diags):
        raise ValueError(
            "no diagonal table for data_type {!r} up to diag {!r}".format(data_type, diag_upto))
    return np.concatenate(diags)

def get_cub_base_idx():
    '''
    Return the 31 selected CUB base-attribute indices as a numpy array.
    '''
    # contiguous runs 0-5, 9-17 and 52-66, plus the single index 7
    selected = list(range(0, 6)) + [7] + list(range(9, 18)) + list(range(52, 67))
    return np.array(selected)  # 31
    
def get_use_attribute_idx_type2(data_type, mode = 1):
    '''
    Return the seen attribute indices for a predefined pair of diagonals.

    data_type: str, the dataset we use ('CUB' or 'alpha-CLEVR')
    mode: int, decide the seen attribute selection (1~7 for CUB, 1~3 for alpha-CLEVR)
    return: numpy array joining the two diagonals of the mode;
            None when data_type or mode is not listed
    '''
    # mode -> (first diag, second diag), concatenated in that order
    diag_pairs = {
        'CUB': {
            1: (0, 1),
            2: (2, 5),
            3: (3, 12),
            4: (6, 8),
            5: (11, 14),
            6: (7, 4),
            7: (10, 13),
        },
        'alpha-CLEVR': {
            1: (0, 1),
            2: (1, 2),
            3: (0, 2),
        },
    }
    pair = diag_pairs.get(data_type, {}).get(mode)
    if pair is None:
        return None
    first, second = pair
    return np.concatenate((get_diag(data_type, diag=first), get_diag(data_type, diag=second)))

        
        
        
        
        
