from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import gym

from .base_agent import BaseAgent
from .expert_dataset import ExpertDataset
from networks import Actor
from utils.info_dict import Info
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    compute_gradient_norm,
    compute_weight_norm,
    sync_networks,
    sync_grads,
    to_tensor,
    sample_from_dataloader,
    center_crop_images,
    apply_random_augmentations,
    center_crop,
)


class BCAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        """Build the behavioral-cloning actor, its loss, optimizer and LR schedule.

        Args:
            config: run configuration. This constructor reads ``tanh_policy``,
                ``device``, ``bc_lr`` and ``max_global_step`` from it.
            ob_space: observation space, forwarded to ``BaseAgent`` and ``Actor``.
            ac_space: action space; it must expose a ``spaces`` mapping because
                its sub-spaces are read through ``ac_space.spaces.values()``.
            env_ob_space: accepted for interface compatibility; unused here.
            layout: accepted for interface compatibility; unused here.
        """
        super().__init__(config, ob_space, ac_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        action_subspaces = list(self._ac_space.spaces.values())
        # The policy is treated as discrete only if every sub-space is Discrete.
        self._discrete = all(
            isinstance(space, gym.spaces.Discrete) for space in action_subspaces
        )

        self._epoch = 0
        # evaluate() reads this attribute, and nothing else in this class
        # assigns it while the validation split is unimplemented, so give it a
        # defined default instead of leaving it missing.
        self._val_loader = None

        self._actor = Actor(config, ob_space, ac_space, config.tanh_policy)
        self._network_cuda(config.device)
        if self._discrete:
            assert len(action_subspaces) == 1, "discrete BC supports a single action sub-space"
            self._n_actions = action_subspaces[0].n
            self._actor_loss = nn.CrossEntropyLoss()
        else:
            self._actor_loss = nn.MSELoss()
        self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.bc_lr)
        # Halve the learning rate five times over the run. The step size is
        # clamped to at least 1: with max_global_step < 5 the integer division
        # yields 0, and StepLR then fails with a modulo-by-zero on step().
        self._actor_lr_scheduler = StepLR(
            self._actor_optim,
            step_size=max(1, self._config.max_global_step // 5),
            gamma=0.5,
        )

        self._log_creation()

    def load_training_data(self, sampled_few_demo_index=None):
        """Load the demonstration data needed for training.

        Does nothing unless ``config.is_train`` is set. Otherwise it loads the
        target-task dataset from ``config.demo_path``, optionally the
        demonstration trajectories (``config.demo_conditioned_policy``) and
        optionally the other-task dataset from ``config.other_demo_path``
        (when ``config.multitask_bc_loss_coeff > 0`` or ``config.pretrain_BC``).

        Args:
            sampled_few_demo_index: forwarded to ``_load_dataset``; the value
                that call returns is stored on ``self.sampled_few_demo_index``.

        Raises:
            NotImplementedError: if ``config.val_split`` is non-zero. Splitting
                off a validation set is not supported; the datasets above have
                already been loaded when this is raised.
        """
        config = self._config
        if not config.is_train:
            return

        (
            self._dataset,
            self._train_loader,
            _,
            self.sampled_few_demo_index,
        ) = self._load_dataset(config.demo_path, sampled_few_demo_index=sampled_few_demo_index)

        if config.demo_conditioned_policy:
            self._demo_dataset, self._demo_loader, self._demo_iter = self._load_traj(
                config.demo_path,
                sampled_few_demo_index=self._dataset.sampled_few_demo_index,
            )

        if config.multitask_bc_loss_coeff > 0.0 or config.pretrain_BC:
            # NOTE(review): this call unpacks three values while the call above
            # unpacks four from the same helper -- confirm that _load_dataset
            # really returns a different arity without sampled_few_demo_index.
            (
                self._other_dataset,
                self._other_data_loader,
                self._other_data_iter,
            ) = self._load_dataset(config.other_demo_path)

        if config.val_split != 0:
            # The train/validation split code that used to follow this raise
            # could never run (and assigned the two index ranges the wrong way
            # round), so it has been dropped until the feature is implemented.
            raise NotImplementedError(
                "splitting off a validation set (val_split != 0) is not supported"
            )

    
    def _log_creation(self):
        """Log the agent type and the actor's parameter count when ``config.is_chef`` is set."""
        if not self._config.is_chef:
            return

        logger.info("Creating a BC agent")
        n_actor_params = count_parameters(self._actor)
        logger.info("The actor has %d parameters", n_actor_params)

    def state_dict(self):
        """Return the checkpointable state of the agent.

        Returns:
            dict with the state dicts of the actor, its optimizer, the
            observation normalizer and the actor's LR scheduler. The scheduler
            entry is included so that the position in the decay schedule is
            not lost when a run is resumed from the checkpoint.
        """
        return {
            "actor_state_dict": self._actor.state_dict(),
            "actor_optim_state_dict": self._actor_optim.state_dict(),
            "ob_norm_state_dict": self._ob_norm.state_dict(),
            "actor_lr_scheduler_state_dict": self._actor_lr_scheduler.state_dict(),
        }

    def load_state_dict(self, ckpt):
        """Restore the agent from a checkpoint mapping.

        Args:
            ckpt: mapping with the keys ``actor_state_dict``,
                ``ob_norm_state_dict`` and ``actor_optim_state_dict``. If it
                also contains ``actor_lr_scheduler_state_dict`` the LR
                scheduler is restored as well; checkpoints without that key
                are still accepted and leave the scheduler untouched.
        """
        self._actor.load_state_dict(ckpt["actor_state_dict"])
        self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"])
        # Move the freshly loaded weights to the configured device before the
        # optimizer state is loaded and moved there too.
        self._network_cuda(self._config.device)

        self._actor_optim.load_state_dict(ckpt["actor_optim_state_dict"])
        optimizer_cuda(self._actor_optim, self._config.device)

        # Optional entry: restores the position in the LR decay schedule.
        if "actor_lr_scheduler_state_dict" in ckpt:
            self._actor_lr_scheduler.load_state_dict(ckpt["actor_lr_scheduler_state_dict"])

    def _network_cuda(self, device):
        """Move the actor network to ``device``."""
        actor = self._actor
        actor.to(device)

    def sync_networks(self):
        """Pass the actor to the ``sync_networks`` helper imported from ``utils.pytorch``."""
        actor = self._actor
        # Inside a method body the bare name resolves to the module-level
        # helper, not to this method of the same name.
        sync_networks(actor)

    def pretrain_bc(self, step=0):
        """Run one pass of BC updates over the other-task demonstration loader.

        Args:
            step: global step, forwarded unchanged to ``_update_network``.

        Returns:
            The scalar entries of the statistics accumulated over all batches.
        """
        epoch_info = Info()
        for batch in self._other_data_loader:
            epoch_info.add(self._update_network(batch, train=True, step=step))

        return epoch_info.get_dict(only_scalar=True)

    def update_normalizer_pretrain(self, obs=None):
        """Update the normalizer with every batch of the multi-task demos.

        Only active when ``config.ob_norm`` is set. The ``obs`` argument is
        accepted but never read: each batch yielded by the other-task data
        loader is handed to the base class' ``update_normalizer`` instead.
        """
        if not self._config.ob_norm:
            return

        for batch in self._other_data_loader:
            super().update_normalizer(batch)


    def train(self, step=0):
        """Train the actor for one pass over the training loader.

        Args:
            step: global step, forwarded unchanged to ``_update_network``.

        Returns:
            dict of the scalar statistics gathered during the pass, including
            the actor's gradient and weight norms measured after the pass.
        """
        epoch_info = Info()
        for batch in self._train_loader:
            epoch_info.add(self._update_network(batch, train=True, step=step))

        # One pass over the loader counts as one epoch and one scheduler step.
        self._epoch += 1
        self._actor_lr_scheduler.step()

        norms = {
            "actor_grad_norm": compute_gradient_norm(self._actor),
            "actor_weight_norm": compute_weight_norm(self._actor),
        }
        epoch_info.add(norms)

        summary = epoch_info.get_dict(only_scalar=True)
        logger.info("BC loss %f", summary["actor_loss"])
        return summary

    def evaluate(self):
        """Compute the BC statistics on the validation loader without updating the actor.

        Returns:
            dict of scalar statistics accumulated over the validation batches,
            or ``None`` (after logging a warning) when no validation loader is
            available.
        """
        # The attribute may never have been assigned (no validation split is
        # created anywhere yet), so do not assume it exists.
        val_loader = getattr(self, "_val_loader", None)
        if not val_loader:
            logger.warning("No validation set available, make sure '--val_split' is set")
            return None

        eval_info = Info()
        # No parameter update happens here, so skip building autograd graphs.
        with torch.no_grad():
            for batch in val_loader:
                eval_info.add(self._update_network(batch, train=False))
        # NOTE(review): the epoch counter is also advanced by train(); confirm
        # that counting an evaluation pass as an epoch is intended.
        self._epoch += 1
        return eval_info.get_dict(only_scalar=True)

    def _multitask_bc_loss(self):
        """Compute the actor's BC loss on one batch of other-task demonstrations.

        The next batch is drawn from ``self._other_data_iter``; when that
        iterator is exhausted it is recreated from ``self._other_data_loader``.

        Returns:
            The loss tensor produced by ``self._actor_loss``. If drawing the
            batch raises ``ValueError`` a warning is logged and a zero scalar
            tensor is returned, so the caller's blended loss is unaffected.
        """
        # Only the batch fetch is guarded: errors raised while normalizing the
        # data or running the actor should surface instead of being hidden.
        try:
            try:
                batch = next(self._other_data_iter)
            except StopIteration:
                self._other_data_iter = iter(self._other_data_loader)
                batch = next(self._other_data_iter)
        except ValueError:
            # Previously this path printed a message and then failed with an
            # UnboundLocalError because no loss had been computed.
            logger.warning("No other expert data")
            return torch.zeros((), device=self._config.device)

        device = self._config.device

        def _flatten(actions):
            # Concatenate a dict of per-key action tensors along the last
            # axis; 1-D entries first get a leading axis. Anything that is not
            # an OrderedDict is passed through unchanged.
            if not isinstance(actions, OrderedDict):
                return actions
            parts = list(actions.values())
            if len(parts[0].shape) == 1:
                parts = [part.unsqueeze(0) for part in parts]
            return torch.cat(parts, dim=-1)

        # Pre-process observations, then move everything to the device.
        o = to_tensor(self.normalize(batch["ob"]), device)
        ac = _flatten(to_tensor(batch["ac"], device))

        # TODO: if demo_conditioned_policy, a demo needs to be sampled here.
        pred_ac, _ = self._actor(o)
        pred_ac = _flatten(pred_ac)

        # torch loss modules take (input, target): prediction first.
        return self._actor_loss(pred_ac, ac)

    def _update_network(self, transitions, train=True, step=0):
        """Compute the BC loss on one batch and optionally update the actor.

        Args:
            transitions: batch mapping with the entries ``"ob"`` and ``"ac"``.
            train: when true, back-propagate the loss, synchronise gradients
                with ``sync_grads`` and take an optimizer step.
            step: global step; while it is below ``config.bc_step`` the
                multi-task BC coefficient is multiplied by 10.

        Returns:
            The result of ``mpi_average`` applied to the scalar statistics of
            this batch (``actor_loss``, the per-dimension ``action<i>_L1loss``
            values and, when the multi-task term is active, ``bc_loss``).
        """
        info = Info()
        device = self._config.device

        def _flatten(actions):
            # Concatenate a dict of per-key action tensors along the last
            # axis; 1-D entries first get a leading axis. Anything that is not
            # an OrderedDict is passed through unchanged.
            if not isinstance(actions, OrderedDict):
                return actions
            parts = list(actions.values())
            if len(parts[0].shape) == 1:
                parts = [part.unsqueeze(0) for part in parts]
            return torch.cat(parts, dim=-1)

        # Pre-process observations, then move everything to the device.
        o = to_tensor(self.normalize(transitions["ob"]), device)
        ac = _flatten(to_tensor(transitions["ac"], device))

        pred_ac, _ = self._actor(o)
        pred_ac = _flatten(pred_ac)

        # torch loss modules take (input, target): prediction first.
        # NOTE(review): only this float-target form is exercised; a discrete
        # policy (CrossEntropyLoss) would need integer class targets -- confirm
        # before enabling discrete action spaces.
        actor_loss = self._actor_loss(pred_ac, ac)

        # Diagnostic only: absolute error summed over the first axis, reported
        # once per remaining leading index.
        l1_per_dim = torch.sum(torch.abs(ac - pred_ac), dim=0).detach().cpu()
        for i in range(l1_per_dim.shape[0]):
            info["action" + str(i) + "_L1loss"] = l1_per_dim[i].mean().item()

        # Logged before the multi-task term is blended in below.
        info["actor_loss"] = actor_loss.item()
        info["pred_ac"] = pred_ac.detach().cpu()
        info["GT_ac"] = ac.cpu()

        # Blend the policy loss with the other demos' BC loss.
        base_coeff = self._config.multitask_bc_loss_coeff
        if base_coeff > 0.0:
            # The auxiliary term is weighted 10x during the first bc_step steps.
            coeff = base_coeff * 10.0 if step < self._config.bc_step else base_coeff
            bc_loss = coeff * self._multitask_bc_loss()
            actor_loss = actor_loss + bc_loss
            info["bc_loss"] = bc_loss.item()

        if train:
            # Update the actor.
            self._actor_optim.zero_grad()
            actor_loss.backward()
            sync_grads(self._actor)
            self._actor_optim.step()

        return mpi_average(info.get_dict(only_scalar=True))


    def reset(self):
        """Draw a fresh expert demonstration when the policy is demo-conditioned.

        With ``config.demo_conditioned_policy`` set, one trajectory is sampled
        and its observations and actions are stored on ``_expert_demo_o`` and
        ``_expert_demo_a``; otherwise nothing happens.
        """
        if not self._config.demo_conditioned_policy:
            return

        demo_obs, demo_actions = self._sample_expert_traj(num_samples=1)
        self._expert_demo_o = demo_obs
        self._expert_demo_a = demo_actions

    # TODO: overwrite base_agent act method to sample input demo to actor
    def act(self, ob, is_train=True, target_taskID=None, return_log_prob=False):
        """Query the actor for an action given a single observation.

        Args:
            ob: mapping of observation arrays; it is normalized and copied
                before being modified.
            is_train: the actor is asked for a deterministic action when this
                is false.
            target_taskID: forwarded to the actor as ``task_id``.
            return_log_prob: forwarded to the actor; log-probabilities are only
                returned when this is set and ``config.gaussian_policy`` is on.

        Returns:
            Tuple ``(ac, activation, log_probs)``. ``ac`` and ``activation``
            have their tensors converted to Python scalars (discrete actions)
            or numpy arrays with the leading axis squeezed out; ``log_probs``
            is a numpy array or ``None``.
        """
        config = self._config

        ob = self.normalize(ob).copy()
        for key, value in ob.items():
            if config.encoder_type == "cnn" and len(value.shape) == 3:
                # NOTE(review): unlike the branch below, no leading axis is
                # added here -- confirm center_crop's output shape is what the
                # actor expects.
                ob[key] = center_crop(value, config.encoder_image_size)
            else:
                ob[key] = np.expand_dims(value, axis=0)

        with torch.no_grad():
            ob = to_tensor(ob, config.device)
            actor_kwargs = {
                "deterministic": not is_train,
                "return_log_prob": return_log_prob,
                "task_id": target_taskID,
            }
            if config.demo_conditioned_policy:
                # _expert_demo_o is assigned by reset(); it must have been
                # called before acting with a demo-conditioned policy.
                actor_kwargs["demo"] = self._expert_demo_o
            ac, activation, log_probs, _ = self._actor.act(ob, **actor_kwargs)

        for key in ac.keys():
            action = ac[key].cpu()
            if self._discrete:
                ac[key] = action.item()
            else:
                ac[key] = action.numpy().squeeze(0)
            activation[key] = activation[key].cpu().numpy().squeeze(0)

        if not (return_log_prob and config.gaussian_policy):
            return ac, activation, None

        log_probs = log_probs.cpu().numpy().squeeze(0)
        return ac, activation, log_probs

    def _sample_expert_traj(self, demo_loader=None, num_samples=None):
        """Sample expert demonstrations and return them as device tensors.

        Args:
            demo_loader: loader to sample from; ``self._demo_loader`` is used
                when this is falsy.
            num_samples: number of samples requested; ``config.batch_size`` is
                used when this is falsy.

        Returns:
            Tuple ``(observations, actions)``; the observations are normalized
            before conversion, the actions are converted as they are.
        """
        config = self._config
        loader = demo_loader or self._demo_loader
        n_samples = num_samples or config.batch_size

        expert_demo = sample_from_dataloader(loader, n_samples, config.batch_size)

        demo_obs = to_tensor(self.normalize(expert_demo["ob"]), config.device)
        demo_actions = to_tensor(expert_demo["ac"], config.device)
        return demo_obs, demo_actions
