from tqdm import tqdm
from copy import deepcopy
from warnings import filterwarnings
# Module-level side effect: importing this file suppresses *every* warning
# category for the whole process (no category/module filter is given).
filterwarnings("ignore")

import torch


class Noise():
    """Source of standard-normal latent vectors.

    Args:
        dim: Length of each latent vector. Defaults to 100, the size that
            used to be hard-coded in ``sample``.
    """

    # Class-level fallback: instances restored without running __init__
    # (e.g. unpickled objects created before ``dim`` existed) still work.
    dim = 100

    def __init__(self, dim=100):
        self.dim = dim

    def sample(self, num_samples=1):
        """Return a ``(num_samples, dim)`` tensor of N(0, 1) draws.

        The tensor is created on the default (CPU) device; callers move it
        to their own device.
        """
        return torch.randn(num_samples, self.dim)


class PretrainedModelWrapper():
    """Wrap a conditional generator ``model(z, x)`` with device bookkeeping.

    The wrapper owns a deep copy of ``model`` (the caller's object is never
    modified), a ``noise`` source used to draw latent vectors, and the name
    of the device the copy currently lives on.

    Args:
        model: Callable module invoked as ``model(z, x)``; it must support
            ``.to``, ``.train`` and ``.eval``.
        device: Device the copied model is moved to.
    """

    def __init__(self, model, device='cpu'):
        self.model = deepcopy(model).to(device)
        self.noise = Noise()
        self.device = device

    def _random_labels(self, count):
        """Draw ``count`` integer labels uniformly from 0..9 on ``self.device``."""
        # Drawn on the CPU generator and then moved, so a given
        # ``torch.manual_seed`` yields the same labels on every device.
        return torch.randint(0, 10, (count,)).to(self.device)

    def _on_device(self, value):
        """Move ``value`` to ``self.device`` if it is a tensor; else return it as is."""
        if isinstance(value, torch.Tensor):
            return value.to(self.device)
        return value

    def __call__(self, z, x=None):
        """Run the model on latent batch ``z`` with labels ``x``.

        Args:
            z: Latent batch; its first dimension is the batch size.
            x: Conditioning labels. When ``None``, one random label in 0..9
                is drawn per row of ``z``.

        Returns:
            Whatever ``model(z, x)`` returns.
        """
        # Inputs may arrive on another device (e.g. freshly created on CPU);
        # ``.to`` is a no-op when they are already in place.
        z = self._on_device(z)
        x = self._random_labels(z.size(0)) if x is None else self._on_device(x)
        return self.model(z, x)

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self``."""
        self.device = device
        self.model.to(device)
        return self

    def train(self):
        """Put the wrapped model in training mode and return ``self``."""
        self.model.train()
        return self

    def eval(self):
        """Put the wrapped model in evaluation mode and return ``self``."""
        self.model.eval()
        return self

    def sample(self, num_samples=1, condition=None):
        """Generate ``num_samples`` outputs from fresh noise.

        Args:
            num_samples: Number of latent vectors to draw.
            condition: Labels to condition on. When ``None``, random labels
                in 0..9 are drawn.

        Returns:
            The model output for the drawn noise and labels.
        """
        # Labels are drawn before the noise to keep the original order in
        # which the random generator is consumed.
        if condition is None:
            condition = self._random_labels(num_samples)
        # Use the wrapper's own noise source (the one the training loops
        # read via ``.noise``) instead of a second hard-coded randn call.
        latent = self.noise.sample(num_samples=num_samples)
        return self(latent, condition)


class RL():
    """Fine-tune a generator (the actor) to maximise a critic's score.

    The constructor deep-copies its arguments, so the caller's objects are
    left untouched. A second copy of the actor is kept as a fixed
    ``reference`` for the optional ``beta`` penalty.

    Args:
        actor: Generator wrapper exposing ``.model``, ``.noise``, ``.to``,
            ``.train``, ``.eval``, ``.sample`` and ``__call__(z, labels)``.
        critic: Callable ``critic(action, labels)`` returning rewards; must
            support ``.to``.
        device: Device used for training.
    """

    def __init__(self, actor, critic, device='cpu'):
        self.actor = deepcopy(actor)
        self.critic = deepcopy(critic)
        self.reference = deepcopy(actor)
        self.device = device

    def train(self, num_epochs=200, batch_size=64, learning_rate=1e-3, beta=0.):
        """Optimise the actor with Adam to minimise ``-mean(reward)``.

        Args:
            num_epochs: Number of optimisation steps (one batch each).
            batch_size: Latent vectors / labels drawn per step.
            learning_rate: Adam learning rate.
            beta: Weight of the mean squared difference between the actor's
                and the reference's outputs; ``0`` disables the penalty.

        Returns:
            tuple: ``(actor, losses)`` where ``losses`` holds the scalar
            loss of every step. The actor is left in eval mode.
        """
        self.actor.to(self.device)
        self.critic.to(self.device)
        # The reference is fed tensors living on self.device, so it has to
        # be moved as well or the beta penalty hits a device mismatch.
        self.reference.to(self.device)
        self.actor.train()
        noise = self.actor.noise
        optimizer = torch.optim.Adam(self.actor.model.parameters(), lr=learning_rate)

        losses = []
        pbar = tqdm(range(num_epochs))
        for _ in pbar:
            z = noise.sample(num_samples=batch_size).to(self.device)
            labels = torch.randint(0, 10, (batch_size,)).to(self.device)
            action = self.actor(z, labels).view(-1, 1, 28, 28)
            reward = self.critic(action, labels)
            loss = -torch.mean(reward)

            if beta:
                # Compare like with like: same latent AND same labels as the
                # actor, reshaped identically. The reference is never
                # optimised, so no autograd graph is needed for it.
                with torch.no_grad():
                    reference_action = self.reference(z, labels).view(-1, 1, 28, 28)
                loss = loss + beta * torch.mean((action - reference_action) ** 2)

            loss_value = loss.item()
            losses.append(loss_value)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"loss {loss_value:0.3f}")

        self.actor.eval()
        return self.actor, losses

    def sample(self, num_samples=1, condition=None):
        """Draw samples from the actor in eval mode without tracking gradients.

        Args:
            num_samples: Number of samples to generate.
            condition: Optional labels forwarded to ``actor.sample``.
        """
        self.actor.eval()
        with torch.no_grad():
            samples = self.actor.sample(
                num_samples=num_samples,
                condition=condition
            )
        return samples


class DPO():
    """Preference-based fine-tuning of a policy against a fixed reference.

    Both ``policy`` and ``reference`` are deep-copied, so the caller's
    objects are left untouched.

    Args:
        policy: Generator wrapper exposing ``.model``, ``.noise``, ``.to``,
            ``.train``, ``.eval``, ``.sample`` and ``__call__(z, labels)``.
        reference: Generator with the same call interface; never optimised.
        device: Device used for training.
    """

    def __init__(self, policy, reference, device='cpu'):
        self.device = device
        self.policy = deepcopy(policy)
        self.reference = deepcopy(reference)

    def density(self, z, mu=0, sigma=1):
        """Evaluate the N(mu, sigma^2) density element-wise at ``z``.

        Two trailing singleton dimensions are appended to ``z`` first, so
        the result has shape ``z.shape + (1, 1)``.
        """
        z = z[..., None, None]
        # Plain Python float: avoids building a throwaway CPU tensor.
        norm = sigma * (2 * torch.pi) ** 0.5
        return torch.exp(-0.5 * ((z - mu) / sigma) ** 2) / norm

    def fit(self, reward_function, num_epochs=100, batch_size=64, learning_rate=1e-3, beta=0.):
        """Optimise the policy with Adam on a preference loss.

        Each step draws latents ``z`` and labels, generates ``y1`` from the
        policy and ``y0`` from the reference, scores both with
        ``reward_function`` and minimises
        ``-mean(p * density(z))`` with ``p = l / (1 - l)`` and
        ``l = log sigmoid(r1 - r0)``.

        Args:
            reward_function: Callable ``reward_function(y, labels)``; it is
                moved to ``self.device`` in place.
            num_epochs: Number of optimisation steps (one batch each).
            batch_size: Latent vectors / labels drawn per step.
            learning_rate: Adam learning rate.
            beta: Weight of the mean squared difference between policy and
                reference outputs; ``0`` disables the penalty.

        Returns:
            tuple: ``(policy, losses)`` where ``losses`` holds the scalar
            loss of every step. The policy is left in eval mode.
        """
        self.policy.to(self.device)
        self.reference.to(self.device)
        reward_function.to(self.device)
        # fit() and sample() both leave the policy in eval mode; switch back
        # so repeated fit() calls train under the same mode (as RL.train does).
        self.policy.train()
        noise = self.policy.noise

        optimizer = torch.optim.Adam(self.policy.model.parameters(), lr=learning_rate)
        losses = []
        pbar = tqdm(range(num_epochs))
        for _ in pbar:
            z = noise.sample(num_samples=batch_size).to(self.device)
            labels = torch.randint(0, 10, (batch_size,)).to(self.device)

            y1 = self.policy(z, labels).view(-1, 1, 28, 28)
            r1 = reward_function(y1, labels)
            # The reference branch is a constant w.r.t. the optimised
            # parameters, so skip building its autograd graph.
            with torch.no_grad():
                y0 = self.reference(z, labels).view(-1, 1, 28, 28)
                r0 = reward_function(y0, labels)

            # log(exp(r1) / (exp(r1) + exp(r0))) == logsigmoid(r1 - r0);
            # this form cannot overflow to inf/inf = nan for large rewards.
            log_preference = torch.nn.functional.logsigmoid(r1 - r0)
            p = log_preference / (1 - log_preference)
            # NOTE(review): p has the reward function's output shape while
            # density(z) is z.shape + (1, 1); the product relies on
            # broadcasting between the two - confirm this is intended.
            loss = -torch.mean(p * self.density(z))

            if beta:
                loss = loss + beta * torch.mean((y1 - y0) ** 2)

            loss_value = loss.item()
            losses.append(loss_value)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"loss {loss_value:0.3f}")

        self.policy.eval()
        return self.policy, losses

    def sample(self, num_samples=1, condition=None):
        """Draw samples from the policy in eval mode without tracking gradients.

        Args:
            num_samples: Number of samples to generate.
            condition: Optional labels forwarded to ``policy.sample``.
        """
        self.policy.eval()
        with torch.no_grad():
            samples = self.policy.sample(
                num_samples=num_samples,
                condition=condition
            )
        return samples