from typing import Union, Optional, List, Set, Dict
import numpy as np
import test
from scipy import sparse as sp

import torch
from torch import nn

from .graph import Graph
from .generator_core import GraphConv, Agent

from typing import List, Tuple
def tf(x):
    """Wrap ``x`` (a list, nested list or array) in a default-dtype ``torch.Tensor``."""
    return torch.Tensor(x)

def csc_to_torch(csc_matrix):
    """Convert each SciPy sparse matrix in ``csc_matrix`` to a sparse COO torch tensor.

    Args:
        csc_matrix: an iterable of SciPy sparse matrices (any format with ``tocoo``).

    Returns:
        A list of float32 sparse COO tensors with the same shapes and entries.
        The tensors own copies of the data, so they do not alias the inputs.
    """
    torch_list = []
    for item in csc_matrix:
        coo_object = item.tocoo()
        # ``torch.tensor`` copies and fixes the dtype; the legacy
        # ``torch.sparse.FloatTensor`` / ``LongTensor`` constructors are deprecated.
        indices = torch.tensor(np.vstack((coo_object.row, coo_object.col)), dtype=torch.long)
        values = torch.tensor(coo_object.data, dtype=torch.float32)
        torch_list.append(
            torch.sparse_coo_tensor(indices, values, torch.Size(coo_object.shape), dtype=torch.float32)
        )
    return torch_list




def eos_to_number(current_state):
    """Replace every ``'EOS'`` marker in ``current_state`` with the numeric code 1000.

    The list is modified in place and also returned, so it can be handed to
    ``torch.Tensor`` (which cannot hold strings). Unlike popping the last
    element, this stays correct when ``'EOS'`` is not the final entry.

    NOTE(review): 1000 is also a valid node id on graphs with more than 1000
    nodes; a negative sentinel would be safer but must change everywhere at once.
    """
    for pos, item in enumerate(current_state):
        if item == 'EOS':
            current_state[pos] = 1000
    return current_state



class ExpansionEnv:
    """Batched environment that grows node lists one outer-boundary node at a time.

    Trajectory ``i`` starts as a copy of ``selected_nodes[i]``; its first element
    is used as the seed. A trajectory finishes when the chosen node is ``'EOS'``
    (or its numeric code 1000), when it reaches ``max_size`` nodes, or when
    ``graph.outer_boundary`` of it is empty.
    """

    def __init__(self, graph: Graph, selected_nodes: List[List[int]], max_size: int):
        self.max_size = max_size
        self.graph = graph
        self.n_nodes = self.graph.n_nodes
        self.data = selected_nodes
        self.bs = len(self.data)
        self.trajectories = None  # per-trajectory node lists, set by reset()
        self.dones = None  # per-trajectory finished flags, set by reset()

    @property
    def lengths(self):
        """Number of nodes in each trajectory, not counting a trailing ``'EOS'``."""
        return [len(x) - (x[-1] == 'EOS') for x in self.trajectories]

    @property
    def done(self):
        """True once every trajectory in the batch has finished."""
        return all(self.dones)

    @property
    def valid_index(self) -> List[int]:
        """Batch indices of the trajectories that are still running."""
        return [i for i, d in enumerate(self.dones) if not d]

    def __len__(self):
        return len(self.data)

    def reset(self):
        """Restart every trajectory from its initial node list.

        Returns:
            ``(x_seeds, x_nodes)``: ``(n_nodes, bs)`` CSC indicator matrices of
            each trajectory's first node and of all of its initial nodes.

        Raises:
            AssertionError: if some initial list is already finished.
        """
        self.trajectories = [x.copy() for x in self.data]
        self.dones = [x[-1] == 'EOS' or len(x) >= self.max_size or len(self.graph.outer_boundary(x)) == 0
                      for x in self.trajectories]
        assert not any(self.dones)
        seeds = [x[0] for x in self.data]
        nodes = list(self.data)
        x_seeds = self.make_single_node_encoding(seeds)
        x_nodes = self.make_nodes_encoding(nodes)
        return x_seeds, x_nodes

    def step(self, new_nodes: List[Union[int, str]], index: List[int], fn, reward_list):
        """Append one chosen node to each running trajectory.

        Args:
            new_nodes: chosen node per entry of ``index``; ``'EOS'`` or 1000 stops it.
            index: batch indices the entries of ``new_nodes`` belong to.
            fn: scorer called as ``fn([trajectory], [seed])``; its result must
                support ``.item()``.
            reward_list: per-trajectory rewards; the entry of every trajectory
                that is finished after this step is overwritten in place.

        Returns:
            ``(delta_x_nodes, trajectories, reward_list, dones)`` where
            ``delta_x_nodes`` is an ``(n_nodes, bs)`` CSC indicator of the nodes
            added to trajectories that are still running.
        """
        assert len(new_nodes) == len(index)

        full_new_nodes: List[Optional[int]] = [None] * self.bs
        for i, v in zip(index, new_nodes):
            trajectory = self.trajectories[i]
            trajectory.append(v)

            if v == 'EOS' or v == 1000:
                # NOTE(review): 1000 is the numeric EOS code and collides with a
                # real node id on graphs with more than 1000 nodes.
                self.dones[i] = True
                trajectory.pop(-1)
            elif len(trajectory) == self.max_size:
                self.dones[i] = True
            elif len(self.graph.outer_boundary(trajectory)) == 0:
                # Nothing left to add. Stop now rather than handing the agent an
                # empty candidate list on the next step.
                self.dones[i] = True
            else:
                full_new_nodes[i] = v

            if self.dones[i]:
                reward_list[i] = fn([trajectory], [self.data[i][0]]).item()

        delta_x_nodes = self.make_single_node_encoding(full_new_nodes)
        return delta_x_nodes, self.trajectories, reward_list, self.dones

    def make_single_node_encoding_2(self, nodes: List[int]):
        """Return an ``(n_nodes, 1)`` CSC indicator with a 1 at ``nodes[0]``.

        An empty ``nodes`` yields an all-zero matrix, matching the other encoders.

        NOTE(review): only the first node is encoded even though callers pass a
        whole parent state; confirm whether every node was meant to be set.
        """
        if len(nodes) == 0:
            return sp.csc_matrix((self.n_nodes, 1), dtype=np.float32)
        ind = np.asarray([[nodes[0]], [0]], dtype=np.int64)
        data = np.ones(ind.shape[1], dtype=np.float32)
        return sp.csc_matrix((data, ind), shape=[self.n_nodes, 1])

    def make_single_node_encoding(self, nodes: List[int]):
        """Return an ``(n_nodes, bs)`` CSC indicator with ``[nodes[i], i] = 1``.

        ``None`` entries leave column ``i`` empty.
        """
        bs = len(nodes)
        assert bs == self.bs
        ind = np.array([[v, i] for i, v in enumerate(nodes) if v is not None], dtype=np.int64).T
        if len(ind):
            data = np.ones(ind.shape[1], dtype=np.float32)
            return sp.csc_matrix((data, ind), shape=[self.n_nodes, bs])
        return sp.csc_matrix((self.n_nodes, bs), dtype=np.float32)

    def make_nodes_encoding(self, nodes: List[List[int]]):
        """Return an ``(n_nodes, bs)`` CSC indicator with ``[v, i] = 1`` for each ``v`` in ``nodes[i]``."""
        bs = len(nodes)
        assert bs == self.bs
        ind = [[v, i] for i, vs in enumerate(nodes) for v in vs]
        ind = np.asarray(ind, dtype=np.int64).T
        if len(ind):
            data = np.ones(ind.shape[1], dtype=np.float32)
            return sp.csc_matrix((data, ind), shape=[self.n_nodes, bs])
        return sp.csc_matrix((self.n_nodes, bs), dtype=np.float32)


class Generator:
    """Samples node-set expansions with ``model`` and trains it by matching in- and out-flow.

    Rollouts run inside an ``ExpansionEnv``. Node indicator matrices are passed
    through ``GraphConv(graph, k, alpha)`` before being gathered as model inputs.
    """

    def __init__(self, args, graph: Graph, model: Agent, optimizer,
                 device: Optional[torch.device] = None,
                 entropy_coef: float = 1e-6,
                 n_rollouts: int = 10,
                 max_size: int = 25,
                 k: int = 3,
                 alpha: float = 0.85,
                 max_reward: float = 1.,
                 ):
        self.args = args  # ``args.ablation`` selects the parent rule in parent_transition
        self.graph = graph
        self.model = model
        self.model_gflow = model  # alias of ``model``
        self.optimizer = optimizer
        self.entropy_coef = entropy_coef
        self.max_reward = max_reward
        self.n_nodes = self.graph.n_nodes
        self.max_size = max_size  # default cap on trajectory length
        self.n_rollouts = n_rollouts
        self.conv = GraphConv(graph, k, alpha)
        self.nodefeats = None  # optional per-node feature array, see load_nodefeats
        self.device = torch.device('cpu') if device is None else device

    def load_nodefeats(self, x):
        """Store the node-feature tensor ``x`` as a host NumPy array."""
        # detach/cpu so tensors that require grad or live on an accelerator also work.
        self.nodefeats = x.detach().cpu().numpy()

    def generate(self, seeds: List[int], fn_score, max_size: Optional[int] = None):
        """Sample one trajectory per seed in eval mode without gradients; return the node lists."""
        max_size = self.max_size if max_size is None else max_size
        env = ExpansionEnv(self.graph, [[s] for s in seeds], max_size)
        self.model.eval()
        with torch.no_grad():
            episodes, *_ = self._sample_trajectories(env, fn_score)
        return episodes

    def generate_2(self, seeds: List[int], fn, max_size: Optional[int] = None):
        """Like ``generate`` but uses ``_sample_trajectories_2`` (no transition recording)."""
        max_size = self.max_size if max_size is None else max_size
        env = ExpansionEnv(self.graph, [[s] for s in seeds], max_size)
        self.model.eval()
        with torch.no_grad():
            episodes, *_ = self._sample_trajectories_2(env, fn)
        return episodes

    def sample_episodes(self, seeds: List[int], fn, max_size: Optional[int] = None):
        """Roll out one trajectory per seed and return ``_sample_trajectories``' full output."""
        max_size = self.max_size if max_size is None else max_size
        env = ExpansionEnv(self.graph, [[s] for s in seeds], max_size)
        return self._sample_trajectories(env, fn)

    def prepare_flow_calculation(self, selected_nodes, rewards):
        """Split node sequences into prefix transitions.

        For sequence ``item = selected_nodes[j]`` and ``1 <= i < len(item) - 1``
        this emits ``[item[:i], item[i], item[:i + 1], rewards[j][i], done]``
        (parent, action, child, reward, done), where ``done`` is
        ``i == len(item) - 2``.
        """
        batch = []
        for j, item in enumerate(selected_nodes):
            for i in range(1, len(item) - 1):
                batch.append([item[:i], item[i], item[:i + 1], rewards[j][i], i == len(item) - 2])
        return batch

    def prepare_new_inputs(self, trajectory, z_nodes, z_seeds):
        """Gather model inputs for one stored state (a batch of size one).

        Args:
            trajectory: 1-D tensor of node ids; the code 1000 may mark EOS.
            z_nodes: node representation of this state, transposed and indexed
                by column (one column expected).
            z_seeds: seed representation, same layout as ``z_nodes``.

        Returns:
            ``(vals_attr, vals_seed, vals_node, indptr, batch_candidates)``, the
            layout ``_prepare_inputs`` produces; ``vals_attr`` is None without
            node features.
        """
        # ``.cpu()`` lets tensors stored on an accelerator be converted to a list.
        trajectory_nodes = trajectory.int().cpu().numpy().tolist()
        candidate_nodes = list(self.graph.outer_boundary(trajectory_nodes))
        involved_nodes = candidate_nodes + trajectory_nodes
        if 1000 in involved_nodes:
            # 1000 encodes EOS; it is assumed to be the final entry.
            involved_nodes = involved_nodes[:-1]

        vals_attr = None
        if self.nodefeats is not None:
            vals_attr = torch.from_numpy(np.asarray(self.nodefeats[involved_nodes])).to(self.device)
        vals_seed = torch.from_numpy(np.array(z_seeds.T[:, involved_nodes].todense())[0]).to(self.device)
        vals_node = torch.from_numpy(np.array(z_nodes.T[:, involved_nodes].todense())[0]).to(self.device)
        indptr = np.array([(0, len(involved_nodes), len(candidate_nodes))])
        return vals_attr, vals_seed, vals_node, indptr, [candidate_nodes]

    def learn_from(self, batch):
        """Compute flow-matching losses for transitions recorded by ``_sample_trajectories``.

        Each element of ``batch`` is ``(parents, parents_z_node, action,
        current_state, current_z_node, z_node_seed, reward, done)``.

        * The in-flow of a state is the sum over its parents of
          ``exp(logit of the action that leads to it)``.
        * The out-flow is ``reward + exp(first logit of the state)``. For
          terminal states the second term is replaced by ``exp(-1000)``.

        NOTE(review): only the first logit of the current state enters the
        out-flow; confirm whether a sum over all actions was intended.

        Returns:
            ``(loss, term_loss, flow_loss, loss_total)``: the mean squared
            in/out-flow mismatch, its terminal-only and non-terminal-only parts,
            and ``25 * term_loss + flow_loss``.
        """
        batch_parent = []
        batch_action = []
        batch_current = []
        batch_done = []
        batch_reward = []
        for (parents, parents_z_node, action, current_state, current_z_node,
             z_node_seed, reward, done) in batch:
            for i in range(len(parents)):
                batch_parent.append([parents[i], parents_z_node[i], z_node_seed])
                batch_action.append(action[i])
            batch_current.append([current_state, current_z_node, z_node_seed])
            batch_done.append(done)
            batch_reward.append(reward)

        new_model_inputs_parent = [self.prepare_new_inputs(t, zn, zs) for (t, zn, zs) in batch_parent]
        new_model_inputs_current = [self.prepare_new_inputs(t, zn, zs) for (t, zn, zs) in batch_current]

        # Maps every parent row to the transition (state) it flows into.
        batch_idxs = torch.LongTensor([i for i, record in enumerate(batch) for _ in range(len(record[0]))])

        parent_Qsa = []
        for item, action_index in zip(new_model_inputs_parent, batch_action):
            batch_logits, _values = self.model(item[0], item[1], item[2], item[3])
            candidates = item[4][0]
            action_item = int(action_index.item())
            # The last logit scores EOS (code 1000); the others follow ``candidates``.
            action_ix = candidates.index(action_item) if action_item != 1000 else -1
            parent_Qsa.append(batch_logits[0][action_ix])
        parents_Qsa = torch.stack(parent_Qsa)
        device = parents_Qsa.device

        current_logits = []
        for item in new_model_inputs_current:
            # item[1] (seed values) is always a tensor; item[0] is None without node features.
            if len(item[1]) != 0:
                batch_logits, _values = self.model(item[0], item[1], item[2], item[3])
                current_logits.append(batch_logits[0])
            else:
                # No involved nodes: a zero logit keeps the shapes below valid.
                current_logits.append(torch.zeros(1, device=device))

        in_flow = torch.zeros((len(new_model_inputs_current),), device=device) \
            .index_add_(0, batch_idxs.to(device), torch.exp(parents_Qsa))
        done = torch.stack(batch_done).squeeze(1).to(device)
        r = torch.stack(batch_reward).squeeze(1).to(device)
        loginf = torch.full((1,), 1000.0, device=device)

        log_r = torch.log(r)[:, None]
        out_terms = []
        for i, logits in enumerate(current_logits):
            next_qd = logits[0].unsqueeze(0) * (1 - done[i]) + done[i] * (-loginf)
            out_terms.append(torch.sum(torch.exp(torch.cat([log_r[i], next_qd], 0)), 0))
        out_flow = torch.stack(out_terms)

        diff = in_flow - out_flow
        loss = diff.pow(2).mean()
        term_loss = (diff * done).pow(2).sum() / (done.sum() + 1e-20)
        flow_loss = (diff * (1 - done)).pow(2).sum() / ((1 - done).sum() + 1e-20)
        loss_total = term_loss * 25 + flow_loss
        return loss, term_loss, flow_loss, loss_total

    def train_from_rewards(self, seeds: List[int], fn):
        """Sample episodes from ``seeds``, then take one optimizer step on the plain MSE loss.

        Returns the tuple produced by ``learn_from``.
        """
        self.model.train()
        _selected, _logps, _values, _entropys, batch = self.sample_episodes(seeds, fn)
        losses = self.learn_from(batch)
        if losses is not None:
            self.optimizer.zero_grad()
            # losses[0] is the unweighted MSE; losses[3] is the weighted alternative.
            losses[0].backward()
            self.optimizer.step()
        return losses

    def getCuttingPointAndCuttingEdge(self, edges: List[Tuple]):
        """Find articulation points and bridges with Tarjan's DFS.

        Only the connected component containing ``edges[0][0]`` is explored.

        Args:
            edges: undirected edges ``(a, b)``.

        Returns:
            ``(cutting_points, cutting_edges)``; every bridge is reported as
            ``[smaller, larger]``. Both lists are empty when ``edges`` is empty.
        """
        if not edges:
            return [], []

        unvisited = 0x7fffffff
        link, dfn, low = {}, {}, {}
        for a, b in edges:
            link.setdefault(a, []).append(b)
            link.setdefault(b, []).append(a)
            dfn[a] = dfn[b] = unvisited
            low[a] = low[b] = unvisited

        global_time = [0]
        cutting_points, cutting_edges = [], []

        def dfs(cur, prev, root):
            global_time[0] += 1
            dfn[cur] = low[cur] = global_time[0]
            children_cnt = 0
            flag = False
            for nxt in link[cur]:
                if nxt == prev:
                    continue
                if dfn[nxt] == unvisited:
                    children_cnt += 1
                    dfs(nxt, cur, root)
                    if cur != root and low[nxt] >= dfn[cur]:
                        flag = True
                    low[cur] = min(low[cur], low[nxt])
                    if low[nxt] > dfn[cur]:
                        cutting_edges.append([cur, nxt] if cur < nxt else [nxt, cur])
                else:
                    low[cur] = min(low[cur], dfn[nxt])
            # A root is an articulation point only if it has two or more DFS children.
            if flag or (cur == root and children_cnt >= 2):
                cutting_points.append(cur)

        root = edges[0][0]
        dfs(root, None, root)
        return cutting_points, cutting_edges

    @staticmethod
    def update_matrix(state, action_idx, adj, cut_vertex_list, cut_vertex_dict):
        """Incrementally update cut-vertex bookkeeping after ``action_idx`` joins ``state``.

        NOTE(review): experimental and not used by ``parent_transition``, which
        recomputes cut vertices with ``getCuttingPointAndCuttingEdge``. It records
        positions in ``state`` as cut vertices but stores ``action_idx`` (a node
        id) beside them; confirm the intended convention before relying on it.

        Args:
            state: current node list.
            action_idx: node being added.
            adj: indexable as ``adj[u][v]``; positive entries mark edges.
            cut_vertex_list: known cut vertices, updated in place.
            cut_vertex_dict: ``str(cut vertex) -> list of groups``, updated in place.

        Returns:
            ``(cut_vertex_list, cut_vertex_dict)``.
        """
        action_link = [i for i, node in enumerate(state) if adj[node][action_idx] > 0]

        if len(action_link) == 1:
            cut_ = action_link[0]
            if cut_ in cut_vertex_list:
                cut_vertex_dict[str(cut_)].append([action_idx])
                for groups in cut_vertex_dict.values():
                    for group in groups:
                        if cut_ in group:
                            group.append(cut_)
            else:
                cut_vertex_list.append(cut_)
                cut_vertex_dict[str(cut_)] = [[action_idx], state[:-1]]
            return cut_vertex_list, cut_vertex_dict

        # Iterate over a snapshot of the keys: entries may be deleted below.
        for item in list(cut_vertex_dict.keys()):
            true_list = []
            true_element = []
            false_element = []
            for group in cut_vertex_dict[item]:
                for jj in group:
                    if jj in action_link:
                        true_list.append(True)
                        true_element.append(jj)
                        break
                else:
                    true_list.append(False)
                    false_element += group
            if all(true_list):
                cut_vertex_list.remove(int(item))
                del cut_vertex_dict[item]
            else:
                true_element.append(action_idx)
                cut_vertex_dict[item] = [true_element, false_element]
        return cut_vertex_list, cut_vertex_dict

    def _induced_cut_vertices(self, state, adj_mat):
        """Return the nodes of ``state`` that are articulation points of its induced subgraph."""
        # Two or fewer nodes never have an articulation point; markers cannot be indexed.
        if len(state) <= 2 or state[-1] == 1000 or 'EOS' in state:
            return []
        # Slice the sparse adjacency instead of densifying the whole graph.
        sub_adj = sp.csr_matrix(adj_mat)[state][:, state].toarray()
        size = len(state)
        edges = [(i, j) for i in range(size) for j in range(i + 1, size)
                 if sub_adj[i, j] > 0 or sub_adj[j, i] > 0]
        cut_index, _ = self.getCuttingPointAndCuttingEdge(edges)
        return [state[k] for k in cut_index]

    def parent_transition(self, index, state, z_node_vector, env, reward, z_seed):
        """Enumerate the parents of trajectory ``index``'s current state.

        If the last entry of ``state`` is ``'EOS'`` or 1000, it is popped from
        ``state`` in place first. A parent is ``state`` with one node removed;
        ``self.args.ablation`` selects which removals are allowed:

        * 1: any node except the initial nodes (``env.data[index]``) and the
          articulation points of the subgraph induced by ``state``, so every
          parent stays connected;
        * 2: any node except the initial nodes;
        * 3: any node;
        * otherwise: only the last node.

        Returns:
            ``(parent_states, parent_z_nodes, parent_actions, done,
            z_node_vector, reward, z_seed)``. ``parent_actions[j]`` is the node
            removed to obtain ``parent_states[j]``, and ``parent_z_nodes[j]`` is
            ``self.conv(env.make_single_node_encoding_2(parent_states[j]))``.
        """
        if state and (state[-1] == 1000 or state[-1] == 'EOS'):
            state.pop(-1)

        ablation = self.args.ablation
        if ablation == 1:
            protected = set(self._induced_cut_vertices(state, env.graph.adj_mat))
            protected.update(env.data[index])
            removable = [pos for pos, node in enumerate(state) if node not in protected]
        elif ablation == 2:
            removable = [pos for pos, node in enumerate(state) if node not in env.data[index]]
        elif ablation == 3:
            removable = list(range(len(state)))
        else:
            removable = [len(state) - 1] if state else []

        parent_state = []
        parent_action = []
        parent_z_nodes = []
        for pos in removable:
            parent = state[:pos] + state[pos + 1:]
            parent_state.append(parent)
            parent_action.append(state[pos])
            parent_z_nodes.append(self.conv(env.make_single_node_encoding_2(parent)))

        return parent_state, parent_z_nodes, parent_action, env.dones[index], z_node_vector, reward, z_seed

    def _sample_trajectories(self, env: ExpansionEnv, fn):
        """Roll out every trajectory in ``env`` and record transitions for ``learn_from``.

        After each step, ``parent_transition`` runs for all ``bs`` trajectories
        (finished ones included). One record per trajectory is appended to
        ``batch``: ``(parents, parent_z_nodes, actions, current_state,
        current_z_nodes, z_seed, reward, done)``.

        Returns:
            ``(trajectories, logps, values, entropys, batch)``. The middle three
            stack each trajectory's per-step model outputs and are padded with
            ``pad_sequence(batch_first=True)``.
        """
        bs = env.bs
        x_seeds, delta_x_nodes = env.reset()
        z_seeds = self.conv(x_seeds)
        z_nodes = sp.csc_matrix((self.n_nodes, bs), dtype=np.float32)
        # ``z_nodes`` feeds the model and is updated at the start of the next step.
        # ``z_nodes_update`` already includes the nodes added in the current step.
        z_nodes_update = sp.csc_matrix((self.n_nodes, bs), dtype=np.float32)
        episode_logps = [[] for _ in range(bs)]
        episode_values = [[] for _ in range(bs)]
        episode_entropys = [[] for _ in range(bs)]
        batch = []
        reward_list = [0] * bs
        z_nodes_update += self.conv(delta_x_nodes)

        step = 0
        while not env.done:
            step += 1
            z_nodes += self.conv(delta_x_nodes)
            valid_index = env.valid_index
            *model_inputs, batch_candidates = self._prepare_inputs(valid_index, env.trajectories, z_nodes, z_seeds)
            batch_logits, values = self.model(*model_inputs)
            actions, logps, entropys = self._sample_actions(batch_logits, step, batch_candidates)
            # Indices past the candidate list select stopping.
            new_nodes = [x[i] if i < len(x) else 'EOS' for i, x in zip(actions, batch_candidates)]

            delta_x_nodes, trajectories, rewards, _dones = env.step(new_nodes, valid_index, fn, reward_list)
            z_nodes_update += self.conv(delta_x_nodes)

            p_a = [self.parent_transition(i, trajectories[i], z_nodes_update[:, i], env, rewards[i], z_seeds[:, i])
                   for i in range(bs)]
            for (parent_state, parent_z_nodes, action, done, current_z_nodes, reward, z_node_seed), current_state \
                    in zip(p_a, trajectories):
                batch.append([
                    tf(parent_state).to(self.device),
                    parent_z_nodes,
                    tf(eos_to_number(action)).to(self.device),
                    tf(eos_to_number(current_state)).to(self.device),
                    current_z_nodes,
                    z_node_seed,
                    tf([reward]),
                    tf([done]).to(self.device),
                ])

            for i, v1, v2, v3 in zip(valid_index, logps, values, entropys):
                episode_logps[i].append(v1)
                episode_values[i].append(v2)
                episode_entropys[i].append(v3)

        logps, values, entropys = [nn.utils.rnn.pad_sequence([torch.stack(x) for x in episode_xs], True)
                                   for episode_xs in (episode_logps, episode_values, episode_entropys)]
        return env.trajectories, logps, values, entropys, batch

    def _sample_trajectories_2(self, env: ExpansionEnv, fn):
        """Roll out every trajectory in ``env`` without recording transitions.

        Returns:
            ``(trajectories, logps, values, entropys, batch)``. ``logps``,
            ``values`` and ``entropys`` come from the final step only, and
            ``batch`` is always empty.
        """
        bs = env.bs
        x_seeds, delta_x_nodes = env.reset()
        z_seeds = self.conv(x_seeds)
        z_nodes = sp.csc_matrix((self.n_nodes, bs), dtype=np.float32)
        z_nodes_update = sp.csc_matrix((self.n_nodes, bs), dtype=np.float32)
        batch = []
        reward_list = [0] * bs
        z_nodes_update += self.conv(delta_x_nodes)

        step = 0
        while not env.done:
            step += 1
            z_nodes += self.conv(delta_x_nodes)
            valid_index = env.valid_index
            *model_inputs, batch_candidates = self._prepare_inputs(valid_index, env.trajectories, z_nodes, z_seeds)
            batch_logits, values = self.model(*model_inputs)
            actions, logps, entropys = self._sample_actions(batch_logits, step, batch_candidates)
            new_nodes = [x[i] if i < len(x) else 'EOS' for i, x in zip(actions, batch_candidates)]
            step_new = env.step(new_nodes, valid_index, fn, reward_list)
            delta_x_nodes = step_new[0]
            z_nodes_update += self.conv(delta_x_nodes)

        return env.trajectories, logps, values, entropys, batch

    def _prepare_inputs(self, valid_index: List[int], trajectories: List[List[int]],
                        z_nodes: sp.csc_matrix, z_seeds: sp.csc_matrix):
        """Gather packed model inputs for the running trajectories.

        For each ``i`` in ``valid_index``, the involved nodes are the outer-boundary
        candidates followed by the trajectory's nodes. Their features and their
        entries in column ``i`` of ``z_seeds`` / ``z_nodes`` are concatenated.
        ``indptr`` rows are ``(start, end, start + n_candidates)`` offsets into
        the packed arrays.

        Returns:
            ``(vals_attr, vals_seed, vals_node, indptr, batch_candidates)``;
            ``vals_attr`` is None without node features.
        """
        vals_attr = [] if self.nodefeats is not None else None
        vals_seed = []
        vals_node = []
        indptr = []
        offset = 0
        batch_candidates = []
        for i in valid_index:
            candidate_nodes = list(self.graph.outer_boundary(trajectories[i]))
            involved_nodes = candidate_nodes + trajectories[i]
            batch_candidates.append(candidate_nodes)
            if self.nodefeats is not None:
                vals_attr.append(self.nodefeats[involved_nodes])
            vals_seed.append(z_seeds.T[i, involved_nodes].todense())
            vals_node.append(z_nodes.T[i, involved_nodes].todense())
            indptr.append((offset, offset + len(involved_nodes), offset + len(candidate_nodes)))
            offset += len(involved_nodes)
        if self.nodefeats is not None:
            vals_attr = torch.from_numpy(np.concatenate(vals_attr, 0)).to(self.device)
        vals_seed = torch.from_numpy(np.array(np.concatenate(vals_seed, 1))[0]).to(self.device)
        vals_node = torch.from_numpy(np.array(np.concatenate(vals_node, 1))[0]).to(self.device)
        indptr = np.array(indptr)
        return vals_attr, vals_seed, vals_node, indptr, batch_candidates

    def _sample_actions(self, batch_logits: List, step, bc) -> Tuple[List, List, List]:
        """Sample one action per entry of ``batch_logits`` from ``exp(logits)``.

        Indices ``>= len(bc[n])`` mean "stop". While ``step < 5`` and row ``n``
        has candidates, stopping is not allowed: a stop draw is replaced by a draw
        from the candidate entries only. This matches resampling until a
        candidate comes up, in distribution, but cannot loop forever when the
        candidate list is empty or has zero probability.

        Returns:
            ``(actions, logps, entropys)``: a NumPy array of action indices, a
            tuple of ``logits[action]`` and a tuple of ``-(exp(logits) * logits).sum()``.
        """
        batch = []
        for n, logits in enumerate(batch_logits):
            candidates = bc[n]
            ps = torch.exp(logits)
            entropy = -(ps * logits).sum()
            action = torch.multinomial(ps, 1).item()
            if step < 5 and len(candidates) > 0 and action >= len(candidates):
                head = ps[:len(candidates)]
                if head.sum() > 0:
                    action = torch.multinomial(head, 1).item()
            logp = logits[action]
            batch.append([action, logp, entropy])
        actions, logps, entropys = zip(*batch)
        return np.array(actions), logps, entropys
