import torch
import torch.nn as nn
import time
import math
import numpy as np
from ete3 import TreeNode
from Bio.Phylo.TreeConstruction import *
import psutil
import os
from phyloModel import PHY
from gnn_Model import GNN_BranchModel
from transformer import Transformer
import gc


def mp_add(args):
    """Attach the leaf for ``taxon`` by subdividing the edge above ``pos``.

    ``args`` is a ``(tree, taxon, pos, name_dict)`` tuple so the function can
    be used with ``map``. ``pos`` is the node whose parent edge is split: a new
    internal node (named ``2*taxon-2``) is spliced between ``pos`` and its
    parent, and the new leaf (named ``2*taxon-3``) hangs off that internal
    node. The root is renamed to ``2*taxon-1``.

    ``name_dict`` (node name -> node) is updated in place with the three
    affected names; the old root's entry under ``2*taxon-3`` is overwritten
    by the new leaf.

    Returns:
        ``(tree, name_dict)``, the same objects that were passed in.
    """
    tree, taxon, pos, name_dict = args
    assert isinstance(tree, TreeNode)
    leaf_id = 2 * taxon - 3
    junction_id = leaf_id + 1
    root_id = leaf_id + 2

    new_leaf = TreeNode(name=leaf_id)
    old_parent = pos.up
    old_parent.remove_child(pos)

    # The split edge keeps ``pos`` as the first child, the new taxon second.
    junction = TreeNode(name=junction_id)
    for child in (pos, new_leaf):
        junction.add_child(child)
    old_parent.add_child(junction)

    tree.name = root_id
    name_dict.update({leaf_id: new_leaf, junction_id: junction, root_id: tree})
    return tree, name_dict

def mp_renamenum(args):
    """Renumber a finished tree's nodes in place and return it.

    ``args`` is a ``(tree, level)`` tuple (for use with ``map``).

    - Leaves named 0, 1 and 2 keep their names.
    - Every other leaf is renamed ``(name + 3) // 2``, which maps the insertion
      label ``2k-3`` back to taxon index ``k``.
    - Internal nodes, including the root, get consecutive integers starting
      at ``level``, assigned in postorder.
    """
    tree, level = args
    next_internal = level
    for node in tree.traverse("postorder"):
        if not node.is_leaf():
            node.name = next_internal
            next_internal += 1
        elif node.name > 2:
            node.name = (node.name + 3) // 2
    return tree

class VBPIbase(nn.Module):
    """Shared machinery for variational Bayesian phylogenetic inference.

    Owns the autoregressive tree-topology model (``self.tree_model``) and
    provides topology sampling, topology log-probability evaluation, and a
    KL-divergence diagnostic against an empirical tree distribution.
    """
    EPS = np.finfo(float).eps

    def __init__(self, ntips, emp_tree_freq=None,
                hidden_dim_tree=100, nheads=4):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.emp_tree_freq = emp_tree_freq
        self.ntips = ntips
        # -sum(log(3), log(5), ..., log(2n-5)) == -log((2n-5)!!), i.e. the log
        # of a uniform probability over the (2n-5)!! unrooted binary topologies.
        self.log_p_tau = - np.sum(np.log(np.arange(3, 2*self.ntips-3, 2)))

        self.tree_model = Transformer(self.ntips, hidden_dim=hidden_dim_tree, n_head=nheads, device=self.device).to(self.device)

    def init_tree(self):
        """Return the 3-leaf star tree (leaves 0, 1, 2 under root 3) and its name->node map."""
        name_dict = {}
        tree = TreeNode(name=3)
        name_dict[3] = tree
        for i in [0, 1, 2]:
            node = TreeNode(name=i)
            tree.add_child(node)
            name_dict[i] = node
        return tree, name_dict

    @staticmethod
    def _mix_with_uniform(log_probs, eps):
        """Return ``log((1 - eps) * exp(log_probs) + eps / K)``.

        ``K`` is the size of the last dimension, i.e. the number of candidate
        positions per row. The uniform component must be spread over the
        candidates, not over the batch dimension, so each row still sums to 1.
        For ``eps <= 0`` the input is returned unchanged.
        """
        if eps <= 0.0:
            return log_probs
        num_choices = log_probs.shape[-1]
        return torch.log(log_probs.exp() * (1 - eps) + eps / num_choices)

    def sample_trees(self, n_particles, eps=0.0):
        """Sample ``n_particles`` topologies by sequential taxon insertion.

        Starting from the 3-leaf star tree, taxa ``3 .. ntips-1`` are inserted
        one at a time. At each step an insertion position is drawn per particle
        from the tree model's distribution. That distribution is mixed with a
        uniform distribution over positions when ``eps > 0``.

        Returns:
            ``(trees, logq_tree)``:

            - ``trees`` is the list of sampled ete3 trees, renumbered by
              ``mp_renamenum``.
            - ``logq_tree`` holds the per-particle log-probability of the
              sampled insertion sequence under the (mixed) proposal.
        """
        logq_tree = 0.0
        node_features, edge_index = self.tree_model._init(n_particles)
        trees, name_dicts = zip(*[self.init_tree() for _ in range(n_particles)])
        trees, name_dicts = list(trees), list(name_dicts)
        for taxon in range(3, self.ntips):
            logits, new_feature = self.tree_model(node_features, edge_index, taxon)
            log_prob = self._mix_with_uniform(logits, eps)
            pos = torch.multinomial(input=log_prob.exp(), num_samples=1)
            logq_tree += torch.gather(log_prob, dim=1, index=pos).squeeze(-1)
            node_features, edge_index = self.tree_model.update(node_features, edge_index, new_feature, pos)
            # One host transfer instead of a per-particle ``.item()``.
            anchor_nodes = [name_dict[p] for name_dict, p in zip(name_dicts, pos.squeeze(-1).tolist())]
            trees, name_dicts = zip(*map(mp_add, zip(trees, [taxon] * n_particles, anchor_nodes, name_dicts)))
            trees, name_dicts = list(trees), list(name_dicts)
        if not torch.is_tensor(logq_tree):
            # ntips == 3: no insertion step ran, but callers expect a tensor.
            logq_tree = torch.zeros(n_particles, device=self.device)
        trees = [mp_renamenum((tree, self.ntips)) for tree in trees]
        return trees, logq_tree

    def tree_prob(self, batch):
        """Log-probability of a batch of encoded topologies under the tree model.

        Args:
            batch: index tensor whose column ``i - 3`` holds the insertion
                position for taxon ``i``. It is used as a ``torch.gather``
                index, so it presumably must be int64 (TODO confirm).

        Returns:
            The summed per-row log-probabilities. Because the result is
            ``.squeeze()``-d, a single-row batch yields a 0-d tensor.
        """
        logprob = 0.0
        node_features, edge_index = self.tree_model._init(batch.shape[0])
        for i in range(3, self.ntips):
            pos = batch[:, i-3].to(self.device)
            logits, new_feature = self.tree_model(node_features, edge_index, i)
            logprob += torch.gather(logits, dim=1, index=pos[:, None]).squeeze()
            node_features, edge_index = self.tree_model.update(node_features, edge_index, new_feature, pos[:, None])
        return logprob

    @torch.no_grad()
    def kl_div(self):
        """KL divergence from the empirical tree distribution to the model.

        Iterates once over ``self.emp_tree_freq``, which provides
        ``initialize()``, ``next()``, ``batch_size`` and ``dataset.wts`` /
        ``dataset.length``.

        Returns:
            ``(kl_div, probs)``: the KL value, and the list of model
            probabilities for each empirical tree in iteration order.
        """
        probs = []
        wts = self.emp_tree_freq.dataset.wts
        negDataEnt = np.sum(wts * np.log(np.maximum(wts, self.EPS)))
        self.emp_tree_freq.initialize()
        for _ in range(0, self.emp_tree_freq.dataset.length, self.emp_tree_freq.batch_size):
            batch = self.emp_tree_freq.next()
            tree_prob = self.tree_prob(batch).exp().tolist()
            # A 1-row batch squeezes to a 0-d tensor, whose tolist() is a float.
            if isinstance(tree_prob, list):
                probs.extend(tree_prob)
            elif isinstance(tree_prob, float):
                probs.append(tree_prob)
            else:
                raise TypeError
        kl_div = negDataEnt - np.sum(wts * np.log(np.maximum(probs, self.EPS)))
        return kl_div, probs
    
    
class VBPI(VBPIbase):
    """VBPI model: shared tree-topology model plus a GNN branch-length model.

    ``self.phylo_model`` supplies the log-likelihood and log-prior of
    (branch lengths, tree) pairs, and ``self.branch_model`` supplies sampled
    log branch lengths together with their log-densities.
    """
    def __init__(self, taxa, data, pden, subModel, emp_tree_freq=None,
                 scale=0.1, hidden_dim_tree=100, hidden_dim_branch=100, num_layers_branch=2, gnn_type='edge', aggr='sum',project=False, nheads=4):
        super(VBPI, self).__init__(ntips=len(data), emp_tree_freq=emp_tree_freq,
                hidden_dim_tree=hidden_dim_tree, nheads=nheads)
        self.phylo_model = PHY(data, taxa, pden, subModel, scale=scale, device=self.device)  ## the unnormalized posterior density.
        self.branch_model = GNN_BranchModel(self.ntips, hidden_dim_branch, num_layers=num_layers_branch, gnn_type=gnn_type, aggr=aggr, project=project, device=self.device).to(device=self.device)
        self.scale = scale
        self.taxa = taxa

    def _sample_and_score(self, n_particles, eps=0.0):
        """Sample particles and evaluate every density the bounds need.

        Draws ``n_particles`` topologies and branch lengths, then computes
        their model and target densities.

        Returns:
            ``(logq_tree, logq_branch, logll, logp_prior)``, one entry per
            particle for each.
        """
        samp_trees, logq_tree = self.sample_trees(n_particles, eps=eps)
        samp_log_branch, logq_branch = self.branch_model(samp_trees)
        logll = torch.stack([self.phylo_model.loglikelihood(log_branch, tree)
                             for log_branch, tree in zip(samp_log_branch, samp_trees)])
        logp_prior = self.phylo_model.logprior(samp_log_branch)
        return logq_tree, logq_branch, logll, logp_prior

    def _reported_lower_bound(self, logll, logp_prior, logq_tree, logq_branch, n_particles):
        """Compute the importance-weighted lower bound reported in logs and tests.

        The bound is taken at inverse temperature 1 and includes the topology
        prior ``log_p_tau``.
        """
        return torch.logsumexp(logll + logp_prior - logq_tree - logq_branch + self.log_p_tau - math.log(n_particles), 0)

    @torch.no_grad()
    def lower_bound(self, n_particles=1, n_runs=1000):
        """Return the mean of ``n_runs`` independent ``n_particles``-sample bounds, as a float."""
        lower_bounds = []
        for _ in range(n_runs):
            logq_tree, logq_branch, logll, logp_prior = self._sample_and_score(n_particles)
            lower_bounds.append(self._reported_lower_bound(logll, logp_prior, logq_tree, logq_branch, n_particles))
        return torch.stack(lower_bounds).mean().item()

    def rws_lower_bound(self, inverse_temp=1.0, n_particles=10, eps=0.0):
        """Build the reweighted-wake-sleep training objective.

        Returns:
            ``(temp_lower_bound, rws_fake_term, lower_bound, max_logll)``.
            ``rws_fake_term`` is a surrogate whose gradient is the
            self-normalised importance-weighted score of ``logq_tree``.
        """
        logq_tree, logq_branch, logll, logp_prior = self._sample_and_score(n_particles, eps)
        logp_joint = inverse_temp * logll + logp_prior
        lower_bound = self._reported_lower_bound(logll, logp_prior, logq_tree, logq_branch, n_particles)

        # The tree proposal is trained only through rws_fake_term, hence the detach.
        l_signal = logp_joint - logq_tree.detach() - logq_branch
        temp_lower_bound = torch.logsumexp(l_signal - math.log(n_particles), dim=0)
        snis_wts = torch.softmax(l_signal, dim=0)
        rws_fake_term = torch.sum(snis_wts.detach() * logq_tree, dim=0)

        return temp_lower_bound, rws_fake_term, lower_bound, torch.max(logll)

    def vimco_lower_bound(self, inverse_temp=1.0, n_particles=10, eps=0.0):
        """Build the VIMCO training objective with leave-one-out control variates.

        Returns:
            ``(temp_lower_bound, vimco_fake_term, lower_bound, max_logll)``.

        Raises:
            ValueError: if ``n_particles < 2``. The leave-one-out mean divides
                by ``n_particles - 1``.
        """
        if n_particles < 2:
            raise ValueError('VIMCO requires n_particles >= 2, got {}'.format(n_particles))
        logq_tree, logq_branch, logll, logp_prior = self._sample_and_score(n_particles, eps)
        logp_joint = inverse_temp * logll + logp_prior
        lower_bound = self._reported_lower_bound(logll, logp_prior, logq_tree, logq_branch, n_particles)

        l_signal = logp_joint - logq_tree - logq_branch
        # Column j of the matrix below is l_signal with entry j replaced by the
        # mean of the other particles' signals (the leave-one-out baseline).
        mean_exclude_signal = (torch.sum(l_signal) - l_signal) / (n_particles-1.)
        control_variates = torch.logsumexp(l_signal.view(-1,1).repeat(1, n_particles) - l_signal.diag() + mean_exclude_signal.diag() - math.log(n_particles), dim=0)
        temp_lower_bound = torch.logsumexp(l_signal - math.log(n_particles), dim=0)
        vimco_fake_term = torch.sum((temp_lower_bound - control_variates).detach() * logq_tree, dim=0)

        return temp_lower_bound, vimco_fake_term, lower_bound, torch.max(logll)

    @staticmethod
    def _annealing_scheduler(optimizer, anneal_freq, anneal_freq_warm, warm_start_interval, maxiter, gamma):
        """Create a MultiStepLR with warm-phase milestones followed by post-warm milestones.

        NOTE(review): the warm-phase milestones run up to
        ``anneal_freq_warm + warm_start_interval``, so one of them may land
        after the warm start has ended. This is preserved as in the original.
        """
        milestones = (list(range(anneal_freq_warm, anneal_freq_warm + warm_start_interval, anneal_freq_warm))
                      + list(range(warm_start_interval + anneal_freq, maxiter + anneal_freq, anneal_freq)))
        return torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=gamma)

    def learn(self, stepsz, maxiter=100000, test_freq=1000, lb_test_freq=5000, anneal_freq_tree=20000, anneal_freq_branch=20000, anneal_freq_tree_warm=20000, anneal_freq_branch_warm=20000, anneal_rate_tree=0.75, anneal_rate_branch=0.75, n_particles=10, init_inverse_temp=0.001, save_freq=1000, warm_start_interval=50000, method='vimco',  save_to_path=None, logger=None, clip_grad=False, clip_value=100.0, eps_max=0.0, eps_period=20000):
        """Train the tree and branch models jointly with separate Adam optimisers.

        ``stepsz`` is either a single learning rate or a dict with keys
        ``'tree'`` and ``'branch'``. ``method`` must be ``'vimco'`` or
        ``'rws'``.

        Schedules and periodic work:

        - The likelihood inverse temperature ramps linearly to 1 over
          ``warm_start_interval`` iterations.
        - The exploration ``eps`` decays linearly from ``eps_max`` to 0 over
          ``eps_period`` iterations.
        - Every ``test_freq`` iterations running statistics are logged.
        - Every ``lb_test_freq`` iterations a test lower bound (and the KL,
          when ``emp_tree_freq`` is set) is computed.
        - Every ``save_freq`` iterations a checkpoint is written to
          ``save_to_path`` with ``'final'`` replaced by the iteration number,
          but only when ``save_to_path`` is given.

        If ``logger`` is None, a module-level ``logging`` logger is used.

        Returns:
            ``(test_lb, test_kl_div)``, the lists of periodic test metrics.
        """
        import logging
        if logger is None:
            logger = logging.getLogger(__name__)

        lbs, lls = [], []
        test_kl_div, test_lb = [], []
        grad_norms = []
        if not isinstance(stepsz, dict):
            stepsz = {'tree': stepsz, 'branch': stepsz}
        optimizer_tree = torch.optim.Adam(params=self.tree_model.parameters(), lr=stepsz['tree'])
        optimizer_branch = torch.optim.Adam(params=self.branch_model.parameters(), lr=stepsz['branch'])
        scheduler_tree = self._annealing_scheduler(optimizer_tree, anneal_freq_tree, anneal_freq_tree_warm, warm_start_interval, maxiter, anneal_rate_tree)
        scheduler_branch = self._annealing_scheduler(optimizer_branch, anneal_freq_branch, anneal_freq_branch_warm, warm_start_interval, maxiter, anneal_rate_branch)

        run_time = -time.time()
        self.tree_model.train()
        for it in range(1, maxiter+1):
            inverse_temp = min(1., init_inverse_temp + it * 1.0/warm_start_interval)
            eps = eps_max * max(1. - it / eps_period, 0.0)
            if method == 'vimco':
                temp_lower_bound, vimco_fake_term, lower_bound, logll = self.vimco_lower_bound(inverse_temp, n_particles, eps)
                loss = - temp_lower_bound - vimco_fake_term
            elif method == 'rws':
                temp_lower_bound, rws_fake_term, lower_bound, logll = self.rws_lower_bound(inverse_temp, n_particles, eps)
                loss = - temp_lower_bound - rws_fake_term
            else:
                raise NotImplementedError

            lbs.append(lower_bound.item())
            lls.append(logll.item())

            optimizer_tree.zero_grad()
            optimizer_branch.zero_grad()
            loss.backward()
            # With clip_grad disabled, max_norm=inf only measures the norm.
            # Non-finite gradients still raise.
            grad_norm = nn.utils.clip_grad.clip_grad_norm_(parameters=self.parameters(), max_norm=clip_value if clip_grad else float('inf'), error_if_nonfinite=True)
            grad_norms.append(grad_norm.item())
            optimizer_tree.step()
            scheduler_tree.step()
            optimizer_branch.step()
            scheduler_branch.step()

            gc.collect()
            if it % test_freq == 0:
                run_time += time.time()
                logger.info('{} Iter {}:({:.3f}s) Lower Bound: {:.4f} | Loglikelihood: {:.4f} | GradNorm: Mean: {:.4f} Max: {:.4f} | Memory: {:.04f} MB'.format(time.asctime(time.localtime(time.time())), it, run_time, np.mean(lbs), np.max(lls), np.mean(grad_norms), np.max(grad_norms), psutil.Process(os.getpid()).memory_info().rss/1024/1024))
                if it % lb_test_freq == 0:
                    self.tree_model.eval()
                    run_time = -time.time()
                    test_lb.append(self.lower_bound(n_particles=1))
                    run_time += time.time()
                    if self.emp_tree_freq:
                        kldiv, pred_probs = self.kl_div()
                        test_kl_div.append(kldiv)
                        logger.info('>>> Iter {}:({:.1f}s) Test Lower Bound: {:.4f} Test KL: {:.4f}'.format(it, run_time, test_lb[-1], test_kl_div[-1]))
                    else:
                        logger.info('>>> Iter {}:({:.1f}s) Test Lower Bound: {:.4f}'.format(it, run_time, test_lb[-1]))
                    self.tree_model.train()
                    gc.collect()
                run_time = -time.time()
                lbs, lls = [], []
                grad_norms = []
            # Periodic checkpoints are skipped when no path was given, matching
            # the guarded final save below.
            if save_to_path is not None and it % save_freq == 0:
                torch.save(self.state_dict(), save_to_path.replace('final', str(it)))
        if save_to_path is not None:
            torch.save(self.state_dict(), save_to_path)
        return test_lb, test_kl_div

