import torch
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from .architectures import *
from utils import gaussian_params

# Natural log of 2*pi: the constant term of a Gaussian log-density,
# log N(x; mu, var) = -0.5 * (log(2*pi) + log(var) + (x - mu)**2 / var).
logtwopi = math.log(math.tau)


class GFN(nn.Module):
    """GFlowNet diffusion sampler over continuous states in ``R^dim``.

    The forward policy is a Gaussian transition whose drift (and, optionally,
    log-variance) is predicted by a network conditioned on the current state and
    time. The backward policy is the Brownian bridge pinned at the origin at the
    first grid time (mean ``s * t_i / t_{i+1}``, variance
    ``t_scale * dt * t_i / t_{i+1}``), optionally corrected by a learned
    ``back_model`` when ``learn_pb`` is set. A log-flow estimate is produced per
    state, either by a conditional ``flow_model`` or a single learned scalar.
    """

    def __init__(
        self,
        dim: int,
        s_emb_dim: int,
        hidden_dim: int,
        harmonics_dim: int,
        t_dim: int,
        log_var_range: float = 4.0,
        t_scale: float = 1.0,
        langevin: bool = False,
        learned_variance: bool = True,
        partial_energy: bool = False,
        clipping: bool = False,
        lgv_clip: float = 1e2,
        gfn_clip: float = 1e4,
        pb_scale_range: float = 0.9,
        pb_scale_policy: str = "const",
        langevin_scaling_per_dimension: bool = True,
        conditional_flow_model: bool = False,
        learn_pb: bool = False,
        process_param: str = "standard",
        pis_architectures: bool = False,
        share_backbone: bool = False,
        lgv_layers: int = 3,
        joint_layers: int = 2,
        zero_init: bool = False,
        device=torch.device("cuda"),
    ):
        """Build the encoders, policy heads, flow model and Langevin scaling model.

        Args:
            dim: Dimensionality of the sampled states.
            s_emb_dim: Width of the state embedding.
            hidden_dim: Hidden width of the networks.
            harmonics_dim: Number of harmonics used by the time encoding.
            t_dim: Width of the time embedding.
            log_var_range: Learned log-variances are ``tanh(raw) * log_var_range``
                around the base log-variance ``log(t_scale)``.
            t_scale: Base diffusion variance per unit time.
            langevin: Add a learned-scaled ``grad log_r`` term to the predicted drift.
            learned_variance: If False, the forward log-variance is fixed to ``log(t_scale)``.
            partial_energy: Add an interpolated reference/reward log-density to the log-flow.
            clipping: Enable clipping of the Langevin gradient and the policy output.
            lgv_clip: Clip value for ``grad log_r`` when ``clipping`` is set.
            gfn_clip: Clip value for the policy output when ``clipping`` is set.
            pb_scale_range: Maximum relative size of the learned backward corrections.
            pb_scale_policy: ``"const"`` or ``"linear"`` (warm-up over ``current_it``).
            langevin_scaling_per_dimension: Per-dimension (vs. scalar) Langevin scaling.
            conditional_flow_model: Use a state/time-conditional flow network instead of
                a single learned scalar.
            learn_pb: Learn corrections to the backward (bridge) policy.
            process_param: Output parameterisation; only ``"standard"`` is supported by
                the trajectory methods.
            pis_architectures: Use the PIS-style network variants.
            share_backbone: Share hidden layers between the forward and backward heads.
            lgv_layers: Depth of the PIS Langevin scaling model.
            joint_layers: Depth of the joint policy / shared backbone.
            zero_init: Forwarded to the architectures (presumably zero-initialises the
                output layers — TODO confirm in ``architectures``).
            device: Device for the buffers allocated during trajectory generation.
        """
        super(GFN, self).__init__()
        self.dim = dim
        self.harmonics_dim = harmonics_dim
        self.t_dim = t_dim
        self.s_emb_dim = s_emb_dim

        self.langevin = langevin
        self.learned_variance = learned_variance
        self.partial_energy = partial_energy
        self.t_scale = t_scale

        self.clipping = clipping
        self.lgv_clip = lgv_clip
        self.gfn_clip = gfn_clip

        self.langevin_scaling_per_dimension = langevin_scaling_per_dimension
        self.conditional_flow_model = conditional_flow_model
        self.learn_pb = learn_pb
        self.process_param = process_param

        self.pis_architectures = pis_architectures
        self.lgv_layers = lgv_layers
        self.joint_layers = joint_layers

        self.pf_std_per_traj = np.sqrt(self.t_scale)
        self.log_var_range = log_var_range

        self.device = device

        self.pb_scale_range = pb_scale_range
        self.pb_scale_policy = pb_scale_policy
        # Only read by the "linear" pb_scale_policy; presumably updated by the
        # training loop — TODO confirm.
        self.current_it = 25000

        if self.pis_architectures:

            self.t_model = TimeEncodingPIS(harmonics_dim, t_dim, hidden_dim)
            self.s_model = StateEncodingPIS(dim, hidden_dim, s_emb_dim)

            if share_backbone:
                shared_backbone = [nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(joint_layers)]
            else:
                shared_backbone = None

            self.joint_model = JointPolicyPIS(dim, s_emb_dim, t_dim, hidden_dim, 2 * dim, joint_layers, shared_backbone, zero_init)
            if learn_pb:
                self.back_model = JointPolicyPIS(dim, s_emb_dim, t_dim, hidden_dim, 2 * dim, joint_layers, shared_backbone, zero_init)

            if self.conditional_flow_model:
                self.flow_model = FlowModelPIS(dim, s_emb_dim, t_dim, hidden_dim, 1, joint_layers)
            else:
                self.flow_model = torch.nn.Parameter(torch.tensor(0.0).to(self.device))

            if self.langevin_scaling_per_dimension:
                self.langevin_scaling_model = LangevinScalingModelPIS(s_emb_dim, t_dim, hidden_dim, dim, lgv_layers, zero_init)
            else:
                self.langevin_scaling_model = LangevinScalingModelPIS(s_emb_dim, t_dim, hidden_dim, 1, lgv_layers, zero_init)

        else:

            self.t_model = TimeEncoding(harmonics_dim, t_dim, hidden_dim)
            self.s_model = StateEncoding(dim, hidden_dim, s_emb_dim)

            if share_backbone:
                shared_backbone = nn.Sequential(
                    nn.Linear(s_emb_dim + t_dim, hidden_dim),
                    nn.GELU(),
                    *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(joint_layers - 1)],
                )
            else:
                shared_backbone = None
            self.joint_model = JointPolicy(dim, s_emb_dim, t_dim, hidden_dim, 2 * dim, shared_backbone, zero_init)
            if learn_pb:
                self.back_model = JointPolicy(dim, s_emb_dim, t_dim, hidden_dim, 2 * dim, shared_backbone, zero_init)

            if self.conditional_flow_model:
                self.flow_model = FlowModel(s_emb_dim, t_dim, hidden_dim, 1)
            else:
                self.flow_model = torch.nn.Parameter(torch.tensor(0.0).to(self.device))

            if self.langevin_scaling_per_dimension:
                self.langevin_scaling_model = LangevinScalingModel(s_emb_dim, t_dim, hidden_dim, dim, zero_init)
            else:
                self.langevin_scaling_model = LangevinScalingModel(s_emb_dim, t_dim, hidden_dim, 1, zero_init)

    def split_params(self, tensor, batch_jacobians=None):
        """Split a policy-network output into ``(mean, log_variance)``.

        The raw split is done by ``gaussian_params``. With ``learned_variance`` off,
        the log-variance is the constant ``log(t_scale)``; otherwise it is
        ``tanh(raw) * log_var_range + log(t_scale)``.

        Args:
            tensor: Policy-network output.
            batch_jacobians: Unused; kept for interface compatibility.

        Raises:
            ValueError: If the variance is learned and ``process_param`` is not "standard".
        """
        mean, raw_var = gaussian_params(tensor)
        base_logvar = np.log(self.pf_std_per_traj) * 2.0  # == log(t_scale)
        if not self.learned_variance:
            return mean, torch.zeros_like(raw_var) + base_logvar
        if self.process_param == "standard":
            return mean, torch.tanh(raw_var) * self.log_var_range + base_logvar
        raise ValueError(f"Invalid process_param: {self.process_param}")

    def get_pb_scale_range(self):
        """Return the current maximum relative size of the learned backward corrections.

        "const" returns ``pb_scale_range``; "linear" ramps it up linearly over the
        first 15000 values of ``current_it``.

        Raises:
            ValueError: On an unknown ``pb_scale_policy``.
        """
        if self.pb_scale_policy == "const":
            return self.pb_scale_range
        elif self.pb_scale_policy == "linear":
            return self.pb_scale_range * min(self.current_it / 15000, 1)
        else:
            raise ValueError(f"Invalid pb_scale_policy: {self.pb_scale_policy}")

    def predict_next_state(self, s, t, log_r):
        """Run the forward-policy network on states ``s`` at times ``t``.

        Args:
            s: States of shape ``(bsz, dim)``, or a single state of shape ``(dim,)``
                (then ``t`` must be 0-d). With ``langevin`` set, ``s.requires_grad``
                is switched on in place.
            t: Times of shape ``(bsz,)`` (0-d for a single state).
            log_r: Callable mapping states to log-rewards; only called when
                ``langevin`` is set.

        Returns:
            ``(params, flow)``. ``params`` has ``2 * dim`` trailing entries, split by
            ``split_params``; the Langevin term is added to the first ``dim`` of them
            (presumably the drift). ``flow`` is the log-flow estimate.
        """
        if self.langevin:
            # The Langevin drift needs d log_r / ds even when called under no_grad.
            s.requires_grad_(True)
            with torch.enable_grad():
                grad_log_r = torch.autograd.grad(log_r(s).sum(), s)[0].detach()
                grad_log_r = torch.nan_to_num(grad_log_r)
                if self.clipping:
                    grad_log_r = torch.clip(grad_log_r, -self.lgv_clip, self.lgv_clip)

        if s.dim() == 1:
            assert t.dim() == 0
            s = s.unsqueeze(0)
            t = t.unsqueeze(0)

        t_lgv = t  # raw times, consumed by the PIS Langevin scaling model

        t = self.t_model(t)
        s = self.s_model(s)
        s_new_info = self.joint_model(s, t)

        # NOTE(review): with partial_energy=True but conditional_flow_model=False,
        # flow_model is a scalar nn.Parameter and calling it fails — confirm that
        # partial_energy is always used together with conditional_flow_model.
        flow = self.flow_model(s, t).squeeze(-1) if self.conditional_flow_model or self.partial_energy else self.flow_model

        if self.langevin:
            if self.pis_architectures:
                scale = self.langevin_scaling_model(t_lgv)
            else:
                scale = self.langevin_scaling_model(s, t)
            s_new_info[..., : self.dim] += scale * grad_log_r

        if self.clipping:
            s_new_info = torch.clip(s_new_info, -self.gfn_clip, self.gfn_clip)
        return s_new_info, flow.squeeze(-1)

    def _partial_energy_correction(self, s, ts, idx, log_r):
        """Return the partial-energy term added to the log-flow of states ``s`` at grid index ``idx``.

        With ``t = ts[:, idx]`` this is ``(1 - t) * log p_ref(s) + t * log_r(s)``.
        ``p_ref`` is a zero-mean isotropic Gaussian with variance
        ``t_scale * ts[:, max(1, idx)]``. The index is clamped to 1, presumably to
        avoid ``log(0)`` when ``ts[:, 0] == 0`` — TODO confirm.
        Shared by the forward and backward passes so both use the same weighting.
        """
        t = ts[:, idx]
        ref_log_var = (self.t_scale * ts[:, max(1, idx)]).log()
        log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s**2)).sum(1)
        return (1 - t) * log_p_ref + t * log_r(s)

    @staticmethod
    def _exploration_log_var(pflogvars, expl_std, dts):
        """Return the log-variance giving per-step exploration noise with std ``expl_std``.

        Sampling scales ``exp(logvar / 2)`` by ``sqrt(dt)``, so a per-step std of
        ``expl_std`` requires ``logvar = 2 * log(expl_std) - log(dt)``. The result
        is broadcast to the shape of ``pflogvars``; ``dts`` holds one step size per
        batch row.
        """
        return torch.full_like(pflogvars, 2.0 * np.log(expl_std)) - dts.log().unsqueeze(1)

    def _backward_corrections(self, s, t):
        """Return ``(mean_correction, var_correction)`` multipliers for the backward bridge.

        Both are all-ones when ``learn_pb`` is off. Otherwise they come from
        ``back_model`` evaluated at states ``s`` and times ``t``.

        Raises:
            ValueError: On an unsupported ``process_param`` when ``learn_pb`` is on.
        """
        if not self.learn_pb:
            return torch.ones_like(s), torch.ones_like(s)
        pbs = self.back_model(self.s_model(s), self.t_model(t))
        dmean, dvar = gaussian_params(pbs)
        mean_correction = 1 + torch.tanh(dmean) * self.get_pb_scale_range()
        if self.process_param == "standard":
            var_correction = 1 + torch.tanh(dvar) * self.get_pb_scale_range()
        elif self.process_param == "log_variance":
            var_correction = torch.exp(torch.tanh(dvar) * self.log_var_range)
        else:
            raise ValueError(f"Invalid process_param: {self.process_param}")
        return mean_correction, var_correction

    def get_trajectory_fwd(self, s, discretizer_fn, exploration_std, log_r, states=None, pis=False, return_min_max=False):
        """Sample (or re-score) forward trajectories and compute per-step log-probabilities.

        Args:
            s: Initial states ``(bsz, dim)`` used as the starting point of the
                dynamics. May be None when ``states`` is given. When sampling,
                ``states[:, 0]`` is left at zero rather than set to ``s``; callers
                such as ``sample`` pass zeros.
            discretizer_fn: Callable ``bsz -> ts`` returning the time grid, shape
                ``(bsz, T + 1)``.
            exploration_std: None, or a callable. ``exploration_std(None) <= 0``
                disables exploration. Otherwise ``exploration_std(i)`` is the extra
                per-step noise std mixed in (via logaddexp) when sampling.
            log_r: Callable mapping states to log-rewards (Langevin / partial energy).
            states: Optional trajectories ``(bsz, T + 1, dim)`` to re-score instead
                of sampling.
            pis: If True, sampled states keep gradients through the policy mean and
                log-variance (no detach), and no exploration noise is mixed in.
            return_min_max: Also return diagnostics (see Returns).

        Returns:
            ``(states, logpf, logpb, logf)``.
            - ``logpf`` and ``logpb`` have shape ``(bsz, T)``; ``logpb[:, 0]`` stays 0.
            - ``logf`` has shape ``(bsz, T + 1)``; ``logf[:, T]`` is left 0 here.
            With ``return_min_max`` the tuple also contains, in order:
            - ``all_pf_mean``: per-step drift increments ``dt * mean``.
            - ``all_pfvars``: per-step stds ``sqrt(dt) * exp(logvar / 2)``.
            - the backward corrections: all ones, not filled in this pass.
            - ``min_logvar`` / ``max_logvar``: ``inf`` / ``-inf``, not tracked here.

        Raises:
            ValueError: If ``process_param`` is not "standard".
        """
        bsz = states.shape[0] if s is None else s.shape[0]

        ts = discretizer_fn(bsz).to(self.device)
        trajectory_length = ts.shape[1] - 1

        if states is None:
            sample_trajectory = True
            states = torch.zeros((bsz, trajectory_length + 1, self.dim), device=self.device)
        else:
            assert states.shape == (bsz, trajectory_length + 1, self.dim)
            sample_trajectory = False
            s = states[:, 0]

        logpf = torch.zeros((bsz, trajectory_length), device=self.device)
        logpb = torch.zeros((bsz, trajectory_length), device=self.device)
        logf = torch.zeros((bsz, trajectory_length + 1), device=self.device)
        back_mean_corrections = torch.ones((bsz, trajectory_length + 1, self.dim), device=self.device)
        back_var_corrections = torch.ones((bsz, trajectory_length + 1, self.dim), device=self.device)
        all_pf_mean = torch.zeros((bsz, trajectory_length + 1, self.dim), device=self.device)
        all_pfvars = torch.zeros((bsz, trajectory_length + 1, self.dim), device=self.device)

        # Not tracked in the forward pass; returned as placeholders.
        min_logvar, max_logvar = torch.inf, -torch.inf

        for i in range(trajectory_length):
            dts = ts[:, i + 1] - ts[:, i]

            pfs, flow = self.predict_next_state(s, ts[:, i], log_r)

            logf[:, i] = flow
            if self.partial_energy:
                logf[:, i] += self._partial_energy_correction(s, ts, i, log_r)

            if self.process_param == "standard":
                pf_mean, pflogvars = self.split_params(pfs)
                pf_step_std = dts.sqrt().unsqueeze(1) * (pflogvars / 2).exp()
                all_pfvars[:, i] = pf_step_std
                all_pf_mean[:, i] = dts.unsqueeze(1) * pf_mean

                if exploration_std is None:
                    pflogvars_sample = pflogvars if pis else pflogvars.detach()
                elif exploration_std(None) <= 0.0:
                    # exploration_std(None) is a global on/off switch; the per-step
                    # value could eventually depend on ts (needs utils changes).
                    pflogvars_sample = pflogvars.detach()
                elif pis:
                    pflogvars_sample = pflogvars
                else:
                    add_log_var = self._exploration_log_var(pflogvars, exploration_std(i), dts)
                    pflogvars_sample = torch.logaddexp(pflogvars.detach(), add_log_var)

                if sample_trajectory:
                    drift = pf_mean if pis else pf_mean.detach()
                    s_ = (
                        s
                        + dts.unsqueeze(1) * drift
                        + dts.sqrt().unsqueeze(1) * (pflogvars_sample / 2).exp() * torch.randn_like(s, device=self.device)
                    )
                else:
                    s_ = states[:, i + 1]

                noise = ((s_ - s) - dts.unsqueeze(1) * pf_mean) / pf_step_std
                logpf[:, i] = -0.5 * (noise**2 + logtwopi + dts.log().unsqueeze(1) + pflogvars).sum(1)

            else:
                raise ValueError(f"Invalid process_param: {self.process_param}")

            back_mean_correction, back_var_correction = self._backward_corrections(s_, ts[:, i + 1])

            # The first transition leaves the fixed initial state, so its backward
            # log-probability stays 0.
            if i > 0:
                back_mean = s_ - s_ * (dts / ts[:, i + 1]).unsqueeze(1) * back_mean_correction
                back_var = (self.pf_std_per_traj**2) * (dts * ts[:, i] / ts[:, i + 1]).unsqueeze(1) * back_var_correction
                noise_backward = (s - back_mean) / back_var.sqrt()
                logpb[:, i] = -0.5 * (noise_backward**2 + logtwopi + back_var.log()).sum(1)

            s = s_
            if sample_trajectory:
                states[:, i + 1] = s

        if return_min_max:
            return (
                states,
                logpf,
                logpb,
                logf,
                all_pf_mean,
                all_pfvars,
                back_mean_corrections,
                back_var_corrections,
                min_logvar,
                max_logvar,
            )
        return states, logpf, logpb, logf

    def get_trajectory_bwd(self, s, discretizer_fn, log_r, return_min_max=False):
        """Sample backward trajectories from terminal states ``s`` and score them.

        Walks from grid index ``T`` down to 0 using the (optionally corrected)
        Brownian bridge. The final step deterministically returns to the origin.
        At each step the forward policy scores the reverse transition.

        Args:
            s: Terminal states ``(bsz, dim)``.
            discretizer_fn: Callable ``bsz -> ts`` returning the time grid ``(bsz, T + 1)``.
            log_r: Callable mapping states to log-rewards (Langevin / partial energy).
            return_min_max: Also return diagnostics. These are per-step backward
                mean increments, backward stds, the backward corrections, and the
                min/max of the backward and forward log-variances.

        Returns:
            ``(states, logpf, logpb, logf)``, shaped as in ``get_trajectory_fwd``.
            ``logpb[:, 0]`` stays 0 and ``logf[:, T]`` is left 0.

        Raises:
            ValueError: If ``process_param`` is not "standard".
        """
        bsz = s.shape[0]

        ts = discretizer_fn(bsz).to(self.device)
        trajectory_length = ts.shape[1] - 1

        logpf = torch.zeros((bsz, trajectory_length), device=self.device)
        logpb = torch.zeros((bsz, trajectory_length), device=self.device)
        logf = torch.zeros((bsz, trajectory_length + 1), device=self.device)
        states = torch.zeros((bsz, trajectory_length + 1, self.dim), device=self.device)
        states[:, -1] = s
        back_mean_corrections = torch.zeros((bsz, trajectory_length, self.dim), device=self.device)
        back_var_corrections = torch.zeros((bsz, trajectory_length, self.dim), device=self.device)
        all_pb_mean = torch.zeros((bsz, trajectory_length + 1, self.dim), device=self.device)
        all_pbstds = torch.zeros((bsz, trajectory_length + 1, self.dim), device=self.device)

        min_logvar, max_logvar = torch.inf, -torch.inf

        for i in range(trajectory_length):
            k = trajectory_length - i  # grid index of the current (later) state s
            dts = ts[:, k] - ts[:, k - 1]

            if k > 1:
                back_mean_correction, back_var_correction = self._backward_corrections(s, ts[:, k])

                mean = s - s * (dts / ts[:, k]).unsqueeze(1) * back_mean_correction
                var = (self.pf_std_per_traj**2) * (dts * ts[:, k - 1] / ts[:, k]).unsqueeze(1) * back_var_correction
                s_ = mean.detach() + var.sqrt().detach() * torch.randn_like(s, device=self.device)
                noise_backward = (s_ - mean) / var.sqrt()
                logpb[:, k - 1] = -0.5 * (noise_backward**2 + logtwopi + var.log()).sum(1)

                min_logvar, max_logvar = min(min_logvar, var.log().min()), max(max_logvar, var.log().max())

                all_pb_mean[:, k, :] = mean - s
                all_pbstds[:, k, :] = var.sqrt().detach()
            else:
                # Grid index 0 is the fixed initial state (the origin).
                s_ = torch.zeros_like(s)

            pfs, flow = self.predict_next_state(s_, ts[:, k - 1], log_r)

            logf[:, k - 1] = flow
            if self.partial_energy:
                # Same weighting as the forward pass, evaluated at s_ (the state whose
                # flow is stored at index k - 1).
                logf[:, k - 1] += self._partial_energy_correction(s_, ts, k - 1, log_r)

            if self.process_param == "standard":
                pf_mean, pflogvars = self.split_params(pfs)

                min_logvar, max_logvar = min(min_logvar, pflogvars.min()), max(max_logvar, pflogvars.max())

                noise = ((s - s_) - dts.unsqueeze(1) * pf_mean) / (dts.sqrt().unsqueeze(1) * (pflogvars / 2).exp())
                logpf[:, k - 1] = -0.5 * (noise**2 + logtwopi + dts.log().unsqueeze(1) + pflogvars).sum(1)
            else:
                raise ValueError(f"Invalid process_param: {self.process_param}")

            if k > 1:
                back_mean_corrections[:, k - 1] = back_mean_correction
                back_var_corrections[:, k - 1] = back_var_correction

            s = s_
            states[:, k - 1] = s_

        if return_min_max:
            return (
                states,
                logpf,
                logpb,
                logf,
                all_pb_mean,
                all_pbstds,
                back_mean_corrections,
                back_var_corrections,
                min_logvar,
                max_logvar,
            )
        return states, logpf, logpb, logf

    def sample(self, batch_size, discretizer_fn, log_r):
        """Sample ``batch_size`` terminal states from the origin without exploration noise."""
        s = torch.zeros(batch_size, self.dim).to(self.device)
        return self.get_trajectory_fwd(s, discretizer_fn, None, log_r)[0][:, -1]

    def sleep_phase_sample(self, batch_size, discretizer_fn, exploration_std):
        """Sample terminal states with exploration noise and no reward function.

        ``log_r`` is None here, so this presumably must not be used with
        ``langevin`` or ``partial_energy`` enabled (both call ``log_r``) — TODO confirm.
        """
        s = torch.zeros(batch_size, self.dim).to(self.device)
        return self.get_trajectory_fwd(s, discretizer_fn, exploration_std, log_r=None)[0][:, -1]

    def forward(self, s, discretizer_fn, exploration_std, log_r, states=None, pis=False):
        """Alias for ``get_trajectory_fwd`` returning ``(states, logpf, logpb, logf)``."""
        return self.get_trajectory_fwd(s, discretizer_fn, exploration_std, log_r, states, pis)
