import torch
from torch import nn
from torch.nn import functional as F

from ddlm.time.low_discrepency_sampling import get_t
from matplotlib import pyplot as plt
import seaborn as sns


class TimeWrapping(nn.Module):
    """Learnable piecewise-linear map between two partitions of [0, 1].

    Two logit vectors, ``l_t`` and ``l_u``, each hold ``num_bins - 1``
    entries. A softmax (smoothed by ``epsilon`` and renormalised so the
    result still sums to one) turns each vector into bin widths, and the
    cumulative sum of those widths gives the *right* edge of every bin.
    Values are mapped from one axis to the other by locating the bin that
    contains them and interpolating linearly back from that bin's right edge.
    """

    def __init__(self, config):
        """Create the two logit vectors and the smoothing constant.

        Args:
            config: Object exposing ``num_bins`` (int) and
                ``time_wrapping_epsilon`` (float).
        """
        super().__init__()
        self.num_bins = config.num_bins
        self.l_t = nn.Parameter(
            -torch.ones(config.num_bins - 1)
        )  # -1 as we parametrize left edges
        self.l_u = nn.Parameter(-torch.ones(config.num_bins - 1))
        self.register_buffer("epsilon", torch.tensor(config.time_wrapping_epsilon))

    @staticmethod
    def _prepend_zero(cdf):
        """Return ``cdf`` unchanged (kept only for backward compatibility).

        Despite its name, this helper has always handed its input straight
        back; the zero-padded copy it used to build was discarded. The
        interpolation in this class relies on that: a leading zero-width bin
        would make ``t == 0`` select it and evaluate ``0 / 0``. The dead
        allocation has been removed and the class no longer calls this
        method internally.
        """
        return cdf

    def _bin_widths(self, logits):
        """Turn logits into epsilon-smoothed bin widths that sum to one."""
        widths = F.softmax(logits, -1)
        # Adding epsilon to each of the num_bins - 1 widths raises their sum
        # to 1 + epsilon * (num_bins - 1); dividing restores a unit total.
        return (widths + self.epsilon) / (1 + self.epsilon * (self.num_bins - 1))

    @staticmethod
    def _bin_index(t, edges):
        """Count the right edges strictly below each ``t``, capped at the last bin."""
        below = (t.view(-1, 1) > edges).sum(-1)
        # Values past the final edge (e.g. through rounding) use the last bin.
        return below.clamp(max=edges.size(0) - 1)

    @classmethod
    def _interpolate(cls, t, src_widths, src_edges, dst_widths, dst_edges):
        """Map ``t`` from the source partition onto the destination curve.

        Within the bin containing ``t`` the result is linear, anchored at the
        bin's right edge: ``dst_edge - slope * (src_edge - t)`` with
        ``slope = dst_width / src_width``.
        """
        idx = cls._bin_index(t, src_edges)
        slope = dst_widths[idx] / src_widths[idx]
        return slope * (t - src_edges[idx]) + dst_edges[idx]

    @staticmethod
    def _as_pairs(xs, ys):
        """Pair x values (as numpy scalars) with y values (as 0-d CPU tensors)."""
        return [
            [x, y] for (x, y) in zip(xs.detach().cpu().numpy(), ys.detach().cpu())
        ]

    def forward(self, t):
        """Evaluate both learned curves at ``t``.

        Args:
            t: Tensor of values, expected in [0, 1]. Bin indices are computed
                from ``t.view(-1, 1)``, so a 1-D tensor is the shape that
                lines up with the per-element lookups.

        Returns:
            Tuple ``(u, ent_u)``:
              * ``u`` -- ``t`` located among the ``l_u`` edges and mapped
                onto the ``l_t`` edges (marked ``F^{-1}(t)`` by the original
                author).
              * ``ent_u`` -- ``t`` located among the ``l_t`` edges and mapped
                onto the cumulative sum of the unnormalised ``exp(l_u)``.
        """
        w_t = self._bin_widths(self.l_t)
        w_u = self._bin_widths(self.l_u)
        e_t = w_t.cumsum(-1)
        e_u = w_u.cumsum(-1)

        u = self._interpolate(t, w_u, e_u, w_t, e_t)  # F^{-1}(t)

        w_exp_u = torch.exp(self.l_u)
        e_exp_u = w_exp_u.cumsum(-1)
        ent_u = self._interpolate(t, w_t, e_t, w_exp_u, e_exp_u)

        return u, ent_u

    @torch.no_grad()
    def plot_F_u(self):
        """Sample the map from the ``l_t`` edges to the ``l_u`` edges.

        Returns:
            List of 1000 ``[t, u]`` pairs for ``t`` evenly spaced on [0, 1].
        """
        w_t = self._bin_widths(self.l_t)
        w_u = self._bin_widths(self.l_u)
        e_t = w_t.cumsum(-1)
        e_u = w_u.cumsum(-1)

        t = torch.linspace(0, 1, 1000, device=self.l_u.device)
        u = self._interpolate(t, w_t, e_t, w_u, e_u)

        return self._as_pairs(t, u)

    @torch.no_grad()
    def plot_F_prime_u(self):
        """Sample the second curve returned by ``forward`` (``ent_u``).

        Returns:
            List of 1000 ``[t, ent_u]`` pairs for ``t`` evenly spaced on
            [0, 1].
        """
        w_t = self._bin_widths(self.l_t)
        e_t = w_t.cumsum(-1)

        w_exp_u = torch.exp(self.l_u)
        e_exp_u = w_exp_u.cumsum(-1)

        t = torch.linspace(0, 1, 1000, device=self.l_u.device)
        ent_u = self._interpolate(t, w_t, e_t, w_exp_u, e_exp_u)

        return self._as_pairs(t, ent_u)
