import math
from contextlib import nullcontext
from copy import deepcopy
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from sal.utils import ForwardPolicyMeta
from sal.gflownet import GFlowNet
from sal.experiments_utils import create_gfn, create_env, create_log_reward

# code adapted from https://github.com/kylehkhsu/role-of-data

def inverse_softplus(x):
    """Return ``y`` such that ``softplus(y) = log(1 + exp(y)) = x``.

    Algebraically this is ``log(exp(x) - 1)``. It is evaluated as
    ``x + log(1 - exp(-x))`` via ``math.expm1``, which does not overflow for
    large ``x`` and keeps full precision for small ``x``.

    Args:
        x: a positive scalar (anything ``float()`` accepts, e.g. a 0-d tensor).

    Returns:
        float: the inverse softplus of ``x``.

    Raises:
        ValueError: if ``x <= 0`` (softplus is strictly positive, so there is no inverse).
    """
    x = float(x)
    return x + math.log(-math.expm1(-x))

def kl_between_gaussians(p_mean, p_var, q_mean, q_var):
    """Elementwise KL(p || q) for p = N(p_mean, p_var) and q = N(q_mean, q_var).

    Closed form for univariate Gaussians (parameterised by variances, not stddevs):

        KL = 0.5 * (log(q_var / p_var) + (p_var + (p_mean - q_mean)^2) / q_var - 1)

    No reduction is applied; callers sum the result themselves.
    """
    log_variance_ratio = (q_var / p_var).log()
    scaled_second_moment = (p_var + (p_mean - q_mean).pow(2)).div(q_var)
    return 0.5 * (log_variance_ratio + scaled_second_moment - 1)

def common_entries(*dcts):
    """Yield ``(key, dcts[0][key], dcts[1][key], ...)`` for every key present in all dicts.

    Keys are visited in the insertion order of the first dictionary, so the
    output order is deterministic. Iterating a ``set`` of string keys would
    depend on hash randomization. Calling with no dictionaries yields nothing.
    """
    if not dcts:
        return
    first, *rest = dcts
    for key in first:
        if all(key in other for other in rest):
            yield (key,) + tuple(d[key] for d in dcts)

def var_of_rho(rho: torch.Tensor):
    """Map an unconstrained ``rho`` to a variance: ``softplus(rho) ** 2``.

    ``softplus(rho)`` is the (strictly positive) standard deviation; squaring it gives the variance.
    """
    stddev = F.softplus(rho)
    return stddev.pow(2)

def rho_of_var(var: torch.Tensor):
    """Inverse of ``var_of_rho``: return ``rho`` such that ``softplus(rho) ** 2 == var``.

    With ``s = sqrt(var)`` this is ``log(exp(s) - 1)``. It is evaluated as
    ``s + log(-expm1(-s))``, which avoids overflow of ``exp(s)`` for large
    ``s`` (``inf`` in float32 beyond ~88) and cancellation for small ``s``.
    """
    stddev = var.sqrt()
    return stddev + torch.log(-torch.expm1(-stddev))

class MLP(nn.Module):
    def __init__(self, n_input, n_output, hidden_layer_sizes):
        super(MLP, self).__init__()
        n_in = n_input
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_layer_sizes)):
            layer = nn.Linear(n_in, hidden_layer_sizes[i], bias=True)
            with torch.no_grad():
                nn.init.normal_(layer.weight, mean=0, std=0.1).clamp_(min=-0.2, max=0.2)
                nn.init.normal_(layer.bias, mean=0, std=0.1).clamp_(min=-0.2, max=0.2)
            self.hidden_layers.append(layer)
            n_in = hidden_layer_sizes[i]
        self.output_layer = nn.Linear(n_in, n_output, bias=True)
        with torch.no_grad():
            nn.init.normal_(self.output_layer.weight, mean=0, std=0.1).clamp_(min=-0.2, max=0.2)
            nn.init.normal_(self.output_layer.bias, mean=0, std=0.1).clamp_(min=-0.2, max=0.2)

    def forward(self, x):
        x = x.view([x.shape[0], -1])
        for hidden_layer in self.hidden_layers:
            x = hidden_layer(x)
            x = F.relu(x)
        x = self.output_layer(x)
        return x

class BayesianLayer(nn.Module):
    """Layer with a diagonal-Gaussian prior and posterior over its parameters.

    The prior and the posterior each consist of a *mean* module (the module
    passed in) and a *rho* module. The rho module is a deepcopy of the
    corresponding mean module whose parameters encode the per-parameter
    standard deviation as ``softplus(rho)``. Both rho modules start at
    ``inverse_softplus(prior_stddev)``, i.e. at standard deviation
    ``prior_stddev``.

    Args:
        prior_mean: module holding the prior means.
        posterior_mean: module holding the posterior means; its ``str`` must equal
            ``str(prior_mean)`` (a loose structural check).
        prior_stddev: initial standard deviation of both prior and posterior (must be > 0).
        optimize_prior_mean, optimize_prior_rho, optimize_posterior_mean, optimize_posterior_rho:
            when False, ``requires_grad`` is switched off for every parameter of the
            corresponding module. This mutates the modules that were passed in.
        device: stored as ``self.device``.
    """

    def __init__(
            self,
            prior_mean: nn.Module,
            posterior_mean: nn.Module,
            prior_stddev: torch.Tensor,
            optimize_prior_mean: bool,
            optimize_prior_rho: bool,
            optimize_posterior_mean: bool,
            optimize_posterior_rho: bool,
            device='cpu'
    ):
        super().__init__()
        self.device = device
        assert (str(prior_mean) == str(posterior_mean))  # hacky, incomplete structural check
        self.prior_mean = prior_mean
        self.posterior_mean = posterior_mean
        self.prior_stddev = prior_stddev

        # rho modules mirror the mean modules: one rho per parameter, i.e. a diagonal covariance.
        self.prior_rho = deepcopy(prior_mean)
        self.posterior_rho = deepcopy(posterior_mean)

        with torch.no_grad():
            # Initialise every rho so that softplus(rho) == prior_stddev.
            for rho_layer in [self.prior_rho, self.posterior_rho]:
                for p in rho_layer.parameters():
                    nn.init.constant_(p, inverse_softplus(self.prior_stddev))

            # Freeze the modules that should not be optimised.
            for (optimize, layer) in zip([optimize_prior_mean,
                                          optimize_prior_rho,
                                          optimize_posterior_mean,
                                          optimize_posterior_rho],
                                         [self.prior_mean,
                                          self.prior_rho,
                                          self.posterior_mean,
                                          self.posterior_rho]):
                if not optimize:
                    for p in layer.parameters():
                        p.requires_grad = False

    def perturb_posterior(self):
        """Draw one reparameterised posterior sample ``mean + softplus(rho) * eps``.

        ``eps ~ N(0, I)`` is drawn with ``torch.randn_like``, so it has the device
        and dtype of each mean parameter. Sampling therefore keeps working after
        the module is moved with ``.to(...)``. Gradients flow into the posterior
        mean and rho parameters.

        Returns:
            dict mapping parameter name (as in ``named_parameters()``) to the sampled tensor.
        """
        mean_parameters = {name: p for (name, p) in self.posterior_mean.named_parameters()}
        rho_parameters = {name: p for (name, p) in self.posterior_rho.named_parameters()}

        # Samples from a standard Gaussian, matching each parameter's device and dtype.
        noise_parameters = {name: torch.randn_like(p) for name, p in mean_parameters.items()}

        perturbed_parameters = {name: mean + F.softplus(rho) * noise
                                for name, mean, rho, noise in common_entries(mean_parameters,
                                                                             rho_parameters,
                                                                             noise_parameters)}
        return perturbed_parameters

    def posterior_mean_parameters(self):
        """Return a fresh ``{name: Parameter}`` dict of the (live) posterior-mean parameters."""
        mean_parameters = {name: p for (name, p) in self.posterior_mean.named_parameters()}
        return mean_parameters

    def extract_parameters(self):
        """Flatten each of the four modules' parameters into one 1-D tensor.

        Returns:
            dict with keys ``'prior_mean'``, ``'prior_rho'``, ``'posterior_mean'`` and
            ``'posterior_rho'``, each mapping to the concatenation of that module's
            flattened parameters (in ``parameters()`` order).
        """
        return {name: torch.cat([p.view(-1) for p in layer.parameters()])
                for (name, layer) in zip(['prior_mean', 'prior_rho', 'posterior_mean', 'posterior_rho'],
                                         [self.prior_mean, self.prior_rho, self.posterior_mean, self.posterior_rho])}

    def kl(self):
        """Return KL(posterior || prior) summed over all parameters of this layer (a 0-d tensor).

        Both distributions are diagonal Gaussians with variance ``softplus(rho) ** 2``.
        """
        parameter_vectors = self.extract_parameters()
        return kl_between_gaussians(
            p_mean=parameter_vectors['posterior_mean'],
            p_var=var_of_rho(parameter_vectors['posterior_rho']),
            q_mean=parameter_vectors['prior_mean'],
            q_var=var_of_rho(parameter_vectors['prior_rho'])
        ).sum()

class BayesianLinear(BayesianLayer):
    """Linear layer whose weight and bias are read from ``self.params``.

    ``self.params`` starts as the live posterior-mean parameters, so a plain
    forward pass uses the posterior mean. It can be swapped for a dict of sampled
    tensors (``PerturbedCtx`` does this) to run with perturbed weights.

    Note: ``device`` is the first positional parameter, so pass every other
    ``BayesianLayer`` argument by keyword.
    """

    def __init__(self, device='cpu', *args, **kwargs):
        super().__init__(*args, device=device, **kwargs)
        self.params = self.posterior_mean_parameters()

    def forward(self, x):
        # .get(): a missing 'bias' becomes None, which F.linear treats as "no bias".
        weight = self.params.get('weight')
        bias = self.params.get('bias')
        return F.linear(x, weight, bias)

class BayesianMLP(nn.Module):
    """Bayesian counterpart of ``MLP``, built from a prior-mean MLP and a posterior-mean MLP.

    Each hidden layer and the output layer become a ``BayesianLinear`` that pairs
    the corresponding ``nn.Linear`` of the two MLPs. ReLU is applied after every
    hidden layer.

    Args:
        mlp_prior_mean: MLP providing the prior means.
        mlp_posterior_mean: MLP providing the posterior means (same architecture).
        device: forwarded to every ``BayesianLinear``.
        **bayesian_kwargs: the remaining ``BayesianLayer`` arguments (``prior_stddev`` and the
            ``optimize_*`` flags that select which parameters are trained).
    """

    def __init__(self,
                 mlp_prior_mean: MLP,
                 mlp_posterior_mean: MLP,
                 device='cpu',
                 **bayesian_kwargs
            ):
        super().__init__()

        def make_layer(prior_layer, posterior_layer):
            return BayesianLinear(
                prior_mean=prior_layer,
                posterior_mean=posterior_layer,
                device=device,
                **bayesian_kwargs
            )

        layer_pairs = zip(mlp_prior_mean.hidden_layers, mlp_posterior_mean.hidden_layers)
        self.hidden_layers: List[BayesianLinear] = nn.ModuleList(
            make_layer(prior_layer, posterior_layer) for prior_layer, posterior_layer in layer_pairs
        )
        self.output_layer = make_layer(mlp_prior_mean.output_layer, mlp_posterior_mean.output_layer)

    def forward(self, x):
        out = x.view([x.shape[0], -1])  # flatten everything but the batch dimension
        for layer in self.hidden_layers:
            out = F.relu(layer(out))  # ReLU between consecutive layers
        return self.output_layer(out)

    def perturb_params(self):
        """Return a ``PerturbedCtx``; inside its ``with`` block the layers use posterior samples."""
        return PerturbedCtx(self)
    
class PerturbedCtx:
    """Context manager that temporarily runs a ``BayesianMLP`` with sampled weights.

    On ``__enter__``, every ``BayesianLinear`` of ``mlp`` (hidden layers in
    order, then the output layer) gets its ``params`` replaced by a fresh
    posterior sample from ``perturb_posterior()``. On ``__exit__``, the layers
    are pointed back at their live posterior-mean parameters.
    """

    def __init__(self, mlp: BayesianMLP):
        self.mlp = mlp
        # Keep *references* to the live posterior-mean Parameters; do not deep-copy them.
        # Restoring copies would leave each layer computing with detached snapshots after
        # the context exits: unperturbed forwards would use stale values, and their
        # gradients would never reach the Parameters registered with the optimizer.
        self.mlp_params = [layer.posterior_mean_parameters() for layer in self._layers()]

    def _layers(self):
        """All BayesianLinear layers of the MLP: hidden layers in order, then the output layer."""
        return [*self.mlp.hidden_layers, self.mlp.output_layer]

    def __enter__(self):
        for layer in self._layers():
            layer.params = layer.perturb_posterior()
        return self

    def __exit__(self, *_unused_args):
        for layer, params in zip(self._layers(), self.mlp_params):
            layer.params = params
        return False  # never swallow exceptions raised inside the block

class BayesianPolicyMeta(ForwardPolicyMeta):
    """Forward policy whose logit network is a ``BayesianMLP``.

    Subclasses provide the environment-specific parts: the networks annotated
    below and the ``get_latent_emb`` / ``get_pol`` hooks used by ``forward``.
    This class adds the KL(posterior || prior) term and the bounds that combine
    it with an empirical risk.
    """

    # These attributes should be instantiated for each environment
    mlp_logit: BayesianMLP
    mlp_logit_prior: MLP
    mlp_logit_posterior: MLP
    bayesian_layers: List[BayesianLayer]  # cached by kl(); dropped whenever mlp_logit is rebuilt
    sample_dataset_mode: bool = False

    def __init__(self, eps=.3, device='cpu', **bayesian_kwargs):
        super(BayesianPolicyMeta, self).__init__(eps=eps, device=device)
        # Kept so update_prior_mean() can rebuild mlp_logit with the same BayesianLayer settings.
        self.bayesian_kwargs = bayesian_kwargs

    def update_bayesian_kwargs(self, **bayesian_kwargs):
        """Rebuild ``mlp_logit`` from the current prior/posterior MLPs with new BayesianLayer settings."""
        self.bayesian_kwargs = bayesian_kwargs
        self.mlp_logit = BayesianMLP(
            self.mlp_logit_prior, self.mlp_logit_posterior, device=self.device, **bayesian_kwargs
        )
        self._forget_bayesian_layers()

    def update_prior_mean(self, mlp_prior_mean):
        """Use ``mlp_prior_mean`` as the prior mean and reset the posterior mean to a deep copy of it.

        ``mlp_logit`` is rebuilt with the stored ``bayesian_kwargs``.
        """
        self.mlp_logit_prior = mlp_prior_mean
        self.mlp_logit_posterior = deepcopy(self.mlp_logit_prior)
        self.mlp_logit = BayesianMLP(
            self.mlp_logit_prior, self.mlp_logit_posterior, device=self.device, **self.bayesian_kwargs
        )
        self._forget_bayesian_layers()

    def _forget_bayesian_layers(self):
        """Drop the cached ``bayesian_layers`` so kl() re-collects them from the new ``mlp_logit``.

        Without this, kl() would keep summing the KL of the discarded layers.
        """
        try:
            del self.bayesian_layers
        except AttributeError:
            pass  # nothing cached yet

    def get_latent_emb(self, *args, **kwargs):
        """Hook: compute the latent embedding of a state (called as ``get_latent_emb(state)``)."""
        raise NotImplementedError

    def get_pol(self, *args, **kwargs):
        """Hook: return ``(pol, gflows)`` (called as ``get_pol(latent_emb, forward_mask)``)."""
        raise NotImplementedError

    def forward(self, state, actions=None, perturb_params=False):
        """Compute the policy for ``state`` and return ``(actions, log_probs, gflows)``.

        Args:
            state: batch of environment states; it gets ``forward_mask = None`` if it has none.
            actions: actions to score; when None they are chosen with ``self.get_actions``.
            perturb_params: if True, run the logit network with one posterior sample of its weights.

        Returns:
            ``(actions, log(pol[state.batch_ids, actions]), gflows)``.
        """
        if not hasattr(state, 'forward_mask'):
            state.forward_mask = None

        context = self.mlp_logit.perturb_params() if perturb_params else nullcontext()
        with context:
            latent_emb = self.get_latent_emb(state)
            pol, gflows = self.get_pol(latent_emb, state.forward_mask)

        if actions is None:
            actions, _ = self.get_actions(pol, state.forward_mask)
        return actions, torch.log(pol[state.batch_ids, actions]), gflows

    def kl(self):
        """Return the summed KL(posterior || prior) of every BayesianLayer inside ``mlp_logit``.

        The layer list is discovered on first use and cached in ``self.bayesian_layers``.
        Returns the Python float ``0.0`` if no Bayesian layer is found.
        """
        kl = 0.0
        if not hasattr(self, 'bayesian_layers'):
            self.bayesian_layers = self._find_bayesian_layers()
        for layer in self.bayesian_layers:
            kl += layer.kl()
        return kl

    def _find_bayesian_layers(self):
        """Collect every BayesianLayer in ``mlp_logit`` (depth-first, parents before children)."""
        bayesian_layers = []

        def __find_bayesian_layers(module):
            if isinstance(module, BayesianLayer):
                bayesian_layers.append(module)
            for m in module.children():
                __find_bayesian_layers(m)

        __find_bayesian_layers(self.mlp_logit)
        return bayesian_layers

    def quad_bound(self, risk, kl, dataset_size, delta):
        """Return ``(sqrt(risk + eps) + sqrt(eps)) ** 2``.

        Here ``eps = (kl + log(2 * sqrt(n) / delta)) / (2 * n)`` and ``n = dataset_size``.
        ``kl`` may be a tensor or a Python number (kl() returns ``0.0`` without Bayesian layers).
        """
        log_2_sqrt_n_over_delta = math.log(2 * math.sqrt(dataset_size) / delta)
        fraction = (torch.as_tensor(kl) + log_2_sqrt_n_over_delta).div(2 * dataset_size)
        sqrt1 = (risk + fraction).sqrt()
        sqrt2 = fraction.sqrt()
        return (sqrt1 + sqrt2).pow(2)

    def pinsker_bound(self, risk, kl, dataset_size, delta):
        """Return ``risk + sqrt((kl + log(2 * sqrt(n) / delta)) / (2 * n))`` with ``n = dataset_size``.

        ``kl`` may be a tensor or a Python number.
        """
        B = (torch.as_tensor(kl) + math.log(2 * math.sqrt(dataset_size) / delta)).div(dataset_size)
        return risk + B.div(2).sqrt()

    def inverted_kl_bound(self, risk, kl, dataset_size, delta):
        """Return the elementwise minimum of :meth:`quad_bound` and :meth:`pinsker_bound`."""
        return torch.min(
            self.quad_bound(risk, kl, dataset_size, delta),
            self.pinsker_bound(risk, kl, dataset_size, delta)
        )

    
# Utils for training 

class TrajectoriesDataLoader:
    """Mini-batch iterator over trajectories that are stored step by step.

    ``trajectories`` is a non-empty list with one ``(env_t, actions, env_tp1)``
    tuple per time step. Each env object holds a batch of environments and
    exposes ``batch_size``, ``get(indices)`` and ``merge(other)``; ``actions``
    is indexable by the same indices. Every yielded batch is a list with one
    ``(env_t.get(idx), actions[idx], env_tp1.get(idx))`` tuple per time step,
    where ``idx`` selects up to ``batch_size`` trajectories.

    The number of trajectories is read from the last step's ``env_t.batch_size``.
    This assumes every step holds all trajectories (TODO confirm for
    variable-length environments).

    Args:
        trajectories: per-step data as described above.
        batch_size: number of trajectories per mini-batch.
        shuffle: if True, trajectories are visited in a random permutation drawn at
            construction (and again by ``merge``); otherwise in order.
    """

    def __init__(self, trajectories, batch_size, shuffle=True):
        assert len(trajectories)
        (sample, _, _) = trajectories[-1]

        self.trajectories = trajectories
        self.batch_size = batch_size
        self.shuffle = shuffle

        self.traj_length = len(trajectories)
        self.num_trajs = sample.batch_size

        self.indices = torch.arange(self.num_trajs)
        if shuffle:
            self.indices = torch.randperm(self.num_trajs)

        self.current_idx = 0

    def __iter__(self):
        # Start from the first mini-batch even if a previous pass was abandoned early.
        self.current_idx = 0
        return self

    def __next__(self):
        if self.current_idx >= self.num_trajs:
            self.current_idx = 0
            raise StopIteration
        indices = self.indices[self.current_idx:self.current_idx + self.batch_size]

        batch_traj = [
            (env_t.get(indices), actions[indices], env_tp1.get(indices))
            for (env_t, actions, env_tp1) in self.trajectories
        ]

        self.current_idx += self.batch_size
        return batch_traj

    def merge(self, dataloader):
        """Return a new loader with this loader's trajectories followed by ``dataloader``'s.

        ``self`` is left unchanged. ``env.merge`` works in place (its return value
        is unused), so merging is done on a deep copy of ``self``'s data rather
        than on ``self``'s own env objects. Steps are paired with ``zip``, so
        steps beyond the shorter of the two loaders are dropped.
        """
        merged = deepcopy(self)
        trajectories = list()
        for (
            (env_t, actions, env_tp1), (env_t_new, actions_new, env_tp1_new)
        ) in zip(merged.trajectories, dataloader.trajectories):
            env_t.merge(env_t_new)
            env_tp1.merge(env_tp1_new)
            trajectories.append(
                (env_t, torch.hstack([actions, actions_new]), env_tp1)
            )
            merged.num_trajs = env_t.batch_size

        merged.trajectories = trajectories
        merged.traj_length = len(trajectories)
        merged.current_idx = 0
        merged.indices = torch.arange(merged.num_trajs)
        if merged.shuffle:
            merged.indices = torch.randperm(merged.num_trajs)
        return merged
    
def create_bayesian_gfn(config, **bayesian_kwargs): 
    match config.env: 

        case 'sets' | 'bags': 
            from sal.models.sets import BayesianPolicy, BackwardPolicy 
            pf = BayesianPolicy(config.src_size, config.hidden_dim, config.num_layers, 
                                device=config.device, 
                                force_mask_idx=config.force_mask_idx, **bayesian_kwargs).to(config.device) 
            pb = BackwardPolicy(config.device) 
            pass
        case 'sequences': 
            from sal.models.sequences import BayesianPolicy, BackwardPolicy 
            pf = BayesianPolicy(config.seq_size, config.vocab_size * config.src_size, 
                                config.hidden_dim, config.num_layers, 
                                device=config.device, **bayesian_kwargs).to(config.device) 
            pb = BackwardPolicy(config.device) 
            pass 
        case _: 
            raise ValueError
    return GFlowNet(pf, pb, criterion=config.criterion, device=config.device) 
    pass 

def train_gfn_epoch(
    gfn, optim, train_dataloader, pbar
):
    """Run one epoch of plain (non-Bayesian) GFlowNet training.

    For every batch, compute ``gfn.evaluate_loss_on_trajectories(batch)``,
    back-propagate it and take an optimizer step. The running mean loss is
    shown on ``pbar`` via ``set_postfix``.

    Returns:
        float: mean loss over the epoch's batches, or ``nan`` if the loader yielded no batch.
    """
    num_batches = 0
    running_loss = 0.
    for traj in train_dataloader:
        loss = gfn.evaluate_loss_on_trajectories(traj)
        optim.zero_grad()
        loss.backward()
        optim.step()
        running_loss += loss.detach().cpu().item()
        num_batches += 1
        pbar.set_postfix(loss=running_loss / num_batches)
    if num_batches == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return math.nan
    return running_loss / num_batches

def train_bayesian_gfn_epoch(
    gfn: GFlowNet, train_dataloader, optim, config, pbar  
): 
    pf: BayesianPolicyMeta = gfn.pf 
    running_bound = 0. 
    num_samples = 0 
    for batch in train_dataloader: 
        surrogate = gfn.evaluate_loss_on_trajectories(batch) 
        surrogate_bound = pf.inverted_kl_bound(
            risk=surrogate, kl=pf.kl(), dataset_size=train_dataloader.num_trajs, delta=config.delta
        )
        optim.zero_grad() 
        surrogate_bound.backward() 
        optim.step()
        running_bound += surrogate_bound.detach().cpu().item() 
        num_samples += 1  
        pbar.set_postfix(bound=running_bound/num_samples) 
    pass 

def get_dataset(config, num_trajectories=int(5e3), log_reward=None):
    """Sample an off-policy dataset of trajectories from a freshly created GFlowNet.

    Runs ``ceil(num_trajectories / config.batch_size)`` batches of environments
    to completion. The policy is called with ``sample_dataset_mode = True``
    inside ``gfn.off_policy()``.

    Args:
        config: experiment configuration passed to ``create_gfn`` / ``create_env``.
        num_trajectories: minimum number of trajectories to sample.
        log_reward: reward to use; created with ``create_log_reward`` when None.

    Returns:
        ``(trajectories, log_reward)``. ``trajectories`` holds one
        ``(env_t, actions, env_tp1)`` tuple per time step; step ``i`` of every
        batch is merged into entry ``i``.
    """
    gfn = create_gfn(config)

    if log_reward is None:
        log_reward = create_log_reward(config, create_gfn(config))

    trajectories = list()
    # Ceiling division: just enough batches for at least ``num_trajectories`` trajectories.
    # ``(n + b) // b`` would over-sample a whole extra batch whenever n is a multiple of b.
    num_batches = (num_trajectories + config.batch_size - 1) // config.batch_size

    pf: BayesianPolicyMeta = gfn.pf
    previous_mode = getattr(pf, 'sample_dataset_mode', False)
    pf.sample_dataset_mode = True
    try:
        with gfn.off_policy():
            for _ in range(num_batches):
                step = 0
                env = create_env(config, log_reward)

                while (env.stopped < 1).any():
                    env_t = deepcopy(env)
                    actions = pf(env)[0]
                    env.apply(actions)
                    env_tp1 = deepcopy(env)

                    if step < len(trajectories):
                        # Append this batch's step to the data already collected for this step.
                        stored_env_t, stored_actions, stored_env_tp1 = trajectories[step]
                        stored_env_t.merge(env_t)
                        stored_env_tp1.merge(env_tp1)
                        trajectories[step] = (
                            stored_env_t, torch.hstack([stored_actions, actions]), stored_env_tp1
                        )
                    else:
                        trajectories.append((env_t, actions, env_tp1))

                    step += 1
    finally:
        # Restore the flag even if sampling raised.
        pf.sample_dataset_mode = previous_mode
    return trajectories, log_reward