import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def mask_percent(times, dim, p: float = 0):
    """Draw a random boolean keep-mask of width ``dim`` for each positive entry of ``times``.

    Args:
        times: 1-D tensor with one entry per output row (e.g. per-sample counts);
            only rows whose entry is ``> 0`` can contain ``True`` cells.
        dim: number of columns of the returned mask.
        p: drop probability of each cell in an eligible row.

    Returns:
        Bool tensor of shape ``(times.size(0), dim)`` on ``times.device``. Rows with
        ``times <= 0`` are all ``False``; every other cell is ``True`` when its
        uniform ``[0, 1)`` draw is ``>= p`` (i.e. with probability ``1 - p``).
    """
    device = times.device
    eligible = times > 0
    # Exactly one uniform draw per (eligible row, column), in row-major order.
    draws = torch.zeros(int(eligible.sum()), dim, device=device).uniform_()
    keep = torch.zeros(times.size(0), dim, device=device)
    keep[eligible, :] = (draws >= p).type(torch.float)
    return keep > 0


class multiTimeAttention(nn.Module):
    """Multi-head time attention (mTAN-style) with optional channel mixing.

    Query and key inputs are time embeddings of width ``embed_time``; each is
    projected and split into ``num_heads`` heads of width
    ``embed_time // num_heads``. Attention weights are computed separately for
    every value channel, so ``mask`` can hide individual (time, channel) cells.
    The per-head outputs are concatenated and projected to ``nhidden``.

    Args:
        input_dim: number of channels in ``value``.
        weight: optional square channel-mixing matrix ``W``; the attended values
            are multiplied by ``block_diag(W, W)``, so ``W`` must be
            ``(input_dim // 2, input_dim // 2)``. ``None`` disables mixing.
        nhidden: output feature size.
        embed_time: width of the query/key time embeddings.
        num_heads: number of heads; must divide ``embed_time``.
    """

    def __init__(self, input_dim, weight=None, nhidden=16, embed_time=16, num_heads=1):
        super(multiTimeAttention, self).__init__()
        assert embed_time % num_heads == 0
        self.embed_time = embed_time
        self.embed_time_k = embed_time // num_heads
        self.h = num_heads
        self.dim = input_dim
        # Stored as given. A plain tensor here is neither trained nor moved by
        # ``.to(device)``, which is why adjust_attention casts it on every use.
        self.weight = weight
        self.nhidden = nhidden
        self.linears = nn.ModuleList([nn.Linear(embed_time, embed_time),
                                      nn.Linear(embed_time, embed_time),
                                      nn.Linear(input_dim * num_heads, nhidden)])

    def attention(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention with a separate softmax per value channel.

        ``scores = query @ key^T / sqrt(d_k)`` has shape ``(..., Lq, Lk)``; it is
        repeated across the ``dim = value.size(-1)`` channels so that ``mask``
        (broadcastable to ``(..., Lk, dim)``, zeros = hidden) can hide individual
        (key, channel) cells before the softmax over the key axis.

        Returns:
            ``(attended, p_attn)``: ``attended`` is ``(..., Lq, dim)`` and
            ``p_attn`` is ``(..., Lq, Lk, dim)``.
        """
        dim = value.size(-1)
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        scores = scores.unsqueeze(-1).repeat_interleave(dim, dim=-1)
        if mask is not None:
            # With -1e9 (not -inf), a channel whose keys are all masked gets
            # uniform weights rather than NaNs.
            scores = scores.masked_fill(mask.unsqueeze(-3) == 0, -1e9)
        p_attn = F.softmax(scores, dim=-2)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.sum(p_attn * value.unsqueeze(-3), -2), p_attn

    def adjust_attention(self, value):
        """Mix the channels of ``value`` with ``block_diag(weight, weight)``; identity when ``weight`` is None."""
        if self.weight is None:
            return value
        # Cast to value's device/dtype: the weight is not a registered buffer, so
        # it stays wherever it was created (e.g. CPU) when the module is moved.
        weight = torch.as_tensor(self.weight, device=value.device, dtype=value.dtype)
        return value @ torch.block_diag(weight, weight)

    def forward(self, query, key, value, mask=None, dropout=None):
        """Attend from the ``query`` time points over the ``key``/``value`` observations.

        Args:
            query: ``(Bq, Lq, embed_time)`` embeddings of the output time points;
                a leading size of 1 broadcasts against the batch.
            key: ``(B, Lk, embed_time)`` embeddings of the observation times.
            value: ``(B, Lk, input_dim)`` observed values.
            mask: optional tensor broadcastable to ``value``; zeros mark cells to ignore.
            dropout: optional callable applied to the attention weights.

        Returns:
            ``(B, Lq, nhidden)`` tensor.
        """
        batch, _, dim = value.size()
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        value = value.unsqueeze(1)
        query, key = [l(x).view(x.size(0), -1, self.h, self.embed_time_k).transpose(1, 2)
                      for l, x in zip(self.linears, (query, key))]
        x, _ = self.attention(query, key, value, mask, dropout)
        x = self.adjust_attention(x)
        # (B, h, Lq, dim) -> (B, Lq, h * dim): concatenate the heads per time point.
        x = x.transpose(1, 2).contiguous().view(batch, -1, self.h * dim)
        return self.linears[-1](x)


class dec_mtan_rnn(nn.Module):
    """mTAN decoder: bidirectional GRU over latents, time attention onto target times, MLP head.

    Args:
        input_dim: number of output channels.
        query: reference time points the latent sequence ``z`` is laid out on; its
            length must match the sequence length of ``z``. It may be ``None`` at
            construction but must be set (``set_query``) before ``forward``.
        latent_dim: feature size of each step of ``z``.
        nhidden: GRU hidden size per direction.
        embed_time: width of the time embeddings.
        num_heads: number of attention heads.
        learn_emb: use the learned (linear + sine) embedding instead of the fixed one.
        device: device the time embeddings are moved to.
    """

    def __init__(self, input_dim, query, latent_dim=2, nhidden=16,
                 embed_time=16, num_heads=1, learn_emb=False, device='cuda'):
        super(dec_mtan_rnn, self).__init__()
        self.embed_time = embed_time
        self.dim = input_dim
        self.device = device
        self.nhidden = nhidden
        self.query = query
        self.learn_emb = learn_emb
        self.att = multiTimeAttention(2 * nhidden, None, 2 * nhidden, embed_time, num_heads)
        self.gru_rnn = nn.GRU(latent_dim, nhidden, bidirectional=True, batch_first=True)
        self.z0_to_obs = nn.Sequential(
            nn.Linear(2 * nhidden, 50),
            nn.ReLU(),
            nn.Linear(50, input_dim))
        if learn_emb:
            self.periodic = nn.Linear(1, embed_time - 1)
            self.linear = nn.Linear(1, 1)

    def set_query(self, query):
        """Replace the reference time points used as attention keys in ``forward``."""
        self.query = query

    def learn_time_embedding(self, tt):
        """Learned embedding of ``tt``: ``[linear(t), sin(periodic(t))]`` of width ``embed_time``."""
        tt = tt.to(self.device)
        tt = tt.unsqueeze(-1)
        out2 = torch.sin(self.periodic(tt))
        out1 = self.linear(tt)
        return torch.cat([out1, out2], -1)

    def fixed_time_embedding(self, pos):
        """Fixed sinusoidal embedding of 2-D ``pos`` -> ``(pos.shape[0], pos.shape[1], embed_time)``.

        Even columns hold ``sin(48 * pos * f_k)`` and odd columns ``cos(48 * pos * f_k)``
        with ``f_k = 10 ** (-2k / embed_time)``. The result lives on ``pos.device``.
        """
        d_model = self.embed_time
        # Build on pos.device so GPU inputs do not mix devices with CPU constants.
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model, device=pos.device)
        position = 48. * pos.unsqueeze(2)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=pos.device) *
                             -(np.log(10.0) / d_model))
        pe[:, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, z, time_steps):
        """Decode latents ``z`` at the target times ``time_steps``.

        Args:
            z: batch-first latent sequence aligned with ``self.query``
                (presumably ``(B, len(query), latent_dim)`` — confirm with callers).
            time_steps: 2-D target times, indexed ``[batch, time]``.

        Returns:
            Tensor of shape ``(B, time_steps.size(1), input_dim)``.
        """
        out, _ = self.gru_rnn(z)
        embed = self.learn_time_embedding if self.learn_emb else self.fixed_time_embedding
        # Targets attend (as queries) over the reference points (as keys).
        query = embed(time_steps).to(self.device)
        key = embed(self.query.unsqueeze(0)).to(self.device)
        out = self.att(query, key, out)
        out = self.z0_to_obs(out)
        return out


class Backbone(nn.Module):
    """Multi-resolution mTAN encoder/decoder for irregularly sampled series.

    ``forward`` encodes the raw series (level 0) and, in training mode, one
    window-averaged copy per entry of ``windows``. Each level is attended onto
    its own reference points, encoded, summarized, and decoded back to that
    level's time stamps.

    Args:
        timestamp: length of the raw time axis (recorded as level 0's length).
        channel: number of value channels ``C``.
        n_hidden: attention output / encoder GRU hidden size.
        latent_dim: half of the representation size (``repr_dim = 2 * latent_dim``).
        weight: two-element ``[diagonal, off_diagonal]`` channel-mixing weights,
            normalized to sum to 1; ``None`` means ``[1, 0]`` (no mixing). Ignored
            when ``corr`` is given.
        query_length: number of level-0 reference points; level ``i`` uses
            ``query_length // 2 ** i``.
        embed_times: width of the time embeddings.
        num_heads: attention heads.
        windows: pooling window widths for levels 1.. (copied, not modified).
        length: number of pooled slots for levels 1.. (copied, not modified).
        device: device the embeddings and ``corr`` are moved to.
        mask: currently unused; masking always uses ``mask_percent``.
        corr: optional ``(C, C)`` matrix used directly as the channel-mixing weight.
        p: drop probability passed to the mask during pooling.
    """

    def __init__(self, timestamp, channel, n_hidden: int = 32, latent_dim: int = 32, weight: list = None,
                 query_length: int = 16, embed_times: int = 16, num_heads: int = 1,
                 windows: list = None, length: list = None, device: str = 'cpu',
                 mask: str = 'random', corr: np.ndarray = None, p: float = 0.1):
        super(Backbone, self).__init__()

        # B x T x C
        self.timestamp = timestamp
        self.channel = channel
        self.query_length = query_length
        self.device = device

        # Channel-mixing matrix for the encoder attention.
        if corr is None:
            self.weight = self._channel_weight(weight, self.channel)
            self.corr = None
        else:
            self.corr = torch.tensor(corr).float().to(self.device)
            self.weight = None

        # Encoder (attention + GRU + projection) and decoder.
        self.n_hidden = n_hidden
        self.repr_dim = latent_dim * 2
        self.gru_rnn = nn.GRU(self.n_hidden, self.n_hidden, bidirectional=True, batch_first=True)
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(2 * self.n_hidden, 50),
            nn.ReLU(),
            nn.Linear(50, self.repr_dim))

        self.periodic = nn.Linear(1, embed_times - 1)
        self.linear = nn.Linear(1, 1)
        self.mtans = multiTimeAttention(2 * self.channel, self.weight if corr is None else self.corr, self.n_hidden, embed_times, num_heads)
        self.recon = dec_mtan_rnn(channel, None, self.repr_dim, n_hidden, embed_times, num_heads, True, device)
        self.representation = nn.Linear(self.repr_dim, self.channel)

        # Copy so the caller's lists are never mutated; None means "no pooled levels".
        length = [] if length is None else list(length)
        # Reference points: level 0 plus one halved set per pooled level.
        self.query = [torch.linspace(0, 1, self.query_length)]
        for level in range(1, len(length) + 1):
            self.query.append(torch.linspace(0, 1, self.query_length // (2 ** level)))
        if windows is not None:
            # Level 0 is the raw series: window 0, full length.
            self.windows, self.length = [0] + list(windows), [self.timestamp] + length
        else:
            self.windows, self.length = [0], [self.timestamp]

        self.gru_rnn_repr = nn.GRU(self.repr_dim, self.repr_dim, batch_first=True)

        # create mask implementation
        self.mask = mask_percent
        self.percent = p

    @staticmethod
    def _channel_weight(weight, channel):
        """Build a ``(channel, channel)`` float32 matrix: normalized ``weight[0]`` on the diagonal, ``weight[1]`` elsewhere."""
        assert weight is None or len(weight) == 2
        # Float dtype: an all-integer list (including the default [1, 0]) would
        # make the in-place division below raise a numpy casting error.
        w = np.array(weight if weight is not None else [1, 0], dtype=float)
        w /= np.sum(w)
        matrix = float(w[1]) * torch.ones(channel, channel)
        matrix[torch.eye(channel, channel) == 1.] = float(w[0])
        return matrix

    def learn_time_embedding(self, tt):
        """Learned embedding of ``tt``: ``[linear(t), sin(periodic(t))]`` of width ``embed_times``."""
        tt = tt.to(self.device)
        tt = tt.unsqueeze(-1)
        out2 = torch.sin(self.periodic(tt))
        out1 = self.linear(tt)
        return torch.cat([out1, out2], -1)

    def calc_zero_index_mask(self, ts):
        """Return a bool mask shaped like ``ts`` that is True only in column 0 where that time stamp is 0."""
        mask = ts < 0
        mask[:, 0] = True
        # (ts < 0) & (ts == 0) is always False, so only column 0 with ts == 0 survives.
        return mask & (ts == 0)

    def average_pooling(self, value, ts, mask, window, lts, dec_ts=None, zero_mask=None, further: bool = True):
        """Average ``value`` over consecutive time windows of width ``window``.

        Windows ``(start, start + window)`` are visited from ``start = 0`` up to
        ``max(ts)`` (or up to 1 when ``further`` is False). Every window containing
        at least one time stamp anywhere in the batch fills the next output slot
        with per-sample nan-means of the observed values, time stamps, mask and
        decoder time stamps inside it.

        Args:
            value: values indexed ``[batch, time, channel]``.
            ts: time stamps indexed ``[batch, time]``.
            mask: observation mask matching ``value``; only cells equal to 1 are averaged.
            window: window width.
            lts: number of output slots.
            dec_ts: decoder time stamps pooled alongside ``ts``; defaults to a copy of ``ts``.
            zero_mask: positions admitted into the first window in addition to
                ``ts > 0``; defaults to ``calc_zero_index_mask(ts)``.
            further: when True, windows run up to ``max(ts)`` and each filled mask
                slot is randomly thinned with ``self.mask(..., self.percent)``.

        Returns:
            ``(pooling, ts_pooling, ts_mask, dec_ts_pooling)`` with ``lts`` slots on the
            time axis; NaNs (empty means) become 0 and ``ts_mask`` is binarized.

        NOTE(review): both window bounds are strict, so a time stamp lying exactly on a
        boundary (including ``max(ts)`` when it is a multiple of ``window``) is never
        pooled, and more than ``lts`` non-empty windows raises an IndexError. Both
        presumably agree with how the caller computes ``lts``; confirm.
        """
        start, tmax, record_ind = 0, torch.max(ts) if further else 1, 0
        pooling = torch.zeros(value.size(0), lts, self.channel, device=self.device)
        ts_pooling, ts_mask = torch.zeros(value.size(0), lts, device=self.device), torch.zeros_like(pooling)
        dec_ts, dec_ts_pooling = ts.clone() if dec_ts is None else dec_ts, torch.zeros_like(ts_pooling)
        for i in range(math.ceil(tmax / window)):
            ind = ((ts > start) & (ts < start + window))
            if i == 0:
                # ts > 0 excludes zero padding; re-admit a genuine t == 0 at index 0.
                if zero_mask is None:
                    zero_mask = self.calc_zero_index_mask(ts)
                ind |= zero_mask[:, :ts.size(-1)]
            if ind.sum() != 0:
                val, times, m, dec_times = value.clone(), ts.clone(), mask.clone(), dec_ts.clone()
                # NaN-out unobserved cells and everything outside the window so nanmean ignores them.
                val[m != 1] = torch.nan
                val[~ind], times[~ind], m[~ind], dec_times[~ind] = torch.nan, torch.nan, torch.nan, torch.nan
                pooling[:, record_ind, :] = torch.nanmean(val, dim=1)
                ts_pooling[:, record_ind] = torch.nanmean(times, dim=1)
                ts_mask[:, record_ind, :] = torch.nanmean(m, dim=1)
                dec_ts_pooling[:, record_ind] = torch.nanmean(dec_times, dim=1)

                if further:
                    # Randomly hide pooled cells; samples with no points in this window are hidden entirely.
                    record_mask = self.mask(ind.sum(dim=1), self.channel, self.percent)
                    ts_mask[:, record_ind, :][~record_mask] = 0

                record_ind += 1
            start += window
        pooling[torch.isnan(pooling)], ts_mask[torch.isnan(ts_mask)] = 0, 0
        ts_pooling[torch.isnan(ts_pooling)], dec_ts_pooling[torch.isnan(dec_ts_pooling)] = 0, 0
        ts_mask[ts_mask > 0] = 1.
        return pooling, ts_pooling, ts_mask, dec_ts_pooling

    def _encoder(self, mtans, query, key, val, m):
        """Attend ``val`` onto the ``query`` points, run the bidirectional GRU, and project to ``repr_dim``."""
        out = mtans(query, key, val, m)
        out, _ = self.gru_rnn(out)
        out = self.hiddens_to_z0(out)
        return out

    def forward(self, x, time_steps, dec_timesteps=None):
        """Encode (and reconstruct) ``x`` at every resolution level.

        Args:
            x: ``x[..., :channel]`` holds the values and ``x[..., channel:]`` their
                observation mask (presumably B x T x 2C — confirm with the loader).
            time_steps: observation times indexed ``[batch, time]``.
            dec_timesteps: optional times to reconstruct at; defaults to ``time_steps``.

        Returns:
            ``(out, recon, x_pooling, recon_pooling, embedding)``. In training mode
            ``out``, ``recon`` and ``x_pooling`` hold one entry per level. Level 0 is
            the raw series and level ``i > 0`` the ``windows[i]``-pooled series.
            ``recon_pooling`` starts with ``None``, followed by each reconstruction
            pooled to the next level. ``embedding`` is level 0's per-reference-point
            encoding. In eval mode only level 0 runs and ``out``/``recon`` are that
            level's tensors.
        """
        mask = x[..., self.channel:]
        zero_mask, x = self.calc_zero_index_mask(time_steps), x[..., :self.channel]
        out, recon, x_pooling, recon_pooling, embedding = [], [], [], [None], None
        for ind in range(len(self.windows)):
            if ind != 0:
                val, ts, m, dec_ts = self.average_pooling(x, time_steps, mask, self.windows[ind], self.length[ind], dec_timesteps, zero_mask)
            else:
                val, ts, m, dec_ts = x, time_steps, mask, time_steps.clone() if dec_timesteps is None else dec_timesteps

            # Encoder input: masked values followed by the mask itself.
            val, m = torch.cat([val * m, m], dim=-1), torch.cat([m, m], dim=-1)
            key = self.learn_time_embedding(ts).to(self.device)
            query = self.learn_time_embedding(self.query[ind].unsqueeze(0)).to(self.device)
            latent = self._encoder(self.mtans, query, key, val, m)
            # Sequence summary: last hidden state of the representation GRU.
            out.append(self.gru_rnn_repr(latent)[1].squeeze(0))
            self.recon.set_query(self.query[ind])
            recon.append(self.recon(latent, dec_ts))
            x_pooling.append(val.clone())
            if embedding is None:
                embedding = latent.clone()
            if not self.training:
                out, recon = out[0], recon[0]
                break
            elif ind < len(self.windows) - 1:
                # Pool this level's reconstruction to the next level for comparison.
                val, _, _, _ = self.average_pooling(recon[-1], ts, m[..., :self.channel], self.windows[ind + 1],
                                                    self.length[ind + 1], None, zero_mask, False)
                recon_pooling.append(val.clone())
        return out, recon, x_pooling, recon_pooling, embedding


class CLS(nn.Module):
    """Sequence classifier: a GRU summarizes the input and an MLP classifies its final hidden state.

    Args:
        n_hidden: feature size of each input step.
        latent_dim: GRU hidden size (and MLP input size).
        num_classes: number of output logits.
    """

    def __init__(self, n_hidden: int, latent_dim: int, num_classes: int = 2):
        super(CLS, self).__init__()
        self.gru_rnn = nn.GRU(n_hidden, latent_dim, batch_first=True)
        width = 300
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, num_classes),
        )

    def forward(self, z):
        """Return class logits for the batch-first sequence ``z``."""
        final_hidden = self.gru_rnn(z)[1]
        # h_n carries a leading layer axis of size 1 (single-layer GRU); drop it.
        return self.classifier(final_hidden.squeeze(0))


class CLS2(nn.Module):
    """Three-layer MLP classifier applied to the last axis of its input.

    Args:
        n_hidden: input feature size.
        latent_dim: width of both hidden layers.
        num_classes: number of output logits.
    """

    def __init__(self, n_hidden: int, latent_dim: int, num_classes: int = 2):
        super(CLS2, self).__init__()
        layers = [
            nn.Linear(n_hidden, latent_dim), nn.ReLU(),
            nn.Linear(latent_dim, latent_dim), nn.ReLU(),
            nn.Linear(latent_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, z):
        """Return class logits for features ``z``."""
        logits = self.classifier(z)
        return logits


class CLSStatic(CLS2):
    """``CLS2`` classifier that can append projected static features to its input.

    Args:
        n_hidden: size of the sequence representation ``z``.
        latent_dim: MLP hidden width.
        static_dim: size of the raw static features.
        static_hidden: size the static features are projected to.
        num_classes: number of output logits.
    """

    def __init__(self, n_hidden: int, latent_dim: int, static_dim: int, static_hidden: int, num_classes: int = 2):
        super(CLSStatic, self).__init__(n_hidden + static_hidden, latent_dim, num_classes)
        self.static_fc = nn.Linear(static_dim, static_hidden)

    def forward(self, z, static=None):
        """Classify ``z``, concatenated on the last axis with ``static_fc(static)`` when ``static`` is given."""
        features = z if static is None else torch.cat([z, self.static_fc(static)], dim=-1)
        return self.classifier(features)


if __name__ == '__main__':
    # Smoke test: build a Backbone from the PhysioNet loaders and run every
    # training batch through it once.
    from datautils import load_physionet_data

    class Args:
        """Hard-coded arguments handed to ``load_physionet_data``."""

        def __init__(self):
            self.n, self.q = 10000, 0.016
            self.batch_size, self.classif = 64, False

    data_object = load_physionet_data(Args(), 'cpu', 0.016)

    def _to_list(key):
        # Loader tensors -> plain Python lists for Backbone's windows/length.
        return data_object[key].detach().numpy().tolist()

    pooling_windows, pooled_lengths = _to_list('windows'), _to_list('ts_length')
    net = Backbone(
        data_object['timestamp'],
        data_object['input_dim'],
        weight=[0.6, 0.4],
        query_length=128,
        embed_times=128,
        windows=pooling_windows,
        length=pooled_lengths,
    )

    train_loader, test_loader = data_object["train_dataloader"], data_object["test_dataloader"]

    for batch in train_loader:
        # The last feature column is passed as the time stamps, the rest as x.
        net(batch[..., :-1], batch[..., -1])
        print('ok')
