import shutil
import os
import sys
import subprocess
from collections import Counter
from typing import Any, Dict, List, Tuple
import torch
import tqdm

from importers.ea_ra_kgc import EaRaKgcData


# Color name -> numeric ANSI SGR parameter, interpolated into "\033[1;<code>m"
# by the printing helpers of this module.
# NOTE(review): 0 is the SGR reset code rather than black (30), so "black" renders in
# the terminal's default color -- confirm that is intended.
color_map = {
    "green": 32,
    "black": 0,
    "blue": 34,
    "red": 31,
    "yellow": 33,
}


def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=None, fill='*', color="blue"):
    """
    Simple utility to display an updatable progress bar on the terminal.

    The bar is redrawn in place (the line starts and ends with a carriage return), and a
    newline is printed once ``iteration == total``.

    :param iteration: The number of iterations completed so far
    :param total: The total number of iterations to be done. A total of 0 is treated as an
        already finished task (100%) instead of raising ZeroDivisionError.
    :param prefix: The string to be displayed before the progress bar
    :param suffix: The string to be displayed after the progress bar
    :param decimals: The number of decimal places in the percent complete indicator
    :param length: The length of the bar. If None, the length is calculated as to fill up the screen
    :param fill: The character to fill the bar with
    :param color: Color of the bar (a key of color_map)
    :return: None
    """
    if length is None:
        columns, _ = shutil.get_terminal_size((120, 80))
        # 11 columns are taken by " |", "| ", the 5-wide percentage and "% ".
        length = max(columns - len(prefix) - len(suffix) - 11, 10)
    if total == 0:
        # Nothing to do counts as done; avoids dividing by zero.
        percent_value = 100.0
        filled_length = length
    else:
        percent_value = 100.0 * iteration / total
        filled_length = int(length * iteration // total)
    # Clamp so the bar keeps a constant width even if iteration overshoots total.
    filled_length = min(max(filled_length, 0), length)
    percent = ("{0:5." + str(decimals) + "f}").format(percent_value)
    bar = fill * filled_length + '-' * (length - filled_length)
    print('\r%s |\033[1;%dm%s\033[0;0m| %s%% %s' % (prefix, color_map[color], bar, percent, suffix), end='\r')
    if iteration == total:
        print()


def colored_print(color, message):
    """
    Print ``message`` wrapped in a bold ANSI color escape sequence followed by a reset.

    :param color: Name of the color; must be a key of color_map
    :param message: Object to print (rendered with %s formatting)
    :return: None
    """
    template = '\033[1;%dm%s\033[0;0m'
    code = color_map[color]
    print(template % (code, message))


def duplicate_stdout(filename):
    """
    Duplicate (tee) stdout and stderr into a file, keeping a permanent record of the log on disk.

    A ``tee filename`` child process is started; it inherits the current stdout. File
    descriptors 1 and 2 are then redirected into the tee's stdin pipe. Because the
    redirection happens at the file-descriptor level, anything written directly to fd 1/2
    (including by child processes that inherit them later) is captured too. Note that
    stderr output therefore also ends up on the tee's (original) stdout.

    :param filename: The filename to which stdout should be duplicated
    :return: None
    """
    print("duplicating stdout to", filename, file=sys.stderr)
    # Flush pending output first so it reaches the original destination now instead of
    # being written later through the tee (out of order, and into the log file).
    sys.stdout.flush()
    sys.stderr.flush()
    # closefd=False: if this wrapper is ever replaced (e.g. on a second call), garbage
    # collecting it must not close fd 1, which the redirection relies on.
    sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', closefd=False)
    tee = subprocess.Popen(["tee", filename], stdin=subprocess.PIPE)
    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
    os.dup2(tee.stdin.fileno(), sys.stderr.fileno())

def removeElements(lst, k):
    """
    Return a new list containing only the items of ``lst`` that occur at least ``k`` times.

    Surviving items keep their original order and multiplicity.
    """
    occurrences = Counter(lst)
    kept = []
    for item in lst:
        if occurrences[item] >= k:
            kept.append(item)
    return kept


def get_entity_alignments(filename, emap=None):
    """
    Deprecated: read entity alignment pairs and return them as a symmetric id mapping.

    Each non-blank line of ``filename`` starts with two whitespace-separated entity names
    (extra tokens are ignored). When both names are in ``emap``, their ids are recorded
    in both directions. Blank lines are skipped.

    :param filename: path of the alignment file
    :param emap: entity map (name -> id). ``None`` is treated as an empty map, so nothing
        is aligned.
    :return: dict mapping each aligned id to its counterpart (symmetric; later lines
        overwrite earlier ones)
    """
    if emap is None:
        emap = {}
    mappings = {}
    with open(filename) as fh:
        for line in fh:
            tokens = line.split()
            if not tokens:
                continue
            x = tokens[0]
            y = tokens[1]
            if x in emap and y in emap:
                p1 = emap[x]
                p2 = emap[y]
                mappings[p1] = p2
                mappings[p2] = p1
    return mappings


def get_filter(filename, em=None,rm=None,add_unknowns=True, nonoov_entity_count=None):
    """
    Map the first-column entity names of ``filename`` to ids using ``em``.

    Each non-blank line is split on whitespace. Its first token is looked up in ``em``,
    and the mapped id is collected when present. Blank lines are skipped. The number of
    collected ids is printed.

    :param filename: path of a whitespace-separated text file
    :param em: entity map (name -> id). ``None`` is treated as an empty map.
    :param rm: unused; kept for interface compatibility
    :param add_unknowns: unused; kept for interface compatibility
    :param nonoov_entity_count: unused; kept for interface compatibility
    :return: list of mapped ids in file order (duplicates kept)
    """
    if em is None:
        em = {}
    filt = []
    with open(filename) as fh:
        for line in fh:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] in em:
                filt.append(em[tokens[0]])
    print(len(filt))
    return filt


def intersect(a, b):
    """Return the distinct elements present in both ``a`` and ``b`` as a list (order unspecified)."""
    return list(set(a).intersection(set(b)))


def union(a, b):
    """Return the distinct elements present in ``a`` or ``b`` (or both) as a list (order unspecified)."""
    return list(set(a).union(set(b)))


def log_eval_scores(writer, valid_score, test_score, num_iter):
    """
    Write the 'mrr', 'hits10' and 'hits1' scores of both evaluation splits to ``writer``.

    ``valid_score`` and ``test_score`` are indexed as ``score[part][metric]`` with part in
    ('m', 'e1', 'e2'). Every value is logged via ``writer.add_scalar(tag, value, num_iter)``
    under the tag ``'<metric>/<split>_<part>'``, where split is 'valid' or 'test'.
    """
    splits = (('valid', valid_score), ('test', test_score))
    for metric in ['mrr', 'hits10', 'hits1']:
        for split_name, scores in splits:
            for part in ('m', 'e1', 'e2'):
                tag = '{}/{}_{}'.format(metric, split_name, part)
                writer.add_scalar(tag, scores[part][metric], num_iter)


def has_cuda() -> bool:
    """Report whether PyTorch currently considers CUDA available (``torch.cuda.is_available()``)."""
    cuda_available = torch.cuda.is_available()
    return cuda_available


def get_rel_align_dict(meta, kgc_train_path : str,
                       emap : Dict[Any,int]=None,
                       rmap : Dict[Any,int]=None):
    """
    Computes the SO-pair collections of relations from a KGC training file.

    Each non-blank line of ``kgc_train_path`` is a whitespace-separated triple
    ``subject relation object``. The relation name is mapped through ``rmap`` and both
    entities through ``emap``. The first character of the relation name is recorded as
    that relation's language id. Blank lines are skipped.

    :param meta: not needed, but may need later
    :param kgc_train_path: combined (s,r,o) train fold. entIDs are global. relID is
        lang-prefixed global (the first character of the relation name is its language).
    :param emap: entity map (name -> int id); required despite the None default
    :param rmap: relation map (name -> int id); required despite the None default
    :return: (rel_ent_pairs, lang). rel_ent_pairs maps relID -> list of (subject id,
        object id) pairs in file order (duplicates kept). lang maps relID -> one-character
        language id string, taken from the first line where that relID appears.
    :raises TypeError: if ``rmap`` maps a relation name to a value that is not an int
    """
    rel_ent_pairs : Dict[int, List[Tuple[int, int]]] = dict()
    lang : Dict[int, str] = dict()
    with open(kgc_train_path) as train:
        for line in train:
            tokens = line.split()
            if not tokens:
                continue
            rel_name = tokens[1]
            rel_id = rmap[rel_name]
            if type(rel_id) is not int:
                raise TypeError("rmap must map relation names to int ids, got %r for %r"
                                % (rel_id, rel_name))
            if rel_id not in rel_ent_pairs:
                rel_ent_pairs[rel_id] = []
                lang[rel_id] = rel_name[0]
            rel_ent_pairs[rel_id].append((emap[tokens[0]], emap[tokens[2]]))
    return rel_ent_pairs, lang


def update_imp_sc(meta: EaRaKgcData,
                  rel_ent_pairs : Dict[int, List[Tuple[int, int]]],
                  lang : Dict[int, str], S_re, S_im, a, b):
    """
    With current entity embeddings, recomputes soft-asymmetric set similarities
    (as approximate max matching) between relations, which are represented
    as SO-vector sets, and returns the mutually best-matching cross-language relations.

    Steps:
    1. For every relation with at least 2 (s, o) pairs, stack the L2-normalised vectors
       [S_re(s), S_im(s), S_re(o), S_im(o)] of its pairs into a matrix.
    2. For every ordered pair (rel1, rel2), rel1 != rel2, with neither relation in
       language ``meta.maxlid_plus_one()``:
       - compute the cosine-similarity matrix;
       - take the best match of every rel1 pair (``val``) and of every rel2 pair (``val1``);
       - squash both with ``sigmoid(a * x - b)``;
       - score rel1 -> rel2 as the sum of the rel1 sigmoids that exactly equal some rel2
         sigmoid (this holds e.g. for mutual best matches), divided by the number of
         rel1 pairs.
    3. Report rel with partner rel2 under language ``lang[rel2]`` when rel2 is rel's
       top-scored partner in that language and rel is rel2's top-scored partner in rel's
       language.

    :param meta: provides ``lids()`` (language ids, used as str keys) and
        ``maxlid_plus_one()`` (the language id whose relations are skipped)
    :param rel_ent_pairs: dict(relID, list(SO-pair)) ... map : rel -> SO pairs
    :param lang: dict(relID, lidStr) ... relation to language
    :param S_re: entity embeddings, real part
    :param S_im: entity embeddings, imaginary part
    :param a: sigmoid slope
    :param b: sigmoid offset
    :return: dict lidStr -> dict(relID -> (score, partner relID)). The partner belongs to
        language lidStr; score is the smaller of the two directional scores (0-dim tensor).
        Languages without any comparable relation contribute no entries.
    """
    # Step 1. key = rel; val = set of SO vecs stacked into a matrix (one row per pair).
    grid_to_sovecs : Dict[int, Any] = dict()
    for rel0 in tqdm.tqdm(rel_ent_pairs, desc="update_imp_sc, rel_ent_pairs"):
        assert type(rel0) == int
        if len(rel_ent_pairs[rel0])<2:  # MAGIC: too few pairs to compare as a set
            continue
        sovecs = []
        for (s,o) in rel_ent_pairs[rel0]:
            # NOTE(review): assumes S_re/S_im map a 1-element index tensor to a (1, d)
            # tensor (e.g. torch.nn.Embedding) -- confirm. CUDA is hard-coded here.
            t = torch.cat((S_re(torch.tensor([s]).cuda()),
                           S_im(torch.tensor([s]).cuda()),
                           S_re(torch.tensor([o]).cuda()),
                           S_im(torch.tensor([o]).cuda())),-1)
            sovecs.append(t[0]/torch.norm(t[0]))
        grid_to_sovecs[rel0] = torch.stack(sovecs,dim=0)

    # Step 2. key = langid as string; val = dict(rel1 -> list of (score, rel2)) over the
    # relations rel2 of that language, sorted best first.
    impl : Dict[str, Dict] = dict()
    for lid1 in meta.lids():
        lid1str = str(lid1)
        impl[lid1str] = dict()
    print("update_imp_sc  impl", len(impl), impl.keys())

    for rel1 in tqdm.tqdm(grid_to_sovecs, desc="update_imp_sc, tuple_set"):
        assert type(rel1) == int
        if int(lang[rel1])==meta.maxlid_plus_one():
            continue
        for rel2 in grid_to_sovecs:
            if rel1==rel2 or int(lang[rel2])==meta.maxlid_plus_one():
                continue
            if rel1 not in impl[lang[rel2]]:
                impl[lang[rel2]][rel1] = []
            # rows: rel1 pairs, columns: rel2 pairs; entries are cosine similarities
            mat = grid_to_sovecs[rel1] @ grid_to_sovecs[rel2].t()
            val,ind = torch.max(mat,1)    # best rel2 match for each rel1 pair
            val1,ind1 = torch.max(mat,0)  # best rel1 match for each rel2 pair
            sig = torch.sigmoid(a*val-b)
            sig1 = torch.sigmoid(a*val1-b)
            # compareview[j, i] == sig1[j]; after comparing with sig and transposing,
            # row i is all-True iff sig[i] equals no entry of sig1.
            compareview = sig1.repeat(sig.shape[0],1).t()
            not_intersect = sig[(compareview!=sig).t().prod(1)==1]
            # i.e. sum of the sig entries that do occur in sig1, normalised by |rel1|
            impl[lang[rel2]][rel1].append(((torch.sum(sig)-torch.sum(not_intersect))/len(mat),rel2))
        for lid2 in meta.lids():
            lid2str = str(lid2)
            # A language may have no comparable relation for rel1 (e.g. rel1 is its only
            # relation with >= 2 pairs); skip it instead of raising KeyError.
            if rel1 in impl[lid2str]:
                impl[lid2str][rel1].sort(reverse=True)

    # Step 3. key = langid as string; val = dict(rel -> (score, partner rel in that language))
    equiv_rel=dict()
    for lid3 in meta.lids():
        lid3str = str(lid3)
        equiv_rel[lid3str] = {}

    for rel4 in tqdm.tqdm(grid_to_sovecs):
        if int(lang[rel4])==meta.maxlid_plus_one():
            continue
        for lid4 in meta.lids():
            lid4str = str(lid4)
            if lid4str==lang[rel4]:
                continue
            forward = impl[lid4str].get(rel4)
            if not forward:
                # no relation of language lid4str could be compared with rel4
                continue
            sc,rel2 = forward[0]      # best partner of rel4 in language lid4str
            backward = impl[lang[rel4]].get(rel2)
            if not backward:
                continue
            sc2,rel3 = backward[0]    # best partner of rel2 back in rel4's language
            if rel4==rel3:
                # mutual best match: keep the weaker of the two directional scores
                equiv_rel[lid4str][rel4]=(min(sc,sc2),rel2)
    return  equiv_rel

