import random
import math
from dataclasses import dataclass, field

import numpy as np
import torch
from torch.utils.data import IterableDataset
import torch_geometric.data as gd
import igraph as ig

from gflownet.graphenv import ActionType, GraphEnv, MolEnv, GraphState, GraphAction
from gflownet.featurizer import GraphStateFeaturizer
from gflownet.models import PolicyNet



def count_automorphisms(state):
    """Return the number of color-preserving automorphisms of ``state``.

    Vertices are colored by ``state.node_types`` and edges by
    ``state.edge_types``; the count is computed with igraph's VF2 routine.
    Graphs with at most one node are answered directly (only the identity
    mapping exists), without building an igraph graph.

    Args:
        state: graph state exposing ``num_nodes``, ``edge_list``,
            ``node_types`` and ``edge_types``.

    Returns:
        The automorphism count (always >= 1, since the identity counts).
    """
    if state.num_nodes <= 1:
        return 1
    # Give the vertex count explicitly: letting igraph infer it from the edge
    # list drops trailing isolated nodes, which would make ``color`` longer
    # than the vertex set (e.g. a state with two nodes and no edges).
    g = ig.Graph(n=state.num_nodes, edges=state.edge_list)
    return g.count_automorphisms_vf2(color=state.node_types, edge_color=state.edge_types)



@dataclass
class Trajectory:
    """One rollout through the graph environment.

    A complete trajectory of N steps stores N entries in ``states``,
    ``actions``, ``torch_graphs`` and ``fwd_action_idxs`` (the final action
    is the one taken from the final state), but only N - 1 entries in
    ``bck_action_idxs`` and ``bck_log_probs``.
    """

    # Visited states, in the order they were visited.
    states: list[GraphState] = field(default_factory=list)
    # Action taken from the state at the same position.
    actions: list[GraphAction] = field(default_factory=list)
    # Featurized counterpart of each entry in ``states``.
    torch_graphs: list[gd.Data] = field(default_factory=list)
    # Position of each action among the children of its state.
    fwd_action_idxs: list[int] = field(default_factory=list)
    # Backward-step bookkeeping: one entry fewer than ``states``.
    bck_action_idxs: list[int] = field(default_factory=list)
    bck_log_probs: list[float] = field(default_factory=list)
    # Log-reward of the final state; None until it is assigned.
    log_reward: float = field(default=None)

    
@dataclass
class Transition:
    """A single step ``graph -> graph_`` cut out of a trajectory.

    For the last step of a trajectory ``graph_`` is None and ``log_reward``
    carries the reward signal instead of ``bck_log_prob``.
    """

    graph: gd.Data  # featurized source state
    graph_: gd.Data  # featurized successor state; None for a terminal step
    fwd_action_idx: int = field(default=None)  # index of the action taken from ``graph``
    bck_log_prob: float = field(default=None)  # backward log-prob of this step
    log_reward: float = field(default=None)  # set only on terminal steps


def to_transitions(traj, automorphism_correction=False):
    """Split a complete trajectory into per-step ``Transition`` records.

    For a trajectory with N graphs this returns N transitions:

    * N - 1 non-terminal ones, ``graphs[i] -> graphs[i + 1]``, each carrying
      the forward action index and the backward log-prob of that step;
    * one terminal transition from the last graph (``graph_`` is None),
      carrying the last forward action index and the trajectory's log-reward.

    Args:
        traj: a ``Trajectory`` with ``torch_graphs``, ``fwd_action_idxs`` and
            ``bck_log_probs`` populated; ``bck_log_probs`` needs N - 1 entries.
        automorphism_correction: if True, add
            ``log|Aut(s_{i+1})| - log|Aut(s_i)|`` (computed from
            ``traj.states``) to each step's backward log-prob.

    Returns:
        list of ``Transition`` of length N.

    Raises:
        ValueError: if the trajectory contains no graphs.
        IndexError: if ``fwd_action_idxs`` or ``bck_log_probs`` are too short.
    """
    graphs = traj.torch_graphs
    if not graphs:
        raise ValueError("cannot build transitions from an empty trajectory")
    num_steps = len(graphs) - 1

    if automorphism_correction:
        # math.log accepts arbitrarily large Python ints. A NumPy array of the
        # counts would switch to dtype=object once any count exceeds int64,
        # and np.log fails on object arrays.
        log_autm = [math.log(count_automorphisms(state)) for state in traj.states]
        log_autm_ratio = [nxt - cur for cur, nxt in zip(log_autm, log_autm[1:])]
    else:
        log_autm_ratio = [0.0] * num_steps

    # Index explicitly (rather than zip) so a too-short bck_log_probs, e.g.
    # when backward log-probs were never computed, raises IndexError instead
    # of silently dropping transitions.
    transitions = [
        Transition(
            graph=graphs[i],
            graph_=graphs[i + 1],
            fwd_action_idx=traj.fwd_action_idxs[i],
            bck_log_prob=traj.bck_log_probs[i] + log_autm_ratio[i],
        )
        for i in range(num_steps)
    ]

    transitions.append(Transition(graphs[-1], None, traj.fwd_action_idxs[-1], None, traj.log_reward))
    return transitions


class GraphSampler:
    """Produces complete trajectories in a ``GraphEnv``.

    Trajectories are generated forward by rolling out a policy model
    (``sample_from_model``), or backward from a given final state
    (``unif_backward_sample`` and ``fast_backward_sample``).
    """

    def __init__(self, env: GraphEnv, featurizer: GraphStateFeaturizer):
        self.env = env
        self.featurizer = featurizer

    @torch.no_grad()
    def sample_from_model(self, model: PolicyNet, num_samples: int, epsilon: float = 0.0) -> list[Trajectory]:
        """Roll out ``num_samples`` trajectories starting from ``env.new()``.

        At every step, one random draw decides for the whole batch of
        unfinished trajectories. With probability ``epsilon`` each picks a
        uniformly random child; otherwise actions are sampled from the first
        output of ``model`` (a distribution with ``.sample()``). A trajectory
        ends when it takes a Stop action.

        Args:
            model: policy network applied to the collated featurized states.
            num_samples: number of trajectories to generate.
            epsilon: exploration probability; 0.0 means fully on-policy.

        Returns:
            One ``Trajectory`` per sample, with ``states``, ``torch_graphs``,
            ``actions`` and ``fwd_action_idxs`` filled in. The last recorded
            action of each is the Stop action.
        """
        dones = [False] * num_samples
        states = [self.env.new() for _ in range(num_samples)]
        trajectories = [Trajectory() for _ in range(num_samples)]

        while states:
            states_info = self.featurizer.transform(states)
            torch_graphs, children = states_info['graph'], states_info['children']
            # Epsilon-greedy exploration: ``random.random() < epsilon`` holds
            # with probability ``epsilon``, so epsilon=0 always uses the model.
            if random.random() < epsilon:
                actions_idxs = [random.randint(0, len(ch.states) - 1) for ch in children]
            else:
                batch = self.featurizer.collate(torch_graphs)
                fwd_dist, _, _ = model(batch)
                actions_idxs = fwd_dist.sample().tolist()

            graph_actions = [children[i].actions[idx] for i, idx in enumerate(actions_idxs)]
            # ``states`` holds one entry per unfinished trajectory, in
            # trajectory order; map each position j back to trajectory index i.
            not_dones = [i for i, done in enumerate(dones) if not done]
            next_states = []

            for j, i in enumerate(not_dones):
                tr = trajectories[i]
                tr.states.append(states[j])
                tr.torch_graphs.append(torch_graphs[j])
                tr.actions.append(graph_actions[j])
                tr.fwd_action_idxs.append(actions_idxs[j])

                if graph_actions[j].type == ActionType.Stop:
                    dones[i] = True
                else:
                    next_states.append(children[j].states[actions_idxs[j]])
            states = next_states
        return trajectories

    def unif_backward_sample(self, state: GraphState):
        """Sample a trajectory ending at ``state`` by uniform backward steps.

        Starting from ``state``, this takes ``state.num_edges + 1`` backward
        steps. Each step picks uniformly among ``env.parents`` of the current
        state and records ``-log(#parents)`` as that step's backward
        log-prob. The forward actions are recovered with
        ``env.reverse_action``, and ``env.stop_action()`` is appended as the
        final action.

        NOTE(review): ``states[0]`` is initialised to ``env.new()`` but is
        overwritten by the last sampled parent. This assumes that
        ``num_edges + 1`` backward steps always reach the empty graph; confirm
        against ``GraphEnv.parents``.

        Returns:
            A ``Trajectory`` with states, actions, torch_graphs,
            fwd_action_idxs, bck_action_idxs and bck_log_probs set.
            ``log_reward`` is left unset.
        """
        states = [self.env.new()] + [None] * state.num_edges + [state]
        bck_actions = [None] * (state.num_edges + 1)
        bck_log_probs = [None] * (state.num_edges + 1)
        bck_action_idxs = [None] * (state.num_edges + 1)

        # Fill the trajectory from the end: slot t receives the parent of slot t + 1.
        for t in reversed(range(len(bck_actions))):
            parents = self.env.parents(state)
            i = random.randint(0, len(parents) - 1)
            bck_action_idxs[t], states[t], bck_actions[t] = i, parents.states[i], parents.actions[i]
            bck_log_probs[t] = -math.log(len(parents))
            state = states[t]

        # reverse_action receives the child state the backward action was taken from.
        actions = [self.env.reverse_action(s, a) for s, a in zip(states[1:], bck_actions)] + [self.env.stop_action()]
        states_info = self.featurizer.transform(states)
        torch_graphs, children = states_info['graph'], states_info['children']
        fwd_action_idxs = [ch.find_action_index(act) for act, ch in zip(actions, children)]
        traj = Trajectory(states, actions, torch_graphs, fwd_action_idxs, bck_action_idxs, bck_log_probs)
        return traj

    def fast_backward_sample(self, state: GraphState):
        """Build a deterministic trajectory to ``state`` by replaying its edges in order.

        After ``env.new()``, the k-th state keeps the first k - 1 entries of
        ``state.edge_list`` / ``state.edge_types``. Its node types are
        ``state.node_types`` up to the largest node index touched so far
        (node 0 alone when there are no edges).

        NOTE(review): this presumably requires each edge prefix to be
        reachable from the previous prefix by one forward action (i.e. edges
        stored so each adds at most one new, next-indexed node). This is not
        checked here; ``ch.states.index`` fails if a prefix is not a child of
        its predecessor. The Stop action is assumed to be at child index 0.

        Returns:
            A ``Trajectory`` with states, actions, torch_graphs and
            fwd_action_idxs set. The backward fields and ``log_reward`` are
            left at their defaults.
        """
        states = [self.env.new()]
        ni = 0
        for ei in range(state.num_edges + 1):
            edge_types, edge_list = state.edge_types[:ei], state.edge_list[:ei]
            # ``ei and ...`` short-circuits to 0 for the empty prefix, so
            # edge_list[-1] is only read when at least one edge is present.
            ni = max(ni, ei and max(edge_list[-1]))
            st = GraphState(state.node_types[:ni+1], edge_types, edge_list)
            states.append(st)
        states_info = self.featurizer.transform(states)
        torch_graphs, children = states_info['graph'], states_info['children']
        fwd_action_idxs = [ch.states.index(s) for s, ch in zip(states[1:], children)] + [0] # manually add stop action idx
        actions = [ch.actions[i] for i, ch in zip(fwd_action_idxs, children)]
        return Trajectory(states, actions, torch_graphs, fwd_action_idxs)
    


class ReplayBuffer:
    """Fixed-capacity ring buffer with uniform sampling (with replacement).

    A ``buffer_size`` of 0 disables the buffer: ``push`` does nothing, so
    ``sample`` always returns an empty list.
    """

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.buffer = []
        # Slot that the next pushed object will occupy.
        self.position = 0

    def push(self, obj):
        """Store ``obj``, overwriting the oldest entry once the buffer is full."""
        if self.buffer_size == 0:  # buffer disabled
            return
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(None)
        self.buffer[self.position] = obj
        self.position = (self.position + 1) % self.buffer_size

    def push_many(self, objs):
        """Push every element of ``objs`` in order."""
        for obj in objs:
            self.push(obj)

    def sample(self, batch_size):
        """Return ``batch_size`` stored entries drawn uniformly *with replacement*.

        Returns an empty list when nothing has been stored yet.
        """
        if not self.buffer:
            return []
        indices = np.random.choice(len(self.buffer), batch_size)
        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)
    


class TrajectoryDataset(IterableDataset):
    """Endless stream of GFlowNet training batches.

    Each step of ``__iter__`` does the following:

    * rolls out ``num_online_samples`` fresh trajectories with ``model``;
    * annotates them with log-rewards and, depending on the flags, backward
      log-probs, backward action indices and automorphism-based reward
      scaling;
    * mixes in ``num_buffer_samples`` items replayed from a ring buffer;
    * yields one collated batch, either of whole trajectories or of
      individual transitions when ``transition_dataset`` is True.
    """

    def __init__(
            self,
            model: PolicyNet,
            env: GraphEnv,
            featurizer: GraphStateFeaturizer,
            num_online_samples=16,
            num_buffer_samples=0,
            buffer_size=0,
            exploration_epsilon=0.0,
            compute_unif_backward_prob=True,
            compute_backward_action_idxs=False,
            transition_dataset=False,
            compute_automorphism=False,
            scale_reward=False
        ):
        """
        Args:
            model: policy used to sample the online trajectories.
            env: environment providing states, parents, actions and rewards.
            featurizer: turns states into graphs and collates them into batches.
            num_online_samples: fresh trajectories sampled per batch.
            num_buffer_samples: replayed items added to each batch.
            buffer_size: replay buffer capacity; 0 disables replay.
            exploration_epsilon: exploration probability passed to the sampler.
            compute_unif_backward_prob: fill ``bck_log_probs`` with uniform-
                over-parents log-probs and put per-trajectory sums in
                trajectory batches.
            compute_backward_action_idxs: fill ``bck_action_idxs`` and put
                them in trajectory batches.
            transition_dataset: yield transition batches instead of
                trajectory batches.
            compute_automorphism: apply the automorphism correction when
                splitting trajectories into transitions.
            scale_reward: add ``log|Aut(x)|`` of the final state to its
                log-reward.
        """
        self.model = model
        self.env = env
        self.featurizer = featurizer
        self.num_online_samples = num_online_samples
        self.num_buffer_samples = num_buffer_samples
        self.sampler = GraphSampler(env, featurizer)
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.exploration_epsilon = exploration_epsilon
        self.num_total_samples = 0
        self.compute_unif_backward_prob = compute_unif_backward_prob
        self.compute_backward_action_idxs = compute_backward_action_idxs
        self.transition_dataset = transition_dataset
        self.compute_automorphism = compute_automorphism
        self.scale_reward = scale_reward

    def __iter__(self):
        while True:
            online_data = self.sampler.sample_from_model(self.model, self.num_online_samples, self.exploration_epsilon)
            if self.compute_unif_backward_prob:
                self.set_bck_log_probs(online_data)
            if self.compute_backward_action_idxs:
                self.set_bck_action_idxs(online_data)
            self.set_log_reward(online_data)
            # Drawn before this round is pushed, so replayed items come only
            # from earlier rounds. They were stored already annotated.
            buffer_data = self.replay_buffer.sample(self.num_buffer_samples)

            if self.scale_reward:
                self.scale_reward_by_automorphisms(online_data)

            if self.transition_dataset:
                # From here on online_data holds Transitions, so in this mode
                # the buffer stores and replays transitions.
                online_data = self.flatten_trajectories(online_data)
                make_batch = self.make_transition_batch
            else:
                make_batch = self.make_trajectory_batch
            yield make_batch(online_data + buffer_data)

            # NOTE(review): in transition mode this counts transitions rather
            # than sampled trajectories; confirm which one callers expect.
            self.num_total_samples += len(online_data)
            self.replay_buffer.push_many(online_data)

    def flatten_trajectories(self, trajectories: list[Trajectory]) -> list[Transition]:
        """Concatenate the transitions of all trajectories, in order."""
        # A nested comprehension avoids the quadratic cost of sum(lists, []).
        return [
            transition
            for traj in trajectories
            for transition in to_transitions(traj, self.compute_automorphism)
        ]

    def set_bck_log_probs(self, trajectories: list[Trajectory]):
        """Set each step's backward log-prob to that of a uniform choice among the child's parents."""
        for tr in trajectories:
            tr.bck_log_probs = [-math.log(len(self.env.parents(s))) for s in tr.states[1:]]

    def set_log_reward(self, trajectories: list[Trajectory]):
        """Set each trajectory's log-reward from the env's reward of its final state."""
        for tr in trajectories:
            last_state = tr.states[-1]
            tr.log_reward = self.env.log_reward(last_state)

    def set_bck_action_idxs(self, trajectories: list[Trajectory]):
        """For each step s_t -> s_{t+1}, store the index of the reversed action among the parents of s_{t+1}."""
        for tr in trajectories:
            # Reverse each forward action except the final one, which ends the trajectory.
            bck_actions = [self.env.reverse_action(s, a) for s, a in zip(tr.states[:-1], tr.actions[:-1])]
            parents = [self.env.parents(s) for s in tr.states[1:]]
            tr.bck_action_idxs = [p.find_action_index(a) for p, a in zip(parents, bck_actions)]

    def scale_reward_by_automorphisms(self, trajectories: list[Trajectory]):
        """Add ``log|Aut(x)|`` of each trajectory's final state to its log-reward."""
        for tr in trajectories:
            last_state = tr.states[-1]
            aut = count_automorphisms(last_state)
            tr.log_reward = tr.log_reward + math.log(aut)

    def make_trajectory_batch(self, trajectories):
        """Collate whole trajectories into one batch.

        ``batch.tbatch`` maps every graph to its trajectory index.
        ``batch.rbatch`` (only when backward action indices are computed)
        does the same for backward steps, which number one fewer per
        trajectory.
        """
        graphs = [g for tr in trajectories for g in tr.torch_graphs]
        batch = self.featurizer.collate(graphs)
        trajectory_lens = torch.tensor([len(tr.actions) for tr in trajectories], dtype=torch.long)
        batch.tbatch = torch.repeat_interleave(trajectory_lens)
        batch.log_rewards = torch.tensor([tr.log_reward for tr in trajectories], dtype=torch.float)
        fwd_action_idxs = [idx for tr in trajectories for idx in tr.fwd_action_idxs]
        batch.fwd_action_idxs = torch.tensor(fwd_action_idxs, dtype=torch.long)
        batch.num_trajectories = len(trajectories)
        if self.compute_unif_backward_prob:
            batch.tr_bck_log_probs = torch.tensor([sum(tr.bck_log_probs) for tr in trajectories], dtype=torch.float)
        if self.compute_backward_action_idxs:
            batch.rbatch = torch.repeat_interleave(trajectory_lens - 1)
            bck_action_idxs = [idx for tr in trajectories for idx in tr.bck_action_idxs]
            batch.bck_action_idxs = torch.tensor(bck_action_idxs, dtype=torch.long)
        return batch

    def make_transition_batch(self, transitions):
        """Collate transitions into one batch.

        The returned batch holds the source graphs. ``batch.next_batch``
        holds the successor graphs of the non-terminal transitions only, and
        their positions in ``transitions`` are listed in
        ``batch.non_terminals``. ``batch.targets`` is ``log_reward`` where it
        is set (terminal transitions) and ``bck_log_prob`` otherwise.
        """
        batch = self.featurizer.collate([tr.graph for tr in transitions])
        non_terminals = [i for i, tr in enumerate(transitions) if tr.graph_ is not None]
        batch.next_batch = self.featurizer.collate([transitions[i].graph_ for i in non_terminals])
        batch.non_terminals = torch.tensor(non_terminals, dtype=torch.long)
        batch.fwd_action_idxs = torch.tensor([tr.fwd_action_idx for tr in transitions], dtype=torch.long)
        batch.targets = torch.tensor([(tr.log_reward if tr.log_reward is not None else tr.bck_log_prob) for tr in transitions], dtype=torch.float)
        return batch
