import wandb

import numpy as np
import torch
import torch.nn as nn
from onpolicy.utils.util import get_gard_norm, huber_loss, mse_loss
from onpolicy.utils.valuenorm import ValueNorm
from onpolicy.algorithms.utils.util import check

def extract(x, t):
    """Select one entry of ``x`` per index in ``t`` along the axis following ``t``'s axes.

    ``x`` must have shape ``(*t.shape, K, *rest)``. The result has shape
    ``(*t.shape, *rest)`` and satisfies ``result[i...] == x[i..., t[i...]]``.

    :param x: (torch.Tensor) values to index into.
    :param t: (torch.Tensor) integer indices; its shape must be a prefix of ``x.shape``.

    :return: (torch.Tensor) the gathered entries with the indexed axis removed.
    """
    lead = len(t.shape)
    assert (x.shape[:lead] == t.shape), (x.shape, t.shape)

    # Axes of `x` that come after the one being indexed.
    trailing = x.shape[lead + 1:]

    # Append a singleton gather axis plus one singleton per trailing axis, then
    # tile across the trailing axes so the index has the shape gather expects.
    index = t.reshape(*t.shape, 1, *([1] * len(trailing)))
    index = index.repeat(*([1] * (lead + 1)), *trailing)

    picked = torch.gather(x, dim=lead, index=index)
    # The gathered axis has length 1; drop it.
    return picked.squeeze(lead)


class DiffusionBC:
    """Trainer that fits ``policy.actor`` by minimising the actor's ``bc_loss``.

    Only non-recurrent policies are supported (checked in ``__init__``).
    """

    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):
        """
        Store the policy and the hyper-parameters read from ``args``.

        :param args: namespace providing ``algorithm_name``, ``n_timesteps``, ``batch_size``,
                     ``max_grad_norm``, ``hidden_size``, ``use_recurrent_policy``,
                     ``use_naive_recurrent_policy`` and ``use_max_grad_norm``.
        :param policy: object exposing ``actor``, ``critic`` and ``actor_optimizer``.
        :param device: (torch.device) device recorded in ``self.device`` / ``self.tpdv``.
        """
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy

        self.algorithm_name = args.algorithm_name
        self.n_timesteps = args.n_timesteps
        self.batch_size = args.batch_size
        self.max_grad_norm = args.max_grad_norm
        self.hidden_size = args.hidden_size

        # Recurrent policies are rejected: `train` always passes an empty
        # (batch, 0, hidden) array in the second `bc_loss` argument.
        assert ((not args.use_recurrent_policy) and (not args.use_naive_recurrent_policy))

        self._use_max_grad_norm = args.use_max_grad_norm

    def train(self, epochs, dataset):
        """
        Run ``epochs`` passes of minibatch gradient descent on the actor's BC loss.

        :param epochs: (int) number of passes; each pass performs
                       ``dataset._size // self.batch_size`` optimizer steps.
        :param dataset: object exposing ``_size`` and ``sample(batch_size)`` returning an
                        ``(observations, actions)`` pair of sequences; only the last
                        element of each sequence is used.

        :return train_info: (dict) statistics of the LAST epoch: ``bc_loss`` and
                            ``actor_grad_norm`` averaged over its minibatches, plus ``eta``.
                            Empty when ``epochs <= 0``. When the dataset is smaller than
                            one batch, no update is performed and the averages stay 0.
        """
        print("[DEBUG] diff-bc train", epochs, flush=True)
        self.policy.actor.diffusion.update_eta(1.0)

        # Bound up front so that `epochs <= 0` returns an empty dict instead of
        # raising UnboundLocalError at the final `return`.
        train_info = {}

        for e in range(epochs):
            train_info = {}
            train_info['bc_loss'] = 0
            train_info['actor_grad_norm'] = 0

            # Only full batches are used; a trailing partial batch is dropped.
            num_samples = dataset._size // self.batch_size

            print("diff-bc epoch", e, flush=True)

            for _ in range(num_samples):
                observations, actions = dataset.sample(self.batch_size)
                observations = observations[-1]
                actions = actions[-1]
                # Besides observations and actions, `bc_loss` receives placeholders:
                # an empty (batch, 0, hidden) array, all-ones (batch, 1) arrays and None.
                # NOTE(review): presumably rnn states / masks / available actions /
                # active masks for a non-recurrent policy -- confirm against actor.bc_loss.
                bc_loss = self.policy.actor.bc_loss(
                                    observations,
                                    np.zeros((self.batch_size, 0, self.hidden_size)),
                                    actions,
                                    np.ones((self.batch_size, 1)),
                                    None,
                                    np.ones((self.batch_size, 1)),)

                self.policy.actor_optimizer.zero_grad()

                bc_loss.backward()

                # Either clip the gradients (and record the pre-clip norm returned by
                # clip_grad_norm_) or just measure the norm without clipping.
                if self._use_max_grad_norm:
                    actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
                else:
                    actor_grad_norm = get_gard_norm(self.policy.actor.parameters())

                self.policy.actor_optimizer.step()

                train_info['bc_loss'] += bc_loss.item()
                train_info['actor_grad_norm'] += actor_grad_norm

            if num_samples > 0:
                for k in train_info:
                    train_info[k] /= num_samples
            else:
                # Fewer samples than one batch: nothing was accumulated, so skip the
                # averaging (it would divide by zero) and leave the statistics at 0.
                print("[WARN] diff-bc: dataset smaller than batch_size, no update performed", flush=True)

            # Added after averaging so that eta is reported as-is, not divided.
            train_info['eta'] = self.policy.actor.diffusion.eta.item()

            print(train_info, flush=True)

        return train_info

    def prep_training(self):
        """Switch both the actor and the critic to training mode."""
        self.policy.actor.train()
        self.policy.critic.train()

    def prep_rollout(self):
        """Switch both the actor and the critic to evaluation mode."""
        self.policy.actor.eval()
        self.policy.critic.eval()
