import torch
import torch.nn.functional as F
import numpy as np

class Importance_Sampling():
    """Train a conditional diffusion model and a reward model with per-sample weights.

    Holds one Adam optimizer per model. ``IS_train`` performs a single
    importance-weighted update of the diffusion model; ``rw_loss`` computes the
    weighted loss for the reward/terminal model.
    """

    def __init__(self, diffusion_model, rw_model, train_lr=3e-4, weight_decay=1e-4):
        """Create Adam optimizers for both models.

        Args:
            diffusion_model: module with ``parameters()`` and a
                ``loss(x, conditions, weights=..., impo_samp=...)`` method that
                returns ``(loss, infos)``.
            rw_model: module mapping conditions to a 2-D output whose column 0
                is a reward prediction and column 1 a terminal probability.
            train_lr: learning rate shared by both optimizers.
            weight_decay: weight-decay coefficient shared by both optimizers.
                The previous code passed ``int(1e-4)``, which truncates to ``0``
                and silently disabled weight decay; pass ``0.0`` to reproduce
                that old behavior.
        """
        self.model = diffusion_model
        self.rw_model = rw_model
        self.df_optimizer = torch.optim.Adam(
            diffusion_model.parameters(), lr=train_lr, weight_decay=weight_decay)
        self.rw_optimizer = torch.optim.Adam(
            rw_model.parameters(), lr=train_lr, weight_decay=weight_decay)

    def rw_loss(self, rw, conditions, weights):
        """Compute the weighted L1 reward loss plus the terminal BCE loss.

        Args:
            rw: 2-D targets; column 0 is the reward target and column 1 the
                terminal target (a float in [0, 1], as required by
                ``binary_cross_entropy``).
            conditions: input passed unchanged to ``rw_model``.
            weights: per-sample weights applied to the reward error only;
                presumably shape ``(batch,)`` -- TODO confirm against callers.

        Returns:
            ``(loss, info)``: ``loss`` is the weighted mean reward error plus the
            terminal BCE; ``info`` holds the per-sample (unweighted) reward
            errors under ``'reward_loss'`` and the terminal BCE under
            ``'terminal_loss'``.
        """
        predicts = self.rw_model(conditions)
        reward_predicts = predicts[:, 0]
        # Column 1 must already be a probability (e.g. after a sigmoid):
        # binary_cross_entropy rejects inputs outside [0, 1].
        terminal_probs = predicts[:, 1]
        reward_loss = torch.abs(reward_predicts - rw[:, 0])
        weighted_loss = (reward_loss * weights).mean()
        terminal_loss = F.binary_cross_entropy(terminal_probs, rw[:, 1])
        loss = weighted_loss + terminal_loss
        info = {'reward_loss': reward_loss, 'terminal_loss': terminal_loss}

        return loss, info

    def IS_train(self, batch, weights, tensor=False):
        """Run one importance-weighted optimization step on the diffusion model.

        Args:
            batch: mapping with ``'observations'``, ``'actions'`` and
                ``'next_observations'``. With ``tensor=False`` these are
                array-likes converted to float32 tensors on the model's device;
                with ``tensor=True`` they are used as given.
            weights: per-sample importance weights (array-like, or a tensor when
                ``tensor=True``). Clamped to ``[0, 10]``; NaNs are set to 0.
            tensor: whether ``batch`` and ``weights`` are already tensors.

        Returns:
            ``(loss, infos)``: the detached loss returned by ``self.model.loss``
            and its info object. (Earlier versions returned nothing.)
        """
        # Follow the model instead of hard-coding 'cuda' so CPU runs also work.
        device = next(self.model.parameters()).device
        if tensor:
            x = batch['next_observations']
            conditions = torch.cat((batch['observations'], batch['actions']), dim=1)
        else:
            x = torch.as_tensor(batch['next_observations'], dtype=torch.float32, device=device)
            # Explicit float32: NumPy arrays default to float64, which would not
            # match x or the model's float32 parameters.
            conditions = torch.as_tensor(
                np.concatenate((batch['observations'], batch['actions']), axis=1),
                dtype=torch.float32, device=device)
            weights = torch.as_tensor(weights, dtype=torch.float32, device=device)
        # NOTE(review): if a caller passes weights that require grad
        # (tensor=True), backward() also fills grads upstream of them; consider
        # .detach() if the weights are meant to be constants.
        # clamp returns a new tensor (the caller's is untouched) and propagates
        # NaN, so NaNs are zeroed afterwards.
        weights = torch.clamp(weights, min=0, max=10)
        weights[torch.isnan(weights)] = 0
        loss, infos = self.model.loss(x, conditions, weights=weights, impo_samp=True)
        self.df_optimizer.zero_grad()
        loss.backward()
        self.df_optimizer.step()
        return loss.detach(), infos


