import itertools
import networkx as nx
from collections import defaultdict, namedtuple, deque
from tqdm import tqdm
import random
import numpy as np


import torch
from torch import nn
from torch.distributions import Categorical
import torch.nn.functional as F
import matplotlib.pyplot as plt


# Lightweight record describing one environment state:
#   bits        -- sequence of 0/1 entries, one per position
#   is_terminal -- flag telling whether the episode has ended in this state
State = namedtuple("State", ("bits", "is_terminal"))

def state_to_str(state):
    """Render ``state.bits`` as a compact string of digits, e.g. ``'0110'``.

    Each entry is cast to ``int`` first, so float or bool bits are printed
    as ``0``/``1`` as well.
    """
    digits = map(str, map(int, state.bits))
    return ''.join(digits)

class BitFlipEnv:    
    def __init__(self, dim=5):
        self.dim = dim
        self.G = None

    def init(self):
        return State(bits=[0] * self.dim, is_terminal=False)

    def to_tensor(self, state: State):
        return torch.tensor(state.bits, dtype=torch.float)
    
    def forward_step(self, state, action):
        assert not state.is_terminal
        assert state.bits[action] == 0
        bits = state.bits.copy()
        bits[action] = 1
        return State(bits=bits, is_terminal=bits[-1] == 1)

    def backward_step(self, state, action):
        assert sum(state.bits) != 0
        assert state.bits[action] == 1
        bits = state.bits.copy()
        bits[action] = 0
        return State(bits=bits, is_terminal=False)

    def uniform_backward_probability(self, state):
        if state.is_terminal:
            return 1.0
        else:
            num_bits = sum(state.bits)
            return 1. / num_bits
        
    def sample_trajectory_from_terminal(self, state):
        assert state.is_terminal
        traj = [self.init()]
        actions = [i for i, b in enumerate(state.bits[:-1]) if b == 1]
        random.shuffle(actions)
        actions.append(self.dim-1)
        bits = [0] * self.dim
        for i in actions:
            bits = bits.copy()
            bits[i] = 1
            is_terminal = True if bits[-1] == 1 else False
            traj.append(State(bits, is_terminal))
        return traj, actions

    @torch.no_grad()
    def sample_trajectory_from_model(self, model):
        state = self.init()
        traj = [state]
        actions = []
        while not state.is_terminal:
            action = model.forward(self.to_tensor(state)).sample().item()
            state = self.forward_step(state, action)
            traj.append(state)
            actions.append(action)
        return traj, actions
    
    def compute_state_probabilities(self, model):
        with torch.no_grad():
            all_states = self.get_all_states()
            all_tensor = map(self.to_tensor, all_states[:-1])
            all_tensor = torch.vstack(list(all_tensor))
            states_str = [state_to_str(x) for x in all_states]
            fwd = model.forward(all_tensor)
            probs = fwd.logits.softmax(dim=1)
            probs = dict(zip(states_str, probs.tolist()))

        def neighbors_iter(state):
            dim = len(state)
            for i in range(dim):
                if state[-1] == '0' and state[i] == '0':
                    yield state[:i] + '1' + state[i+1:]
        
        def compute_action(state, next_state):
            diff = [int(y) - int(x) for x, y in zip(state, next_state)]
            return max(range(model.dim), key=lambda i: diff[i]) # argmax

        s0 = '0' * model.dim
        queue = deque([s0])
        state_prob = {s0: 1.0}
        while queue:
            cur = queue.popleft()
            for nei in neighbors_iter(cur):
                action = compute_action(cur, nei)
                if nei in state_prob:
                    state_prob[nei] += state_prob[cur] * probs[cur][action]
                else:
                    state_prob[nei] = state_prob[cur] * probs[cur][action]
                    queue.append(nei)
        return state_prob

    def get_all_states(self):
        return [State(list(bits), bits[-1] == 1) for bits in itertools.product((0, 1), repeat=self.dim)]
    
    def get_terminal_states(self):
        return [State(list(bits) + [1], True) for bits in itertools.product((0, 1), repeat=self.dim - 1)]

    def to_transitions(self, traj, actions):
        tensors = [self.to_tensor(state) for state in traj]
        prev_states = torch.vstack(tensors[:-1])
        actions = torch.tensor(actions)
        next_states = torch.vstack(tensors[1:])
        dones = torch.tensor([state.is_terminal for state in traj[1:]])
        backward_probs = torch.tensor([self.uniform_backward_probability(s) for s,a in zip(traj[1:], actions)])
        return prev_states, actions, next_states, dones, backward_probs
    
    def build_graph(self):
        states = [state_to_str(x) for x in self.get_all_states()]
        adj = defaultdict(list)
        for s in states:
            if s[-1] == '1':
                continue
            for i, b in enumerate(s):
                if b == '0':
                    s_ = s[:i] + '1' + s[i + 1 :]
                    adj[s] += [s_]
        G = nx.DiGraph()
        G.add_edges_from([(k, i) for k, v in adj.items() for i in v])
        return G

    def compute_postion(self, G):
        # BFS
        s0 = '0' * self.dim
        frontier = deque([s0])
        pos = {s0: [0, 0]}
        depth2height = defaultdict(int)
        depth2height[0] = 1
        while frontier:
            cur = frontier.popleft()
            new_nodes = sorted(G.adj[cur].keys(), reverse=False)
            for n in new_nodes:
                if n not in pos:
                    d = pos[cur][0] + 1
                    pos[n] = [d, depth2height[d]]
                    depth2height[d] += 1
                    frontier.append(n)

        # recompute position for pretty visualization
        max_height = max(depth2height.values())
        for p in pos:
            d, h = pos[p]
            gap = max_height / (depth2height[d] + 1)
            pos[p][1] = h * gap + gap

        return pos
    
    def visualize(self, model, labels=True):
        if self.G is None:
            self.G = self.build_graph()
            self.pos = self.compute_postion(self.G)
        state_probs = self.compute_state_probabilities(model)
        node_size = [1000*state_probs[n] for n in self.G.nodes()]
        node_color = ["#56B4E9" if n[-1] == '0' else "#CC79A7" for n in self.G.nodes()]
        nx.draw_networkx(self.G, pos=self.pos, with_labels=False, node_color=node_color, node_shape='o', node_size=node_size, width=0.4)
        if labels:
            for k, v in self.pos.items():
                self.pos[k][1] += 0.3
            self.pos['0000'][1] += 0.1
            nx.draw_networkx_labels(self.G, self.pos, font_size=6, font_color='#000000')
        return 
    
    def l1_error(self, model):
        probs = self.compute_state_probabilities(model)
        states = self.get_terminal_states()
        tprobs = [probs[state_to_str(s)] for s in states]
        return np.abs(np.array(tprobs) - 1 / len(states)).mean()


class BitFlipPolicy(nn.Module):
    """Forward policy, backward policy and log-flow head for a bit-flip env.

    Three independent two-layer MLPs map a bit vector of length ``env.dim``
    to forward-action logits, backward-action logits and a scalar log-flow.
    """

    def __init__(self, env):
        super().__init__()
        self.env = env
        self.dim = env.dim
        # Creation order is kept (forward, backward, logflow) so that seeded
        # parameter initialisation stays the same.
        self.forward_mlp = self._make_mlp(self.dim)
        self.backward_mlp = self._make_mlp(self.dim)
        self.logflow_mlp = self._make_mlp(1)

    def _make_mlp(self, out_features):
        """Build one ``dim -> 64 -> out_features`` MLP with a LeakyReLU."""
        return nn.Sequential(
            nn.Linear(self.dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, out_features)
        )

    @property
    def device(self):
        """Device of the module's first parameter."""
        return next(self.parameters()).device

    def forward(self, x):
        """Return the forward-action distribution for state tensor(s) ``x``.

        Positions whose bit is already 1 are masked out.
        """
        return Categorical(logits=self.logits(x))

    def backward(self, x):
        """Return the backward-action distribution for state tensor(s) ``x``.

        Only set bits can be cleared.  For terminal states (last bit set) the
        only allowed backward action is clearing the last bit.  ``x`` may be
        a single state of shape ``(dim,)`` or a batch of shape ``(n, dim)``.
        """
        logits = self.backward_mlp(x)
        logits = logits.masked_fill(x == 0.0, -torch.inf)
        # Terminal rows: shape (..., 1), broadcast against the per-position mask.
        is_terminal = x[..., -1:] == 1.0
        # True everywhere except the last position; built on the input's
        # device so the mask works on CPU and accelerators alike.
        not_last = torch.ones(self.dim, dtype=torch.bool, device=x.device)
        not_last[-1] = False
        logits = logits.masked_fill(is_terminal & not_last, -torch.inf)
        return Categorical(logits=logits)

    def logflow(self, x):
        """Return the log-flow estimate(s) for ``x`` as a flat tensor."""
        return self.logflow_mlp(x).flatten()

    def logZ(self):
        """Return the log-flow of the environment's initial state."""
        init_state = self.env.init()
        init_tensor = self.env.to_tensor(init_state).to(self.device)
        return self.logflow(init_tensor)

    def logits(self, x):
        """Return forward-action logits with already-set positions at ``-inf``."""
        return self.forward_mlp(x).masked_fill(x == 1.0, -torch.inf)


def get_sqil_loss(model, transitions, expert=True):
    """Summed squared residual ``Q(s, a) - r - V(s')`` over a trajectory.

    ``Q`` are the model's masked forward logits, ``V(s')`` is the
    log-sum-exp of the (detached) logits at the next state, and ``r`` is 1
    for expert transitions and 0 otherwise.  At terminal next states ``V``
    is replaced by ``r * (model.dim - number_of_actions)``.
    """
    states, taken, successors, finished, _ = transitions
    reward = 1.0 if expert else 0.0

    q_values = model.logits(states)
    # No gradient flows through the bootstrap target.
    q_next = model.logits(successors).detach()

    chosen_q = torch.gather(q_values, 1, taken.view(-1, 1)).squeeze(1)

    terminal_value = reward * (model.dim - len(taken))
    next_value = torch.where(finished, terminal_value, torch.logsumexp(q_next, dim=1))

    residual = chosen_q - reward - next_value
    return (residual ** 2).sum()

def get_tbil_loss(model, transitions, reward):
    """Squared trajectory-level residual for one trajectory.

    Computes ``(logZ + sum(log P_F) - sum(log P_B) - reward) ** 2`` where
    ``P_F`` comes from ``model.forward`` and ``P_B`` from the stored
    backward probabilities.
    """
    states, taken, _next_states, _dones, pb = transitions
    log_pf_total = model.forward(states).log_prob(taken).sum()
    log_pb_total = torch.log(pb).sum()
    log_z = model.logZ()
    residual = log_z + log_pf_total - log_pb_total - reward
    return (residual ** 2).sum()


def get_dbil_loss(model, transitions, reward):
    """Summed squared per-transition residual for one trajectory.

    For each transition the residual is
    ``logF(s) + log P_F(a|s) - log P_B - logF(s') - reward``, where the
    (detached) next-state log-flow is replaced by
    ``reward * (model.dim - number_of_actions)`` at terminal next states.
    """
    states, taken, successors, finished, pb = transitions

    log_pf = model.forward(states).log_prob(taken)
    log_pb = torch.log(pb)

    flow_here = model.logflow(states)
    # The target flow is treated as a constant.
    flow_next = model.logflow(successors).detach()
    terminal_flow = reward * (model.dim - len(taken))
    flow_next = torch.where(finished, terminal_flow, flow_next)

    residual = flow_here + log_pf - log_pb - flow_next - reward
    return (residual ** 2).sum()

def train(dim, loss_type='tb', n_iter=5000, lr=0.001, r_alpha=1):
    """Train a ``BitFlipPolicy`` on a ``BitFlipEnv`` of size ``dim``.

    Each iteration accumulates the chosen loss over one randomly ordered
    expert trajectory per terminal state and over the same number of
    trajectories sampled from the current model, averages it, and takes one
    Adam step.

    Args:
        dim: number of bits of the environment.
        loss_type: ``'tbil'``, ``'dbil'`` or ``'sqil'``.  The short forms
            ``'tb'`` and ``'db'`` are accepted as aliases of ``'tbil'`` and
            ``'dbil'``.
        n_iter: number of optimisation steps.
        lr: Adam learning rate.
        r_alpha: reward passed to the TB/DB losses for expert trajectories;
            model-sampled trajectories receive ``-r_alpha``.

    Returns:
        ``(env, model, stats)`` where ``stats['loss']`` holds the loss value
        of every iteration.

    Raises:
        ValueError: if ``loss_type`` is not one of the supported names.
    """
    # Validate first: an unrecognised loss type would otherwise leave the
    # accumulated loss as a plain float and fail at ``loss.backward()``.
    aliases = {'tb': 'tbil', 'db': 'dbil'}
    loss_type = aliases.get(loss_type, loss_type)
    if loss_type not in ('tbil', 'dbil', 'sqil'):
        raise ValueError(
            f"unknown loss_type {loss_type!r}; expected 'tbil', 'dbil' or 'sqil'"
        )

    env = BitFlipEnv(dim=dim)
    model = BitFlipPolicy(env)
    data = env.get_terminal_states()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    stats = defaultdict(list)

    def trajectory_loss(transitions, expert):
        # Single dispatch point shared by expert and model trajectories.
        if loss_type == 'sqil':
            return get_sqil_loss(model, transitions, expert=expert)
        reward = r_alpha if expert else -r_alpha
        if loss_type == 'tbil':
            return get_tbil_loss(model, transitions, reward)
        return get_dbil_loss(model, transitions, reward)

    for _ in tqdm(range(n_iter)):
        loss = 0.0
        # Expert trajectories: one per terminal state.
        for state in data:
            traj, actions = env.sample_trajectory_from_terminal(state)
            transitions = env.to_transitions(traj, actions)
            loss = loss + trajectory_loss(transitions, expert=True)

        # The same number of trajectories sampled from the current model.
        for _ in data:
            traj, actions = env.sample_trajectory_from_model(model)
            transitions = env.to_transitions(traj, actions)
            loss = loss + trajectory_loss(transitions, expert=False)

        loss = loss / (2 * len(data))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        stats['loss'].append(loss.item())
    return env, model, stats