from collections import OrderedDict

import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

from utils.normalizer import Normalizer
from utils.pytorch import to_tensor, center_crop, center_crop_images

from algorithms.base_agent import BaseAgent
from .expert_dataset import ExpertDataset, ExpertDataset_H_steps
import algorithms

### TODO: transfer all IL algos to BaseILAgent

class BaseILAgent(BaseAgent):
    """Base class for imitation learning agents.

    The agent optionally wraps an inner RL agent (chosen by
    ``config.gail_rl_algo``). Checkpointing, episode storage and network
    syncing are delegated to that inner agent, while this class adds a
    behavior-cloning (BC) loss computed on expert demonstrations.

    NOTE(review): ``self._config``, ``self._discrete``, ``self.normalize``,
    ``self.set_buffer`` and ``self._load_dataset`` are not defined in this
    class; they are presumably provided by ``BaseAgent`` or by subclasses.
    """

    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        """Set up the inner RL agent, the BC loss and the expert dataset.

        Args:
            config: Experiment configuration. Fields read here:
                ``gail_rl_algo`` (optional), ``is_train``, ``algo``,
                ``demo_path`` and ``bc_loss_coeff``.
            ob_space: Observation space, forwarded to ``BaseAgent`` and to
                the inner RL agent.
            ac_space: Action space, forwarded to ``BaseAgent`` and to the
                inner RL agent.
            env_ob_space: Forwarded unchanged to the inner RL agent.
            layout: Stored on ``self.layout`` and forwarded to the inner RL
                agent.
        """
        super().__init__(config, ob_space, ac_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        self.layout = layout

        # The inner RL agent is optional: `_rl_agent` is only created when
        # the config names an algorithm (see the hasattr check in _bc_loss).
        if getattr(self._config, "gail_rl_algo", None) is not None:
            rl_agent_cls = algorithms.get_agent_by_name(self._config.gail_rl_algo)
            self._rl_agent = rl_agent_cls(config, ob_space, ac_space, env_ob_space, layout)

        # Cross-entropy for discrete actions, mean-squared error otherwise.
        if self._discrete:
            self._bc_loss_fn = nn.CrossEntropyLoss()
        else:
            self._bc_loss_fn = nn.MSELoss()
        self._bc_data_iter = None
        self._bc_data_loader = None

        # expert dataset
        if config.is_train and config.algo not in ["gail-v2", "reachable_gail"]:
            self._dataset, self._data_loader, self._data_iter, self.sampled_few_demo_index = self._load_dataset(config.demo_path)

            # HACK / TODO: is this still needed? The disabled code below
            # loads all the target demos into a separate BC test loader.
            # While it stays disabled, `_bc_test_data_loader` and
            # `_bc_test_data_iter` are never created by this class.
            # if config.evaluate_bc_test_loss:
            #     self._bc_test_dataset = ExpertDataset(
            #         config.demo_path,
            #         config.demo_subsample_interval,
            #         ac_space,
            #         use_low_level=config.demo_low_level,
            #         sample_range_start=config.demo_sample_range_start,
            #         sample_range_end=config.demo_sample_range_end,
            #         num_task=config.num_task,
            #         target_taskID=config.target_taskID,
            #         target_demo_path=config.target_demo_path,
            #         frame_stack=self._config.frame_stack if self._config.encoder_type == "cnn" else None,
            #     )
            #     self._bc_test_data_loader = torch.utils.data.DataLoader(
            #         self._bc_test_dataset,
            #         batch_size=self._config.batch_size * 4,  # why not
            #         shuffle=True,
            #         drop_last=True,
            #     )
            #     self._bc_test_data_iter = iter(self._bc_test_data_loader)

            # NOTE(review): this branch assumes the inner RL agent exists,
            # i.e. `config.gail_rl_algo` is set whenever training with these
            # algorithms; otherwise the lines below raise AttributeError.
            if config.bc_loss_coeff > 0.0:
                # The inner agent samples BC batches from the expert data.
                self._rl_agent.set_bc_data_iter(self._data_iter)
                self._rl_agent.set_bc_data_loader(self._data_loader)

            # Share the inner agent's buffer with this wrapper.
            self.set_buffer(self._rl_agent._buffer)

    def _pretrain_encoder_V2(self, encoder, decoder, _train_loader, _train_iter):
        """Pre-train ``encoder`` with a reconstruction loss through ``decoder``.

        Runs ``config.encoder_pretrain_steps`` optimisation steps, cycling
        round-robin over the per-task loaders (step ``i`` uses loader
        ``i % len(_train_loader)``).

        Args:
            encoder: Module called as ``encoder(ob)``; updated in place.
            decoder: Module called as ``decoder(z)``; updated in place.
            _train_loader: Indexable collection of data loaders, one per task.
            _train_iter: Mutable indexable collection of iterators matching
                ``_train_loader``; exhausted iterators are replaced in place.

        Returns:
            The same ``encoder`` object that was passed in.

        Raises:
            ValueError: If steps are requested but ``_train_loader`` is empty.
        """
        # Both optimizers use the encoder learning rate.
        encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=self._config.encoder_lr)
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=self._config.encoder_lr)

        num_steps = self._config.encoder_pretrain_steps
        num_task = len(_train_loader)
        if num_task == 0 and num_steps > 0:
            # Previously surfaced as an opaque ZeroDivisionError from `i % 0`.
            raise ValueError("_pretrain_encoder_V2 requires at least one data loader")

        for i in range(num_steps):
            j = i % num_task
            try:
                data = next(_train_iter[j])
            except StopIteration:
                # Loader exhausted: restart it and keep going.
                _train_iter[j] = iter(_train_loader[j])
                data = next(_train_iter[j])

            ob = data["ob"]
            ob = self.normalize(ob)
            ob = to_tensor(ob, self._config.device)
            if self._config.encoder_type == "cnn":
                # Rank-4 observations are center-cropped to the encoder input
                # size; anything else is wrapped unchanged under the "ob" key.
                images = ob["ob"]
                if len(images.shape) == 4:
                    ob = {"ob": center_crop_images(images, self._config.encoder_image_size)}
                else:
                    ob = {"ob": ob}

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            z = encoder(ob)
            recon = decoder(z)
            loss = F.mse_loss(recon, ob["ob"])

            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

        return encoder

    def _sample_expert_data(self, test=False):
        """Return the next batch of expert data, restarting the loader if needed.

        Args:
            test: If True, sample from ``_bc_test_data_iter`` /
                ``_bc_test_data_loader``; otherwise from ``_bc_data_iter`` /
                ``_bc_data_loader``.

        Returns:
            The next batch produced by the selected iterator.

        Raises:
            RuntimeError: If the iterator needs to be (re)built but no
                matching data loader has been set.
        """
        if test:
            iter_attr, loader_attr = "_bc_test_data_iter", "_bc_test_data_loader"
        else:
            iter_attr, loader_attr = "_bc_data_iter", "_bc_data_loader"

        # The iterator may be None (initial state) or missing altogether
        # (test loader never built), so it is created lazily from the loader.
        data_iter = getattr(self, iter_attr, None)
        if data_iter is not None:
            try:
                return next(data_iter)
            except StopIteration:
                pass  # exhausted: fall through and restart from the loader

        loader = getattr(self, loader_attr, None)
        if loader is None:
            raise RuntimeError(
                "Cannot sample expert data: `%s` has not been set" % loader_attr
            )
        data_iter = iter(loader)
        setattr(self, iter_attr, data_iter)
        return next(data_iter)

    @staticmethod
    def _flatten_ac(ac):
        """Concatenate an OrderedDict of action tensors along the last dim.

        1-D entries get a leading dimension first (decided from the first
        entry). Inputs that are not an ``OrderedDict`` are returned as is.
        """
        if not isinstance(ac, OrderedDict):
            return ac
        parts = list(ac.values())
        if len(parts[0].shape) == 1:
            parts = [x.unsqueeze(0) for x in parts]
        return torch.cat(parts, dim=-1)

    def _bc_loss(self, bc_data):
        """Compute the behavior-cloning loss of the actor on one expert batch.

        Args:
            bc_data: Mapping with keys ``"ob"`` and ``"ac"``, plus ``"id"``
                when ``config.with_taskID`` is set.

        Returns:
            The loss given by ``self._bc_loss_fn`` between the predicted and
            the expert actions.
        """
        # pre-process observations
        o = self.normalize(bc_data["ob"])

        # convert observations and actions to tensors on the configured device
        o = to_tensor(o, self._config.device)
        ac = self._flatten_ac(to_tensor(bc_data["ac"], self._config.device))

        # Use the inner RL agent's actor when there is one, else our own.
        if hasattr(self, "_rl_agent"):
            actor = self._rl_agent._actor
        else:
            actor = self._actor

        # the actor prediction (second return value is unused here)
        if self._config.with_taskID:
            pred_ac, _ = actor(ob=o, task_id=bc_data["id"])
        else:
            pred_ac, _ = actor(o)
        pred_ac = self._flatten_ac(pred_ac)

        if self._discrete:
            # NOTE(review): indexes the first row of the integer actions as
            # the class targets; presumably `ac` is (1, batch) here — confirm.
            actor_loss = self._bc_loss_fn(pred_ac, ac.long()[0])
        else:
            # (input, target) order, matching the discrete branch; the MSE
            # value does not depend on the argument order.
            actor_loss = self._bc_loss_fn(pred_ac, ac)

        return actor_loss

    def calculate_bc_test_loss(self):
        """Return the BC loss on a freshly sampled expert batch.

        NOTE(review): despite the name, this samples with ``test=False``,
        i.e. from the regular BC loader, not from a held-out test loader.
        """
        return self._bc_loss(self._sample_expert_data(test=False))

    def set_reward_function(self, predict_reward):
        """Store ``predict_reward`` on ``self._predict_reward``."""
        self._predict_reward = predict_reward

    def set_bc_data_iter(self, bc_data_iter):
        """Set the iterator used by ``_sample_expert_data``."""
        self._bc_data_iter = bc_data_iter

    def set_bc_data_loader(self, bc_data_loader):
        """Set the loader used to restart the BC data iterator."""
        self._bc_data_loader = bc_data_loader

    def state_dict(self):
        """Return a checkpoint dict holding the inner RL agent's state."""
        return {
            "rl_agent": self._rl_agent.state_dict()
        }

    def load_state_dict(self, ckpt):
        """Load the inner RL agent's state from ``ckpt``.

        Accepts either the format produced by ``state_dict`` (state stored
        under the ``"rl_agent"`` key) or a checkpoint without that key, which
        is passed to the inner RL agent unchanged.
        """
        if "rl_agent" in ckpt:
            self._rl_agent.load_state_dict(ckpt["rl_agent"])
        else:
            self._rl_agent.load_state_dict(ckpt)

    def store_episode(self, rollouts, step=0):
        """Forward the rollouts to the inner RL agent."""
        self._rl_agent.store_episode(rollouts, step=step)

    def sync_networks(self):
        """Ask the inner RL agent to synchronise its networks."""
        self._rl_agent.sync_networks()
