import os
import pickle
import numpy as np
import torch
from torch.optim import Adam
from nns import logF

from gfn_base import GFlowNetBase
from torch.nn.utils.rnn import pad_sequence

# ContributionsTensor = TT["max_len * (1 + max_len) / 2", "n_trajectories"]
# CumulativeLogProbsTensor = TT["max_length + 1", "n_trajectories"]
# LogStateFlowsTensor = TT["max_length", "n_trajectories"]
# LogTrajectoriesTensor = TT["max_length", "n_trajectories", torch.float]
# MaskTensor = TT["max_length", "n_trajectories"]
# PredictionsTensor = TT["max_length + 1 - i", "n_trajectories"]
# TargetsTensor = TT["max_length + 1 - i", "n_trajectories"]

# GFN sub trajectory from torchgfn: https://github.com/GFNOrg/torchgfn/blob/master/src/gfn/gflownet/sub_trajectory_balance.py
# modified to adapt to the gymnasium environment
class SubTBGFlowNet(GFlowNetBase):
    r"""GFlowNet for the Sub Trajectory Balance Loss.

    This method is described in [Learning GFlowNets from partial episodes
    for improved convergence and stability](https://arxiv.org/abs/2209.12782).

    Attributes:
        logF: state-flow network built from ``nns.logF``; maps a batch of
            observations to one log-flow value per observation.
        weighting: sub-trajectories weighting scheme.
            - "DB": Considers all one-step transitions of each trajectory in the
                batch and weighs them equally (regardless of the length of
                trajectory). Should be equivalent to DetailedBalance loss.
            - "ModifiedDB": Considers all one-step transitions of each trajectory
                in the batch and weighs them inversely proportional to the
                trajectory length. This ensures that the loss is not dominated by
                long trajectories. Each trajectory contributes equally to the loss.
            - "TB": Considers only the full trajectory. Should be equivalent to
                TrajectoryBalance loss.
            - "equal_within": Each sub-trajectory of each trajectory is weighed
                equally within the trajectory. Then each trajectory is weighed
                equally within the batch.
            - "equal": Each sub-trajectory of each trajectory is weighed equally
                within the set of all sub-trajectories.
            - "geometricwithin": Each sub-trajectory of each trajectory is weighed
                proportionally to (lamda ** len(sub_trajectory)), within each
                trajectory. THIS CORRESPONDS TO THE ONE IN THE PAPER.
            - "geometric": Each sub-trajectory of each trajectory is weighed
                proportionally to (lamda ** len(sub_trajectory)), within the set of
                all sub-trajectories.
        lamda: discount factor for longer sub-trajectories (used by the
            "geometric" and "geometricwithin" schemes).
    """

    # Weighting schemes understood by ``train``.
    _WEIGHTINGS = (
        "DB",
        "ModifiedDB",
        "TB",
        "equal_within",
        "equal",
        "geometricwithin",
        "geometric",
    )

    # Replay-memory method used for each ``sample_method`` value; any other
    # value falls back to ``memory.sample``.
    _SAMPLERS = {
        1: "biased_sample",
        2: "generalized_biased_sample",
        3: "mixed_priority_sample",
    }

    def __init__(
        self, env, learning_rate=1e-3, batch_size=32, buffer_size=10000,
        train_freq=16, gradient_steps=10, learning_starts=100,
        temperature=1,
        sample_method=0,
        use_filter=False,
        weighting="geometricwithin",
        lamda: float = 0.9,
        device='auto', continuous=True, tensorboard_log=None, verbose=False,
        hidden_sizes=None,
        activation_fn=torch.nn.ReLU,
        initial_z=0.0,
        num_val_samples=0,
        pessimistic_updates=0,
        model_dir=None,
        validation_env=None,
        data_env=None,
        no_decay=False,
        timeout_mask=False,
        filter_upper=3,
        filter_lower=2,
        epsilon_random=0.1,
    ):
        """Build the SubTB learner.

        Most arguments are forwarded positionally to ``GFlowNetBase``. The
        SubTB-specific ones are:

        Args:
            weighting: sub-trajectory weighting scheme, one of
                ``_WEIGHTINGS`` (see the class docstring).
            lamda: geometric discount used by "geometric"/"geometricwithin".
            hidden_sizes: hidden layer sizes passed to ``GFlowNetBase`` and to
                the ``logF`` network; ``None`` means ``[256, 256]``.
            device: ``'auto'`` is resolved here to ``'cuda'`` when available,
                otherwise ``'cpu'``.

        Raises:
            ValueError: if ``weighting`` is not a known scheme.
        """
        # A list literal default would be shared by every instance.
        if hidden_sizes is None:
            hidden_sizes = [256, 256]
        # Fail at construction time instead of in the middle of training.
        if weighting not in self._WEIGHTINGS:
            raise ValueError(f"Unknown weighting method {weighting}")

        super().__init__(
            env, learning_rate, batch_size, buffer_size,
            train_freq, gradient_steps, learning_starts,
            # NOTE(review): the meaning of these two positional slots is
            # defined by GFlowNetBase's signature; confirm.
            None, 100,
            temperature, sample_method,
            use_filter,
            device, continuous, tensorboard_log, verbose,
            hidden_sizes, activation_fn,
            initial_z, num_val_samples,
            pessimistic_updates,
            model_dir, validation_env, data_env, no_decay,
            timeout_mask, filter_upper, filter_lower, epsilon_random,
        )

        if self.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.state_dim = env.observation_space.shape[0]
        self.logF = logF(
            self.state_dim, hidden_sizes, activation_fn, device=self.device
        ).to(self.device)

        self.weighting = weighting
        self.lamda = lamda

        self.logF_optim = Adam(self.logF.parameters(), lr=self.learning_rate)
        # Extra param-group keys, presumably consumed by an LR schedule set up
        # elsewhere (e.g. in GFlowNetBase) — TODO confirm.
        self.logF_optim.param_groups[0]['initial_lr'] = self.learning_rate
        self.logF_optim.param_groups[0]['min_lr'] = self.learning_rate / 100

    def train(self):
        """Run ``self.gradient_steps`` Sub-Trajectory Balance updates.

        Each step samples a batch of trajectories from ``self.memory`` (the
        sampler is chosen by ``self.sample_method``), optionally runs
        ``self.pessimistic_updates`` maximum-likelihood steps on the backward
        policy, computes the SubTB loss under ``self.weighting`` and steps the
        forward-policy, backward-policy and state-flow optimizers.

        For every weighting except "DB" and "geometric" the per-trajectory
        losses are additionally pushed to the memory's training logs, used to
        update sampling priorities (``sample_method >= 2``) and, when
        ``use_filter`` is set, adjusted by the batch filter.
        """
        if self.sample_method == 1:
            self.memory.update_threshold(self.batch_size)
        traj_losses = 0
        traj_losses_std = 0
        for _ in range(self.gradient_steps):
            (
                batch_obs,
                batch_acts,
                batch_next_obs,
                batch_rews,
                batch_augmented_rews,
                batch_idx,
            ) = self._sample_batch()

            batch_obs_pad = pad_sequence(batch_obs, batch_first=True)
            batch_acts_pad = pad_sequence(batch_acts, batch_first=True)
            batch_next_obs_pad = pad_sequence(batch_next_obs, batch_first=True)

            lengths = torch.tensor(
                np.array([len(obs) for obs in batch_obs]), device=self.device
            )

            # a different PB update described in the paper
            for _ in range(self.pessimistic_updates):
                self._pessimistic_update(batch_next_obs_pad, batch_acts_pad, lengths)

            batch_rews = torch.stack(batch_rews)
            batch_augmented_rews = torch.stack(batch_augmented_rews)
            log_rewards = torch.log(
                batch_rews + self.temperature * batch_augmented_rews
            ).flatten()

            log_pf_trajectories, log_pb_trajectories = self._trajectory_log_probs(
                batch_obs_pad, batch_acts_pad, batch_next_obs_pad, lengths
            )

            # Whole-trajectory log-probabilities, used for logging and filtering.
            logPF = torch.sum(log_pf_trajectories, dim=1)
            logPB = torch.sum(log_pb_trajectories, dim=1)

            scores, flattening_masks = self._subtrajectory_scores(
                batch_obs_pad,
                lengths,
                log_rewards,
                log_pf_trajectories,
                log_pb_trajectories,
            )

            if self.weighting == "DB":
                loss = scores[0][~flattening_masks[0]].pow(2).mean()
            elif self.weighting == "geometric":
                loss = self._geometric_loss(scores, flattening_masks, lengths)
            else:
                # Per-trajectory loss, shape (n_trajectories,).
                loss = self._per_trajectory_loss(scores, flattening_masks, lengths)
                traj_losses += loss.mean().detach().item()
                traj_losses_std += loss.std().detach().item()

                self.memory.push_train_logs(
                    logPF.detach().cpu().numpy(),
                    logPB.detach().cpu().numpy(),
                    loss.detach().cpu().numpy(),
                    batch_rews.cpu().numpy(),
                    self.logZ.item(),
                )

                if self.sample_method >= 2:
                    self.memory.update_priority(
                        batch_idx, np.abs(loss.detach().cpu().numpy())
                    )

                if self.use_filter:
                    loss = self._apply_batch_filter(
                        loss, logPF, logPB, log_rewards, batch_rews
                    )

                loss = loss.sum()

            self.forward_optim.zero_grad()
            self.backward_optim.zero_grad()
            self.logF_optim.zero_grad()

            loss.backward()

            self.forward_optim.step()
            self.backward_optim.step()
            self.logF_optim.step()

            self.logger['log_Z'].append(self.logZ.item())

        # Count every gradient step taken (the loop index alone is off by one).
        self._n_updates += self.gradient_steps
        # Avoid ZeroDivisionError when no gradient step was requested.
        n_steps = max(self.gradient_steps, 1)
        self.logger['traj_losses'].append(traj_losses / n_steps)
        self.logger['traj_losses_std'].append(traj_losses_std / n_steps)

    def _sample_batch(self):
        """Draw ``self.batch_size`` trajectories from the replay memory.

        Returns the memory's 6-tuple
        ``(obs, acts, next_obs, rews, augmented_rews, idx)``.
        """
        method_name = self._SAMPLERS.get(self.sample_method, "sample")
        return getattr(self.memory, method_name)(self.batch_size)

    def _pessimistic_update(self, batch_next_obs_pad, batch_acts_pad, lengths):
        """Take one maximum-likelihood step on the backward policy only."""
        if self.continuous:
            logPB = self.backward_policy.evaluate_actions(
                batch_next_obs_pad, batch_acts_pad, lengths=lengths
            )
        else:
            backward_mask = self.env.unwrapped.get_backward_action_masks(batch_next_obs_pad)
            logPB = self.backward_policy.evaluate_actions(
                batch_next_obs_pad, batch_acts_pad, backward_mask, lengths=lengths
            )

        pessimistic_loss = -torch.sum(logPB, dim=1).mean()
        self.backward_optim.zero_grad()
        pessimistic_loss.backward()
        self.backward_optim.step()

    def _trajectory_log_probs(self, batch_obs_pad, batch_acts_pad, batch_next_obs_pad, lengths):
        """Evaluate per-step forward and backward log-probabilities.

        Returns ``(log_pf, log_pb)``, each presumably of shape
        ``(n_trajectories, max_len)`` (they are permuted to
        ``(max_len, n_trajectories)`` downstream).
        """
        if self.continuous:
            # Forward pass
            log_pf = self.forward_policy.evaluate_actions(
                batch_obs_pad,
                batch_acts_pad,
                self.env.unwrapped.max_t,
                lengths=lengths,
                use_mask=self.timeout_mask,
            )
            # Backward pass
            log_pb = self.backward_policy.evaluate_actions(
                batch_next_obs_pad, batch_acts_pad, lengths=lengths
            )
        else:
            forward_mask = self.env.unwrapped.get_forward_action_masks(batch_obs_pad)
            backward_mask = self.env.unwrapped.get_backward_action_masks(batch_next_obs_pad)
            log_pf = self.forward_policy.evaluate_actions(
                batch_obs_pad, batch_acts_pad, forward_mask, lengths=lengths
            )
            log_pb = self.backward_policy.evaluate_actions(
                batch_next_obs_pad, batch_acts_pad, backward_mask, lengths=lengths
            )
        return log_pf, log_pb

    def _subtrajectory_scores(
        self,
        batch_obs_pad,
        lengths,
        log_rewards,
        log_pf_trajectories,
        log_pb_trajectories,
    ):
        """Compute ``preds - targets`` for every sub-trajectory length.

        Returns ``(scores, flattening_masks)``: two lists with one entry per
        sub-trajectory length ``i = 1..max_len``. Entry ``i - 1`` has shape
        ``(max_len + 1 - i, n_trajectories)``; a True mask value marks a
        sub-trajectory that runs past the end of its trajectory (padding).

        Raises:
            ValueError: if any valid prediction or target is NaN.
        """
        # (n_trajectories, max_len) -> (max_len, n_trajectories)
        log_pf_trajectories = log_pf_trajectories.permute(1, 0)
        log_pb_trajectories = log_pb_trajectories.permute(1, 0)
        n_trajectories = log_pf_trajectories.shape[1]

        # shape: (max_len + 1, n_trajectories)
        log_pf_trajectories_cum = self.cumulative_logprobs(n_trajectories, log_pf_trajectories)
        log_pb_trajectories_cum = self.cumulative_logprobs(n_trajectories, log_pb_trajectories)

        log_state_flows = self.calculate_log_state_flows(
            batch_obs_pad, lengths, log_pf_trajectories
        )
        full_mask, sink_states_mask, is_terminal_mask = self.calculate_masks(
            batch_obs_pad, lengths
        )

        max_len = lengths.max().item()
        flattening_masks = []
        scores = []
        for i in range(1, max_len + 1):
            preds = self.calculate_preds(log_pf_trajectories_cum, log_state_flows, i)
            targets = self.calculate_targets(
                log_rewards,
                preds,
                log_pb_trajectories_cum,
                log_state_flows,
                is_terminal_mask,
                sink_states_mask,
                full_mask,
                i,
            )

            # Row k covers the sub-trajectory from step k to step k + i; it is
            # padding when it ends beyond the trajectory's length.
            flattening_mask = lengths.lt(
                torch.arange(i, max_len + 1, device=lengths.device).unsqueeze(-1)
            )

            if torch.any(torch.isnan(preds[~flattening_mask])):
                raise ValueError("NaN in preds")
            if torch.any(torch.isnan(targets[~flattening_mask])):
                raise ValueError("NaN in targets")

            flattening_masks.append(flattening_mask)
            scores.append(preds - targets)
        return scores, flattening_masks

    def _geometric_loss(self, scores, flattening_masks, lengths):
        """Loss for the "geometric" weighting: a lamda-weighted mean over lengths."""
        per_length_losses = torch.stack(
            [
                length_scores[~length_mask].pow(2).mean()
                for length_scores, length_mask in zip(scores, flattening_masks)
            ]
        )
        max_len = lengths.max().item()
        L = self.lamda
        # Normalizes lamda**0 + ... + lamda**(max_len - 1) to 1.
        ratio = (1 - L) / (1 - L ** max_len)
        weights = ratio * (L ** torch.arange(max_len, device=per_length_losses.device))
        assert (weights.sum() - 1.0).abs() < 1e-5, f"{weights.sum()}"
        return (per_length_losses * weights).sum()

    def _get_contributions(self, lengths, all_scores):
        """Dispatch to the contribution builder for ``self.weighting``.

        Raises:
            ValueError: for an unknown weighting.
        """
        if self.weighting == "equal_within":
            return self.get_equal_within_contributions(lengths)
        if self.weighting == "equal":
            return self.get_equal_contributions(lengths)
        if self.weighting == "TB":
            return self.get_tb_contributions(lengths, all_scores)
        if self.weighting == "ModifiedDB":
            return self.get_modified_db_contributions(lengths)
        if self.weighting == "geometricwithin":
            return self.get_geometric_within_contributions(lengths)
        raise ValueError(f"Unknown weighting method {self.weighting}")

    def _per_trajectory_loss(self, scores, flattening_masks, lengths):
        """Weighted squared scores aggregated to one loss per trajectory.

        Returns a tensor of shape ``(n_trajectories,)`` whose sum is the batch
        loss (the contributions are asserted to sum to 1).
        """
        flattening_mask = torch.cat(flattening_masks)
        all_scores = torch.cat(scores, 0)
        contributions = self._get_contributions(lengths, all_scores)

        flat_contributions = contributions[~flattening_mask]
        assert (
            flat_contributions.sum() - 1.0
        ).abs() < 1e-5, f"{flat_contributions.sum()}"
        losses = flat_contributions * all_scores[~flattening_mask].pow(2)

        # Scatter back to (max_len * (1 + max_len) / 2, n_trajectories) and
        # sum over sub-trajectories to get one value per trajectory.
        loss = torch.zeros_like(flattening_mask, dtype=torch.float)
        loss[~flattening_mask] = losses
        return loss.sum(dim=0)

    def _apply_batch_filter(self, loss, logPF, logPB, log_rewards, batch_rews):
        """Adjust per-trajectory losses for outliers of ``log R - log PF - log PB``.

        Trajectories above ``mean + filter_upper * std`` get an extra term
        that maximizes their log PF + log PB; trajectories below
        ``mean - filter_lower * std`` have their SubTB loss replaced by a
        term that maximizes log PB and minimizes log PF.
        """
        batch_norm = -logPF.detach() - logPB.detach() + log_rewards.detach()
        batch_mean = batch_norm.mean()
        batch_std = batch_norm.std()

        # when this to be true, the entire trajectory, though with high reward, has very low logPF
        promote = batch_norm > (batch_mean + self.filter_upper * batch_std)
        if torch.any(promote) and self.verbose:
            print("Will promote:")
            print("reward:", batch_rews[promote])
            print("logPF:", logPF[promote])
            print("logPB:", logPB[promote])
            print("This", batch_norm[promote], "Mean:", batch_mean, "Std:", batch_std)

        # maximize the PF, PB
        loss = loss - (logPF + logPB) * promote / self.batch_size

        # when this to be true, this is a relatively old trajectory with high logPF and logPB,
        # but low reward, this is a sign of overfitting
        depress = batch_norm < (batch_mean - self.filter_lower * batch_std)
        if torch.any(depress) and self.verbose:
            print("Will depress:")
            print("reward:", batch_rews[depress])
            print("logPF:", logPF[depress])
            print("logPB:", logPB[depress])
            print("This", batch_norm[depress], "Mean:", batch_mean, "Std:", batch_std)

        # release the PF (IMO this is the exploration budget) and update the PB by the maximum likelihood
        return loss * (~depress) - (logPB - logPF) * depress / self.batch_size

    def cumulative_logprobs(
        self,
        n_trajectories,
        log_p_trajectories,
    ):
        """Calculate the cumulative log probabilities for all trajectories.

        Args:
            n_trajectories: number of trajectories (columns) in the batch.
            log_p_trajectories: per-step log probabilities, shape
                ``(max_len, n_trajectories)``.

        Returns:
            Tensor of shape ``(max_len + 1, n_trajectories)`` whose row ``t`` is
            the sum of the first ``t`` log probabilities (row 0 is zeros).
        """
        return torch.cat(
            (
                torch.zeros(1, n_trajectories, device=log_p_trajectories.device),
                log_p_trajectories.cumsum(dim=0),
            ),
            dim=0,
        )

    def calculate_preds(
        self,
        log_pf_trajectories_cum,
        log_state_flows,
        i: int,
    ):
        """Calculate the predictions for sub-trajectories of length ``i``.

        Row ``k`` is ``log F(s_k)`` plus the forward log-probabilities of steps
        ``k .. k + i - 1``; the result has shape
        ``(max_len + 1 - i, n_trajectories)``.
        """
        # Start states are the first max_len + 1 - i rows.
        current_log_state_flows = (
            log_state_flows if i == 1 else log_state_flows[: -(i - 1)]
        )

        logPF = log_pf_trajectories_cum[i:] - log_pf_trajectories_cum[:-i]

        return logPF + current_log_state_flows

    def calculate_targets(
        self,
        log_rewards,
        preds,
        log_pb_trajectories_cum,
        log_state_flows,
        is_terminal_mask,
        sink_states_mask,
        full_mask,
        i: int,
    ):
        """Calculate the targets for sub-trajectories of length ``i``.

        Row ``k`` is the target of the sub-trajectory starting at step ``k``:

        * ending at the trajectory's last step: the log reward plus the
          backward log-probabilities of its first ``i - 1`` steps;
        * ending before the last step: the backward log-probabilities of its
          ``i`` steps plus the log-flow of the state it ends in;
        * running past the end: left at ``-inf`` (masked out by the caller).
        """
        targets = torch.full_like(preds, fill_value=-float("inf"))

        terminal_mask = is_terminal_mask[i - 1 :]

        # The following creates the targets for the finishing sub-trajectories:
        # one log reward per trajectory long enough to contain a length-i
        # sub-trajectory (targets.T keeps the per-trajectory order).
        targets.T[terminal_mask.T] = log_rewards[terminal_mask.any(dim=0)]

        # For now, the targets contain the log-rewards of the ending sub trajectories
        # We need to add to that the log-probabilities of the backward actions up-to
        # the sub-trajectory's terminating state
        # NOTE(review): as in torchgfn (where the last action is an exit
        # action), log P_B of the final transition is not included here —
        # confirm this matches the environment's terminal semantics.
        if i > 1:
            targets[terminal_mask] += (
                log_pb_trajectories_cum[i - 1 :] - log_pb_trajectories_cum[: -i + 1]
            )[:-1][terminal_mask]

        # The following creates the targets for the non-finishing sub-trajectories
        targets[~full_mask[i - 1 :]] = (
            log_pb_trajectories_cum[i:] - log_pb_trajectories_cum[:-i]
        )[:-1][~full_mask[i - 1 : -1]] + log_state_flows[i:][~sink_states_mask[i:]]

        return targets

    def calculate_log_state_flows(
        self,
        batch_obs_pad,
        lengths,
        log_pf_trajectories
    ):
        """Evaluate ``self.logF`` on every non-padding observation.

        Args:
            batch_obs_pad: padded observations, ``(n_trajectories, max_len, obs_dim)``.
            lengths: number of valid steps per trajectory, ``(n_trajectories,)``.
            log_pf_trajectories: tensor of shape ``(max_len, n_trajectories)``,
                used only for the output's shape, dtype and device.

        Returns:
            Tensor shaped like ``log_pf_trajectories`` holding ``log F`` of
            observation ``t`` of trajectory ``j`` where ``t < lengths[j]`` and
            ``-inf`` elsewhere.
        """
        log_state_flows = torch.full_like(log_pf_trajectories, fill_value=-float("inf"))
        # if state < length, then it is not a sink state
        mask = torch.arange(batch_obs_pad.size(1), device=batch_obs_pad.device) < lengths.unsqueeze(-1)

        # Time-major layout to match log_pf_trajectories.
        mask = mask.permute(1, 0)
        states = batch_obs_pad.permute(1, 0, 2)

        log_state_flows[mask] = self.logF(states[mask]).squeeze(-1)

        return log_state_flows

    def calculate_masks(
        self,
        batch_obs_pad,
        lengths
    ):
        """Calculate time-major masks of shape ``(max_len, n_trajectories)``.

        Returns:
            ``(full_mask, sink_states_mask, is_terminal_mask)`` where
            ``sink_states_mask[t, j]`` is ``t >= lengths[j]`` (padding),
            ``is_terminal_mask[t, j]`` is ``t == lengths[j] - 1`` (last step)
            and ``full_mask`` is their union.
        """
        steps = torch.arange(batch_obs_pad.size(1), device=batch_obs_pad.device)
        sink_states_mask = (steps >= lengths.unsqueeze(-1)).permute(1, 0)
        is_terminal_mask = (steps == lengths.unsqueeze(-1) - 1).permute(1, 0)
        full_mask = sink_states_mask | is_terminal_mask
        return full_mask, sink_states_mask, is_terminal_mask

    def get_equal_within_contributions(
        self, lengths
    ):
        """Calculate contributions for the 'equal_within' weighting method.

        Every sub-trajectory of a length-``n`` trajectory gets
        ``2 / (n (n + 1)) / n_trajectories``.
        """
        max_len = lengths.max().item()
        n_rows = int(max_len * (1 + max_len) / 2)

        # the inverse of how many sub-trajectories there are in each trajectory
        contributions = 2.0 / (lengths * (lengths + 1)) / len(lengths)

        # repeat to shape (max_len * (max_len + 1) / 2, n_trajectories) so it
        # lines up with all_scores
        return contributions.repeat(n_rows, 1)

    def get_equal_contributions(
        self, lengths
    ):
        """Calculate contributions for the 'equal' weighting method.

        Every valid sub-trajectory in the batch gets the same weight.
        """
        max_len = lengths.max().item()
        n_rows = int(max_len * (1 + max_len) / 2)
        n_sub_trajectories = int((lengths * (lengths + 1) / 2).sum().item())
        return torch.ones(n_rows, len(lengths), device=self.device) / n_sub_trajectories

    def get_tb_contributions(
        self, lengths, all_scores
    ):
        """Calculate contributions for the 'TB' weighting method.

        Each trajectory contributes only its full-length sub-trajectory,
        weighted ``1 / n_trajectories``.
        """
        max_len = lengths.max().item()

        # Rows of all_scores are stacked by sub-trajectory length i, block i
        # holding max_len + 1 - i rows. The full sub-trajectory of a length-n
        # trajectory is the first row of block n, at offset
        # max_len * (n - 1) - (n - 1) * (n - 2) / 2.
        indices = (
            max_len * (lengths - 1) - (lengths - 1) * (lengths - 2) / 2
        ).long()
        contributions = torch.zeros_like(all_scores, device=self.device)
        contributions.scatter_(0, indices.unsqueeze(0), 1)
        return contributions / len(lengths)

    def get_modified_db_contributions(
        self, lengths
    ):
        """Calculate contributions for the 'ModifiedDB' weighting method.

        Only length-1 sub-trajectories count; those of a length-``n``
        trajectory get ``1 / n / n_trajectories``.
        """
        max_len = lengths.max().item()
        n_rows = int(max_len * (1 + max_len) / 2)

        # the inverse of how many transitions there are in each trajectory
        contributions = (1.0 / lengths / len(lengths)).repeat(max_len, 1)
        return torch.cat(
            (
                contributions,
                torch.zeros(
                    (n_rows - max_len, len(lengths)),
                    device=contributions.device,
                ),
            ),
            0,
        )

    def get_geometric_within_contributions(
        self, lengths
    ):
        """Calculate contributions for the 'geometricwithin' weighting method.

        A length-``i`` sub-trajectory gets weight ``lamda ** (i - 1)``,
        normalized within its trajectory and then divided by the number of
        trajectories.
        """
        L = self.lamda
        max_len = lengths.max().item()

        # weights given to each possible sub-trajectory length
        contributions = (L ** torch.arange(max_len, device=self.device).double()).float()
        contributions = contributions.unsqueeze(-1).repeat(1, len(lengths))

        # block i of all_scores has max_len + 1 - i rows
        contributions = contributions.repeat_interleave(
            torch.arange(max_len, 0, -1, device=self.device),
            dim=0,
            output_size=int(max_len * (max_len + 1) / 2),
        )

        # Divide each column by n + (n-1) lambda + ... + 1*lambda^{n-1}, where n
        # is the length of that column's trajectory, via the closed form:
        # https://www.wolframalpha.com/input?i=sum%28%28n-i%29+*+lambda+%5Ei%2C+i%3D0..n%29
        n = lengths.double()
        per_trajectory_denom = (
            1.0 / (1 - L) ** 2 * (L * (L ** n - 1) + (1 - L) * n)
        ).float()

        return contributions / per_trajectory_denom / len(lengths)

    def save(self, model_dir):
        """Save weights, optimizer states and progress counters to ``model_dir``.

        The directory is created if needed. Progress is pickled as the tuple
        ``(t_so_far, i_so_far, e_so_far)`` read from ``self.logger``.
        """
        os.makedirs(model_dir, exist_ok=True)
        for name, obj in (
            ('forward_policy', self.forward_policy),
            ('backward_policy', self.backward_policy),
            ('logF', self.logF),
            ('forward_optim', self.forward_optim),
            ('backward_optim', self.backward_optim),
            ('logF_optim', self.logF_optim),
        ):
            torch.save(obj.state_dict(), f'{model_dir}/{name}.pth')
        current_progress = (self.logger['t_so_far'], self.logger['i_so_far'], self.logger['e_so_far'])
        with open(f'{model_dir}/progress.pkl', 'wb') as f:
            pickle.dump(current_progress, f)

    def save_replay_buffer(self, model_dir):
        """Delegate saving the replay memory to ``self.memory.save``."""
        self.memory.save(model_dir)

    def load(self, model_dir, load_optim=False):
        """Load weights (and optionally optimizer states) written by ``save``.

        Returns:
            ``(t_so_far, i_so_far, e_so_far)`` from ``progress.pkl`` when
            ``load_optim`` is true, otherwise ``(0, 0, 0)``.
        """
        def _load_state(name):
            # map_location lets checkpoints saved on another device (e.g. a
            # GPU) be restored on this learner's device.
            return torch.load(f'{model_dir}/{name}.pth', map_location=self.device)

        self.forward_policy.load_state_dict(_load_state('forward_policy'))
        self.backward_policy.load_state_dict(_load_state('backward_policy'))
        self.logF.load_state_dict(_load_state('logF'))

        if not load_optim:
            return 0, 0, 0

        self.forward_optim.load_state_dict(_load_state('forward_optim'))
        self.backward_optim.load_state_dict(_load_state('backward_optim'))
        self.logF_optim.load_state_dict(_load_state('logF_optim'))

        # progress.pkl is written by ``save``; unpickling can execute code, so
        # only load checkpoints from trusted sources.
        with open(f'{model_dir}/progress.pkl', 'rb') as f:
            t_so_far, i_so_far, e_so_far = pickle.load(f)
        return t_so_far, i_so_far, e_so_far

    def load_replay_buffer(self, model_dir):
        """Delegate loading the replay memory to ``self.memory.load``."""
        self.memory.load(model_dir)