import torch
import numpy as np
import os
from copy import copy
from torch.utils.data import Dataset
from utils import namenum, summary
from ete3 import TreeNode
from itertools import combinations

def reroot(tree):
    """Restructure the top of ``tree`` in place around the leaves named 0, 1 and 2.

    Every pair drawn from the names ``0, 1, 2`` is examined in turn. When the
    common ancestor of a pair is not the root, that ancestor is first made the
    outgroup of ``tree`` and then dissolved: it is detached from ``tree`` and
    its first two children are attached to ``tree`` directly.

    Args:
        tree: an ete3 ``TreeNode``; the code looks up nodes named 0, 1 and 2
            in it (presumably integer leaf names assigned by the caller).

    Returns:
        The same ``tree`` object, modified in place.

    Raises:
        AssertionError: if ``tree`` is not a ``TreeNode``.
    """
    assert isinstance(tree, TreeNode)
    for first_name, second_name in combinations([0, 1, 2], r=2):
        # The first node found for each name is used.
        first_node = tree.search_nodes(name=first_name)[0]
        second_node = tree.search_nodes(name=second_name)[0]
        ancestor = tree.get_common_ancestor(first_node, second_node)
        if ancestor.is_root():
            continue
        tree.set_outgroup(ancestor)
        # Grab both children before the ancestor is detached from the tree.
        left_child = ancestor.children[0]
        right_child = ancestor.children[1]
        tree.remove_child(ancestor)
        tree.add_child(left_child)
        tree.add_child(right_child)
    return tree

def name_transformer(tree):
    """Renumber every node of ``tree`` in place during a post-order walk.

    * Leaves whose name is greater than 3 are renamed to ``2 * name - 3``;
      every leaf also gets a ``used`` flag initialised to ``False``.
    * Each internal non-root node is named one more than the largest name
      among the not-yet-used leaves below it, and the node carrying that
      largest name is then flagged as used.
    * The root is named ``2 * (number of leaves) - 3``.

    Args:
        tree: an ete3 ``TreeNode`` whose leaf names support ``>`` and
            arithmetic with integers (presumably ints — confirm with caller).

    Returns:
        The same ``tree`` object, modified in place.

    Raises:
        AssertionError: if ``tree`` is not a ``TreeNode``.
    """
    assert isinstance(tree, TreeNode)
    for node in tree.traverse('postorder'):
        if node.is_leaf():
            if node.name > 3:
                node.name = 2 * node.name - 3
            node.used = False
            continue

        if node.is_root():
            node.name = 2 * len(node.get_leaves()) - 3
            continue

        # Post-order guarantees every leaf below this node already has `used`.
        free_names = [leaf.name for leaf in node.get_leaves() if not leaf.used]
        largest_free = max(free_names)
        # Rename first, then search: the node itself now carries largest_free + 1
        # and so cannot be returned by the lookup below.
        node.name = largest_free + 1
        node.search_nodes(name=largest_free)[0].used = True
    return tree

def get_decisions(tree):
    """Collapse ``tree`` node by node and record the sister name at each step.

    With ``n`` leaves, the names ``2n - 5, 2n - 7, ..., 3`` are visited in
    that order. For each one, the node with that name is looked up, its
    parent is detached from the grandparent, and the node's first sister is
    attached to the grandparent in the parent's place. The sister's name is
    recorded. ``tree`` is modified in place.

    Args:
        tree: an ete3 ``TreeNode`` containing nodes with the names listed
            above (presumably the output of ``name_transformer`` — confirm).

    Returns:
        list: the recorded sister names, in the reverse of the visiting order
        (i.e. the entry for name 3 comes first).

    Raises:
        AssertionError: if ``tree`` is not a ``TreeNode``.
    """
    assert isinstance(tree, TreeNode)
    leaf_count = len(tree.get_leaves())
    recorded = []
    for target_name in range(2 * leaf_count - 5, 2, -2):
        target = tree.search_nodes(name=target_name)[0]
        sibling = target.get_sisters()[0]
        parent = target.up
        grandparent = parent.up
        # Splice the sibling into the parent's position, dropping the target.
        grandparent.remove_child(parent)
        grandparent.add_child(sibling)
        recorded.append(sibling.name)
    return recorded[::-1]

def process_empFreq_transformer(dataset, ground_truth_path='data/raw_data_DS1-8/',
                                samp_size=750001, output_root='tfdecisions'):
    """Convert the empirical tree sample of ``dataset`` into decision sequences.

    The trees and their weights are obtained from ``summary``. For every tree
    a copy is passed through ``reroot``, ``name_transformer`` and
    ``get_decisions``. Three files are written to
    ``<output_root>/<dataset>/emp_tree_freq``:

    * ``wts.npy``      - the weights, in tree order;
    * ``taxa.npy``     - the sorted leaf names of the first tree;
    * ``decisions.pt`` - a ``torch.LongTensor`` with one row per tree.

    Args:
        dataset: dataset identifier, forwarded to ``summary`` and used as a
            directory name in the output path.
        ground_truth_path: directory forwarded to ``summary``; defaults to
            the previously hard-coded location.
        samp_size: sample size forwarded to ``summary``; defaults to the
            previously hard-coded value.
        output_root: top-level output directory; defaults to the previously
            hard-coded ``'tfdecisions'``.

    Returns:
        None. The results are written to disk.
    """
    # NOTE(review): EmbedData in this file reads from
    # os.path.join('..', 'tfdecisions', ...); confirm that the working
    # directories of writer and reader line up.
    tree_dict_total, tree_names_total, tree_wts_total = summary(
        dataset, ground_truth_path, samp_size=samp_size)
    emp_tree_freq = {
        tree_dict_total[tree_name]: tree_wts_total[i]
        for i, tree_name in enumerate(tree_names_total)
    }
    wts = list(emp_tree_freq.values())
    # The taxon order is taken from the first tree only; presumably all trees
    # share the same leaf set.
    taxa = sorted(list(emp_tree_freq)[0].get_leaf_names())

    path = os.path.join(output_root, dataset, 'emp_tree_freq')
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, 'wts.npy'), wts)
    np.save(os.path.join(path, 'taxa.npy'), taxa)

    decisions_tensor = []
    for tree in emp_tree_freq:
        # namenum is applied to the original tree before copying, so the copy
        # inherits whatever namenum changed.
        namenum(tree, taxa)
        tree_cp = tree.copy()
        tree_cp = reroot(tree_cp)
        tree_cp = name_transformer(tree_cp)
        decisions_tensor.append(get_decisions(tree_cp))
    torch.save(torch.LongTensor(decisions_tensor), os.path.join(path, 'decisions.pt'))


class EmbedData(Dataset):
    """Dataset over the decision sequences stored on disk for one dataset.

    Files are read from ``../tfdecisions/<dataset>/emp_tree_freq``:
    ``decisions.pt`` via ``torch.load`` and ``wts.npy`` via ``np.load``.
    The reported length is the number of weights, not the size of ``data``.
    """

    def __init__(self, dataset):
        """Load the decisions and weights saved for ``dataset``."""
        super().__init__()
        self.path = os.path.join('..', 'tfdecisions', dataset, 'emp_tree_freq')
        decisions_file = os.path.join(self.path, 'decisions.pt')
        weights_file = os.path.join(self.path, 'wts.npy')
        self.data = torch.load(decisions_file)
        self.wts = np.load(weights_file)
        self.length = len(self.wts)

    def __len__(self):
        """Return the number of stored weights."""
        return self.length

    def __getitem__(self, index):
        """Return the entry of ``data`` at ``index``."""
        return self.data[index]
    
class EmbedDataLoader(object):
    """Minimal batch loader with a sequential mode and a random mode.

    In random mode each batch is drawn with replacement using
    ``np.random.choice`` with probabilities ``wts``. In sequential mode
    batches are consecutive index ranges starting at ``position``; the final
    batch may be shorter, and ``StopIteration`` is raised once the data is
    exhausted. Random mode is enabled at construction only when ``wts`` is a
    ``numpy.ndarray``.
    """

    def __init__(self, dataset, batch_size, wts=None) -> None:
        """Store the configuration and start at position 0."""
        self.dataset = dataset
        self.batch_size = batch_size
        self.wts = wts
        self.random = isinstance(wts, np.ndarray)
        self.length = len(dataset)
        self.position = 0

    def randomize(self):
        """Switch to random sampling."""
        self.random = True

    def nonrandomize(self):
        """Switch to sequential batching."""
        self.random = False

    def initialize(self):
        """Rewind the sequential cursor to the beginning."""
        self.position = 0

    def next(self):
        """Return the next batch as a stacked tensor.

        Raises:
            StopIteration: in sequential mode, when no items remain.
        """
        if self.random:
            drawn = np.random.choice(
                self.length, size=self.batch_size, replace=True, p=self.wts)
            return self.fetch(drawn)

        start = self.position
        stop = start + self.batch_size
        if stop > self.length:
            if start >= self.length:
                raise StopIteration
            # Partial final batch: clip to the end of the data.
            stop = self.length
        self.position = stop
        return self.fetch(list(range(start, stop)))

    def fetch(self, indexes):
        """Stack the dataset items at ``indexes`` with ``torch.stack``."""
        return torch.stack([self.dataset[i] for i in indexes])

def get_empdataloader(dataset, batch_size=10):
    """Build an ``EmbedDataLoader`` over the stored decisions of ``dataset``.

    No weights are passed, so the loader starts in sequential mode.

    Args:
        dataset: dataset identifier handed to ``EmbedData``.
        batch_size: number of items per batch (default 10).

    Returns:
        EmbedDataLoader: a loader wrapping a freshly loaded ``EmbedData``.
    """
    return EmbedDataLoader(EmbedData(dataset), batch_size=batch_size, wts=None)