from collections import defaultdict, namedtuple
from tqdm import tqdm
from copy import deepcopy
import math
import random
import numpy as np


import torch
from torch import nn
from torch.distributions import Categorical
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.patches import RegularPolygon, FancyArrowPatch


# Lattice state: integer coordinates ``x`` and ``y`` plus an ``is_terminal``
# flag. Being a namedtuple, it is immutable and hashable, so it can be used
# as a dictionary key.
State = namedtuple("State", ["x", "y", "is_terminal"])

def state_to_str(state):
    """Return the ``"(x, y)"`` string form of *state*'s coordinates.

    Only ``state.x`` and ``state.y`` are read; any terminal flag is ignored.
    """
    coords = (state.x, state.y)
    return f"{coords!s}"

class TruncatedPascalEnv:
    """Lattice environment on the band between two anti-diagonals.

    A state is a pair ``(x, y)`` of non-negative integers. Trajectories start
    on the anti-diagonal ``x + y == init_dim - 1`` and end on the anti-diagonal
    ``x + y == last_dim - 1``. Forward action ``0`` ("left") increments ``x``
    and forward action ``1`` ("right") increments ``y``.
    """

    def __init__(self, init_dim=5, last_dim=10, init_probs='uniform'):
        """Create the environment.

        Args:
            init_dim: number of initial states (those with
                ``x + y == init_dim - 1``).
            last_dim: number of terminal states (those with
                ``x + y == last_dim - 1``); must exceed ``init_dim``.
            init_probs: either the string ``'uniform'`` or a sequence of
                ``init_dim`` probabilities, indexed by the ``y`` coordinate of
                the initial state, summing to one.
        """
        assert init_dim < last_dim
        if not self._is_uniform(init_probs):
            assert len(init_probs) == init_dim
            # Exact equality with 1.0 rejects valid distributions such as
            # [0.1] * 10 because of floating-point rounding.
            assert math.isclose(sum(init_probs), 1.0, rel_tol=0.0, abs_tol=1e-6)
        self.init_dim = init_dim
        self.last_dim = last_dim
        self.init_probs = init_probs

    @staticmethod
    def _is_uniform(init_probs):
        """Return True when *init_probs* is the ``'uniform'`` sentinel string."""
        # The isinstance guard avoids an element-wise comparison when
        # init_probs is an array-like object.
        return isinstance(init_probs, str) and init_probs == 'uniform'

    @staticmethod
    def _key(x, y):
        """Return the string key used for ``(x, y)`` in probability tables."""
        return str((x, y))

    def init(self):
        """Sample an initial state from the configured initial distribution."""
        if self._is_uniform(self.init_probs):
            y = random.randint(0, self.init_dim - 1)
        else:
            y = random.choices(range(self.init_dim), weights=self.init_probs)[0]
        return State(x=self.init_dim - 1 - y, y=y, is_terminal=False)

    def init_action_prop(self, y):
        """Return the probability of starting at the initial state with coordinate *y*."""
        if self._is_uniform(self.init_probs):
            return 1 / self.init_dim
        return self.init_probs[int(y)]

    def to_tensor(self, state):
        """Encode *state* as ``[one_hot(x) | one_hot(y)]`` of length ``2 * last_dim``."""
        assert state.x + state.y < self.last_dim
        tensor = torch.zeros(2 * self.last_dim)
        tensor[state.x] = 1.
        tensor[self.last_dim + state.y] = 1.
        return tensor

    def forward_step(self, state, action):
        """Apply forward *action* (0: ``x + 1``, 1: ``y + 1``) to a non-terminal state.

        Raises:
            ValueError: if *action* is neither 0 nor 1.
        """
        assert not state.is_terminal
        if action == 0:  # left action
            assert state.x < self.last_dim - 1
            x, y = state.x + 1, state.y
        elif action == 1:  # right action
            assert state.y < self.last_dim - 1
            x, y = state.x, state.y + 1
        else:
            raise ValueError(f"action must be 0 or 1, got {action!r}")
        return State(x, y, is_terminal=(x + y == self.last_dim - 1))

    def backward_step(self, state, action):
        """Undo forward *action* (0: ``x - 1``, 1: ``y - 1``) from a non-initial state.

        Raises:
            ValueError: if *action* is neither 0 nor 1.
        """
        assert state.x + state.y > self.init_dim - 1
        if action == 0:  # left action
            assert state.x > 0
            x, y = state.x - 1, state.y
        elif action == 1:  # right action
            assert state.y > 0
            x, y = state.x, state.y - 1
        else:
            raise ValueError(f"action must be 0 or 1, got {action!r}")
        return State(x, y, False)

    def uniform_backward_probability(self, state):
        """Return the probability of the backward move into *state* under the uniform backward policy.

        On either axis only one parent exists, so the probability is 1;
        otherwise both parents are equally likely.
        """
        if state.x == 0 or state.y == 0:
            return 1.
        return 0.5

    @torch.no_grad()
    def sample_trajectory_from_terminal(self, state, model=None):
        """Sample a trajectory backwards from terminal *state*.

        With ``model=None`` parents are chosen uniformly among the valid ones;
        otherwise actions are drawn from ``model.backward``.

        Returns:
            ``(traj, actions)`` in forward order: ``traj`` runs from an initial
            state to *state*, and ``actions[i]`` leads from ``traj[i]`` to
            ``traj[i + 1]``.
        """
        assert state.is_terminal
        traj = [state]
        actions = []
        while state.x + state.y != self.init_dim - 1:
            if model is None:  # uniform
                xp = 0.5
                if state.x == 0:
                    xp = 0.
                elif state.y == 0:
                    xp = 1.
                action = 0 if xp > random.random() else 1
            else:
                # model.backward receives a batch of size one.
                tensor = self.to_tensor(state)
                action = model.backward(tensor.unsqueeze(0)).sample().item()
            state = self.backward_step(state, action)
            actions.append(action)
            traj.append(state)
        return traj[::-1], actions[::-1]

    @torch.no_grad()
    def sample_trajectory_from_model(self, model):
        """Roll out ``model.forward`` from a sampled initial state to a terminal state.

        Returns:
            ``(traj, actions)`` with ``len(traj) == len(actions) + 1``.
        """
        state = self.init()
        traj = [state]
        actions = []
        for _ in range(self.last_dim - self.init_dim):
            action = model.forward(self.to_tensor(state)).sample().item()
            state = self.forward_step(state, action)
            traj.append(state)
            actions.append(action)
        return traj, actions

    @torch.no_grad()
    def compute_state_probabilities(self, model):
        """Compute each state's visitation probability under the forward policy.

        The initial distribution is pushed forward layer by layer using
        ``softmax(model.logits(state))`` as the (left, right) action
        probabilities.

        Returns:
            dict mapping ``str((x, y))`` to the probability of visiting
            ``(x, y)``.
        """
        all_states = self.get_all_states()
        all_tensor = torch.vstack([self.to_tensor(state) for state in all_states])

        probs = model.logits(all_tensor).softmax(dim=1)
        action_probs = dict(zip(all_states, probs))

        state_probs = {self._key(state.x, state.y): 0.0 for state in all_states}
        for state in self.get_all_init_states():
            state_probs[self._key(state.x, state.y)] = self.init_action_prop(state.y)

        # Process states in order of x + y so that a state's mass is final
        # before it is distributed to its children.
        for state in sorted(all_states, key=lambda state: state.x + state.y):
            if state.is_terminal:
                continue
            p = state_probs[self._key(state.x, state.y)]
            left, right = action_probs[state].tolist()
            state_probs[self._key(state.x + 1, state.y)] += p * left
            state_probs[self._key(state.x, state.y + 1)] += p * right
        return state_probs

    def get_all_states(self):
        """Return every state with ``init_dim - 1 <= x + y <= last_dim - 1``."""
        return [
            State(x, y, x + y == self.last_dim - 1)
            for x in range(self.last_dim)
            for y in range(self.last_dim - x)
            if x + y >= self.init_dim - 1
        ]

    def get_all_init_states(self):
        """Return the ``init_dim`` initial states, ordered by increasing ``x``."""
        return [State(i, self.init_dim - 1 - i, False) for i in range(self.init_dim)]

    def get_terminal_states(self):
        """Return the ``last_dim`` terminal states, ordered by increasing ``x``."""
        return [State(i, self.last_dim - 1 - i, True) for i in range(self.last_dim)]

    def to_transitions(self, traj, actions):
        """Convert a trajectory into batched transition tensors.

        Returns:
            ``(prev_states, actions, next_states, dones, backward_probs)``
            where each tensor has one entry per transition and
            ``backward_probs`` holds the uniform backward probability of each
            transition's destination state.
        """
        tensors = [self.to_tensor(state) for state in traj]
        successors = traj[1:]
        prev_states = torch.vstack(tensors[:-1])
        actions = torch.tensor(actions)
        next_states = torch.vstack(tensors[1:])
        dones = torch.tensor([state.is_terminal for state in successors])
        backward_probs = torch.tensor(
            [self.uniform_backward_probability(state) for state in successors]
        )
        return prev_states, actions, next_states, dones, backward_probs

    def visualize(self, model, cmap="inferno_r", annotate=False, ax=None, vmax=1.0, add_root=False, fontsize=8):
        """Draw the state visitation probabilities as a triangle of hexagons.

        Args:
            model: policy passed to ``compute_state_probabilities``.
            cmap: matplotlib colormap name; colours are looked up at ``p / vmax``.
            annotate: when true, write each non-zero probability inside its hexagon.
            ax: axes to draw on; a new 7x7 inch figure is created when ``None``.
            vmax: probability mapped to the top of the colormap.
            add_root: when true, add a white hexagon at ``(0, 0)`` if that
                state is not already drawn.
            fontsize: font size of the annotations.
        """
        state_probs = self.compute_state_probabilities(model)
        row_height = np.sqrt(3) / 2.0
        H = row_height * (self.last_dim - 1)

        def to_xy(x, y):
            # Map lattice coordinates to plot coordinates: rows are indexed
            # by x + y, columns by y - x.
            X = (y - x) / 2.0 + (self.last_dim - 1) / 2.0
            Y = H - row_height * (x + y)
            return X, Y

        if ax is None:
            _, ax = plt.subplots(figsize=(7, 7))

        pts, probs = [], []
        for state in self.get_all_states():
            pts.append(to_xy(state.x, state.y))
            probs.append(state_probs[self._key(state.x, state.y)])
        probs = np.array(probs)

        r = 1.0 / np.sqrt(3)
        colormap = plt.get_cmap(cmap)

        for (X, Y), p in zip(pts, probs):
            color = colormap(p / vmax)
            hexagon = RegularPolygon(
                (X, Y), numVertices=6, radius=r, orientation=0.0,
                facecolor=color, edgecolor=color, linewidth=0,
            )
            ax.add_patch(hexagon)

            if annotate and p > 0:
                ax.text(X, Y, f"{p:.2f}", ha="center", va="center", fontsize=fontsize)

        if add_root:
            X, Y = to_xy(0, 0)
            if (X, Y) not in pts:
                pts.append((X, Y))
                hexagon = RegularPolygon(
                    (X, Y), numVertices=6, radius=r, orientation=0.0,
                    facecolor='white', edgecolor='white', linewidth=0,
                )
                ax.add_patch(hexagon)

        xs, ys = zip(*pts)
        ax.set_xlim(min(xs) - 1, max(xs) + 1)
        ax.set_ylim(min(ys) - 1, max(ys) + 1)




class PascalPolicy(nn.Module):
    """Forward policy, backward policy and log-flow heads over encoded states.

    Each head is an independent MLP (one 64-unit hidden layer with LeakyReLU)
    whose input has ``2 * env.last_dim`` features, laid out as
    ``[one_hot(x) | one_hot(y)]``.
    """

    def __init__(self, env):
        """Build the three heads for *env* (only ``env.last_dim`` is read here)."""
        super().__init__()
        self.env = env
        self.dim = env.last_dim
        self.input_dim = 2 * self.dim
        # Heads are created in this order so parameter initialisation draws
        # random numbers in the same sequence as before.
        self.forward_mlp = self._make_head(self.input_dim, 2)
        self.backward_mlp = self._make_head(self.input_dim, 2)
        self.logflow_mlp = self._make_head(self.input_dim, 1)

    @staticmethod
    def _make_head(input_dim, output_dim):
        """Return a ``Linear -> LeakyReLU -> Linear`` head with 64 hidden units."""
        return nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, output_dim),
        )

    def forward(self, x):
        """Return the forward action distribution (0: ``x + 1``, 1: ``y + 1``)."""
        return Categorical(logits=self.forward_mlp(x))

    def backward(self, x):
        """Return the backward action distribution with impossible moves masked.

        Action 0 is given ``-inf`` logit where the ``x`` coordinate is zero,
        and action 1 where the ``y`` coordinate is zero. Accepts a single
        encoded state of shape ``(input_dim,)`` as well as a batch of shape
        ``(batch, input_dim)``.
        """
        logits = self.backward_mlp(x)
        # Feature 0 is one_hot(x)[0]; feature self.dim is one_hot(y)[0].
        invalid = torch.stack((x[..., 0] == 1, x[..., self.dim] == 1), dim=-1)
        return Categorical(logits=logits.masked_fill(invalid, -torch.inf))

    def logflow(self, x):
        """Return the log-flow head's output, flattened to one dimension."""
        return self.logflow_mlp(x).flatten()

    def logZ(self):
        """Return the log-flow of the all-zero input as a one-element tensor."""
        # new_zeros follows the parameters' dtype and device.
        x = self.logflow_mlp[0].weight.new_zeros(self.input_dim)
        return self.logflow(x)

    def logits(self, x):
        """Return the raw forward-policy logits."""
        return self.forward_mlp(x)


def get_sqil_loss(model, transitions, expert=True):
    """Summed squared soft-Bellman residual over a batch of transitions.

    ``model.logits`` is used as the Q-function. The target for each
    transition is ``r + V(next_state)`` with
    ``V = logsumexp(Q(next_state))`` (zero when the transition is terminal,
    and detached from the graph), where ``r`` is 1 for expert transitions and
    0 otherwise.
    """
    prev_states, actions, next_states, dones, _ = transitions

    # Q-value of the action actually taken in each previous state.
    q_taken = model.logits(prev_states).gather(dim=1, index=actions.view(-1, 1)).squeeze(1)

    # Soft value of the successor; no gradient flows through the target.
    next_value = torch.logsumexp(model.logits(next_states).detach(), dim=1)
    next_value = torch.where(dones, 0.0, next_value)

    if expert:
        reward = 1.0
    else:
        reward = 0.0

    residual = q_taken - reward - next_value
    return residual.square().sum()

def get_tbil_loss(model, transitions, reward, uniform_backward=True):
    """Squared trajectory-balance residual for a single trajectory.

    Args:
        model: policy exposing ``forward``, ``backward``, ``logZ``,
            ``input_dim`` and ``env``.
        transitions: ``(prev_states, actions, next_states, dones,
            backward_probs)`` for one trajectory, in forward order.
        reward: scalar subtracted from the balance residual.
        uniform_backward: when true, use the supplied ``backward_probs``;
            otherwise evaluate ``model.backward`` on the next states.

    Returns:
        Scalar tensor
        ``(p0 + logZ + sum(logPF) - sum(logPB) - reward) ** 2``.
    """
    prev_states, actions, next_states, dones, backward_probs = transitions
    logPF = model.forward(prev_states).log_prob(actions)
    if uniform_backward:
        logPB = torch.log(backward_probs)
    else:
        logPB = model.backward(next_states).log_prob(actions)
    logZ = model.logZ()
    # States are encoded as [one_hot(x) | one_hot(y)] and the initial
    # distribution is indexed by y, so the start index must come from the
    # second half of the encoding (the first half would give x).
    y0 = prev_states[0][model.input_dim // 2:].nonzero().item()
    p0 = model.env.init_action_prop(y0)
    # NOTE(review): p0 is a probability while every other term is a log
    # quantity; confirm whether log(p0) was intended here.
    return (p0 + logZ + logPF.sum() - logPB.sum() - reward).square().sum()


def get_dbil_loss(model, transitions, reward, uniform_backward=True):
    """Summed squared detailed-balance residuals for a single trajectory.

    Args:
        model: policy exposing ``forward``, ``backward``, ``logflow``,
            ``logZ``, ``input_dim`` and ``env``.
        transitions: ``(prev_states, actions, next_states, dones,
            backward_probs)`` for one trajectory, in forward order.
        reward: scalar subtracted from every per-transition residual.
        uniform_backward: when true, use the supplied ``backward_probs``;
            otherwise evaluate ``model.backward`` on the next states.

    Returns:
        Scalar tensor: the sum over transitions of
        ``(logF(s) + logPF - logPB - logF(s') - reward) ** 2`` where
        ``logF(s')`` is detached and set to zero on terminal transitions.
    """
    prev_states, actions, next_states, dones, backward_probs = transitions
    logPF = model.forward(prev_states).log_prob(actions)
    logF0 = model.logflow(prev_states)
    # The successor flow is a fixed target: no gradient, zero at terminals.
    logF1 = model.logflow(next_states).detach()
    logF1 = torch.where(dones, 0.0, logF1)
    if uniform_backward:
        logPB = torch.log(backward_probs)
    else:
        logPB = model.backward(next_states).log_prob(actions)
        # States are encoded as [one_hot(x) | one_hot(y)] and the initial
        # distribution is indexed by y, so the start index must come from the
        # second half of the encoding (the first half would give x).
        y0 = prev_states[0][model.input_dim // 2:].nonzero().item()
        p0 = model.env.init_action_prop(y0)
        # NOTE(review): p0 is a probability added to a log quantity, and this
        # override of the first state's flow only happens in the
        # learned-backward branch; confirm both are intended.
        logF0[0] = model.logZ() + p0
    loss = (logF0 + logPF - logPB - logF1 - reward).square()
    return loss.sum()

def get_nll_loss(model, transitions):
    """Negative log-likelihood of the taken actions under ``model.forward``.

    Only the previous states and the actions of *transitions* are used; the
    per-transition log-probabilities are summed and negated.
    """
    prev_states, actions, _next_states, _dones, _backward_probs = transitions
    policy = model.forward(prev_states)
    total_log_prob = policy.log_prob(actions).sum()
    return -total_log_prob


def train(init_dim=5, last_dim=10, loss_type='tb', n_iter=5000, lr=0.001, uniform_backward=False, init_probs='uniform'):
    """Train a ``PascalPolicy`` on ``TruncatedPascalEnv`` by imitation.

    Each iteration samples one expert trajectory backwards from every
    terminal state and an equal number of trajectories from the current
    forward policy, sums the chosen loss over both sets, and takes one Adam
    step.

    Args:
        init_dim: forwarded to ``TruncatedPascalEnv``.
        last_dim: forwarded to ``TruncatedPascalEnv``.
        loss_type: one of ``'tbil'``, ``'dbil'``, ``'sqil'`` or ``'nll'``.
            The default spelling ``'tb'`` is accepted as an alias of
            ``'tbil'``.
        n_iter: number of optimizer steps.
        lr: Adam learning rate.
        uniform_backward: when true, expert trajectories are sampled with the
            uniform backward policy and the losses use its probabilities;
            otherwise the model's backward policy is used for both.
        init_probs: forwarded to ``TruncatedPascalEnv``.

    Returns:
        ``(env, model, stats)`` where ``stats['loss']`` lists the total loss
        of every iteration.

    Raises:
        ValueError: if *loss_type* is not supported.
    """
    # The signature default 'tb' matched none of the loss branches, leaving
    # the total loss a plain float and crashing on loss.backward(); map it to
    # the trajectory-balance loss and reject anything unknown up front.
    if loss_type == 'tb':
        loss_type = 'tbil'
    if loss_type not in ('tbil', 'dbil', 'sqil', 'nll'):
        raise ValueError(
            f"unsupported loss_type {loss_type!r}; expected 'tbil', 'dbil', 'sqil' or 'nll'"
        )

    env = TruncatedPascalEnv(init_dim=init_dim, last_dim=last_dim, init_probs=init_probs)
    model = PascalPolicy(env)
    data = env.get_terminal_states()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    stats = defaultdict(list)
    backward_sampler = None if uniform_backward else model

    for _ in tqdm(range(n_iter)):
        # Expert term: trajectories that end in each demonstrated terminal state.
        expert_loss = 0.0
        for state in data:
            traj, actions = env.sample_trajectory_from_terminal(state, backward_sampler)
            transitions = env.to_transitions(traj, actions)
            if loss_type == 'tbil':
                expert_loss = expert_loss + get_tbil_loss(model, transitions, 1, uniform_backward)
            elif loss_type == 'dbil':
                expert_loss = expert_loss + get_dbil_loss(model, transitions, 1, uniform_backward)
            elif loss_type == 'sqil':
                expert_loss = expert_loss + get_sqil_loss(model, transitions, expert=True)
            else:  # 'nll'
                expert_loss = expert_loss + get_nll_loss(model, transitions)

        # Policy term: the same number of on-policy rollouts. 'nll' has no
        # policy term, but the rollouts are still sampled so random-number
        # consumption is unchanged.
        policy_loss = 0.0
        for _ in data:
            traj, actions = env.sample_trajectory_from_model(model)
            transitions = env.to_transitions(traj, actions)
            if loss_type == 'tbil':
                policy_loss = policy_loss + get_tbil_loss(model, transitions, -1, uniform_backward)
            elif loss_type == 'dbil':
                policy_loss = policy_loss + get_dbil_loss(model, transitions, -1, uniform_backward)
            elif loss_type == 'sqil':
                policy_loss = policy_loss + get_sqil_loss(model, transitions, expert=False)

        loss = policy_loss + expert_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        stats['loss'].append(loss.item())
    return env, model, stats