import random
from tqdm import tqdm
from copy import deepcopy

import gym
import d4rl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
from torch.utils.data import DataLoader

from models import MLP


class RewardFunction(nn.Module):
    """ReLU MLP mapping a concatenated ``(state, action)`` pair to a reward.

    Layout: ``input_layer`` (state_dim + action_dim -> hidden_dim), then
    ``num_layers`` hidden ``hidden_dim -> hidden_dim`` layers, each followed by
    ReLU, then a linear ``output_layer`` (hidden_dim -> output_dim) with no
    activation. State and action are concatenated along their last dimension,
    so any leading (batch / time) dimensions are preserved in the output.
    """

    def __init__(
        self, state_dim, action_dim,
        hidden_dim=64, num_layers=3, output_dim=1
        ) -> None:
        super().__init__()
        in_features = state_dim + action_dim
        self.input_layer = nn.Linear(in_features, hidden_dim)
        # Built one by one in the same order as before, so seeded
        # initialisation and state_dict keys are unchanged.
        self.hidden_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, state, action):
        """Return the un-activated output for ``cat([state, action], -1)``."""
        h = torch.cat((state, action), dim=-1)
        # The input layer and every hidden layer share the Linear -> ReLU pattern.
        for linear in (self.input_layer, *self.hidden_layers):
            h = F.relu(linear(h))
        return self.output_layer(h)


class RewardModel():
    """Preference-trained reward model wrapping a ``RewardFunction``.

    ``fit`` trains the network with a Bradley-Terry style objective: for each
    pair in the dataset, the summed reward of the first segment ``(s1, a1)``
    should exceed the summed reward of the second segment ``(s0, a0)``.
    """

    def __init__(
        self, state_dim, action_dim,
        hidden_dim=64, num_layers=3, output_dim=1,
        device='cpu'
        ) -> None:
        self.device = device
        self.model = RewardFunction(
            state_dim, action_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            output_dim=output_dim,
        ).to(self.device)

    def __call__(self, state, action):
        """Return ``self.model(state, action)`` (the raw network output)."""
        return self.model(state, action)

    @staticmethod
    def _segment_return(rewards):
        """Sum rewards over every non-batch dimension, giving one value per sample."""
        return rewards.reshape(rewards.shape[0], -1).sum(dim=1)

    def fit(self, dataset, num_epochs=100, learning_rate=1e-3, batch_size=64):
        """Train the reward network on pairwise preferences.

        Args:
            dataset: dataset yielding ``(x, s1, a1, s0, a0)`` tuples; ``x`` is
                ignored. States/actions may be single steps or segments with
                an extra time dimension; rewards are summed per sample either way.
                NOTE(review): every pair is labelled "``(s1, a1)`` preferred";
                confirm the dataset is ordered that way and that ``x`` is not
                a preference label.
            num_epochs: number of passes over ``dataset``.
            learning_rate: Adam learning rate.
            batch_size: pairs per mini-batch (batches are shuffled).

        Returns:
            ``(model, losses)``: the trained network (left in eval mode) and the
            loss of every mini-batch, in order.
        """
        self.model.train()
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        losses = []
        pbar = tqdm(range(num_epochs), desc="Training Reward Model")
        for _ in pbar:
            for _x, s1, a1, s0, a0 in dataloader:
                r1 = self.model(s1.to(self.device), a1.to(self.device))
                r0 = self.model(s0.to(self.device), a0.to(self.device))
                # Summing over all non-batch dims (not just the last one) keeps
                # one logit per pair. For single-step data the rewards are
                # (batch, 1), and squeeze(-1).sum(-1) would collapse the whole
                # batch into a single scalar logit.
                logits = self._segment_return(r1) - self._segment_return(r0)
                labels = torch.ones_like(logits)  # same device/dtype as logits
                loss = criterion(logits, labels)
                losses.append(loss.item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if losses:
                pbar.set_description(f"Reward Model loss {losses[-1]:0.3f}")

        self.model.eval()
        return self.model, losses


class RL():
    """KL-regularised policy optimisation against a learned reward (RLHF).

    The actor is trained to maximise ``critic(s, actor(s))``. A sampled KL
    penalty keeps it close to ``reference``, which is a frozen copy of the actor
    as it was at construction time. Actor, critic and reference are all
    deep-copied, so the objects passed in are never modified.

    NOTE(review): ``actor`` is assumed to be an ``nn.Module`` whose
    ``forward(states)`` returns actions (presumably sampled with
    reparameterisation so reward gradients reach the actor) and which exposes
    ``log_prob(states, actions)``. ``critic`` is assumed to be a
    ``RewardModel``-like callable with a ``.model`` module. Confirm against
    callers.
    """

    def __init__(self, actor, critic, device='cpu'):
        self.actor = deepcopy(actor)
        self.critic = deepcopy(critic)
        # Frozen anchor for the KL penalty.
        self.reference = deepcopy(actor)
        self.device = device

    def train(self, dataset, num_epochs=10, batch_size=64, learning_rate=1e-3, beta=1.):
        """Optimise the actor on the states contained in ``dataset``.

        Args:
            dataset: dataset yielding ``(x, s1, a1, s0, a0)`` tuples. Only the
                states ``s1`` and ``s0`` are used (flattened to
                ``(-1, state_dim)``); ``x`` and the stored actions are ignored.
            num_epochs: number of passes over ``dataset``.
            batch_size: tuples per mini-batch (batches are shuffled).
            learning_rate: Adam learning rate for the actor.
            beta: weight of the KL penalty.

        Returns:
            ``(actor, losses)``: the trained actor (left in eval mode) and the
            loss of every mini-batch, in order.
        """
        self.actor.to(self.device)
        self.critic.model.to(self.device)
        self.reference.to(self.device)

        # Critic and reference are fixed targets. Freezing their parameters
        # stops backward() from accumulating never-cleared gradients in them.
        # Gradients still flow *through* them into the actor's actions.
        for frozen in (self.critic.model, self.reference):
            frozen.eval()
            frozen.requires_grad_(False)
        self.actor.train()

        optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        losses = []
        pbar = tqdm(range(num_epochs), desc="Training RLHF")
        for _ in pbar:
            for _x, s1, _a1, s0, _a0 in dataloader:
                # reshape (unlike view) also copes with non-contiguous batches.
                s1 = s1.reshape(-1, s1.shape[-1]).to(self.device)
                s0 = s0.reshape(-1, s0.shape[-1]).to(self.device)

                states = torch.cat([s1, s0], dim=0)
                policy_actions = torch.cat([self.actor(s1), self.actor(s0)], dim=0)

                # Sampled estimate of KL(pi || pi_ref) =
                # E_{a~pi}[log pi(a|s) - log pi_ref(a|s)]: both log-probs must
                # be evaluated at the SAME (policy) actions. Evaluating the
                # reference at its own samples does not estimate the KL.
                policy_log_prob = self.actor.log_prob(states, policy_actions)
                reference_log_prob = self.reference.log_prob(states, policy_actions)
                kl = (policy_log_prob - reference_log_prob).mean()

                reward = self.critic(states, policy_actions)
                loss = -torch.mean(reward) + beta * kl

                optimizer.zero_grad()
                loss.backward()
                losses.append(loss.item())
                optimizer.step()

                pbar.set_description(f"RLHF loss {losses[-1]:0.3f}")

        self.actor.eval()
        return self.actor, losses

