# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from drqv2 import Actor, Encoder, Predictor, RandomShiftsAug
import utils


class BCBPRAgent:
    """Behaviour-cloning agent with a separately pre-trained encoder.

    Two training entry points are exposed:

    * ``pretrain`` fits the encoder together with an action predictor so
      that the predictor's output matches the dataset action after both are
      L2-normalized.
    * ``update`` fits the actor to the dataset actions (MSE) on top of the
      encoder features; the encoder receives no gradient in this phase.
    """

    def __init__(self, obs_shape, action_shape, max_action, device, lr, feature_dim,
                 hidden_dim, critic_target_tau, num_expl_steps,
                 update_every_steps, stddev_schedule, stddev_clip, use_tb,
                 augmentation=None):
        """Build the networks, optimizers and augmentation.

        Args:
            obs_shape: Forwarded to ``Encoder``.
            action_shape: Forwarded to ``Actor``; also unpacked positionally
                into ``Predictor``.
            max_action: Forwarded to ``Predictor`` (presumably an action
                bound -- confirm against ``drqv2.Predictor``).
            device: Device the models and batches are moved to.
            lr: Learning rate shared by both Adam optimizers.
            feature_dim: Forwarded to ``Actor``.
            hidden_dim: Forwarded to ``Actor``.
            critic_target_tau: Stored but not used by this class.
            num_expl_steps: While ``step`` is below this value, ``act``
                (in non-eval mode) returns uniform random actions.
            update_every_steps: ``pretrain``/``update`` only do work when
                ``step`` is a multiple of this value.
            stddev_schedule: Schedule spec passed to ``utils.schedule``.
            stddev_clip: Clip value passed to ``dist.sample`` in
                ``update_actor``.
            use_tb: When true, scalar metrics are collected and returned.
            augmentation: Callable applied to the float observation batch.
                Defaults to a fresh ``RandomShiftsAug(pad=4)`` per agent.
        """
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip

        # models
        self.encoder = Encoder(obs_shape).to(device)
        self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim,
                           hidden_dim).to(device)
        self.predictor = Predictor(self.encoder.repr_dim, *action_shape,
                                   max_action).to(device)

        # optimizers: the encoder and the predictor are trained jointly
        encoder_params = (list(self.encoder.parameters())
                          + list(self.predictor.parameters()))
        self.encoder_opt = torch.optim.Adam(encoder_params, lr=lr)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)

        # data augmentation; built here rather than in the signature so the
        # default is not a single instance created at class-definition time
        # and shared by every agent
        if augmentation is None:
            augmentation = RandomShiftsAug(pad=4)
        self.aug = augmentation

        self.train()

    def train(self, training=True):
        """Switch the encoder, actor and predictor between train/eval mode."""
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        self.predictor.train(training)

    def act(self, obs, step, eval_mode):
        """Return an action for a single (unbatched) observation.

        Args:
            obs: Observation; converted with ``torch.as_tensor`` and given a
                leading batch dimension before encoding.
            step: Current step, used for the stddev schedule and to decide
                whether the exploration phase is still active.
            eval_mode: If true the distribution mean is returned, otherwise
                a sample (replaced by uniform noise in [-1, 1] while
                ``step < num_expl_steps``).

        Returns:
            The action for the first (only) batch element as a NumPy array.
        """
        obs = torch.as_tensor(obs, device=self.device)
        features = self.encoder(obs.unsqueeze(0))
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(features, stddev)
        if eval_mode:
            action = dist.mean
        else:
            action = dist.sample(clip=None)
            if step < self.num_expl_steps:
                # overwrite the sample in place with uniform exploration noise
                action.uniform_(-1.0, 1.0)
        return action.cpu().numpy()[0]

    def update_encoder(self, obs, step, action):
        """Take one optimizer step on the encoder and predictor.

        The predictor output and the target action are both L2-normalized
        along dim 1 before the MSE is taken, so the loss compares the two
        after each row has been rescaled to unit norm.

        Args:
            obs: Batch of (augmented) observations fed to the encoder.
            step: Unused; kept for a signature parallel to ``update_actor``.
            action: Batch of target actions from the dataset.

        Returns:
            An empty metrics dict.
        """
        metrics = dict()
        predicted = self.predictor(self.encoder(obs))

        predicted = F.normalize(predicted, dim=1, p=2)
        target = F.normalize(action, dim=1, p=2)
        encoder_loss = F.mse_loss(target, predicted)

        # optimize encoder and predictor (both are owned by encoder_opt)
        self.encoder_opt.zero_grad(set_to_none=True)
        encoder_loss.backward()
        self.encoder_opt.step()

        # NOTE: encoder_loss is not added to metrics; the dict stays empty.
        return metrics

    def update_actor(self, obs, step, behavioural_action=None):
        """Take one behaviour-cloning step on the actor.

        Args:
            obs: Encoded observations (already detached from the encoder).
            step: Current step, used for the stddev schedule.
            behavioural_action: Dataset actions the sampled action is
                regressed onto. Required despite the ``None`` default.

        Returns:
            Metrics dict; populated only when ``use_tb`` is set.

        Raises:
            ValueError: If ``behavioural_action`` is ``None``.
        """
        if behavioural_action is None:
            raise ValueError(
                'behavioural_action is required for the BC actor update')

        metrics = dict()

        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs, stddev)
        action = dist.sample(clip=self.stddev_clip)

        # offline BC loss
        actor_loss = F.mse_loss(action, behavioural_action)

        # optimize actor only
        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        if self.use_tb:
            # diagnostics only, so no autograd graph is needed
            with torch.no_grad():
                log_prob = dist.log_prob(action).sum(-1, keepdim=True)
                metrics['actor_logprob'] = log_prob.mean().item()
                metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
            metrics['actor_bc_loss'] = actor_loss.item()

        return metrics

    def _next_augmented_batch(self, replay_buffer):
        """Draw the next batch and return (augmented obs, action, reward)."""
        batch = next(replay_buffer)
        obs, action, reward, _, _ = utils.to_torch(batch, self.device)
        return self.aug(obs.float()), action, reward

    def pretrain(self, replay_buffer, step):
        """Run one encoder/predictor pre-training step.

        Does nothing (returns an empty dict) unless ``step`` is a multiple
        of ``update_every_steps``.

        Args:
            replay_buffer: Iterator yielding batches accepted by
                ``utils.to_torch``.
            step: Current training step.

        Returns:
            Metrics dict from ``update_encoder``.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        obs, action, _ = self._next_augmented_batch(replay_buffer)
        metrics.update(self.update_encoder(obs, step, action.detach()))

        return metrics

    def update(self, replay_buffer, step):
        """Run one actor update on top of the current encoder.

        Does nothing (returns an empty dict) unless ``step`` is a multiple
        of ``update_every_steps``. The encoder is not updated here.

        Args:
            replay_buffer: Iterator yielding batches accepted by
                ``utils.to_torch``.
            step: Current training step.

        Returns:
            Metrics dict; includes ``batch_reward`` and the actor metrics
            when ``use_tb`` is set.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        obs, action, reward = self._next_augmented_batch(replay_buffer)

        # The actor only ever sees detached features, so encode without
        # recording an autograd graph through the encoder.
        with torch.no_grad():
            features = self.encoder(obs)

        if self.use_tb:
            metrics['batch_reward'] = reward.mean().item()

        # update actor
        metrics.update(self.update_actor(features, step, action.detach()))

        return metrics
