import random
from tqdm import tqdm
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
from torch.utils.data import DataLoader


class DPO:
    """Direct Preference Optimization trainer over pairs of trajectory segments.

    The trainer keeps private deep copies of a trainable ``policy`` and a
    fixed ``reference`` model. Both objects must expose
    ``log_prob(state, action)``; they are also moved with ``.to(device)`` and
    switched with ``.train()`` / ``.eval()``, so they are presumably
    ``torch.nn.Module`` instances (verify against callers).
    """

    def __init__(self, policy, reference, device='cpu'):
        """Copy both models so the caller's objects are never modified.

        Args:
            policy: Model to optimize; needs ``log_prob`` and ``parameters``.
            reference: Baseline model the policy is compared against; needs
                ``log_prob``.
            device: Device that models and batches are moved to.
        """
        self.device = device
        self.policy = deepcopy(policy)
        self.reference = deepcopy(reference)

    def _compute_log_prob(self, policy, state, action):
        """Evaluate ``policy.log_prob`` at every step of a batch of segments.

        Args:
            policy: Object providing ``log_prob(state, action)``.
            state: 3D tensor indexed as ``[batch, step, feature]``.
            action: Tensor indexed the same way on its first two axes.

        Returns:
            The per-step ``log_prob`` outputs concatenated along the last
            dimension.

        Raises:
            AssertionError: If ``state`` is not a 3D tensor.
        """
        assert len(state.shape) == 3, "states must be 3D tensors."
        seg_len = state.shape[1]
        state = state.to(self.device)
        action = action.to(self.device)
        return torch.cat([
            policy.log_prob(state[:, i, :], action[:, i, :])
            for i in range(seg_len)
        ], dim=-1)

    def fit(self, dataset, num_epochs=10, batch_size=64, learning_rate=1e-3, beta=1., label_smoothing=0.):
        """Optimize the policy copy with the DPO objective using Adam.

        Args:
            dataset: Dataset whose batches unpack into five fields
                ``(x, s1, a1, s0, a0)``. The first field is not used by the
                loss; ``(s1, a1)`` is the segment whose likelihood is pushed
                up relative to ``(s0, a0)``.
            num_epochs: Number of passes over ``dataset``.
            batch_size: Mini-batch size for the shuffled data loader.
            learning_rate: Adam learning rate.
            beta: Scale applied to the policy-vs-reference logits.
            label_smoothing: Weight given to the flipped-preference term of
                the loss; ``0.`` disables it.

        Returns:
            A tuple ``(policy, losses)``: the trained policy copy (left in
            eval mode) and the list of per-batch loss values.
        """
        self.policy.to(self.device)
        self.reference.to(self.device)
        # fit() leaves the policy in eval mode on exit; switch back so that a
        # repeated call trains under the same conditions as the first one.
        self.policy.train()
        # The reference is a fixed baseline and is never optimized.
        self.reference.eval()

        optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        losses = []
        pbar = tqdm(range(num_epochs), desc="Training DPO")
        for epoch in pbar:
            # NOTE(review): the first batch field is ignored, so (s1, a1) is
            # always treated as the preferred segment -- confirm the dataset
            # orders its pairs that way.
            for _, s1, a1, s0, a0 in dataloader:
                pi_logratios = (
                    self._compute_log_prob(self.policy, s1, a1)
                    - self._compute_log_prob(self.policy, s0, a0)
                )
                # No graph is needed for the reference: without no_grad,
                # backward() would also populate (and never clear) the
                # reference parameters' .grad buffers.
                with torch.no_grad():
                    ref_logratios = (
                        self._compute_log_prob(self.reference, s1, a1)
                        - self._compute_log_prob(self.reference, s0, a0)
                    )
                # NOTE(review): the loss below is applied to every
                # concatenated log-prob entry separately rather than to a
                # per-segment sum -- confirm this is the intended objective.
                logits = pi_logratios - ref_logratios
                loss = (
                    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
                    - F.logsigmoid(-beta * logits) * label_smoothing
                ).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once and reuse it for logging.
                loss_value = loss.item()
                losses.append(loss_value)
                pbar.set_description(f"DPO loss {loss_value:0.3f}")

        self.policy.eval()
        return self.policy, losses
            