import os
import logging 
import argparse
import itertools
import networkx as nx
from collections import defaultdict, namedtuple, deque
import math
import random
import numpy as np


import torch
from torch import nn
from torch.distributions import Categorical


# One environment state: the bit vector plus a flag marking terminal states.
State = namedtuple("State", "bits is_terminal")
# Flattened transitions of a group of trajectories, as assembled by ``make_batch``.
Batch = namedtuple("Batch", "fwd_batch bck_batch actions dones traj_length")

def get_logger(path=None, level="DEBUG"):
    """Configure and return the root logger.

    Args:
        path: File to append log records to. When ``None`` records go to a
            plain ``logging.StreamHandler`` instead.
        level: Level set on the root logger.

    Returns:
        The root ``logging.Logger``.

    Calling this function repeatedly with the same target does not stack
    handlers, so each record is emitted once per target.
    """
    logger = logging.getLogger()
    logger.setLevel(level)
    if path is None:
        # Exact type check: FileHandler (and many third-party handlers)
        # subclass StreamHandler but do not write to the default stream.
        attached = any(type(h) is logging.StreamHandler for h in logger.handlers)
        if not attached:
            logger.addHandler(logging.StreamHandler())
    else:
        # FileHandler stores the absolute path of its file in ``baseFilename``.
        target = os.path.abspath(path)
        attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not attached:
            logger.addHandler(logging.FileHandler(path, mode="a"))
    return logger


def state_to_str(state):
    """Render ``state.bits`` as a string with one integer digit per bit."""
    digits = (f"{int(bit)}" for bit in state.bits)
    return "".join(digits)

class BitFlipEnv:
    """Bit-flip environment over binary vectors of length ``dim``.

    A state is a ``State(bits, is_terminal)``. Trajectories start at the
    all-zero vector, a forward action ``i`` sets bit ``i`` from 0 to 1, and a
    state is terminal as soon as its last bit is 1.
    """

    def __init__(self, dim=5):
        self.dim = dim
        # Transition graph and node layout; both are built lazily by ``visualize``.
        self.G = None
        self.pos = None

    def init(self):
        """Return the initial state (all bits zero, not terminal)."""
        return State(bits=[0] * self.dim, is_terminal=False)

    def to_tensor(self, state: State):
        """Return ``state.bits`` as a float tensor."""
        return torch.tensor(state.bits, dtype=torch.float)

    def forward_step(self, state, action):
        """Set bit ``action`` and return the resulting state.

        The input state must be non-terminal and the bit must currently be 0.
        """
        assert not state.is_terminal
        assert state.bits[action] == 0
        bits = state.bits.copy()
        bits[action] = 1
        return State(bits=bits, is_terminal=bits[-1] == 1)

    def backward_step(self, state, action):
        """Clear bit ``action`` and return the resulting state.

        The bit must currently be 1. The terminal flag is derived from the
        last bit, exactly as in ``forward_step``, so the returned state is
        consistent even if a bit other than the last one is cleared from a
        terminal state.
        """
        assert sum(state.bits) != 0
        assert state.bits[action] == 1
        bits = state.bits.copy()
        bits[action] = 0
        return State(bits=bits, is_terminal=bits[-1] == 1)

    def uniform_backward_probability(self, state):
        """Probability of the backward action under a uniform backward policy.

        Terminal states return 1.0; any other state returns one over the
        number of set bits (so the all-zero state raises ZeroDivisionError).
        """
        if state.is_terminal:
            return 1.0
        num_bits = sum(state.bits)
        return 1. / num_bits

    def sample_trajectory_from_terminal(self, state):
        """Sample a random forward trajectory that ends in terminal ``state``.

        The set bits other than the last are visited in a random order and
        the last bit is always set at the end.

        Returns:
            ``(traj, actions)`` where ``traj`` has ``len(actions) + 1`` states.
        """
        assert state.is_terminal
        traj = [self.init()]
        actions = [i for i, b in enumerate(state.bits[:-1]) if b == 1]
        random.shuffle(actions)
        actions.append(self.dim - 1)
        bits = [0] * self.dim
        for i in actions:
            # Copy so that every State in the trajectory owns its own list.
            bits = bits.copy()
            bits[i] = 1
            traj.append(State(bits, bits[-1] == 1))
        return traj, actions

    def batch_sample_trajectory_from_terminal(self, states):
        """Apply ``sample_trajectory_from_terminal`` to every state.

        Returns:
            ``(trajectories, actions)`` as two tuples of per-state lists.
        """
        trjs = [self.sample_trajectory_from_terminal(x) for x in states]
        trajectories, actions = zip(*trjs)
        return trajectories, actions

    @torch.no_grad()
    def sample_trajectory_from_model(self, model):
        """Roll out one trajectory by sampling actions from ``model.forward``."""
        state = self.init()
        traj = [state]
        actions = []
        while not state.is_terminal:
            action = model.forward(self.to_tensor(state)).sample().item()
            state = self.forward_step(state, action)
            traj.append(state)
            actions.append(action)
        return traj, actions

    @torch.no_grad()
    def batch_sample_trajectory_from_model(self, model, batch_size: int):
        """Roll out ``batch_size`` trajectories, batching the model calls.

        Returns:
            ``(trajectories, actions)`` as two lists of per-trajectory lists.
        """
        states = [self.init() for _ in range(batch_size)]
        trajectories = [[s] for s in states]  # list of list of states
        actions = [[] for _ in range(batch_size)]  # list of list of actions
        finished = [s.is_terminal for s in states]

        while not all(finished):
            # Only the unfinished trajectories are fed to the model.
            nonterm_idx = [i for i, f in enumerate(finished) if not f]
            nonterm_states = [states[i] for i in nonterm_idx]
            input_tensor = torch.vstack([self.to_tensor(x) for x in nonterm_states])

            # The model is expected to return a distribution over actions.
            dists = model(input_tensor)
            sampled_actions = dists.sample().cpu().tolist()

            for j, i in enumerate(nonterm_idx):
                a = sampled_actions[j]
                s_next = self.forward_step(states[i], a)
                states[i] = s_next
                trajectories[i].append(s_next)
                actions[i].append(a)
                finished[i] = s_next.is_terminal

        return trajectories, actions

    def compute_state_probabilities(self, model):
        """Compute the probability of visiting each state under ``model``.

        The forward action probabilities of all states are evaluated in one
        batch and then propagated from the all-zero state in breadth-first
        order.

        Returns:
            Dict mapping the bit string of each reachable state to the
            probability that a trajectory passes through it.
        """
        with torch.no_grad():
            all_states = self.get_all_states()
            # The last enumerated state (all ones) is not fed to the model;
            # ``zip`` below therefore leaves it without an entry in ``probs``.
            all_tensor = torch.vstack([self.to_tensor(s) for s in all_states[:-1]])
            states_str = [state_to_str(x) for x in all_states]
            fwd = model.forward(all_tensor)
            probs = fwd.logits.softmax(dim=1)
            probs = dict(zip(states_str, probs.cpu().tolist()))

        def successors(state):
            # Terminal states (last bit set) have no outgoing transitions.
            if state[-1] == '1':
                return
            for i, bit in enumerate(state):
                if bit == '0':
                    yield i, state[:i] + '1' + state[i + 1:]

        s0 = '0' * self.dim
        queue = deque([s0])
        state_prob = {s0: 1.0}
        # Every transition sets exactly one bit, so breadth-first order
        # finalises a state's probability before the state is expanded.
        while queue:
            cur = queue.popleft()
            for action, nei in successors(cur):
                flow = state_prob[cur] * probs[cur][action]
                if nei in state_prob:
                    state_prob[nei] += flow
                else:
                    state_prob[nei] = flow
                    queue.append(nei)
        return state_prob

    def get_all_states(self):
        """Enumerate all ``2 ** dim`` bit vectors as ``State`` objects."""
        return [State(list(bits), bits[-1] == 1) for bits in itertools.product((0, 1), repeat=self.dim)]

    def get_terminal_states(self):
        """Enumerate all states whose last bit is 1."""
        return [State(list(bits) + [1], True) for bits in itertools.product((0, 1), repeat=self.dim - 1)]

    def to_transitions(self, traj, actions):
        """Convert one trajectory into per-transition tensors.

        Returns:
            ``(prev_states, actions, next_states, dones, backward_probs)``.
        """
        tensors = [self.to_tensor(state) for state in traj]
        prev_states = torch.vstack(tensors[:-1])
        actions = torch.tensor(actions)
        next_states = torch.vstack(tensors[1:])
        dones = torch.tensor([state.is_terminal for state in traj[1:]])
        backward_probs = torch.tensor([self.uniform_backward_probability(s) for s, _ in zip(traj[1:], actions)])
        return prev_states, actions, next_states, dones, backward_probs

    def build_graph(self):
        """Build the directed transition graph over bit-string states."""
        states = [state_to_str(x) for x in self.get_all_states()]
        adj = defaultdict(list)
        for s in states:
            # Terminal states have no outgoing edges.
            if s[-1] == '1':
                continue
            for i, b in enumerate(s):
                if b == '0':
                    s_ = s[:i] + '1' + s[i + 1:]
                    adj[s] += [s_]
        G = nx.DiGraph()
        G.add_edges_from([(k, i) for k, v in adj.items() for i in v])
        return G

    def compute_postion(self, G):
        """Compute a layered layout: x is the BFS depth, y spreads each layer.

        Returns:
            Dict mapping node -> ``[depth, height]``.
        """
        # BFS
        s0 = '0' * self.dim
        frontier = deque([s0])
        pos = {s0: [0, 0]}
        depth2height = defaultdict(int)
        depth2height[0] = 1
        while frontier:
            cur = frontier.popleft()
            new_nodes = sorted(G.adj[cur].keys(), reverse=False)
            for n in new_nodes:
                if n not in pos:
                    d = pos[cur][0] + 1
                    pos[n] = [d, depth2height[d]]
                    depth2height[d] += 1
                    frontier.append(n)

        # Spread every layer evenly over the height of the tallest layer.
        max_height = max(depth2height.values())
        for p in pos:
            d, h = pos[p]
            gap = max_height / (depth2height[d] + 1)
            pos[p][1] = h * gap + gap

        return pos

    def visualize(self, model, labels=True):
        """Draw the transition graph with node sizes scaled by visit probability.

        Args:
            model: Policy passed to ``compute_state_probabilities``.
            labels: When true, draw the bit string of each node above it.
        """
        if self.G is None:
            self.G = self.build_graph()
            self.pos = None
        if self.pos is None:
            self.pos = self.compute_postion(self.G)
        state_probs = self.compute_state_probabilities(model)
        node_size = [1000 * state_probs[n] for n in self.G.nodes()]
        node_color = ["#56B4E9" if n[-1] == '0' else "#CC79A7" for n in self.G.nodes()]
        nx.draw_networkx(self.G, pos=self.pos, with_labels=False, node_color=node_color, node_shape='o', node_size=node_size, width=0.4)
        if labels:
            # Offset a copy so the cached layout is not shifted on every call.
            label_pos = {k: [v[0], v[1] + 0.3] for k, v in self.pos.items()}
            # The root label gets a little extra clearance.
            label_pos['0' * self.dim][1] += 0.1
            nx.draw_networkx_labels(self.G, label_pos, font_size=6, font_color='#000000')
        return

    def l1_error(self, model):
        """Mean absolute gap between terminal-state probabilities and uniform."""
        probs = self.compute_state_probabilities(model)
        states = self.get_terminal_states()
        tprobs = [probs[state_to_str(s)] for s in states]
        return np.abs(np.array(tprobs) - 1 / len(states)).mean()



class DataSampler:
    """Draws random terminal states built from ``k`` concatenated blocks.

    Each sample is ``k`` strings picked (with replacement) from ``blocks``,
    joined together and followed by a trailing ``'1'``.
    """

    def __init__(self, k=4, blocks=None):
        self.blocks = blocks
        self.k = k

    def sample(self, size=1):
        """Return ``size`` terminal ``State`` objects."""
        sequences = [
            ''.join(random.choices(self.blocks, k=self.k)) + '1'
            for _ in range(size)
        ]
        return [
            State([int(ch) for ch in seq], is_terminal=True)
            for seq in sequences
        ]
    

class BitFlipPolicy(nn.Module):
    """Forward policy, backward policy and log-flow heads for ``BitFlipEnv``.

    All three heads are small MLPs that take a bit vector of length
    ``env.dim``. Every method accepts either a single state of shape
    ``(dim,)`` or a batch of shape ``(n, dim)``.
    """

    def __init__(self, env):
        super().__init__()
        self.env = env
        self.dim = env.dim
        self.forward_mlp = nn.Sequential(
            nn.Linear(self.dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, self.dim)
        )
        self.backward_mlp = nn.Sequential(
            nn.Linear(self.dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, self.dim)
        )
        self.logflow_mlp = nn.Sequential(
            nn.Linear(self.dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 1)
        )

    @property
    def device(self):
        """Device of the first parameter of the module."""
        return next(self.parameters()).device

    def forward(self, x):
        """Distribution over forward actions; bits that are already 1 are masked."""
        return Categorical(logits=self.logits(x))

    def backward(self, x):
        """Distribution over backward actions (which bit to clear).

        Bits that are 0 are masked. For states whose last bit is 1, every
        action except clearing the last bit is masked as well.
        """
        x = x.to(self.device)
        logits = self.backward_mlp(x)
        # ``x[..., -1:]`` keeps the last axis, so the mask broadcasts for
        # both single states and batches.
        is_terminal = x[..., -1:] == 1.0
        not_last = torch.ones(self.dim, dtype=torch.bool, device=self.device)
        not_last[-1] = False
        invalid = (x == 0.0) | (is_terminal & not_last)
        return Categorical(logits=logits.masked_fill(invalid, -torch.inf))

    def logflow(self, x):
        """Log-flow estimate of each state, flattened to one dimension."""
        x = x.to(self.device)
        return self.logflow_mlp(x).flatten()

    def logZ(self):
        """Log-flow of the initial state, as a tensor with one element."""
        init_state = self.env.init()
        init_tensor = self.env.to_tensor(init_state).to(self.device)
        return self.logflow(init_tensor)

    def logits(self, x):
        """Raw ``forward_mlp`` outputs with already-set bits masked to ``-inf``."""
        x = x.to(self.device)
        return self.forward_mlp(x).masked_fill(x == 1.0, -torch.inf)

    def logits_(self, x):
        """``backward_mlp`` outputs under the same mask as ``logits``.

        The mask here is the forward mask (set bits), not the backward one.
        ``get_sqil_loss`` uses this head for the next-state values, and the
        training script copies/averages ``forward_mlp`` weights into
        ``backward_mlp`` when ``--algo sqil`` is selected.
        """
        x = x.to(self.device)
        return self.backward_mlp(x).masked_fill(x == 1.0, -torch.inf)


def make_batch(env, sample_traj):
    """Flatten a group of trajectories into one ``Batch`` of transitions.

    Args:
        env: Environment providing ``to_tensor(state)``.
        sample_traj: ``(trajectories, actions)`` pair; each trajectory holds
            one more state than its action list.

    Returns:
        ``Batch`` whose ``fwd_batch``/``bck_batch`` rows are the source and
        destination states of every transition (trajectories concatenated in
        order), with matching ``actions`` and ``dones`` and the
        per-trajectory ``traj_length``.
    """
    traj, actions = sample_traj
    traj_length = [len(x) for x in actions]
    # chain.from_iterable flattens in linear time; ``sum(lists, [])`` copies
    # the accumulated list on every addition.
    fwd_traj = list(itertools.chain.from_iterable(x[:-1] for x in traj))
    bck_traj = list(itertools.chain.from_iterable(x[1:] for x in traj))
    flat_actions = list(itertools.chain.from_iterable(actions))
    dones = [state.is_terminal for state in bck_traj]

    fwd_batch = torch.vstack([env.to_tensor(x) for x in fwd_traj])
    bck_batch = torch.vstack([env.to_tensor(x) for x in bck_traj])
    return Batch(
        fwd_batch,
        bck_batch,
        torch.tensor(flat_actions),
        torch.tensor(dones, dtype=torch.bool),
        torch.tensor(traj_length, dtype=torch.long),
    )

def get_dbil_loss(model, batch, r_alpha, unif_pb=True):
    """Detailed-balance style loss, summed per trajectory and averaged.

    Args:
        model: Provides ``device``, ``dim``, ``forward``, ``backward`` and
            ``logflow``.
        batch: ``Batch`` of flattened transitions.
        r_alpha: Scalar subtracted inside every residual; it also scales the
            target used at terminal transitions.
        unif_pb: Use a uniform backward probability instead of
            ``model.backward``.

    Returns:
        Scalar tensor: mean over trajectories of the summed squared residuals.
    """
    dev = model.device
    next_states = batch.bck_batch.to(dev)
    states = batch.fwd_batch.to(dev)
    taken = batch.actions.to(dev)
    is_done = batch.dones.to(dev)
    lengths = batch.traj_length.to(dev)

    # traj_of_step[i] is the index of the trajectory that transition i belongs to.
    traj_of_step = torch.repeat_interleave(lengths)
    terminal_target = (r_alpha * (model.dim - lengths)).float()[traj_of_step]

    log_pf = model.forward(states).log_prob(taken)
    if not unif_pb:
        log_pb = model.backward(next_states).log_prob(taken)
    else:
        # Terminal states count as having a single parent.
        parents = torch.where(is_done, 1, next_states.sum(dim=1))
        log_pb = torch.log(1 / parents)

    flow_in = model.logflow(states)
    # The next-state flow is detached; at terminal transitions it is
    # replaced by the fixed target.
    flow_out = torch.where(is_done, terminal_target, model.logflow(next_states).detach())

    residual = flow_in + log_pf - log_pb - flow_out - r_alpha
    per_traj = torch.zeros(len(lengths), device=dev)
    per_traj.scatter_reduce_(dim=0, index=traj_of_step, src=residual.square(), reduce='sum')
    return per_traj.mean()

def get_sqil_loss(model, batch, r_alpha, unif_pb=False):
    """Soft-Q style loss, averaged over all transitions in the batch.

    Args:
        model: Provides ``device``, ``dim``, ``logits`` (current values) and
            ``logits_`` (next-state values, used detached).
        batch: ``Batch`` of flattened transitions.
        r_alpha: Scalar subtracted inside every residual; it also scales the
            target used at terminal transitions.
        unif_pb: Also subtract the log of a uniform backward probability.

    Returns:
        Scalar tensor: mean squared residual over transitions.
    """
    device = model.device
    bck_batch = batch.bck_batch.to(device)
    fwd_batch = batch.fwd_batch.to(device)
    actions = batch.actions.to(device)
    dones = batch.dones.to(device)
    traj_length = batch.traj_length.to(device)

    # index[i] is the trajectory that transition i belongs to.
    index = torch.repeat_interleave(traj_length)
    per_traj_rT = (r_alpha * (model.dim - traj_length)).float()
    rT = per_traj_rT[index]

    Q = model.logits(fwd_batch)
    Q_ = model.logits_(bck_batch).detach()
    Q_sa = Q.gather(dim=1, index=actions.view(-1, 1)).squeeze(1)
    # Soft value of the next state; terminal transitions use the fixed target.
    V = torch.logsumexp(Q_, dim=1)
    V = torch.where(dones, rT, V)

    if unif_pb:
        # Terminal states count as having a single parent.
        num_1s = bck_batch.sum(dim=1)
        num_1s = torch.where(dones, 1, num_1s)
        log_pb = torch.log(1 / num_1s)
    else:
        log_pb = 0.0

    sqil_loss = (Q_sa - V - r_alpha - log_pb).square()
    # NOTE(review): the original also scatter-summed ``sqil_loss`` per
    # trajectory but never used the result; that dead computation is removed
    # and the returned value is unchanged. ``get_dbil_loss`` averages
    # per-trajectory sums instead -- confirm which reduction is intended.
    return sqil_loss.mean()



if __name__ == '__main__':
    # Training entry point: fits a BitFlipPolicy to block-structured terminal
    # states with either the 'dbil' or the 'sqil' loss, logging and
    # checkpointing every ``--eval_every`` iterations.
    parser = argparse.ArgumentParser()
    # Only these two algorithms are wired up in the training loop below.
    parser.add_argument('--algo', type=str, default='dbil', choices=['dbil', 'sqil'])
    parser.add_argument('--save_dir', type=str, default='./')
    parser.add_argument('--r_alpha', type=int, default=1)
    parser.add_argument('--unif_pb', type=int, default=1)

    parser.add_argument('--B', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--n_iters', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=36)
    parser.add_argument('--k', type=int, default=4)
    parser.add_argument('--eval_every', type=int, default=2000)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.B == 5:
        blocks = ['01100', '11000', '11110', '00100', '00011', '01101']
    elif args.B == 4:
        blocks = ['1110', '0011', '0110', '1100', '0100']
    else:
        # A bare ``raise`` here has no active exception to re-raise.
        raise ValueError(f'unsupported --B {args.B}: block sets are defined only for 4 and 5')
    B = args.B
    # The log file and the checkpoints are written into save_dir.
    os.makedirs(args.save_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.save_dir, 'logs.txt'))

    print(vars(args))
    print(device)
    print(blocks)
    print(flush=True)
    logger.info(str(vars(args)))

    def is_valid(k):
        # Valid: last bit set and each of the args.k chunks of width B is a known block.
        return k[-1] == '1' and all([k[:-1][i * B:(i + 1) * B] in blocks for i in range(args.k)])

    def num_traj(k):
        # Number of set bits; only referenced by the commented-out metrics below.
        return sum([i == '1' for i in k])

    env = BitFlipEnv(dim=B * args.k + 1)
    model = BitFlipPolicy(env)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    sampler = DataSampler(k=args.k, blocks=blocks)

    # Half of every batch comes from the data, half from the model.
    sample_size = args.batch_size // 2

    if args.algo == 'sqil':
        # backward_mlp starts out as a copy of forward_mlp for sqil.
        model.backward_mlp.load_state_dict(model.forward_mlp.state_dict())

    for itr in range(1, args.n_iters + 1):
        data_traj = env.batch_sample_trajectory_from_terminal(sampler.sample(sample_size))
        model_traj = env.batch_sample_trajectory_from_model(model, sample_size)

        # Data trajectories use +r_alpha, model trajectories use -r_alpha.
        if args.algo == 'sqil':
            data_loss = get_sqil_loss(model, make_batch(env, data_traj), r_alpha=args.r_alpha, unif_pb=bool(args.unif_pb))
            model_loss = get_sqil_loss(model, make_batch(env, model_traj), r_alpha=-args.r_alpha, unif_pb=bool(args.unif_pb))
        elif args.algo == 'dbil':
            data_loss = get_dbil_loss(model, make_batch(env, data_traj), r_alpha=args.r_alpha, unif_pb=bool(args.unif_pb))
            model_loss = get_dbil_loss(model, make_batch(env, model_traj), r_alpha=-args.r_alpha, unif_pb=bool(args.unif_pb))

        loss = data_loss + model_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if args.algo == 'sqil':
            # Move backward_mlp 5% of the way toward forward_mlp each step.
            src_params = model.forward_mlp.parameters()
            tgt_params = model.backward_mlp.parameters()
            for s, t in zip(src_params, tgt_params):
                t.data.mul_(0.95).add_(s.data * 0.05)

        if itr % args.eval_every == 0:
            probs = env.compute_state_probabilities(model)

            valid_probs = []
            # traj_lens = []
            for k, v in probs.items():
                if is_valid(k):
                    valid_probs.append(v)
                    # traj_lens.append(num_traj(k))

            # Total probability the model puts on valid terminal states.
            sum_probs = sum(valid_probs)
            # var_probs = np.var(valid_probs)
            # cor_probs = np.corrcoef(valid_probs, traj_lens)[0, 1]
            # cor_logps = np.corrcoef(np.log(valid_probs), traj_lens)[0, 1]
            d = dict(
                itr=itr,
                sum_probs=sum_probs,
                # var_probs=var_probs,
                # cor_probs=cor_probs,
                # cor_logps=cor_logps,
            )
            logger.info(str(d))
            sitr = str(itr).zfill(5)
            path = os.path.join(args.save_dir, f'model_{sitr}.pkl')

            # Save from the CPU, then move back for further training.
            model.cpu()
            torch.save(model, path)
            model.to(device)

            del probs
            del valid_probs