from collections import defaultdict, namedtuple
from tqdm import tqdm
from copy import deepcopy
import math
import random
import numpy as np


import torch
from torch import nn
from torch.distributions import Categorical
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.patches import RegularPolygon, FancyArrowPatch


# A position on the triangular grid. The environment starts at (0, 0);
# action 0 changes `x` by one and action 1 changes `y` by one.
# `is_terminal` is set by the environment's forward step when
# x + y == dim - 1; states produced by a backward step always carry False.
State = namedtuple("State", ["x", "y", "is_terminal"])

def state_to_str(state):
    """Return the textual form of the ``(x, y)`` pair of ``state``.

    The terminal flag is ignored, so two states at the same position map to
    the same string, e.g. ``"(1, 2)"``.
    """
    coords = (state.x, state.y)
    # A tuple has no custom __str__, so repr() yields the same text as str().
    return repr(coords)

class PascalEnv:
    """Triangular ("Pascal") grid environment.

    A state is a pair ``(x, y)`` of non-negative integers with
    ``x + y <= dim - 1``. Episodes start at ``(0, 0)``; action 0 increments
    ``x``, action 1 increments ``y``, and a state is terminal as soon as
    ``x + y == dim - 1``. Every complete trajectory therefore has exactly
    ``dim - 1`` steps.
    """

    def __init__(self, dim=5):
        """Create an environment whose terminal states satisfy x + y == dim - 1."""
        self.dim = dim

    def init(self):
        """Return the non-terminal start state ``(0, 0)``."""
        return State(x=0, y=0, is_terminal=False)

    def to_tensor(self, state):
        """Encode ``state`` as a float vector of length ``2 * dim``.

        The first ``dim`` entries one-hot encode ``x`` and the last ``dim``
        entries one-hot encode ``y``.
        """
        assert state.x + state.y < self.dim
        tensor = torch.zeros(2 * self.dim)
        tensor[state.x] = 1.
        tensor[self.dim + state.y] = 1.
        return tensor

    def forward_step(self, state, action):
        """Apply ``action`` to a non-terminal ``state`` and return the successor.

        Args:
            state: current state; must not be terminal.
            action: 0 to increment ``x``, 1 to increment ``y``.

        Raises:
            ValueError: if ``action`` is neither 0 nor 1.
        """
        assert not state.is_terminal
        if action == 0:  # left action
            assert state.x < self.dim - 1
            x, y = state.x + 1, state.y
        elif action == 1:  # right action
            assert state.y < self.dim - 1
            x, y = state.x, state.y + 1
        else:
            # Without this branch x and y would be unbound below.
            raise ValueError(f"action must be 0 or 1, got {action!r}")
        return State(x, y, is_terminal=x + y == self.dim - 1)

    def backward_step(self, state, action):
        """Undo ``action`` from ``state`` and return the (non-terminal) parent.

        Args:
            state: any state other than ``(0, 0)``.
            action: 0 to decrement ``x``, 1 to decrement ``y``.

        Raises:
            ValueError: if ``action`` is neither 0 nor 1.
        """
        assert state.x + state.y != 0
        if action == 0:  # left action
            assert state.x > 0
            x, y = state.x - 1, state.y
        elif action == 1:  # right action
            assert state.y > 0
            x, y = state.x, state.y - 1
        else:
            raise ValueError(f"action must be 0 or 1, got {action!r}")
        return State(x, y, False)

    def uniform_backward_probability(self, state):
        """Return the probability of the backward move under a uniform choice.

        A state on an edge of the triangle (``x == 0`` or ``y == 0``) gets
        probability 1.0; every other state gets 0.5.
        """
        if state.x == 0 or state.y == 0:
            return 1.
        return 0.5

    def sample_trajectory_from_terminal(self, state):
        """Sample a trajectory ending in the terminal ``state``.

        The path is built backwards, choosing uniformly between the two
        backward actions wherever both are possible, and is returned in
        forward order.

        Returns:
            ``(traj, actions)`` where ``traj`` starts at ``(0, 0)`` and ends at
            ``state``, and ``actions[i]`` leads from ``traj[i]`` to
            ``traj[i + 1]``.
        """
        assert state.is_terminal
        traj = [state]
        actions = []
        while not (state.x == 0 and state.y == 0):
            # Probability of undoing an x-step; forced on the two edges.
            xp = 0.5
            if state.x == 0:
                xp = 0.
            elif state.y == 0:
                xp = 1.
            # random.random() is drawn on every step, including forced ones.
            action = 0 if xp > random.random() else 1
            state = self.backward_step(state, action)
            actions.append(action)
            traj.append(state)
        return traj[::-1], actions[::-1]

    @torch.no_grad()
    def sample_trajectory_from_model(self, model):
        """Roll out ``model`` from the start state for ``dim - 1`` steps.

        ``model.forward`` is called on the encoded state and must return an
        object with a ``sample()`` method yielding the action.

        Returns:
            ``(traj, actions)`` with ``len(traj) == len(actions) + 1``.
        """
        state = self.init()
        traj = [state]
        actions = []
        for _ in range(self.dim - 1):
            action = model.forward(self.to_tensor(state)).sample().item()
            state = self.forward_step(state, action)
            traj.append(state)
            actions.append(action)
        return traj, actions

    @torch.no_grad()
    def compute_state_probabilities(self, model):
        """Compute the probability of visiting each state under ``model``.

        The forward action probabilities are the softmax of
        ``model.logits`` over all encoded states. Probability mass 1.0 is
        placed on ``(0, 0)`` and pushed forward level by level.

        Returns:
            dict mapping ``str((x, y))`` to the visit probability.
        """
        all_states = self.get_all_states()
        all_tensor = torch.vstack([self.to_tensor(state) for state in all_states])

        probs = model.logits(all_tensor).softmax(dim=1)
        action_probs = dict(zip(all_states, probs))
        # Process states by depth so a parent is final before it is split.
        all_states_sorted = sorted(all_states, key=lambda state: state.x + state.y)

        state_probs = {str((state.x, state.y)): 0.0 for state in all_states}
        state_probs[str((0, 0))] = 1.0
        for state in all_states_sorted:
            if state.is_terminal:
                continue
            p = state_probs[str((state.x, state.y))]
            left, right = action_probs[state].tolist()
            state_probs[str((state.x + 1, state.y))] += p * left
            state_probs[str((state.x, state.y + 1))] += p * right
        return state_probs

    def get_all_states(self):
        """Return every state of the triangle, flagged terminal on the last level."""
        return [
            State(i, j, i + j == self.dim - 1)
            for i in range(self.dim)
            for j in range(self.dim - i)
        ]

    def get_terminal_states(self):
        """Return the ``dim`` terminal states, ordered by increasing ``x``."""
        return [State(i, self.dim - 1 - i, True) for i in range(self.dim)]

    def to_transitions(self, traj, actions):
        """Convert a trajectory into stacked per-transition tensors.

        Args:
            traj: list of states of length ``T + 1``.
            actions: list of ``T`` actions linking consecutive states.

        Returns:
            ``(prev_states, actions, next_states, dones, backward_probs)``:
            encoded source states and encoded successor states (each with
            ``T`` rows of width ``2 * dim``), the actions as a tensor, the
            terminal flags of the successors, and the uniform backward
            probability of each successor.
        """
        tensors = [self.to_tensor(state) for state in traj]
        successors = traj[1:]
        prev_states = torch.vstack(tensors[:-1])
        actions = torch.tensor(actions)
        next_states = torch.vstack(tensors[1:])
        dones = torch.tensor([state.is_terminal for state in successors])
        backward_probs = torch.tensor(
            [self.uniform_backward_probability(state) for state in successors]
        )
        return prev_states, actions, next_states, dones, backward_probs

    def visualize(self, model, cmap="inferno_r", annotate=False, ax=None, vmax=1.0):
        """Draw the state visit probabilities of ``model`` as coloured hexagons.

        Args:
            model: object accepted by ``compute_state_probabilities``.
            cmap: matplotlib colormap name.
            annotate: if true, write each non-zero probability on its hexagon.
            ax: axes to draw on; a new 7x7 inch figure is created when None.
            vmax: probability mapped to the top of the colormap; when 0 every
                hexagon uses the colormap's lowest colour.

        Returns:
            The axes that were drawn on.
        """
        state_probs = self.compute_state_probabilities(model)
        row_height = np.sqrt(3) / 2.0
        H = row_height * (self.dim - 1)

        def to_xy(x, y):
            # Level x + y sets the row (start state on top); y - x sets the
            # horizontal offset, centred on the widest row.
            X = (y - x) / 2.0 + (self.dim - 1) / 2.0
            Y = H - row_height * (x + y)
            return X, Y

        if ax is None:
            _, ax = plt.subplots(figsize=(7, 7))

        colormap = plt.get_cmap(cmap)
        r = 1.0 / np.sqrt(3)
        pts = []
        for x in range(self.dim):
            for y in range(self.dim - x):
                X, Y = to_xy(x, y)
                p = state_probs[str((x, y))]
                pts.append((X, Y))

                color = colormap(0.0 if vmax == 0 else p / vmax)
                ax.add_patch(RegularPolygon(
                    (X, Y), numVertices=6, radius=r, orientation=0.0,
                    facecolor=color, edgecolor=color, linewidth=0
                ))

                if annotate and p > 0:
                    ax.text(X, Y, f"{p:.2f}", ha="center", va="center", fontsize=8)

        xs, ys = zip(*pts)
        ax.set_xlim(min(xs) - 1, max(xs) + 1)
        ax.set_ylim(min(ys) - 1, max(ys) + 1)
        return ax
    

class PascalPolicy(nn.Module):
    """Three small MLPs over the environment's state encoding.

    * ``forward_mlp``  -- logits over the two forward actions,
    * ``backward_mlp`` -- logits over the two backward actions,
    * ``logflow_mlp``  -- a scalar per state.

    Inputs are vectors of length ``2 * env.dim`` whose first ``dim`` entries
    one-hot encode ``x`` and whose last ``dim`` entries one-hot encode ``y``.
    All methods accept a single encoding or any batch of encodings along the
    leading axes.
    """

    def __init__(self, env, hidden_dim=64):
        """Build the networks for ``env``.

        Args:
            env: environment providing ``dim`` (and ``init`` / ``to_tensor``
                for ``logZ``).
            hidden_dim: width of the single hidden layer of each MLP.
        """
        super().__init__()
        self.env = env
        self.dim = env.dim
        self.input_dim = 2 * self.dim
        # Construction order is kept (forward, backward, logflow) so that
        # seeded parameter initialisation is unchanged.
        self.forward_mlp = self._make_mlp(self.input_dim, hidden_dim, 2)
        self.backward_mlp = self._make_mlp(self.input_dim, hidden_dim, 2)
        self.logflow_mlp = self._make_mlp(self.input_dim, hidden_dim, 1)

    @staticmethod
    def _make_mlp(input_dim, hidden_dim, output_dim):
        """Return a Linear -> LeakyReLU -> Linear network."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        """Return the categorical distribution over forward actions for ``x``."""
        logits = self.forward_mlp(x)
        return Categorical(logits=logits)

    def backward(self, x):
        """Return the categorical distribution over backward actions for ``x``.

        Backward action 0 is masked out where the encoding has ``x == 0`` and
        backward action 1 where it has ``y == 0``, since those moves would
        leave the grid. For the start state both actions are masked, so no
        valid distribution exists and ``Categorical`` is given non-finite
        logits.
        """
        logits = self.backward_mlp(x)
        # Index along the last axis so unbatched (1-D) encodings work too.
        x_is_zero = x[..., 0] == 1
        y_is_zero = x[..., self.dim] == 1
        blocked = torch.stack((x_is_zero, y_is_zero), dim=-1)
        logits = logits.masked_fill(blocked, -torch.inf)
        return Categorical(logits=logits)

    def logflow(self, x):
        """Return the ``logflow_mlp`` output for ``x``, flattened to 1-D."""
        return self.logflow_mlp(x).flatten()

    def logZ(self):
        """Return ``logflow`` of the start state as a one-element tensor."""
        init_state = self.env.init()
        init_tensor = self.env.to_tensor(init_state)
        return self.logflow(init_tensor)

    def logits(self, x):
        """Return the raw forward-action logits for ``x``."""
        return self.forward_mlp(x)


def get_sqil_loss(model, transitions, expert=True):
    """Summed squared soft-Bellman error over a batch of transitions.

    ``model.logits`` is treated as Q-values. For each transition the target
    is a constant reward (1.0 when ``expert`` is truthy, otherwise 0.0) plus
    the log-sum-exp of the successor's Q-values, which is replaced by 0 when
    the successor is terminal. The target is detached from the graph.

    Args:
        model: object exposing ``logits(states)``.
        transitions: ``(prev_states, actions, next_states, dones, _)``.
        expert: selects the reward assigned to every transition.

    Returns:
        Scalar tensor with the sum of squared errors.
    """
    states, actions, successors, dones, _ = transitions
    q_values = model.logits(states)
    next_q_values = model.logits(successors).detach()

    chosen_q = q_values.gather(dim=1, index=actions.view(-1, 1)).squeeze(1)
    soft_value = torch.logsumexp(next_q_values, dim=1)
    # Nothing follows a terminal successor, so its value is zero.
    soft_value = torch.where(dones, torch.zeros_like(soft_value), soft_value)

    reward = float(bool(expert))
    residual = chosen_q - reward - soft_value
    return residual.square().sum()

def get_tbil_loss(model, transitions, reward):
    """Squared whole-trajectory residual for one trajectory.

    The residual is ``logZ + sum(log P_F) - sum(log P_B) - reward`` where
    ``P_F`` comes from ``model.forward`` evaluated on the taken actions and
    ``P_B`` is the supplied backward probability of each transition.
    ``reward`` is subtracted as-is, i.e. it is used on the log scale.

    Args:
        model: object exposing ``forward(states)`` and ``logZ()``.
        transitions: ``(prev_states, actions, _, _, backward_probs)``.
        reward: scalar subtracted from the residual.

    Returns:
        Scalar tensor with the squared residual.
    """
    states, actions, _, _, backward_probs = transitions
    log_pf_total = model.forward(states).log_prob(actions).sum()
    log_pb_total = torch.log(backward_probs).sum()
    residual = model.logZ() + log_pf_total - log_pb_total - reward
    return residual.square().sum()


def get_dbil_loss(model, transitions, reward):
    """Summed squared per-transition flow residual.

    For every transition the residual is
    ``logF(s) + log P_F(a | s) - log P_B - logF(s') - reward``.
    ``logF(s')`` is detached from the graph and replaced by 0 when ``s'`` is
    terminal; ``reward`` is subtracted on every transition.

    Args:
        model: object exposing ``forward(states)`` and ``logflow(states)``.
        transitions: ``(prev_states, actions, next_states, dones,
            backward_probs)``.
        reward: scalar subtracted from each residual.

    Returns:
        Scalar tensor with the sum of squared residuals.
    """
    states, actions, successors, dones, backward_probs = transitions
    log_pf = model.forward(states).log_prob(actions)
    log_pb = torch.log(backward_probs)

    source_flow = model.logflow(states)
    target_flow = model.logflow(successors).detach()
    # Terminal successors contribute no outgoing log-flow.
    target_flow = torch.where(dones, torch.zeros_like(target_flow), target_flow)

    residual = source_flow + log_pf - log_pb - target_flow - reward
    return residual.square().sum()


def train(dim, loss_type='tb', n_iter=5000, lr=0.001):
    """Train a ``PascalPolicy`` on a ``PascalEnv`` of size ``dim``.

    Every update uses one trajectory sampled backwards from each terminal
    state (scored with reward 1, or ``expert=True`` for ``'sqil'``) and the
    same number of trajectories rolled out from the current model (scored
    with reward -1, or ``expert=False``). The summed loss is divided by the
    number of trajectories before a single Adam step.

    Args:
        dim: environment size; must be at least 2 so that trajectories
            contain at least one transition.
        loss_type: ``'tb'`` (``get_tbil_loss``), ``'db'`` (``get_dbil_loss``)
            or ``'sqil'`` (``get_sqil_loss``).
        n_iter: number of optimizer updates.
        lr: Adam learning rate.

    Returns:
        ``(env, model, stats)`` where ``stats['loss']`` holds the loss value
        of every update.

    Raises:
        ValueError: if ``loss_type`` is unknown or ``dim < 2``.
    """
    # Validate first: an unknown loss type would otherwise leave the loss as
    # a plain float and fail only at ``loss.backward()``.
    if loss_type not in ('tb', 'db', 'sqil'):
        raise ValueError(
            f"loss_type must be 'tb', 'db' or 'sqil', got {loss_type!r}"
        )
    if dim < 2:
        raise ValueError(f"dim must be at least 2, got {dim!r}")

    env = PascalEnv(dim=dim)
    model = PascalPolicy(env)
    data = env.get_terminal_states()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    stats = defaultdict(list)

    def trajectory_loss(traj, actions, expert):
        # Score one trajectory with the selected loss.
        transitions = env.to_transitions(traj, actions)
        if loss_type == 'sqil':
            return get_sqil_loss(model, transitions, expert=expert)
        reward = 1 if expert else -1
        if loss_type == 'tb':
            return get_tbil_loss(model, transitions, reward)
        return get_dbil_loss(model, transitions, reward)

    for _ in tqdm(range(n_iter)):
        loss = 0.0
        # Demonstrations: one backward-sampled path per terminal state.
        for state in data:
            traj, actions = env.sample_trajectory_from_terminal(state)
            loss = loss + trajectory_loss(traj, actions, expert=True)

        # The same number of rollouts from the current policy.
        for _ in data:
            traj, actions = env.sample_trajectory_from_model(model)
            loss = loss + trajectory_loss(traj, actions, expert=False)

        loss = loss / (2 * len(data))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        stats['loss'].append(loss.item())
    return env, model, stats