import torch
import torch.nn as nn


class DiscreteVAEEncoder(nn.Module):
    """
        Encode single vectors into discrete latent vectors.

        An MLP emits ``categoricals * latents`` logits per input row. The
        logits are viewed as ``(batch, categoricals, latents)`` and used to
        parameterize one-hot categorical distributions that are sampled with
        the straight-through estimator.
    """
    # Default hyper-parameters. Only "n_layers", "hidden_dim" and
    # "activation" of the "encoder" entry are read by this class.
    base_config = {
        "encoder": {
            "n_layers": 1,
            "hidden_dim": 256,
            "activation": nn.ReLU(),
            "normalization": None,
        },
        "decoder": {
            "n_layers": 1,
            "hidden_dim": 256,
            "activation": nn.ReLU(),
            "normalization": None,
        },
    }

    def __init__(
        self, input_shape, config=base_config, categoricals=8, latents=8
    ) -> None:
        """
            Build the MLP.

            Args:
                input_shape: size of the last dimension of the inputs.
                config: dict with an "encoder" entry (see ``base_config``).
                categoricals: size of the second axis of the logits.
                latents: size of the last axis of the logits.
        """
        super().__init__()

        enc_cfg = config["encoder"]
        width = enc_cfg["hidden_dim"]
        act = enc_cfg["activation"]

        # Input projection, then ``n_layers`` hidden blocks, then the logit
        # head. Layers are created in this order so parameter names and
        # initialization order stay stable.
        layers = [nn.Linear(input_shape, width), act]
        for _ in range(enc_cfg["n_layers"]):
            layers.append(nn.Sequential(nn.Linear(width, width), act))
        layers.append(nn.Linear(width, categoricals * latents))
        self.encoder = nn.Sequential(*layers)

        self.latent_dim = latents
        self.categoricals = categoricals

    def logits(self, observation):
        """
            Return the MLP output reshaped to (-1, categoricals, latent_dim).
        """
        flat = self.encoder(observation)
        return flat.reshape((-1, self.categoricals, self.latent_dim))

    def encode(self, observation, return_dist=False):
        """
            Sample a flattened one-hot code for each input row.

            Returns:
                The sample reshaped to ``(batch, -1)``; when ``return_dist``
                is true, a tuple of that sample and the underlying
                ``OneHotCategoricalStraightThrough`` distribution.
        """
        base_dist = torch.distributions.OneHotCategoricalStraightThrough(
            logits=self.logits(observation)
        )
        sample = torch.distributions.Independent(base_dist, 1).rsample()
        flat_sample = sample.reshape((sample.shape[0], -1))

        if not return_dist:
            return flat_sample
        return flat_sample, base_dist

    def dist(self, observation):
        """
            Return the straight-through one-hot distribution for the inputs.
        """
        return torch.distributions.OneHotCategoricalStraightThrough(
            logits=self.logits(observation))


class DiscreteSequenceEncoder(nn.Module):
    """
        Encode sequences of observations into latent goals.

        A single-layer, batch-first GRU summarizes the sequence; its final
        hidden state is passed to a ``DiscreteVAEEncoder`` that produces the
        discrete latent code.
    """
    base_config = {
        "rnn": {
            "context_size": 256,
        },
        "vae": {
            "encoder": {
                "n_layers": 1,
                "hidden_dim": 256,
                "activation": nn.ReLU(),
                "normalization": None,
            },
        },
    }

    def __init__(
        self, input_shape, config=base_config, categoricals=8, latents=8
    ):
        """
            Args:
                input_shape: feature size of each sequence step.
                config: dict with "rnn" (``context_size``) and "vae" entries.
                categoricals: forwarded to ``DiscreteVAEEncoder``.
                latents: forwarded to ``DiscreteVAEEncoder``.
        """
        super(DiscreteSequenceEncoder, self).__init__()

        self.context_size = config["rnn"]["context_size"]

        self.input_shape = input_shape
        self.encoder = nn.GRU(
            input_size=self.input_shape, hidden_size=self.context_size, batch_first=True, num_layers=1
        )
        self.vae = DiscreteVAEEncoder(
            input_shape=self.context_size,
            config=config["vae"],
            categoricals=categoricals,
            latents=latents
        )

        self.categoricals = categoricals
        self.latents = latents

    @staticmethod
    def _as_sequence(sequence):
        """
            Promote a 2-D input to a length-1 sequence by inserting dim 1.
        """
        if sequence.dim() == 2:
            return sequence.unsqueeze(1)
        return sequence

    def encode(self, sequence, return_dist=False, return_h=False):
        """
            Sample a latent code from the GRU's final hidden state.

            Returns:
                Whatever ``self.vae.encode`` returns (a sample, or a
                ``(sample, dist)`` tuple when ``return_dist`` is true); when
                ``return_h`` is true, a tuple of that value and the GRU's
                final hidden state tensor.
        """
        _, h = self.encoder(self._as_sequence(sequence))

        out = self.vae.encode(h[-1], return_dist)
        if return_h:
            return out, h
        return out

    def context(self, sequence):
        """
            Return the last layer's final GRU hidden state, one row per
            batch element.
        """
        _, h = self.encoder(self._as_sequence(sequence))
        return h[-1]

    def dist(self, sequence):
        """
            Return the latent distribution for every sequence in the batch.

            Uses the same per-batch final hidden state as ``encode`` and
            ``context``. (Indexing the GRU's batch-first *output* with
            ``[-1, -1, :]`` would keep only the last batch element.)
        """
        return self.vae.dist(self.context(sequence))


class ObservationDecoder(nn.Module):
    """
        Reconstruct observations from latent goals.

        An MLP maps a goal vector to the location and log-scale of a
        diagonal Normal distribution over observations.
    """
    base_config = {
        "n_layers": 1,
        "hidden_dim": 256,
        "activation": nn.ReLU(),
        "normalization": None,
    }

    def __init__(self, observation_shape, goal_shape, config=base_config):
        """
            Args:
                observation_shape: size of the observation vector.
                goal_shape: size of the (flattened) goal vector.
                config: dict with "n_layers", "hidden_dim", "activation" and
                    "normalization" ("layernorm" enables LayerNorm; any other
                    value uses Identity).
        """
        super(ObservationDecoder, self).__init__()

        activation = config["activation"]
        hidden_dim = config["hidden_dim"]
        use_layernorm = config["normalization"] == "layernorm"

        def make_norm():
            # A fresh module per layer: reusing a single LayerNorm instance
            # would tie its affine parameters across every layer.
            if use_layernorm:
                return nn.LayerNorm(hidden_dim)
            return nn.Identity()

        self.net = nn.Sequential(
            nn.Linear(goal_shape, hidden_dim),
            make_norm(),
            activation,
            *[
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                              make_norm(), activation)
                for _ in range(config["n_layers"])
            ],
            nn.Linear(hidden_dim, 2*observation_shape),
        )

    def forward(self, goal):
        """
            Return ``Normal(loc, exp(raw_scale))`` where ``loc`` and
            ``raw_scale`` are the two halves of the network output along the
            last dimension.
        """
        out = self.net(goal)
        loc, scale = torch.chunk(out, 2, dim=-1)
        scale = torch.exp(scale)
        return torch.distributions.Normal(loc, scale)


class EnergyModel(nn.Module):
    """
        Energy model used for subset similarity.

        An MLP scores an ordered pair of goal vectors with a single scalar.
    """
    base_config = {
        "n_layers": 1,
        "hidden_dim": 256,
        "activation": nn.ReLU(),
        "normalization": None,
    }

    def __init__(self, goal_shape, config=base_config):
        """
            Args:
                goal_shape: size of one (flattened) goal vector; the network
                    input is two of them concatenated.
                config: dict with "n_layers", "hidden_dim", "activation" and
                    "normalization" ("layernorm" enables LayerNorm; any other
                    value uses Identity).
        """
        super(EnergyModel, self).__init__()

        activation = config["activation"]
        hidden_dim = config["hidden_dim"]
        use_layernorm = config["normalization"] == "layernorm"

        def make_norm():
            # A fresh module per layer: reusing a single LayerNorm instance
            # would tie its affine parameters across every layer.
            if use_layernorm:
                return nn.LayerNorm(hidden_dim)
            return nn.Identity()

        self.net = nn.Sequential(
            nn.Linear(2*goal_shape, hidden_dim),
            make_norm(),
            activation,
            *[
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                              make_norm(), activation)
                for _ in range(config["n_layers"])
            ],
            nn.Linear(hidden_dim, 1),
        )

    def energy(self, x, y):
        """
            Score the ordered pair ``(x, y)``.

            The two inputs are concatenated along the last dimension and fed
            to the MLP; the result has a trailing dimension of size 1.
        """
        return self.net(torch.cat([x, y], dim=-1))


class GoalAbstraction(torch.nn.Module):
    """
        Encode observation sequences into latent goals, trained with contrastive learning using subset similarity.

        Owns one Adam optimizer covering the goal encoder, the subset energy
        model and (optionally) an observation decoder.

        Config keys that are read with ``config[...]`` but are absent from
        ``base_config`` must be supplied by the caller: "min_traj_len" and
        "max_traj_len" (``update``), and "latent_space" when "zero_goal" or
        "one_goal" is enabled. Optional keys read with ``config.get``:
        "max_subtraj_len", "max_observation_decoder_steps", "symmetric",
        "symmetric_negatives", "individual", "zero_goal", "one_goal",
        "kl_target".
    """
    base_config = {
        "batch_size": 64,
        "rnn_steps": 10,
        "threshold": 0.95,
        "n_steps": 1,
        "merged_trajectories": True,
        "kl_warmup": 50000,
        "kl_max": 5.0,
        "loss": "bce",
        "trajectory_length": 32,
        "update_interval": 32,
    }

    def __init__(
        self,
        device,
        subset_energy,
        goal_encoder,
        observation_decoder=None,
        config=base_config,
    ) -> None:
        """
            Args:
                device: device on which helper tensors are created.
                subset_energy: module exposing ``energy(x, y)``.
                goal_encoder: module exposing ``encode(seq, return_dist=False)``.
                observation_decoder: optional module mapping goals to a
                    distribution with ``log_prob``.
                config: hyper-parameter dict (see class docstring).
        """
        super(GoalAbstraction, self).__init__()
        self.config = config

        self.subset_energy = subset_energy
        self.goal_encoder = goal_encoder
        self.observation_decoder = observation_decoder
        self.device = device

        modules = []
        if self.observation_decoder is not None:
            modules.append(self.observation_decoder)

        # Three parameter groups are always created (the first one is empty
        # when there is no decoder) so the optimizer layout does not depend
        # on whether a decoder is present.
        self.opt = torch.optim.Adam([
            {'params': nn.ModuleList(modules).parameters(), 'lr': 1e-3},
            {'params': self.subset_energy.parameters(), 'lr': 1e-3},
            {'params': self.goal_encoder.parameters(), 'lr': 1e-3},
        ])

    def encode(self, observations):
        """
            Encode observations with the goal encoder.
        """
        return self.goal_encoder.encode(observations)

    def update_observation_decoder(self, observations):
        """
            Reconstruction loss: decode one random step of a random
            sub-sequence from that sub-sequence's goal.

            ``observations`` is indexed as (batch, time, ...). Requires
            ``T // 3 > 1`` because the sub-sequence length is drawn from
            ``[1, T // 3)``.

            Returns:
                ``(loss, {"observation_decoder": detached loss})``.
        """
        with torch.no_grad():
            # select random subsequence
            B, T = observations.shape[0], observations.shape[1]
            # T//3: for positive training, at least 3 trajectories must fit into the full sequence
            t = torch.randint(1, T//3, (), device=self.device).item()

            # Per-sample start offsets, then gather t consecutive steps.
            start_idx = (torch.rand((B,), device=self.device)
                         * (T-t)).reshape((B, 1)).int()
            seq = torch.arange(0, t, device=self.device).reshape((1, t))
            indices = start_idx + seq
            batch_idx = torch.arange(B, device=self.device).unsqueeze(1)

            sequences = observations[batch_idx, indices]

            # One reconstruction target per sub-sequence.
            idx = torch.randint(0, t, (B, 1), device=self.device)
            selected = sequences[batch_idx, idx]

        goals = self.goal_encoder.encode(sequences)  # B,G

        dist = self.observation_decoder(goals)
        rec_loss = -dist.log_prob(selected.squeeze(1).detach()).mean()

        return rec_loss, {"observation_decoder": rec_loss.detach()}

    def update(self, buffer, step):
        """
            Sample trajectories from ``buffer`` and take one optimizer step.

            A trajectory length is drawn from
            ``[min_traj_len, max_traj_len)``; if the buffer cannot serve it
            (``sample`` returns ``None``) shorter lengths are tried down to
            ``min_traj_len``.

            Returns:
                Dict of detached loss statistics, or ``{}`` when no batch
                could be sampled.
        """
        T = torch.randint(self.config["min_traj_len"],
                          self.config["max_traj_len"], (1,)).item()
        with torch.no_grad():
            while T >= self.config["min_traj_len"]:
                transitions = buffer.sample(
                    self.config["batch_size"], T, to_device=self.device, keys=None, unique=True)
                if transitions is not None:
                    break
                T = T - 1

            if transitions is None:
                return {}

        return_steps = min(self.config.get("max_subtraj_len", T//3), T//3)

        total_loss, losses = self.train(
            transitions["observation"].detach(),
            step=step,
            return_steps=return_steps
        )

        if self.observation_decoder is not None:
            observation_decoder_steps = self.config.get(
                "max_observation_decoder_steps", None)
            # disable info rec after certain amount of steps
            if observation_decoder_steps is None or step < observation_decoder_steps:
                info_loss, info_dict = self.update_observation_decoder(
                    transitions["observation"].detach())
                total_loss += info_loss
                losses = losses | info_dict

        self.opt.zero_grad(set_to_none=True)
        total_loss.backward()
        self.opt.step()
        return losses

    def shuffle_batch(self,  x):
        """
            Return ``x`` with its first dimension randomly permuted.
        """
        return x[torch.randperm(x.shape[0])]

    def combine_samples(self, samples):
        """
            Concatenate a list of ``[left, right]`` pairs along dim 0 into a
            ``(lefts, rights)`` tuple.
        """
        return torch.cat([s[0] for s in samples], dim=0), torch.cat(
            [s[1] for s in samples], dim=0
        )

    def symmetrize(self, a, b):
        """
            Return a tuple of sample pairs: ``([a, b], [b, a])`` when the
            "symmetric" option is on (the default), else ``([a, b],)``.
        """
        if self.config.get("symmetric", True):
            return [a, b], [b, a]
        else:
            return [a, b],

    def train_on_samples(self, positive_samples, negative_samples):
        """
            Binary cross-entropy on energies: positives target 1, negatives
            target 0.

            With "symmetric_negatives", every positive pair is also used in
            swapped order as a negative, and the negative loss is halved.

            Returns:
                ``(positive_loss, scaled_negative_loss, positive_energy,
                negative_energy)``.
        """
        positive_samples_a, positive_samples_b = self.combine_samples(positive_samples)
        negative_samples_a, negative_samples_b = self.combine_samples(negative_samples)

        if self.config.get("symmetric_negatives", False):
            negative_samples_a = torch.cat(
                [negative_samples_a, positive_samples_b], dim=0)
            negative_samples_b = torch.cat(
                [negative_samples_b, positive_samples_a], dim=0)
            neg_scale = 0.5  # twice as many examples, so scale down the loss
        else:
            neg_scale = 1.0
        positive_energy = self.subset_energy.energy(
            positive_samples_a, positive_samples_b)
        negative_energy = self.subset_energy.energy(
            negative_samples_a, negative_samples_b)

        criterion = nn.BCEWithLogitsLoss()

        negative_loss = criterion(
            negative_energy, torch.zeros_like(negative_energy)
        )
        positive_loss = criterion(
            positive_energy, torch.ones_like(positive_energy)
        )

        return positive_loss, neg_scale * negative_loss, positive_energy, negative_energy

    def sample_subset_sequences(self, sequences, max_seq_len):
        """
            Slice four sub-sequences out of ``sequences`` along dim 1.

            All lengths are drawn from ``[1, max_seq_len]``:
              * ``a``: a window of the longer of two drawn lengths;
              * ``b``: a window of the shorter length lying inside ``a``;
              * ``c``: a window entirely after or entirely before ``a``;
              * ``r``: a window anywhere in the sequence.

            The same window positions are used for every batch element.
            Requires the time dimension to fit ``len_a + len_c``.
        """
        T = sequences.shape[1]

        lens = torch.randint(1, max_seq_len + 1, (4,))
        len_a, len_b = sorted(lens[:2].tolist(), reverse=True)
        len_c, len_r = lens[2].item(), lens[3].item()

        if torch.rand(1).item() > 0.5:
            # c is placed after a
            start_a = torch.randint(0, T - len_a - len_c + 1, (1,)).item()
            start_c = torch.randint(
                start_a + len_a, T - len_c + 1, (1,)).item()
        else:
            # c is placed before a
            start_a = torch.randint(len_c, T - len_a + 1, (1,)).item()
            start_c = torch.randint(0, start_a - len_c + 1, (1,)).item()

        start_b = start_a + torch.randint(0, len_a - len_b + 1, (1,)).item()
        start_r = torch.randint(0, T - len_r + 1, (1,)).item()

        sequence_a = sequences[:, start_a:start_a + len_a]
        sequence_b = sequences[:, start_b:start_b + len_b]
        sequence_c = sequences[:, start_c:start_c + len_c]
        sequence_r = sequences[:, start_r:start_r + len_r]

        return sequence_a, sequence_b, sequence_c, sequence_r

    def _constant_goal(self, batch_size, index):
        """
            Build a flattened one-hot goal whose hot entry sits at ``index``
            of the last axis of ``config["latent_space"]``.
        """
        dim_a, dim_b = self.config["latent_space"][0], self.config["latent_space"][1]
        goal = torch.zeros((batch_size, dim_a, dim_b), device=self.device)
        goal[:, :, index] = 1.0
        return goal.reshape((-1, dim_a * dim_b))

    def get_zero_goal(self, batch_size):
        """
            Flattened goal with the first class selected in every group.
        """
        return self._constant_goal(batch_size, 0)

    def get_one_goal(self, batch_size):
        """
            Flattened goal with the last class selected in every group.
        """
        return self._constant_goal(batch_size, -1)

    def train(self, sequences=True, step=None, return_steps=None):
        """
            Compute the contrastive loss, or switch train/eval mode.

            This method shares its name with ``nn.Module.train(mode)``, which
            PyTorch itself calls from ``eval()`` and when a parent module
            changes mode. When the first argument is a bool the call is
            forwarded to ``nn.Module.train``; otherwise the loss is computed.

            Args:
                sequences: observation batch indexed as (batch, time, ...),
                    or a bool training-mode flag.
                step: current training step (used for the KL schedule).
                return_steps: maximum sub-sequence length to sample.

            Returns:
                ``(loss, stats_dict)`` for the loss path, or ``self`` for the
                mode-switch path.
        """
        if isinstance(sequences, bool):
            return super().train(sequences)
        return self._subset_contrastive_loss(sequences, step, return_steps)

    def _subset_contrastive_loss(self, sequences, step, return_steps):
        """
            Build positive/negative goal pairs from sub-sequences and return
            the energy loss plus the KL term.
        """
        abstractor_loss = torch.zeros((1,), device=self.device)

        B = sequences.shape[0]
        F = sequences.shape[2:]
        with torch.no_grad():
            sequence_a, sequence_b, sequence_c, sequence_r = self.sample_subset_sequences(
                sequences, max_seq_len=return_steps)
            len_a = sequence_a.shape[1]

            # Example batch: [aaaaaa]  [cccccccc]
            #                 [bb]   [rr]
            # => b is a subset of a; c is strictly non-overlapping with a
            #    and b; r is random
            sequence_ar = torch.cat([sequence_a, sequence_r], dim=1)
            len_ar = sequence_ar.shape[1]
            # Individual steps flattened into the batch dimension.
            ind_a = sequence_a.reshape((B * len_a, *F))
            ind_ar = sequence_ar.reshape((B * len_ar, *F))

        seq_enc_a, dist_a = self.goal_encoder.encode(
            sequence_a, return_dist=True)
        seq_enc_b = self.goal_encoder.encode(sequence_b)
        seq_enc_c = self.goal_encoder.encode(sequence_c)
        seq_enc_r = self.goal_encoder.encode(sequence_r)
        seq_enc_ar = self.goal_encoder.encode(sequence_ar)

        # Also use all 1 step encodings
        ind_enc_a, _ = self.goal_encoder.encode(ind_a, return_dist=True)
        ind_enc_ar = self.goal_encoder.encode(ind_ar)

        # Repeat sequence encoding to same length
        seq_enc_a_rep = (
            seq_enc_a.reshape((B, 1, -1))
            .repeat((1, len_a, 1))
            .reshape((B * len_a, -1))
        )

        seq_enc_ar_rep = (
            seq_enc_ar.reshape((B, 1, -1))
            .repeat((1, len_ar, 1))
            .reshape((B * len_ar, -1))
        )

        # Shuffle (breaks the pairing to create negatives)
        seq_enc_a_s = self.shuffle_batch(seq_enc_a)

        ind_enc_a_s = self.shuffle_batch(ind_enc_a)

        seq_enc_ar_rep_s = self.shuffle_batch(seq_enc_ar_rep)
        seq_enc_a_rep_s = self.shuffle_batch(seq_enc_a_rep)

        positive_samples = [
            *self.symmetrize(seq_enc_b, seq_enc_a),  # b subset of a
        ]

        negative_samples = [
            *self.symmetrize(seq_enc_a, seq_enc_c),  # a not subset of c
        ]

        if self.config.get("individual", True):
            positive_ind = [
                # individual states
                *self.symmetrize(ind_enc_a, seq_enc_a_rep),
            ]
            negative_ind = [
                # individual states, shuffled
                *self.symmetrize(ind_enc_a, seq_enc_a_rep_s),
                # individual states, shuffled
                *self.symmetrize(ind_enc_a_s, seq_enc_a_rep),
            ]

            positive_samples += positive_ind
            negative_samples += negative_ind

        if self.config["merged_trajectories"]:
            mss_positive_samples = [
                *self.symmetrize(ind_enc_ar, seq_enc_ar_rep),
                *self.symmetrize(seq_enc_a, seq_enc_ar),
                *self.symmetrize(seq_enc_r, seq_enc_ar),
            ]

            mss_negative_samples = [
                *self.symmetrize(seq_enc_c, seq_enc_ar),
                *self.symmetrize(seq_enc_a_s, seq_enc_ar),
                *self.symmetrize(ind_enc_ar, seq_enc_ar_rep_s),
            ]

            positive_samples += mss_positive_samples
            negative_samples += mss_negative_samples

        if self.config.get("zero_goal", False):
            zero_a = self.get_zero_goal(seq_enc_a.shape[0])
            zero_pos_samples = [
                *self.symmetrize(zero_a, seq_enc_a),
            ]
            zero_neg_samples = [
                *self.symmetrize(seq_enc_a, zero_a),
            ]
            positive_samples += zero_pos_samples
            negative_samples += zero_neg_samples
        if self.config.get("one_goal", False):
            one_a = self.get_one_goal(seq_enc_a.shape[0])
            one_pos_samples = [
                *self.symmetrize(seq_enc_a, one_a),
            ]
            one_neg_samples = [
                *self.symmetrize(one_a, seq_enc_a),
            ]
            positive_samples += one_pos_samples
            negative_samples += one_neg_samples

        positive_loss, negative_loss, positive_energy, negative_energy = (
            self.train_on_samples(
                positive_samples,
                negative_samples,
            )
        )

        abstractor_loss += positive_loss + negative_loss

        kl_loss = self.kl_warmup(dist_a, step)

        abstractor_loss += kl_loss

        return abstractor_loss, {
            "abstr_positive_energy": positive_energy.mean().detach(),
            "abstr_negative_energy": negative_energy.mean().detach(),
            "abstr_positive_loss": positive_loss.mean().detach(),
            "abstr_negative_loss": negative_loss.mean().detach(),
            "abstr_kl_loss": kl_loss.mean().detach(),
            "abstr_ent": dist_a.entropy().mean().detach(),
        }

    def kl_warmup(self, dist, step):
        """
            KL term between ``dist`` and a uniform one-hot categorical prior.

            While ``step < config["kl_warmup"]`` the per-element KL is
            lower-bounded by ``kl_max * step / kl_warmup``; afterwards it is
            lower-bounded by ``config["kl_target"]`` if that key is present
            and positive. In every other case a zero tensor of shape ``(1,)``
            is returned.
        """
        kl_loss = torch.zeros((1,), device=self.device)
        prior = torch.distributions.OneHotCategorical(
            logits=torch.ones_like(dist.probs))
        warmup_steps = self.config["kl_warmup"]
        # "kl_target" is optional: base_config does not define it.
        kl_target = self.config.get("kl_target", 0)

        if warmup_steps > 0 and step < warmup_steps:
            kl = torch.distributions.kl.kl_divergence(dist, prior)
            warmup = step / warmup_steps
            kl_loss = torch.max(kl, torch.ones_like(
                kl) * self.config["kl_max"] * warmup).mean()

        elif kl_target > 0:
            kl = torch.distributions.kl.kl_divergence(dist, prior)
            kl_loss = torch.max(kl, torch.ones_like(
                kl) * kl_target).mean()
        return kl_loss


class HierarchicalPolicy(nn.Module):
    """
        Parametrized goal, trained with contrastive learning to learn to encode high rewarding states.

        The network is always fed an all-zeros input, so its output is a
        single learned set of logits over ``categoricals x latents``.
    """

    def __init__(self, input_shape, goal_shape, goal_rep, categoricals, latents, device, batch_size=1024, hidden_dim=256, contrastive=True, *args, **kwargs):
        """
            Args:
                input_shape: size of the (all-zeros) network input.
                goal_shape: size of the flattened goal vector.
                goal_rep: object exposing ``encode(observations)`` and
                    ``subset_energy.energy(x, y)``.
                categoricals: size of the second axis of the logits.
                latents: size of the last axis of the logits.
                device: device on which helper tensors are created.
                batch_size: number of transitions requested from the buffer.
                hidden_dim: width of the two hidden layers.
                contrastive: selects the loss variant in ``compute_loss``.
                *args, **kwargs: forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.net = nn.Sequential(
            nn.Linear(input_shape, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, goal_shape)
        )

        self.input_shape = input_shape
        self.goal_shape = goal_shape
        # Only the policy network is optimized here.
        self.opt = torch.optim.Adam(self.net.parameters(), lr=1e-3)
        self.device = device
        self.goal_rep = goal_rep
        self.categoricals = categoricals
        self.latents = latents
        self.bsz = batch_size
        self.contrastive = contrastive

    def _goal_distribution(self, batch_size):
        """
            Straight-through one-hot distribution from an all-zeros input.
        """
        zeros_input = torch.zeros(
            (batch_size, self.input_shape), device=self.device)
        logits = self.net(zeros_input).reshape(
            (-1, self.categoricals, self.latents))
        return torch.distributions.OneHotCategoricalStraightThrough(
            logits=logits)

    def sample(self, batch_size=1):
        """
            Sample ``batch_size`` goals, each flattened to ``goal_shape``.
        """
        # Always use zeros as input
        sample = self._goal_distribution(batch_size).rsample()
        return sample.reshape((-1, self.goal_shape))

    def update(self, buffer, step):
        """
            Take one optimizer step on the loss from ``compute_loss``.

            The step is skipped when the loss carries no gradient (which is
            the case when the buffer returned no data). ``step`` is unused.

            Returns:
                The statistics dict from ``compute_loss``.
        """
        loss, loss_dict = self.compute_loss(buffer)
        if loss.requires_grad:
            self.opt.zero_grad(set_to_none=True)
            loss.backward()
            self.opt.step()
        return loss_dict

    def compute_loss(self, buffer):
        """
            Energy loss between reached goals and one sampled policy goal.

            Returns:
                ``(loss, stats)``; ``(zeros, {})`` when ``buffer.sample``
                returns ``None``.
        """
        with torch.no_grad():
            data = buffer.sample(self.bsz, 1, to_device=self.device)
            if data is None:
                return torch.zeros((1,), device=self.device), {}

            rewards = data["reward"].reshape((-1, 1))
            mean_rewards = rewards.mean()

            # Merge the leading two dimensions of the observations.
            observations = data["observation"]
            reached = self.goal_rep.encode(
                observations.reshape((-1, *observations.shape[2:])))

        flat_dim = self.categoricals * self.latents
        target = reached.reshape((-1, flat_dim))

        # Hindsight energy update: one sampled goal, tiled to match the
        # number of targets actually returned by the buffer (rather than the
        # requested batch size, which would mismatch on a short batch).
        goal = self._goal_distribution(1).rsample().reshape((-1, flat_dim))
        prediction = goal.repeat((target.shape[0], 1))

        energy_f = self.goal_rep.subset_energy.energy(target, prediction)
        e = torch.nn.BCEWithLogitsLoss(reduction="none")(
            energy_f, torch.ones_like(energy_f)*rewards)
        if self.contrastive:
            loss = e.mean()
        else:
            loss = (rewards * e).mean()

        loss_dict = {"energy_loss": loss.detach(
        ), "mean_rewards": mean_rewards.detach()}

        return loss, loss_dict


def sample_goal(goal_abstraction, target, dir, steps=100, lr=0.1, latents=16, categories=16, device="cpu"):
    """Sample a goal more concrete (dir=1) or more abstract (dir=-1) than target.

    Randomly initialized logits of shape ``(1, latents * categories)`` are
    optimized with SGD so that straight-through one-hot samples drawn from
    them receive a high subset energy against every row of ``target``.

    Args:
        goal_abstraction: object exposing ``subset_energy.energy(x, y)``.
        target: goal batch; its first dimension sets how often the sampled
            goal is tiled.
        dir: ``1`` scores ``energy(goal, target)``; any other value scores
            ``energy(target, goal)``.
        steps: number of SGD steps. With ``steps <= 0`` a single
            un-optimized sample is scored instead of failing.
        lr: SGD learning rate.
        latents, categories: the logits are viewed as
            ``(-1, latents, categories)``.
        device: device for the logits.

    Returns:
        ``(goal, score)``: the detached goal drawn in the last iteration
        (i.e. before the final SGD step is applied) and the mean sigmoid of
        its energy as a Python float.

    Note:
        ``backward()`` also accumulates gradients into the energy model's
        parameters; they are not cleared here.
    """
    flat_dim = latents * categories
    fixed_target = target.detach()
    n_targets = target.shape[0]

    def draw_and_score(logits):
        # Returns the sampled goal and its negated energy.
        goal = torch.distributions.OneHotCategoricalStraightThrough(
            logits=logits.reshape((-1, latents, categories))
        ).rsample().reshape((-1, flat_dim))
        tiled = goal.repeat(n_targets, 1)
        if dir == 1:
            # More concrete: sampled goal should be subset of target
            energy = goal_abstraction.subset_energy.energy(tiled, fixed_target)
        else:
            # More abstract: target should be subset of sampled goal
            energy = goal_abstraction.subset_energy.energy(fixed_target, tiled)
        return goal, -energy

    with torch.enable_grad():
        logits = torch.randn((1, flat_dim),
                             device=device, requires_grad=True)
        optim = torch.optim.SGD(params=[logits], lr=lr)

        if steps <= 0:
            # Nothing to optimize: score one sample from the initial logits.
            with torch.no_grad():
                goal, e = draw_and_score(logits)

        for _ in range(steps):
            goal, e = draw_and_score(logits)
            loss = e.mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
        return goal.detach(), torch.sigmoid(-e).mean().item()
