from copy import deepcopy
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F

# Logging is configured at import time: the root logger is set to INFO and a
# module-level logger is created. The fmt directives keep the auto-formatter
# from rewriting this section.
# fmt: off
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# fmt: on


# Device every sampled batch is moved to: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _to_float_tensor(data, as_column=False):
    """Copy ``data`` into a float32 tensor created directly on ``DEVICE``.

    Args:
        data: Array-like batch returned by the replay buffer.
        as_column: When true, reshape the result to ``(-1, 1)``.

    Returns:
        A ``torch.float32`` tensor living on ``DEVICE``.
    """
    tensor = torch.tensor(data, dtype=torch.float32, device=DEVICE)
    return tensor.reshape(-1, 1) if as_column else tensor


def train_v(
    buffer,
    value_function,
    optimizer,
    num_updates,
    target_update_interval=100,
    name="",
    gamma=0.99,
):
    """Train ``value_function`` on one-step bootstrapped targets.

    Each iteration samples a batch from ``buffer``, builds the regression
    target ``rewards + gamma * (1 - dones) * V_target(next_observations)``
    with a frozen copy of the network, and performs one optimisation step
    through ``update_vf``.

    Args:
        buffer: Object whose ``sample()`` returns the tuple
            ``(observations, next_observations, rewards, dones)``. Every
            element is converted to a float32 tensor on ``DEVICE``.
        value_function: Network being trained. It is deep-copied to create
            the target network, so it must support ``state_dict`` /
            ``load_state_dict``.
        optimizer: Optimizer forwarded to ``update_vf``.
        num_updates: Number of gradient updates to run.
        target_update_interval: The target network is overwritten with the
            current weights whenever ``u % target_update_interval == 0``,
            which includes the very first iteration (``u == 0``).
        name: Suffix appended to the progress-bar description.
        gamma: Discount factor applied to the bootstrapped next-state value.
            Defaults to ``0.99``, the value that used to be hard-coded.

    Returns:
        A tuple ``(value_function, losses, value_preds, value_targets)``
        where the three lists hold, per update, the loss reported by
        ``update_vf``, the mean predicted value and the mean target value.
    """
    # Frozen copy used only to compute bootstrap targets.
    target_value_network = deepcopy(value_function)
    target_value_network.eval()

    losses = []
    value_preds = []
    value_targets = []
    for u in tqdm(
        range(num_updates), desc=f"Training Value Function {name}", leave=False
    ):
        observations, next_observations, rewards, dones = buffer.sample()

        observations = _to_float_tensor(observations)
        next_observations = _to_float_tensor(next_observations)
        # NOTE(review): rewards and dones are forced to shape (-1, 1), so the
        # networks are presumably expected to return (batch, 1) outputs; a 1-D
        # output would broadcast against them -- confirm with the model code.
        rewards = _to_float_tensor(rewards, as_column=True)
        dones = _to_float_tensor(dones, as_column=True)

        # Targets must not contribute gradients; (1 - dones) removes the
        # bootstrap term for transitions flagged as done.
        with torch.no_grad():
            next_values = target_value_network(next_observations)
            target_values = rewards + gamma * (1 - dones) * next_values

        loss, predicted_values = update_vf(
            value_function, observations, target_values, optimizer
        )

        # Periodic hard sync of the target network with the trained weights.
        if u % target_update_interval == 0:
            target_value_network.load_state_dict(value_function.state_dict())

        losses.append(loss)
        value_preds.append(np.mean(predicted_values))
        value_targets.append(np.mean(target_values.cpu().numpy()))

    return value_function, losses, value_preds, value_targets


def update_vf(vf, observations, values, opt):
    """Perform one optimisation step of ``vf`` towards ``values``.

    Args:
        vf: Callable network producing value estimates for ``observations``.
        observations: Batch passed to ``vf``.
        values: Regression targets compared against the estimates with a
            mean-squared-error loss.
        opt: Optimizer holding the parameters of ``vf``.

    Returns:
        A tuple ``(loss, predictions)``: the scalar loss as a Python float and
        the estimates computed before the step, detached and converted to a
        NumPy array on the CPU.
    """
    estimates = vf(observations)
    mse = F.mse_loss(input=estimates, target=values)

    # Discard stale gradients, back-propagate, then apply the update.
    opt.zero_grad()
    mse.backward()
    opt.step()

    loss_value = mse.item()
    estimates_np = estimates.detach().cpu().numpy()
    return loss_value, estimates_np
