# -*- coding: utf-8 -*-
import pandas as pd
from os.path import dirname, basename

import os
import time
from os import path
import gym, minerl
import numpy as np
from tqdm import tqdm as tqdm
from collections import OrderedDict
import glob
import imageio
import shutil
import subprocess
import json
import operator
import pickle
from train.subtask_identifier.data_handling import get_frames, show_frames, get_actions, get_observations, SpacesMapper#, get_metadata_score
from train.behavioral_cloning.datasets.minerl_dataset import MineRLBaseDataset





def no_craft_action(data,pos):
    """Tell whether no craft/smelt/place action occurs in the 4 steps ending at ``pos``.

    Parameters
    ----------
    data : mapping
        Per-step action arrays. An ``OrderedDict`` is read through the keys
        'craft', 'nearbyCraft', 'nearbySmelt' and 'place'; any other container
        is read through the same keys prefixed with 'action_'.
    pos : int
        Index of the last step of the inspected window ``[pos - 3, pos]``.

    Returns
    -------
    bool
        True when every inspected field sums to zero over the window, or when
        ``pos <= 3`` (the window would start before the sequence).
    """
    if pos <= 3:
        return True

    if isinstance(data, OrderedDict):
        fields = ('craft', 'nearbyCraft', 'nearbySmelt', 'place')
    else:
        fields = ('action_craft', 'action_nearbyCraft', 'action_nearbySmelt', 'action_place')

    # Every field has to be idle; the previous loop returned after looking at
    # the first field only, so nearbyCraft/nearbySmelt/place were never checked.
    return all(np.sum(data[f][pos - 3:pos + 1]) == 0 for f in fields)

def get_minimum_transition(feature_files_and_scores,inv_space,dataset):
    """Estimate, per inventory item, the smallest typical positive inventory increment.

    For every replay the inventory is turned into a (steps, items) array,
    short-lived fluctuations are smoothed away, and all positive step-to-step
    increments are collected per item. The result for an item is the 10%
    quantile of its increments (truncated to int), or 1 when the item never
    increased.

    Parameters
    ----------
    feature_files_and_scores : list of (str, score)
        Replay feature-file paths; the name of the parent directory is the key
        into ``dataset``. Must contain at least one entry.
    inv_space
        Unused; kept for interface compatibility.
    dataset : mapping
        ``dataset[name]['observation_dict']['inventory']`` must be a mapping
        from item to a per-step sequence of counts.

    Returns
    -------
    numpy.ndarray
        1-D integer array with one minimum increment per inventory item.
    """
    # The first replay is only used to learn how many inventory items exist.
    ff_path, score = feature_files_and_scores[0]
    name = os.path.split(os.path.split(ff_path)[0])[-1]
    inv = dataset[name]['observation_dict']['inventory'].copy()
    inv = np.transpose(np.array([np.array(v) for k, v in inv.items()]))

    min_change = [[] for _ in range(inv.shape[1])]

    smoothing_window = 5
    for ff_path, score in tqdm(feature_files_and_scores):
        name = os.path.split(os.path.split(ff_path)[0])[-1]
        data = dataset[name]
        inv = data['observation_dict']['inventory'].copy()
        # change inv into an array of shape (steps, items)
        inv = np.transpose(np.array([np.array(v) for k, v in inv.items()]))

        # Smooth each item column: a change is ignored when the count is back
        # at the previously accepted value `smoothing_window` steps later.
        for i in range(inv.shape[1]):
            tmp = np.zeros(inv.shape[0])
            for j in range(1, inv.shape[0]):
                if inv[j, i] < tmp[j - 1] and inv[min(inv.shape[0] - 1, j + smoothing_window), i] == tmp[j - 1]:
                    tmp[j] = tmp[j - 1]
                elif inv[j, i] > tmp[j - 1] and inv[min(inv.shape[0] - 1, j + smoothing_window), i] == tmp[j - 1]:
                    tmp[j] = tmp[j - 1]
                else:
                    tmp[j] = inv[j, i]
            inv[:, i] = tmp

        # collect every positive increment per item
        diff = (inv[1:] - inv[:-1])
        inv_changes = np.where(diff > 0)
        for pos, idx in zip(inv_changes[0], inv_changes[1]):
            min_change[idx].append(np.copy(diff[pos, idx]))

    # One integer per item. The per-item lists are ragged, so they must not be
    # handed to np.zeros_like: that either raises or yields an array of the
    # wrong shape/dtype depending on the NumPy version.
    minimum_change = np.ones(len(min_change), dtype=int)
    for i, d in enumerate(min_change):
        if d:
            minimum_change[i] = int(np.quantile(d, q=0.1))

    return minimum_change


def make_lookup_dict_from_raw(dataset, input_dir: str, output_dir: str, best_replay_frac: float = 1., make_files: bool = False, debug: bool = False):
    """Extract the positive inventory changes of the best-scoring demonstrations.

    Parameters
    ----------
    dataset : mapping
        ``dataset[name]`` must provide ``['observation_dict']['inventory']``
        (item -> per-step counts) and ``['reward_seq']``, where ``name`` is the
        directory name of a demonstration.
    input_dir : str
        Folder whose sub-folders each contain a 'rendered.npz' file.
    output_dir : str
        Folder that receives 'inventory_changes.npz' and 'rewards.npz' when
        ``make_files`` is True; it is created if missing.
    best_replay_frac : float
        Fraction of the scored replays (best first) that is processed.
    make_files : bool
        If True, write the results to ``output_dir``.
    debug : bool
        If True, only the first five scored replays (in path order) are used.

    Returns
    -------
    tuple
        ``(inventory_changes, subtask_lookup, rewards, dataset)`` where
        ``inventory_changes`` maps each feature-file path to the
        ``(frames, item_ids)`` index arrays of its positive inventory changes,
        ``subtask_lookup`` maps every inventory item to a list (left empty by
        this function), ``rewards`` maps each feature-file path to its reward
        sequence, and ``dataset`` is the object passed in.
    """

    feature_file_name = 'rendered.npz'
    spaces_mapper = SpacesMapper(envname= "MineRLObtainDiamond-v0")
    inv_space, *_ = spaces_mapper.get_environment_spaces()

    feature_files = glob.glob(os.path.join(input_dir, '**', feature_file_name))
    feature_files.sort()

    # Create list (filename, score), keeping only replays that have a score.
    # The score is looked up once per replay instead of twice.
    # NOTE(review): get_metadata_score is not imported in the visible part of
    # this file (its import is commented out) - confirm it is defined elsewhere.
    feature_files_and_scores = []
    for ff in feature_files:
        score = get_metadata_score(os.path.dirname(ff))
        if score is not None:
            feature_files_and_scores.append((ff, score))
    if debug:
        feature_files_and_scores = feature_files_and_scores[:5]
    # Sort by score, take only best_replay_frac
    feature_files_and_scores.sort(key=lambda x: x[1], reverse=True)
    feature_files_and_scores = feature_files_and_scores[:int(np.ceil(len(feature_files_and_scores) * best_replay_frac))]

    min_change = get_minimum_transition(feature_files_and_scores,inv_space,dataset)

    inventory_changes = OrderedDict()
    subtask_lookup = OrderedDict([(item, []) for item in inv_space])
    rewards = OrderedDict()

    for ff_path, score in tqdm(feature_files_and_scores):
        name = os.path.split(os.path.split(ff_path)[0])[-1]
        data = dataset[name]
        inv = data['observation_dict']['inventory'].copy()
        # change inv into an array of shape (steps, items)
        inv = np.transpose(np.array([ np.array(v) for k,v in inv.items()]))
        rewards[ff_path] = data['reward_seq']

        # Smooth each item column: a change is ignored when the count is back
        # at the previously accepted value `smoothing_window` steps later.
        smoothing_window = 5
        for i in range(inv.shape[1]):
            tmp = np.zeros(inv.shape[0])
            for j in range(1, inv.shape[0]):
                if inv[j, i] < tmp[j - 1] and inv[min(inv.shape[0] - 1, j + smoothing_window), i] == tmp[j - 1]:
                    tmp[j] = tmp[j - 1]
                elif inv[j, i] > tmp[j - 1] and inv[min(inv.shape[0] - 1, j + smoothing_window), i] == tmp[j - 1]:
                    tmp[j] = tmp[j - 1]
                else:
                    tmp[j] = inv[j, i]
            inv[:, i] = tmp

        diff = (inv[1:] - inv[:-1])
        inv_changes = np.where(diff > 0)
        # A jump larger than the item's minimum change is split up: the frames
        # following the jump are lowered so that the count rises in steps of
        # min_change[idx] instead of all at once.
        for pos,idx in zip(inv_changes[0],inv_changes[1]):
            if np.abs(inv[pos+1,idx]-inv[pos,idx]) > min_change[idx]:
                item_diference = int(np.abs(inv[pos+1,idx]-inv[pos,idx])/min_change[idx])
                added = 0
                while added < item_diference and pos+added+1 < inv[:,0].size:
                    added += 1
                    inv[pos+added, idx] += min_change[idx]*(added-item_diference)

        # Recompute the changes on the corrected inventory; +1 turns the index
        # of the difference into the frame at which the new count is visible.
        inv_changes = np.where((inv[1:] - inv[:-1]) > 0)
        inv_changes[0][:] += 1
        inventory_changes[ff_path] = inv_changes

    # Save results
    if make_files:
        os.makedirs(output_dir, exist_ok=True)
        np.savez(os.path.join(output_dir, 'inventory_changes.npz'), **inventory_changes)
        np.savez(os.path.join(output_dir, 'rewards.npz'), **rewards)

    return inventory_changes, subtask_lookup, rewards, dataset

def create_fasta_sequences(aa_dict, rewards, inventory_changes,outdir,
                           top_n=10, exclude=None):
    """Encode every replay's inventory changes as a letter string and rank the replays.

    Parameters
    ----------
    aa_dict : mapping
        Item id -> one-letter symbol.
    rewards : mapping
        File name -> reward sequence of that replay.
    inventory_changes : mapping
        File name -> ``(positions, item_ids)`` pair; only ``item_ids`` is used.
    outdir : str or None
        When not None, the ``top_n`` best replays are written there in FASTA
        format.
    top_n : int
        Number of best replays to keep.
    exclude : iterable of str or None
        Symbols to leave out of the encoding; they also end up in the name of
        the written file.

    Returns
    -------
    tuple
        ``(top_n_seq, fasta_sequences, outfile)``: the ranked top rows, the
        full table (columns 'reward', 'length', 'fasta') and the written file
        path (None when nothing was written).
    """
    if exclude is None:
        symbol_of_item = aa_dict.copy()
        name_tag = ""
    else:
        symbol_of_item = {item: sym for item, sym in aa_dict.copy().items() if sym not in exclude}
        name_tag = "_" + "_".join(exclude)

    # one record per replay: total reward, number of steps, encoded changes
    records = {}
    for fname, changes in inventory_changes.items():
        episode_rewards = rewards[fname]
        n_steps = len(episode_rewards)
        _, item_ids = changes
        total_reward = np.sum(episode_rewards)
        letters = "".join(symbol_of_item[i] for i in item_ids if i in symbol_of_item)
        records[fname] = {"reward": total_reward, "length": n_steps, "fasta": letters}

    fasta_sequences = pd.DataFrame.from_dict(records, orient="index")
    # best = highest reward, ties broken by the shorter episode
    ranked = fasta_sequences.sort_values(["reward", "length"], ascending=[False, True])
    top_n_seq = ranked.iloc[:top_n]

    outfile = None
    if outdir is not None:
        outfile = path.join(outdir, "top_{}{}.fasta".format(top_n, name_tag))
        with open(outfile, "w") as f:
            for fname, row in top_n_seq.iterrows():
                f.write(">{}|rew={}|len={}\n{}\n".format(basename(dirname(fname)),  row.reward,row.length, row.fasta))
    return top_n_seq, fasta_sequences, outfile

def create_scoring_matrix(
        aa_dict,
        reward_dict_ids,
        outdir=None,
        outfile='scoring_matrix',
        fasta_sequences=None,
        expected_dict=None,
        offdiag=-10.0,
        offdiag_gap = -1.0,
        main_diag_factor=1,
        scaling=None,
        reward_op="mul", freq_scaling=None,
        subtask='my_sub_task'):
    """Build a substitution (scoring) matrix from symbol frequencies and write it to a file.

    The diagonal entry of a symbol is ``-2 * log(frequency)`` (0 for symbols
    that do not occur), optionally squared, multiplied by ``main_diag_factor``
    and combined with the symbol's reward. The diagonal is then rescaled so
    that its maximum is 10. Off-diagonal entries are ``offdiag``; the last row
    and the last column are set to ``offdiag_gap``.

    Parameters
    ----------
    aa_dict : mapping
        Item id -> one-letter symbol; the values define the alphabet.
    reward_dict_ids : mapping
        Row index -> reward used by ``scaling``.
    outdir : str or None
        Folder for the matrix file. When None nothing is written and the
        returned file name is None.
    outfile : str
        Prefix of the file name; subtask, diagonal factor, ``offdiag`` and the
        scaling options are appended.
    fasta_sequences : iterable of str or None
        Sequences whose symbols are counted; None is treated as no sequences.
    expected_dict
        Unused; kept for interface compatibility.
    offdiag, offdiag_gap : float
        Mismatch score and gap row/column score.
    main_diag_factor : number
        Factor applied to the diagonal before reward scaling.
    scaling : {None, 'linear', 'sqrt', 'log'}
        Transformation of the reward before it is combined with the diagonal.
    reward_op : {'mul', 'add'}
        How the transformed reward is combined with the diagonal.
    freq_scaling : {None, 'square'}
        'square' squares the log-frequency term.
    subtask : str
        Label that becomes part of the file name.

    Returns
    -------
    tuple
        ``(scoring, outfile, True)``: the float32 matrix, the path of the
        written file (None if ``outdir`` is None) and a success flag.

    Raises
    ------
    KeyError
        If a sequence contains a symbol that is not a value of ``aa_dict``.
    """
    all_symbols = "".join(fasta_sequences) if fasta_sequences is not None else ""

    # count symbols
    sym_dict = OrderedDict((sym, 0) for sym in aa_dict.values())
    for sym in all_symbols:
        sym_dict[sym] += 1
    n_total = len(all_symbols)

    # rare symbols get a high score: -2 * log(relative frequency)
    weights = OrderedDict()
    for sym, cnt in sym_dict.items():
        weight = cnt
        if cnt > 0:
            weight = -2 * np.log(cnt / n_total)
            if freq_scaling == "square":
                weight = weight ** 2
        weights[sym] = weight
    keys = list(weights.keys())

    scoring = np.full(shape=(len(keys), len(keys)),
                      fill_value=offdiag,
                      dtype=np.float32)
    # gap scores: last row and last column
    for i in range(len(scoring)):
        scoring[-1, i] = offdiag_gap
        scoring[i, -1] = offdiag_gap

    # main diagonal and reward scaling
    transform = {"linear": lambda r: r, "sqrt": np.sqrt, "log": np.log}.get(scaling)
    for i, sym in enumerate(keys):
        scoring[i, i] = weights[sym] * main_diag_factor
        if transform is None or i not in reward_dict_ids:
            continue
        if reward_op == "mul":
            scoring[i, i] *= transform(reward_dict_ids[i])
        elif reward_op == "add":
            scoring[i, i] += transform(reward_dict_ids[i])

    # Normalize the diagonal to a maximum of 10 (like BLOSUM). An all-zero
    # diagonal (e.g. a single symbol type) is left as is: dividing by zero
    # would fill the diagonal with inf/NaN.
    diag_max = np.max([scoring[i, i] for i in range(len(scoring))])
    if diag_max != 0:
        main_diag_factor = 10 / diag_max
        for i in range(len(scoring)):
            scoring[i, i] *= main_diag_factor

    if outdir is None:
        return scoring, None, True

    # Write
    outfile += "{}_m{}_o{}".format(subtask,main_diag_factor, offdiag)
    if scaling is not None:
        outfile += "_{}{}".format(reward_op, scaling)
    if freq_scaling is not None:
        outfile += "_f{}".format(freq_scaling)
    outfile = path.join(outdir, outfile)
    # every row gets an extra gap column, followed by a final '*' gap row
    with open(outfile, "w") as f:
        f.write("   " + " ".join(list(aa_dict.values())) + " *\n")
        for i in range(len(scoring)):
            f.write(keys[i] + " ")
            f.write(" ".join([str(x) for x in scoring[i]]))
            f.write(" " + str(offdiag_gap))
            f.write("\n")
        f.write("*" + (" " + str(offdiag_gap)) * (len(scoring) + 1))
    return scoring, outfile, True


def read_sequences (fasta_file=None):
    """Read a two-line-per-record FASTA file.

    Even lines (0, 2, ...) are taken as identifier lines and odd lines as the
    corresponding sequences.

    Parameters
    ----------
    fasta_file : str
        Path of the file to read.

    Returns
    -------
    tuple
        ``(my_sequences, ids, ids_int)``: the sequences without their line
        break, the identifier lines exactly as read (line break included) and
        the running record numbers 0, 1, 2, ...
    """
    ids = []
    my_sequences = []
    ids_int = []
    # `with` closes the handle; the previous bare open() never did.
    with open(fasta_file, 'r') as handle:
        for line_no, line in enumerate(handle):
            if line_no % 2 == 0:
                ids.append(line)
                ids_int.append(line_no // 2)
            else:
                # Only strip the line break: a final line without one must not
                # lose its last symbol.
                my_sequences.append(line.rstrip('\n'))
    return my_sequences,ids,ids_int


def _trim_block_tail(chunks, drop_last_symbol):
    """Strip the line break (and optionally the final symbol) of the last chunk, in place."""
    tail = chunks[-1]
    if drop_last_symbol:
        tail = tail[:-2] if tail[-1] == '\n' else tail[:-1]
    elif tail[-1] == '\n':
        tail = tail[:-1]
    # Replace the last element itself; list.remove() would delete the first
    # chunk with the same text, which is not necessarily the last one.
    chunks[-1] = tail


def _read_gde_alignment(aln_file, n_slots, drop_last_symbol):
    """Read an alignment file made of '%id' header lines followed by sequence lines.

    Returns ``(blocks, block_ids)`` where ``blocks[k]`` holds the lines of the
    k-th record with the first character of its first line removed and its
    tail trimmed by ``_trim_block_tail``. Lines before the first header are
    ignored.
    """
    blocks = [[] for _ in range(n_slots)]
    block_ids = []
    with open(aln_file, 'r') as f:
        lines = f.readlines()
    i = -1
    first_line = False
    for l in lines:
        if l[0] == '%':
            if i >= 0:
                _trim_block_tail(blocks[i], drop_last_symbol)
            block_ids.append(l)
            i += 1
            first_line = True
        elif i < 0:
            continue
        elif first_line:
            blocks[i].append(l[1:])  # drop the artificial leading 'T'
            first_line = False
        else:
            blocks[i].append(l)
    _trim_block_tail(blocks[i], drop_last_symbol)
    return blocks, block_ids


def get_alignment(aa_dict, reward_dict_ids, my_clustered_seqs, my_clustered_ids,ids_int, milestones_seqs,output_dir):
    """Align the sequences segment by segment with clustalw2 and stitch the segments together.

    Every sequence is cut at the first appearance of each milestone letter.
    The pieces of all sequences that belong to the same milestone-to-milestone
    segment are written to one folder, concatenated, and aligned by an
    external ``clustalw2`` call that uses a scoring matrix produced by
    ``create_scoring_matrix``. The per-segment alignments along the milestone
    order of ``milestones_seqs[0]`` are then concatenated per sequence and
    written to 'master_alignment2.aln'.

    Parameters
    ----------
    aa_dict : mapping
        Item id -> one-letter symbol.
    reward_dict_ids : mapping
        Forwarded to ``create_scoring_matrix``.
    my_clustered_seqs : list of str
        Sequences to align.
    my_clustered_ids : list of str
        Identifier line of each sequence (written verbatim as FASTA header).
    ids_int : iterable
        Per-sequence label used in the names of the piece files.
    milestones_seqs : list of list
        Milestone orders; they define the segment folders, and the first one
        defines which segment alignments are stitched together.
    output_dir : str
        Existing folder below which 'my_clustered_sequences' is (re)created.

    Returns
    -------
    tuple
        ``(master_seq, master_ids, score_file)``: per sequence the list of
        aligned text chunks, the header lines of the first segment alignment,
        and the path of the last scoring matrix written (None if no segment
        was aligned).
    """
    # Collect the distinct segment names: "Start-X", "X-Y", ..., "Z-End".
    milestones_junks = []
    for milestone_seq in milestones_seqs:
        segment_names = ['Start-' + str(milestone_seq[0])]
        segment_names += [str(m) + '-' + str(nxt)
                          for m, nxt in zip(milestone_seq[:-1], milestone_seq[1:])]
        # The closing segment is always named after the last milestone, also
        # for a sequence with a single milestone.
        segment_names.append(str(milestone_seq[-1]) + '-End')
        for sub_seq_m in segment_names:
            if sub_seq_m not in milestones_junks:
                milestones_junks.append(sub_seq_m)

    # (re)create the working directory and one folder per segment
    directory = 'my_clustered_sequences'
    cluster_root = os.path.join(output_dir, directory)
    if os.path.exists(cluster_root):
        shutil.rmtree(cluster_root)
    os.mkdir(cluster_root)

    for m in milestones_junks:
        sub_folder = os.path.join(output_dir, directory, m)
        if os.path.exists(sub_folder):
            shutil.rmtree(sub_folder)
        os.mkdir(sub_folder)

    my_sequences = my_clustered_seqs
    ids = my_clustered_ids

    # Cut every sequence at the first appearance of each milestone letter and
    # write each piece into the folder of its segment.
    milestone_letters = ('S', 'P', 'V', 'L', 'N', 'A', 'F', 'Y', 'Q', 'K', 'E')
    for seq, idx, name in zip(my_sequences, ids, ids_int):
        reached = set()
        last_a = 'Start'
        temp_sub_sequence = ['T']  # leading 'T' forces the alignment to the left
        for a in seq:
            if a in milestone_letters and a not in reached:
                reached.add(a)
                temp_sub_sequence.append(str(a))  # the piece ends with the milestone letter
                junk = str(last_a) + '-' + str(a)
                piece_file = os.path.join(output_dir, directory, junk, '{}_{}.fasta'.format(name, junk))
                with open(piece_file, 'w') as f:
                    f.write(idx)
                    f.write(''.join(temp_sub_sequence))
                    f.write('\n')
                # start the next piece; it begins with the same milestone letter
                temp_sub_sequence = ['T']
                last_a = a
            temp_sub_sequence.append(a)
        junk = str(last_a) + '-End'
        piece_file = os.path.join(output_dir, directory, junk, '{}_{}.fasta'.format(name, junk))
        with open(piece_file, 'w') as f:
            f.write(idx)
            f.write(''.join(temp_sub_sequence))
            f.write('\n')

    # one multi-FASTA file per segment
    my_sub_sequences_files = []
    for m in milestones_junks:
        subdirectory = m
        sub_folder = os.path.join(output_dir,directory, subdirectory)
        os.system('cat {}/*.fasta > {}/{}_m.mfasta'.format(sub_folder, sub_folder, subdirectory))
        my_sub_sequences_files.append('{}/{}_m.mfasta'.format(sub_folder, subdirectory))

    msa_folder = os.path.join(output_dir,directory, 'my_msa')
    outdir = msa_folder
    if os.path.exists(outdir):
        shutil.rmtree(outdir)
    os.mkdir(outdir)

    # relative frequency of every symbol over all pieces
    my_sequences_for_exp = []
    for file_s in my_sub_sequences_files:
        with open(file_s, 'r') as f:
            for d in f.read().split():
                if d[0] != '>':
                    my_sequences_for_exp.append(d[1:]) # remove the artificial 'T'
    exp_dict = OrderedDict({k: 0 for k in aa_dict.values()})
    for sym in "".join(my_sequences_for_exp):
        exp_dict[sym] += 1
    n_total = len("".join(my_sequences_for_exp))
    if n_total:
        for k in exp_dict:
            exp_dict[k] = exp_dict[k]/n_total

    score_file = None
    for seq_file in my_sub_sequences_files:
        infile = seq_file
        outfile = "{}/msa_{}.aln".format(outdir, path.basename(seq_file)[:-7])  # strip '.mfasta'
        msatree = "{}/msa.dnd".format(outdir)
        # collect all pieces of this segment
        sequences_for_scoring = []
        with open(seq_file, 'r') as f:
            for d in f.read().split():
                if d[0] != '>':
                    sequences_for_scoring.append(d)
        if sequences_for_scoring == []:
            continue
        scoring_matrix, score_file, result = create_scoring_matrix(aa_dict, reward_dict_ids, outdir=output_dir,scaling='linear',
                                                                   fasta_sequences=sequences_for_scoring,expected_dict =exp_dict,  offdiag=-10.0,
                                                                   offdiag_gap=-1.0, subtask=path.basename(seq_file)[:-7])
        if result == False: # no sequences available
            continue
        # largest number below 10 found in the matrix file
        with open(score_file, 'r') as f:
            floats = []
            for elem in f.read().split():
                try:
                    if float(elem) < 10:
                        floats.append(float(elem))
                except ValueError:
                    pass
        gap_open_1 = np.max(floats)/10
        gap_open = 0.001
        gap_ext = 0

        if os.path.exists(outfile):
            os.remove(outfile)
        if os.path.exists(msatree):
            os.remove(msatree)
        # The options are passed as a list: -NUMITER and -NEWTREE used to be
        # fused into one argument by a missing space, and splitting a command
        # string on blanks breaks paths that contain spaces.
        cmd = [
            "clustalw2", "-ALIGN", "-NEGATIVE", "-CLUSTERING=UPGMA",
            "-INFILE={}".format(infile),
            "-OUTFILE={}".format(outfile),
            "-PWMATRIX={}".format(score_file),
            "-PWGAPOPEN={}".format(gap_open_1),
            "-PWGAPEXT={}".format(gap_ext),
            "-MATRIX={}".format(score_file),
            "-GAPOPEN={}".format(gap_open),
            "-GAPEXT={}".format(gap_ext),
            "-CASE=UPPER",
            "-GAPDIST=0", "-NOPGAP", "-NOHGAP", "-MAXDIV=0", "-ENDGAPS", "-NOVGAP",
            "-NUMITER=10000",
            "-NEWTREE={}".format(msatree),
            "-TYPE=PROTEIN", "-OUTPUT=GDE",
        ]
        subprocess.run(cmd, stdout=subprocess.PIPE)

    # segment alignments along the milestone order of the first sequence
    champion = milestones_seqs[0]
    my_files = [os.path.join(msa_folder, 'msa_{}-{}_m.aln'.format(m, nxt))
                for m, nxt in zip(champion[:-1], champion[1:])]
    my_files.append(os.path.join(msa_folder, 'msa_{}-{}_m.aln'.format(champion[-1], 'End')))

    # The first file provides the identifiers; the final symbol of every piece
    # is dropped because the next piece starts with the same milestone letter.
    master_seq, master_ids = _read_gde_alignment(my_files[0], len(my_sequences), True)

    # Now we open the rest of the files
    last_position = len(my_files) - 1
    for file_position, file in enumerate(my_files[1:], start=1):
        # the pieces of the last segment keep their final symbol
        temp_seq, temp_ids = _read_gde_alignment(file, len(my_sequences),
                                                 file_position != last_position)
        # append every piece to the master sequence with the same identifier
        for j, master_id in enumerate(master_ids):
            check = 0
            for i, temp_id in enumerate(temp_ids):
                if master_id == temp_id:
                    master_seq[j].extend(temp_seq[i])
                    check = 1
            if check == 0:  # no sequences was inserted
                print('no sequence inserted in {} for file {}'.format(master_id, file))

    # write the stitched alignment
    file = os.path.join(msa_folder, 'master_alignment2.aln')
    with open(file, 'w') as f:
        for s, i in zip(master_seq, master_ids):
            f.write('\n')
            f.write(i)
            f.write('\n')
            for ss in s:
                f.write(ss)

    return master_seq, master_ids,score_file

def get_first_milestone_appearance ( my_sequences,ids):
    """Select the largest group of sequences that reach their milestones in the same order.

    For every sequence the order in which the milestone letters first appear
    is determined. Sequences with fewer milestones than the 90% quantile of
    the milestone counts are dropped; the remaining ones are grouped by
    identical milestone order and the largest group is selected (the earliest
    one on ties). The selected sequences are also written to
    'my_clustered_sequences/<index>.fasta' in the current working directory,
    where <index> is the position of the sequence in ``my_sequences``.

    Parameters
    ----------
    my_sequences : list of str
        Letter sequences.
    ids : list of str
        Identifier line of each sequence, in the same order.

    Returns
    -------
    tuple
        ``(milestones_seqs, my_clustered_seqs, my_clustered_ids,
        my_clustered_ids_int)``: the milestone orders that survived the length
        filter, the selected sequences, their identifiers and a running index
        ``0 .. len(selected) - 1``.
    """
    # TODO: get collection of letter that appears in ALL successful sequences
    milestone_letters = ('S', 'P', 'V', 'L', 'N', 'A', 'F', 'Y', 'Q', 'K', 'E')

    # order of first appearance of the milestone letters, per sequence
    all_milestones = []
    for seq in my_sequences:
        reached = set()
        order = []
        for a in seq:
            if a in milestone_letters and a not in reached:
                reached.add(a)
                order.append(a)
        all_milestones.append(order)

    max_number = int(np.quantile([len(order) for order in all_milestones], q=0.9))

    # Keep the sequences with the highest number of milestones, remembering
    # where each one sits in `my_sequences`/`ids`: positions in the filtered
    # list no longer match the input lists once something has been dropped.
    kept = [(source_idx, order) for source_idx, order in enumerate(all_milestones)
            if len(order) >= max_number]
    milestones_seqs = [order for _, order in kept]

    # group by identical milestone order (insertion order = first appearance)
    members_by_order = OrderedDict()
    for source_idx, order in kept:
        members_by_order.setdefault(tuple(order), []).append(source_idx)
    # max() returns the first of several equally large groups
    largest_group = max(members_by_order.values(), key=len)

    # Write down sequences in the cluster
    directory = 'my_clustered_sequences'
    if not os.path.exists(directory):
        os.mkdir(directory)

    my_clustered_seqs = []
    my_clustered_ids = []
    for s in largest_group:
        with open('{}/{}.fasta'.format(directory, s), 'w') as f:
            f.write(ids[s])
            f.write(my_sequences[s])
        my_clustered_seqs.append(my_sequences[s])
        my_clustered_ids.append(ids[s])

    my_clustered_ids_int = np.arange(len(my_clustered_ids))

    return milestones_seqs,my_clustered_seqs,my_clustered_ids, my_clustered_ids_int


############################################
def clean_seq(myseq):
    """Join the string pieces of ``myseq`` and remove every line break from the result."""
    joined = ''.join(myseq)
    return joined.replace('\n', '')


############################################
def get_consensus(master, threshold=0.8):
    """Derive a consensus string from aligned sequences.

    Each entry of ``master`` is joined into one string with line breaks
    removed. For every alignment column the last non-gap symbol (scanning the
    sequences in order) is taken as candidate; it is kept when the fraction of
    sequences showing exactly that symbol in the column is at least
    ``threshold``. Columns that contain only gaps ('-') contribute nothing.

    Parameters
    ----------
    master : list
        One entry per sequence; each entry is an iterable of string chunks.
    threshold : float
        Minimum fraction of sequences that must agree with the candidate.

    Returns
    -------
    str
        The consensus symbols in column order; '' for an empty ``master``.
    """
    cleaned_sequences = [''.join(chunks).replace('\n', '') for chunks in master]
    if not cleaned_sequences:
        return ''

    # Sequences shorter than the longest one are padded with gaps; indexing
    # them up to the maximum length used to raise an IndexError.
    width = max(len(seq) for seq in cleaned_sequences)
    padded = [seq.ljust(width, '-') for seq in cleaned_sequences]
    n_sequences = len(padded)

    my_consensus = []
    for column in zip(*padded):
        non_gap = [symbol for symbol in column if symbol != '-']
        if not non_gap:
            continue
        c = non_gap[-1]
        if column.count(c) / n_sequences >= threshold:
            my_consensus.append(c)
    return ''.join(my_consensus)




def compute_consensus_boostraping (aa_dict, reward_dict_ids, my_clustered_seqs,my_clustered_ids, milestones_seqs,score_file,
                                   fasta_file,outfile='consensus.fasta',output_dir='tmp',MAX_ITER=10,debug=False):
    """Bootstrap a consensus sequence from repeated alignments of random subsets.

    In each of ``MAX_ITER`` rounds a random subset of the clustered sequences
    is aligned with ``get_alignment`` and reduced to an 80% consensus, which
    is appended to ``outfile``. The consensus sequences of all rounds are then
    aligned with ``clustalw2`` and reduced once more.

    Parameters
    ----------
    aa_dict, reward_dict_ids
        Forwarded to ``get_alignment`` and ``create_scoring_matrix``.
    my_clustered_seqs : list of str
        Sequences to sample from.
    my_clustered_ids : list of str
        Identifier line of each sequence.
    milestones_seqs : list of list
        Forwarded to ``get_alignment``.
    score_file, fasta_file
        Unused; kept for interface compatibility.
    outfile : str
        FASTA file that receives one consensus per round.
    output_dir : str
        Working folder handed to ``get_alignment`` and
        ``create_scoring_matrix``.
    MAX_ITER : int
        Number of bootstrap rounds.
    debug : bool
        If True, every round uses all sequences and the number of rounds
        equals the number of sequences.

    Returns
    -------
    tuple
        ``(final_consensus_seq, temp_list)``: the consensus at threshold 0.2
        and the aligned per-round consensus sequences (lists of text chunks).
    """
    N = 8
    if debug:
        N= len(my_clustered_ids)
        MAX_ITER = N
    # sampling without replacement cannot draw more sequences than exist
    N = min(N, len(my_clustered_ids))

    my_clustered_ids_int = np.arange(len(my_clustered_ids))
    with open(outfile, 'w') as f:
        for j in range(MAX_ITER):
            np.random.seed(seed=j)  # one fixed seed per round keeps the subsets reproducible
            selection = np.random.choice(my_clustered_ids_int, N, replace=False)

            boos_seq = [my_clustered_seqs[i] for i in selection]
            boos_ids = [my_clustered_ids[i] for i in selection]
            boos_ids_int = [my_clustered_ids_int[i] for i in selection]

            temp_master_seq, temp_master_ids, _ = get_alignment(aa_dict, reward_dict_ids, boos_seq, boos_ids, boos_ids_int, milestones_seqs,output_dir)

            consensus_seq = get_consensus(temp_master_seq,threshold=0.8)

            f.write(">{}\n{}\n".format(j, consensus_seq))

    infile = outfile
    out_folder = os.path.dirname(os.path.abspath(outfile))
    msa_outfile = os.path.join(out_folder,"msa_consensus_{}.aln".format(path.basename(infile)[:-6]))  # strip '.fasta'
    msatree = "msa_consensus.dnd"

    # collect the per-round consensus sequences
    sequences_for_scoring = []
    with open(infile, 'r') as f:
        for d in f.read().split():
            if d[0] != '>':
                sequences_for_scoring.append(d)

    scoring_matrix, score_file , _ = create_scoring_matrix(aa_dict, reward_dict_ids, outdir=output_dir,
                                                           fasta_sequences=sequences_for_scoring, offdiag=-10.0,
                                                           offdiag_gap=-1.0, subtask='consensus_score')
    # largest number below 10 found in the matrix file
    with open(score_file, 'r') as f:
        floats = []
        for elem in f.read().split():
            try:
                if float(elem) < 10:
                    floats.append(float(elem))
            except ValueError:
                pass

    # gap_ext is set before the try block so that it is also defined when the
    # fallback for gap_open is taken.
    gap_ext = 0
    try:
        gap_open = np.max(floats)/100
    except ValueError:
        print("Error in gap penalties selection")
        gap_open = 0.01

    # Options are passed as a list so that paths containing blanks survive.
    cmd = [
        "clustalw2", "-ALIGN", "-NEGATIVE", "-CLUSTERING=UPGMA",
        "-INFILE={}".format(infile),
        "-OUTFILE={}".format(msa_outfile),
        "-PWMATRIX={}".format(score_file),
        "-PWGAPOPEN={}".format(gap_open),
        "-PWGAPEXT={}".format(gap_ext),
        "-MATRIX={}".format(score_file),
        "-GAPOPEN={}".format(gap_open),
        "-GAPEXT={}".format(gap_ext),
        "-CASE=UPPER",
        "-GAPDIST=0", "-NOPGAP", "-NOHGAP", "-MAXDIV=0", "-ENDGAPS", "-NOVGAP",
        "-NEWTREE={}".format(msatree),
        "-TYPE=PROTEIN", "-OUTPUT=GDE",
    ]
    subprocess.run(cmd, stdout=subprocess.PIPE)

    # read the alignment: '%' lines start a new record, other lines belong to it
    with open(msa_outfile, 'r') as f:
        lines = f.readlines()

    consensus = [[] for i in range(MAX_ITER)]
    i = -1
    for l in lines:
        if l[0] == '%':
            i += 1
        else:
            consensus[i].append(l)

    # clean consensus: drop 'T' symbols and the line break from the first line of every record
    temp_list = []
    for con in consensus:
        temp_con = con
        temp_con[0] = con[0].replace('T', '').replace('\n', '')
        temp_list.append(temp_con)

    print('90% consensus')
    print(get_consensus(temp_list, threshold=0.9))
    # NOTE(review): the label below says 49% but the threshold used is 0.2 - confirm which is intended.
    print('49% consensus (51%maximal, mayority voting)')
    print(get_consensus(temp_list, threshold=0.2))
    print('0% consensus (maximal seq)')
    print(get_consensus(temp_list, threshold=0.0))

    ## Final consensus:
    final_consensus_seq = get_consensus(temp_list, threshold=0.2)

    return final_consensus_seq, temp_list

def get_crafting_items(s_act, s_inv, aa_dict):
    """Return the alignment symbols of every item that can be crafted or smelted.

    Parameters
    ----------
    s_act : mapping
        Action space description; the entries 'craft', 'nearbySmelt' and
        'nearbyCraft' are read and must expose their options under 'values'.
    s_inv : sequence
        Inventory item names; ``s_inv.index(item)`` gives the item index.
    aa_dict : mapping or sequence
        Maps an inventory index to its alignment symbol.

    Returns
    -------
    list
        Unique symbols in first-seen order.

    Raises
    ------
    ValueError
        If a craftable item (other than 'none') is missing from ``s_inv``.
    """
    crafting_items = []
    for a in s_act:
        if a in ['craft', 'nearbySmelt', 'nearbyCraft']:
            for item in s_act[a]['values']:
                # Compare by value: identity (`is not`) only works by accident
                # when both strings happen to be interned.
                if item != 'none':
                    aa = aa_dict[s_inv.index(item)]
                    if aa not in crafting_items:
                        crafting_items.append(aa)

    return crafting_items

def get_transitions_from_consensus(aa_dict, consensus,crafting_items): #####
    """Turn a consensus sequence into lists of [previous_index, index] transitions.

    Every symbol of ``consensus`` is mapped to its position among
    ``aa_dict.values()``. The predecessor of the first symbol is the
    string 'None'.

    Returns
    -------
    tuple of three lists
        * transitions that end in a non-crafting symbol (deduplicated),
        * transitions that end in a crafting symbol (deduplicated),
        * crafting transitions where each run of consecutive crafting
          symbols is collapsed to one pair going from the predecessor of
          the run to the last symbol of the run.
    """
    symbols = list(aa_dict.values())

    # Pass 1: collapse runs of consecutive crafting symbols.
    merged_crafting = []
    previous = 'None'
    run_origin = 0
    inside_run = False
    previous_pair = None
    for symbol in consensus:
        position = symbols.index(symbol)
        pair = [previous, position]
        if symbol not in crafting_items:
            inside_run = False
        elif inside_run:
            # Extend the current run: swap the previous pair for one that
            # still starts at the origin of the run.
            pair = [run_origin, position]
            if previous_pair is not None:
                merged_crafting.remove(previous_pair)
                merged_crafting.append(pair)
        else:
            merged_crafting.append(pair)
            inside_run = True
            run_origin = previous
        previous = position
        previous_pair = pair

    # Pass 2: plain pairwise transitions, split by the kind of the target.
    plain = []
    crafting = []
    previous = 'None'
    for symbol in consensus:
        position = symbols.index(symbol)
        pair = [previous, position]
        bucket = crafting if symbol in crafting_items else plain
        if pair not in bucket:
            bucket.append(pair)
        previous = position

    return plain, crafting, merged_crafting

    #################################
def no_other_action(data,iter_pos):
    """Return True when every action except 'camera' is 0 at step ``iter_pos``."""
    return not any(
        values[iter_pos] != 0
        for name, values in data['action_dict'].items()
        if name != 'camera'
    )


def add_triplet(list_of_triplets, triplet):
    """Tell whether ``triplet`` is new with respect to ``list_of_triplets``.

    Two triplets match when their 'action', 'results' and 'item' entries
    are equal.

    Returns
    -------
    list
        ``[True, 0]`` when no stored triplet matches (the triplet should be
        added), otherwise ``[False, position]`` with the index of the first
        matching entry.
    """
    compared_keys = ('action', 'results', 'item')
    for position, known in enumerate(list_of_triplets):
        if all(known[key] == triplet[key] for key in compared_keys):
            return [False, position]
    return [True, 0]



def get_triplets(rudder_samples):
    """Collect the distinct (action, item, result) crafting triplets found in the samples.

    For every sample and every step where one of 'craft', 'nearbyCraft' or
    'nearbySmelt' is non-zero, a few following steps are scanned and the
    inventory at the step where the scan stops is compared with the
    inventory at the action step. When exactly one inventory entry
    increased, a triplet is recorded.

    Parameters
    ----------
    rudder_samples : mapping
        Each value must provide 'action_dict' (arrays per action name) and
        ``['observation_dict']['inventory']`` (one sequence per item).

    Returns
    -------
    list of dict
        Each dict has 'action' (action name), 'item' (value of the action at
        that step), 'results' (index of the inventory entry that increased)
        and 'min' (smallest increase seen for that triplet).
    """
    # move into another function
    window_to_add = 5  # maximum number of steps scanned after the action
    list_of_triplets = []
    for ke,samples in tqdm(rudder_samples.items()):
        sa_dict = samples['action_dict']
        # One row per inventory entry before the transpose; indexed by step
        # afterwards (presumably shape (steps, n_items) -- TODO confirm).
        observation_inventory =  np.transpose([np.array(samples['observation_dict']['inventory'][d]) for d in samples['observation_dict']['inventory'] ])

        for act in ['craft', 'nearbyCraft', 'nearbySmelt']:
            # NOTE: this dict is reused for every position of `act`; only a
            # copy of it is ever stored in list_of_triplets.
            triplet_acit = dict()
            triplet_acit['action'] = act
            took_action = np.where(sa_dict[act] > 0)[0]
            if took_action.size > 0:
                for pos in took_action:
                    triplet_acit['item'] = sa_dict[act][pos]
                    range_to = min(pos + 1 + window_to_add, sa_dict[act].size - 1)
                    # `i` deliberately survives the loop below: it ends at the
                    # first step where no_craft_action() is False, at the last
                    # scanned step, or stays pos + 1 when the range is empty.
                    i = pos + 1
                    for i in range(pos + 1, range_to):
                        if not no_craft_action(sa_dict, i):
                            break
                    # The action was on the final step: nothing to compare with.
                    if i > sa_dict[act].size - 1:
                        continue
                    # get differences on the window
                    diff = observation_inventory[i] - observation_inventory[pos]
                    object_created = np.where(diff > 0)[0]
                    if object_created.size == 1:
                        triplet_acit['results'] = object_created[0]
                        triplet_acit['min'] = diff[object_created[0]]
                        # check if the triplet is already inserted
                        add_triplet_bool, number = add_triplet(list_of_triplets, triplet_acit)
                        if add_triplet_bool:
                            list_of_triplets.append(triplet_acit.copy())
                        else:
                            # Known triplet: keep the smallest increase observed.
                            list_of_triplets[number]['min'] = np.min(
                                [list_of_triplets[number]['min'], triplet_acit['min']])

                    elif object_created.size > 1:
                        # More than one inventory entry increased: only reported, not recorded.
                        print("12345 Not implemented")

    return list_of_triplets
########################
def get_frames_and_actions_from_datadict(datasetdict,file_dir,startidx,endidx,fix_for_dataloader=False):
    """Slice the 'pov' frames and all action sequences of one demonstration.

    The demonstration is looked up in ``datasetdict`` by the last path
    component of ``file_dir``. The slice is ``[startidx:endidx]``; with
    ``fix_for_dataloader`` it is widened by one step to
    ``[startidx:endidx + 1]`` as long as ``endidx + 1`` is smaller than the
    number of frames.

    Returns
    -------
    tuple
        ``(frames, actions)`` where ``actions`` is a plain dict holding one
        slice per action name.
    """
    record = datasetdict[os.path.split(file_dir)[-1]]
    pov = record['observation_dict']['pov']

    stop = endidx
    if fix_for_dataloader and endidx + 1 < len(pov):
        stop = endidx + 1

    window = slice(startidx, stop)
    frames = pov[window]
    actions = {key: sequence[window] for key, sequence in record['action_dict'].items()}
    return frames, actions

def get_initial_state_from_datadict(datasetdict,file_dir,startidx):
    """Return the inventory of a demonstration at step ``startidx``.

    The demonstration is selected by the last path component of
    ``file_dir``; the result maps each inventory item name to its value at
    ``startidx``.
    """
    sequence_name = os.path.split(file_dir)[-1]
    inventory = datasetdict[sequence_name]['observation_dict']['inventory']
    return {item: history[startidx] for item, history in inventory.items()}


def prepare_data(rootdir,prepare):
    """Load the MineRLObtainDiamond-v0 demonstrations into a dict keyed by stream name.

    Each entry holds 'observation_dict', 'action_dict', 'reward_seq',
    'done_seq' and 'meta', taken from positions 0, 1, 2, 4 and 5 of the
    corresponding element of ``dataset.sequences``.
    """
    ############################
    # TODO: data loader

    dataset = MineRLBaseDataset(
        root=rootdir,
        download=False,
        prepare=prepare,
        experiment="MineRLObtainDiamond-v0",
        data_split=1,
        include_metadata=True,
    )

    datasetdict = {}
    for sequence in dataset.sequences:
        meta = sequence[5]
        datasetdict[meta['stream_name']] = {
            'observation_dict': sequence[0],
            'action_dict': sequence[1],
            'reward_seq': sequence[2],
            'done_seq': sequence[4],
            'meta': meta,
        }
    return datasetdict

def extract_substask_sequences(samples ,datasetdict,  type_seq = 'learning', outpath=None,count_actions_dict = None):
    """Write one output folder per sample and accumulate action statistics.

    For every sample a folder ``<outpath>/<sequence name>-<counting>`` is
    created containing 'metadata.json' and 'rendered.npz' (plus
    'recording.npz' with the frames when ``type_seq == 'rudder'``). After
    all samples, 'metainfo.json' with the metadata of every sample is
    written to the head of ``os.path.split(outpath)``.

    Parameters
    ----------
    samples : sequence
        Items unpack to (numpy_path, startidx, endidx, n_act, counting, ranking).
    datasetdict : dict
        Demonstrations keyed by sequence name (read through
        get_frames_and_actions_from_datadict / get_initial_state_from_datadict).
    type_seq : str
        Only the value 'rudder' is checked; it enables saving the frames.
    outpath : str
        Output directory. Must be given: the default None is passed straight
        to path.join and fails there.
    count_actions_dict : dict
        Per-action counters, updated in place. Must be given: the default
        None is not handled.

    Returns
    -------
    tuple
        ``(True, count_actions_dict)``; the dict additionally gets
        'total_length', the sum of ``endidx - startidx`` over all samples.
    """


    keep_length = 0
    list_of_meta_info = []
    for sample in tqdm(samples, total=len(samples), leave=False):
        numpy_path, startidx, endidx, n_act, counting, ranking = sample
        file_dir = path.dirname(numpy_path)
        # --
        # Fetch everything from startidx up to (excluding) the last step, only
        # to measure how far frames and actions differ in length.
        all_frames, all_actions = get_frames_and_actions_from_datadict(datasetdict,file_dir, startidx, -1)

        start_offset = np.abs(len(all_frames) - len(all_actions['attack']))
        start_offset_end = start_offset

        # this function returns +1 (minerl data loader then removes the last State action)
        # An end index of -1 means "until the end": it is replaced by the number
        # of action steps fetched above and gets no end offset.
        if endidx == -1:
            endidx = len(all_actions['attack'])
            start_offset_end = 0

        frames, actions = get_frames_and_actions_from_datadict(datasetdict, file_dir,startidx + start_offset, endidx + start_offset_end,fix_for_dataloader=True)
        # get sum of all actions
        # count_actions_dict

        seqname = path.basename(file_dir)
        seq_out_dir = path.join(outpath, "{}-{}".format(seqname,int(counting)) )
        if not path.exists(seq_out_dir):
            os.mkdir(seq_out_dir)

        # frames
        if type_seq == 'rudder':
            #imageio.mimwrite(seq_out_dir + "/recording.mp4", frames, fps=20)
            np.savez_compressed(seq_out_dir + "/recording.npz", frames=frames)
        # metadata
        metadata = dict()
        metadata['initial_state'] = get_initial_state_from_datadict(datasetdict,file_dir, startidx)
        metadata["numpy_path"] = os.path.split(os.path.split(numpy_path)[0])[-1] #relative path instead of absolute. name of the sequence
        metadata["startidx"] = int(startidx)
        metadata["endidx"] = int(endidx)
        metadata["n_act"] = int(n_act)
        metadata["counting"] = int(counting)
        metadata["ranking"] = int(ranking)

        sa = np.load(numpy_path)
        sa_dict = {}

        # Same one-step widening as fix_for_dataloader: include endidx itself
        # unless that would run past the end of the stored arrays.
        if endidx+1 <  len(sa[sa.files[0]]):
            for file in sa.files:
                sa_dict[file] = sa[file][startidx:endidx+1] # JAM before -> sa[file][startidx:endidx - 1]
        else:
            for file in sa.files:
                sa_dict[file] = sa[file][startidx:endidx]

        # Overwrite with the corrected data
        # NOTE(review): 'acttion_camera' looks like a typo for 'action_camera'.
        # As written the second condition is always true, so 'action_camera' is
        # overwritten too -- confirm whether the camera was meant to be skipped.
        for k,v in sa_dict.items():
            if 'action' in k and k!='acttion_camera':
                sa_dict[k] = actions[k.replace('action_','')]
            # make it int
        # Cast the inventory values to int before they are dumped to JSON.
        for k, v in metadata['initial_state'].items():
            metadata['initial_state'][k] = int(v)
        with open(seq_out_dir + "/metadata.json", 'w') as o_f:
            json.dump(metadata, fp=o_f)
            print(seq_out_dir + "/metadata.json")

        np.savez_compressed(seq_out_dir + "/rendered.npz", **sa_dict)

        # Update the running action counters (the camera is not counted).
        for k, v in sa_dict.items():
            if 'action' in k:
                if k != 'action_camera':
                    k = k.replace('action_', '')
                    if np.size(count_actions_dict[k])>1:
                        # Counter with several slots: slot i (i > 0) counts the
                        # steps whose value equals i; slot 0 is left untouched.
                        for i,_ in enumerate(count_actions_dict[k]):
                            if i> 0:
                                count_actions_dict[k][i] += int(np.sum(v==i))
                    else:
                        # Single counter: count the steps whose value equals 1.
                        count_actions_dict[k] += int(np.sum(v==1))
        keep_length += (endidx - startidx)

        list_of_meta_info.append(metadata)
    #Save meta_info_file

    new_path = os.path.split(outpath)[0]
    with open(os.path.join(new_path, 'metainfo.json'), 'w') as o_f:
        json.dump(list_of_meta_info, fp=o_f)

    # return dictionary with action count and sequence length
    count_actions_dict['total_length'] = int(keep_length)
    return True,count_actions_dict

def count_all_actions(file, start, end):
    """Count the steps in ``[start:end]`` on which at least one action is active.

    All arrays of the ``.npz`` archive ``file`` whose name contains
    'action_' but not 'camera' are summed element-wise (arrays whose shape
    differs from the first selected one are ignored), the sum is capped at
    1 per step, and the capped values inside the slice are added up.

    Returns
    -------
    int
    """
    archive = np.load(file, allow_pickle=True)
    selected = [name for name in archive.files if "action_" in name and 'camera' not in name]

    combined = archive[selected[0]].astype(np.float64)
    for name in selected[1:]:
        values = archive[name]
        if values.shape == combined.shape:
            combined += values

    capped = np.minimum(combined, 1)
    return int(np.sum(capped[start:end]))

def count_non_crafting_actions(file, start, end):
    """Count the steps in ``[start:end]`` with at least one non-crafting action.

    Arrays of the ``.npz`` archive ``file`` are selected when their name
    contains 'action_' and mentions neither 'craft' nor 'smelt'
    (case-insensitive). The selected arrays are summed element-wise
    (arrays whose shape differs from the first selected one are ignored),
    capped at 1 per step and summed over the slice.

    Parameters
    ----------
    file : str or file-like
        Archive readable by ``np.load``.
    start, end : int
        Slice bounds applied to the per-step activity.

    Returns
    -------
    float
        Number of active steps (``0.0`` when no array is selected).
    """
    data = np.load(file, allow_pickle=True)
    # Both words must be absent. The former `or` was true for every name that
    # did not contain *both* words, so crafting/smelting actions were counted.
    files = [
        f for f in data.files
        if "action_" in f and "craft" not in f.lower() and "smelt" not in f.lower()
    ]
    if not files:
        return 0.0

    actions = data[files[0]].astype(np.float64)
    for name in files[1:]:
        values = data[name]
        if actions.shape == values.shape:
            actions += values
    actions = np.minimum(actions, 1)
    return np.sum(actions[start:end])

def _transition_sample_name(spaces_mapper, transition):
    """Build the '<source>-<target>' key of a transition; a 'None' source is kept literally."""
    if transition[0] == 'None':
        source = transition[0]
    else:
        source = spaces_mapper.inventory_to_key(transition[0])
    return source + '-' + spaces_mapper.inventory_to_key(transition[1])


def _transition_window_start(keyframe, type_, overlap):
    """Start index of a sample: 3 steps before the keyframe (not for 'atomic_craft'), minus the RUDDER overlap."""
    if keyframe > 3 and type_ != 'atomic_craft':
        start = int(keyframe - 3)
    else:
        start = int(keyframe)
    # Add an extra overlapping, but never move before the beginning.
    if type_ == 'RUDDER' and overlap > 0 and start > overlap:
        start = int(start - overlap)
    return start


def _transition_window_end(keyframe, type_):
    """End index of a sample: 3 steps after the keyframe to catch inventory delays (not for 'atomic_craft')."""
    if type_ != 'atomic_craft':
        return int(keyframe + 3)
    return int(keyframe)


def get_list_of_sequences_transitions_index(inventory_changes, spaces_mapper, rudder_transitions,type_= 'RUDDER',overlap=0):
    """Locate, per demonstration, the index ranges that realise each transition.

    Parameters
    ----------
    inventory_changes : NpzFile-like
        Must provide ``.files`` and ``.items()``. Each value unpacks to
        ``(keyframes, items)``: the step of every inventory change and the
        item that changed. The keys are also handed to
        ``count_all_actions`` as file names.
    spaces_mapper : object
        Provides ``inventory_to_key(index)`` returning the item name.
    rudder_transitions : list
        ``[source, target]`` pairs; ``source`` is the string 'None' for a
        transition without predecessor.
    type_ : str
        'RUDDER' (default) applies ``overlap`` and appends the "end"
        samples, 'atomic_craft' disables the +-3 step margins,
        'statistics' changes how self-transitions (source == target) are
        collected. Any other value uses the plain behaviour.
    overlap : int
        Extra steps prepended to every sample start when ``type_`` is 'RUDDER'.

    Returns
    -------
    dict
        Maps '<source>-<target>' (and "end" for 'RUDDER') to a list of
        ``[file, start, end, n_actions, occurrence]`` entries, where
        ``occurrence`` counts earlier samples of the same name in that file.

    Notes
    -----
    Strings are compared by value. The previous identity comparisons
    (``is 'None'`` etc.) only worked when the caller's strings happened to
    be interned.
    """
    rudder_samples = {}
    last_transition = {}

    transition_names = [_transition_sample_name(spaces_mapper, t) for t in rudder_transitions]
    # create a dictionary to track the appearance of a certain item
    rudder_transitions_count = {
        f: {name: 0 for name in transition_names} for f in inventory_changes.files
    }

    for transition, cur_item_name in zip(rudder_transitions, transition_names):
        source, target = transition[0], transition[1]
        cur_item_transitions = []
        for f, seq in inventory_changes.items():
            keyframes, items = seq
            counts = rudder_transitions_count[f]
            start, end = 0, 0
            if source == 'None':
                # No predecessor: the sample runs from the beginning to the
                # first appearance of the target.
                tmp = np.where(items == target)[0]
                if len(tmp) > 0:
                    end = int(keyframes[tmp[0]] + 3)  # small offset to get delays in inventory
                    cur_item_transitions.append(
                        [f, start, end, count_all_actions(f, start, end), counts[cur_item_name]])
                    counts[cur_item_name] += 1
            else:
                i = 0
                while i < len(items):
                    done = False
                    if items[i] == source:
                        # Small offset to get early actions that didn't show yet in the inventory.
                        start = _transition_window_start(keyframes[i], type_, overlap)
                        for j in range(i + 1, len(items)):
                            # for imitation sequences
                            if items[j] == source and source != target:
                                # Move the start to the latest source, e.g. to the
                                # last S of "SSSP" for the transition S-P.
                                start = _transition_window_start(keyframes[j], type_, overlap)
                            if items[j] != target:
                                continue
                            # for statistics
                            if type_ == 'statistics' and source == target:
                                if j < len(items) - 1 and items[j + 1] == target:
                                    # The run of targets continues: pass to the next.
                                    continue
                                end = _transition_window_end(keyframes[j], type_)
                                cur_item_transitions.append(
                                    [f, start, end, count_all_actions(f, start, end), counts[cur_item_name]])
                                i = j
                                done = True
                                counts[cur_item_name] += 1
                                # No break here (as before): later runs are
                                # collected with the same start.
                            else:
                                end = _transition_window_end(keyframes[j], type_)
                                cur_item_transitions.append(
                                    [f, start, end, count_all_actions(f, start, end), counts[cur_item_name]])
                                i = j
                                done = True
                                counts[cur_item_name] += 1
                                break
                    if not done:
                        i += 1
            # memorize last transition end per demonstration
            last_transition[f] = end
        if cur_item_name in rudder_samples:
            rudder_samples[cur_item_name].extend(cur_item_transitions)
        else:
            rudder_samples[cur_item_name] = cur_item_transitions

    # append everything after last transition for each sequence
    if type_ == 'RUDDER':
        rudder_samples["end"] = []
        for f in last_transition:
            rudder_transitions_count[f]["end"] = 0
        for f, end in last_transition.items():
            if end > 0:
                rudder_samples["end"].append(
                    [f, int(end + 1), -1, count_all_actions(f, int(end + 1), -1),
                     rudder_transitions_count[f]["end"]])
                rudder_transitions_count[f]["end"] += 1

    return rudder_samples

def get_metadata_score(dirname: str):
    """Return the total reward stored for a demonstration, or None when it is not positive.

    Parameters
    ----------
    dirname: str
        Directory that contains the 'metadata.json' file

    Returns
    ----------
    score
        The 'total_reward' entry of the metadata when it is greater than 0,
        otherwise None
    """
    metadata_path = os.path.join(dirname, 'metadata.json')
    with open(metadata_path, 'r') as handle:
        total_reward = json.load(handle)['total_reward']
    return total_reward if total_reward > 0 else None
