from collections import OrderedDict

import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

from utils.normalizer import Normalizer
from .expert_dataset import ExpertDataset, ExpertDataset_H_steps

from utils.pytorch import to_tensor, center_crop, center_crop_images
import matplotlib.pyplot as plt

class BaseAgent(object):
    """Base class for agents."""

    def __init__(self, config, ob_space, ac_space):
        """Set up the observation normalizer and the behavioral-cloning loss.

        Args:
            config: configuration object; ``clip_range`` and ``clip_obs`` are
                read here, other attributes are read by the other methods.
            ob_space: observation space, handed to ``Normalizer``.
            ac_space: action space exposing a ``spaces`` mapping. When every
                sub-space is a ``gym.spaces.Discrete`` the agent is treated
                as discrete, which requires exactly one sub-space.
        """
        self._config = config

        self._ob_norm = Normalizer(
            ob_space,
            default_clip_range=config.clip_range,
            clip_obs=config.clip_obs,
        )

        sub_spaces = list(ac_space.spaces.values())
        self._discrete = np.all(
            [isinstance(space, gym.spaces.Discrete) for space in sub_spaces]
        )
        if self._discrete:
            # Discrete control is only supported with a single action head.
            assert len(sub_spaces) == 1
            self._n_actions = sub_spaces[0].n
            self._bc_loss_fn = nn.CrossEntropyLoss()
        else:
            self._bc_loss_fn = nn.MSELoss()

        # Replay buffer and expert-data plumbing are injected later via setters.
        self._buffer = None
        self._bc_data_iter = None
        self._bc_data_loader = None

    def _pretrain_encoder(self, encoder, decoder, _train_loader, _other_data_loader):
        """Pretrain ``encoder`` (with ``decoder``) by observation reconstruction.

        Runs ``config.encoder_pretrain_steps`` optimisation steps. Batches are
        drawn alternately from the two loaders (even steps from
        ``_train_loader``, odd steps from ``_other_data_loader``); a loader that
        is exhausted is restarted. Each step minimises the MSE between
        ``decoder(encoder(ob))`` and the observation stored under ``"ob"``.

        Args:
            encoder: module mapping the observation dict to a latent.
            decoder: module mapping the latent back to observation space.
            _train_loader: first iterable of batches; each batch has an
                ``"ob"`` entry.
            _other_data_loader: second iterable of batches, same layout.

        Returns:
            The same ``encoder`` object, updated in place. ``requires_grad``
            of its parameters is not modified here.
        """
        cfg = self._config
        loaders = [_train_loader, _other_data_loader]
        # Iterators are created in loader order, before the optimizers.
        iterators = [iter(loader) for loader in loaders]

        # Both optimizers use the encoder learning rate.
        encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=cfg.encoder_lr)
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=cfg.encoder_lr)

        for step in range(cfg.encoder_pretrain_steps):
            source = step % len(loaders)
            try:
                batch = next(iterators[source])
            except StopIteration:
                # Loader exhausted: start a new pass over it.
                iterators[source] = iter(loaders[source])
                batch = next(iterators[source])

            ob = to_tensor(self.normalize(batch["ob"]), cfg.device)
            if cfg.encoder_type == "cnn":
                frames = ob["ob"]
                if len(frames.shape) == 4:
                    ob = {"ob": center_crop_images(frames, cfg.encoder_image_size)}
                else:
                    # NOTE(review): non-4D input is left nested one level
                    # deeper, exactly as before; confirm this path is unused.
                    ob = {"ob": ob}

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            recon = decoder(encoder(ob))
            loss = F.mse_loss(recon, ob["ob"])
            # The loss is deliberately not copied to the host each step: the
            # collected values were never read and only forced a device sync.
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

        return encoder

    def _ac_one_hot(self, ac):
        for k, v in ac.items():
            ac[k] = F.one_hot(v.long(), num_classes=self._n_actions).float()
            if len(ac[k].shape) == 3:
                ac[k] = ac[k].squeeze(-2)
        return ac

    def normalize(self, ob):
        """Normalizes observations."""
        if self._config.ob_norm:
            return self._ob_norm.normalize(ob)

        if self._config.encoder_type == "cnn":
            # do image normalization [0,1] TODO: [-1, 1]
            def ob_norm(v):
                if not isinstance(v, dict):
                    return v / 255.0
                return OrderedDict([(k, v_ / 255.0) for k, v_ in v.items()])

            if isinstance(ob, list):
                return [ob_norm(x) for x in ob]
            else:
                return ob_norm(ob)

        return ob

    def act(self, ob, is_train=True, target_taskID=None, return_log_prob=False):
        """Returns action and the actor's activation given an observation @ob."""
        if hasattr(self, "_rl_agent"):
            return self._rl_agent.act(
                ob,
                is_train,
                target_taskID=target_taskID,
                return_log_prob=return_log_prob,
            )

        ob = self.normalize(ob)

        ob = ob.copy()
        for k, v in ob.items():
            if self._config.encoder_type == "cnn" and len(v.shape) == 3:
                ob[k] = center_crop(v, self._config.encoder_image_size)
            else:
                ob[k] = np.expand_dims(ob[k], axis=0)

        with torch.no_grad():
            ob = to_tensor(ob, self._config.device)
            ac, activation, log_probs, _ = self._actor.act(
                ob,
                deterministic=not is_train,
                return_log_prob=return_log_prob,
                task_id=target_taskID,
            )

        for k in ac.keys():
            if self._discrete:
                ac[k] = ac[k].cpu().item()
            else:
                ac[k] = ac[k].cpu().numpy().squeeze(0)
            activation[k] = activation[k].cpu().numpy().squeeze(0)
        if return_log_prob and self._config.gaussian_policy:
            log_probs = log_probs.cpu().numpy().squeeze(0)
        else:
            log_probs = None

        return ac, activation, log_probs

    def compute_kl(self, ob, target_taskID=None):
        """Returns action and the actor's activation given an observation @ob."""
        if hasattr(self, "_rl_agent"):
            return self._rl_agent.compute_kl(ob, target_taskID=target_taskID)

        ob = self.normalize(ob)

        ob = ob.copy()
        for k, v in ob.items():
            if self._config.encoder_type == "cnn" and len(v.shape) == 3:
                ob[k] = center_crop(v, self._config.encoder_image_size)
            else:
                ob[k] = np.expand_dims(ob[k], axis=0)

        o_p = ob.copy()
        dim = 4  # change later for different envs
        target_data = {}
        for k, v in ob.items():
            if isinstance(v, dict):
                sub_keys = v.keys()
                target_data[k] = {sub_key: v[sub_key][:dim] for sub_key in sub_keys}
            else:
                target_data[k] = v[:, :dim]
        ob = target_data

        with torch.no_grad():
            ob = to_tensor(ob, self._config.device)
            o_p = to_tensor(o_p, self._config.device)
            pred_mean, pred_std = self._default_actor(ob, task_id=target_taskID)
            mean, std = self._actor(ob=o_p, task_id=target_taskID)

        def un_oderdic(temp):
            if isinstance(temp, OrderedDict):
                temp = list(temp.values())
                if len(temp[0].shape) == 1:
                    temp = [x.unsqueeze(0) for x in temp]
                temp = torch.cat(temp, dim=-1)
            return temp

        mean, std, pred_mean, pred_std = (
            un_oderdic(mean),
            un_oderdic(std),
            un_oderdic(pred_mean),
            un_oderdic(pred_std),
        )
        KL = (
            torch.log(std / pred_std)
            + (pred_std**2 + (pred_mean - mean) ** 2) / (2 * std**2)
            - 0.5
        )
        KL = torch.mean(KL, dim=0)
        KL = KL.mean()

        pred_mean, KL = pred_mean.cpu().numpy(), KL.cpu().numpy()

        return pred_mean, KL

    def update_normalizer(self, obs=None):
        """Updates normalizers."""
        if self._config.ob_norm:
            if obs is None:
                for i in range(len(self._dataset)):
                    self._ob_norm.update(self._dataset[i]["ob"])
                self._ob_norm.recompute_stats()
            else:
                self._ob_norm.update(obs)
                self._ob_norm.recompute_stats()

    def _sample_expert_data(self, test=False):
        """Return the next expert batch, restarting the loader when exhausted.

        Args:
            test: when True, draw from ``self._bc_test_data_iter`` /
                ``self._bc_test_data_loader``; otherwise from
                ``self._bc_data_iter`` / ``self._bc_data_loader``.

        Returns:
            The next item produced by the selected iterator.

        Raises:
            StopIteration: if the selected loader yields nothing even after
                being restarted.
        """
        if test:
            iter_attr, loader_attr = "_bc_test_data_iter", "_bc_test_data_loader"
        else:
            iter_attr, loader_attr = "_bc_data_iter", "_bc_data_loader"

        try:
            return next(getattr(self, iter_attr))
        except StopIteration:
            # Restart the *same* stream that ran out. The test branch used to
            # overwrite the training iterator and then re-read the exhausted
            # test iterator, so it always raised a second time.
            fresh_iter = iter(getattr(self, loader_attr))
            setattr(self, iter_attr, fresh_iter)
            return next(fresh_iter)

    def _bc_loss(self, bc_data):
        """Compute the behavioral-cloning loss of the actor on one batch.

        Args:
            bc_data: batch with ``"ob"`` and ``"ac"`` entries, plus ``"id"``
                when ``config.with_taskID`` is set.

        Returns:
            The value of ``self._bc_loss_fn`` between the actor's predicted
            actions and the batch actions.
        """
        cfg = self._config

        def _as_tensor(value):
            # convert double tensor to float32 tensor
            return to_tensor(value, cfg.device)

        def _merge(value):
            """Concatenate an OrderedDict of tensors along the last axis."""
            if not isinstance(value, OrderedDict):
                return value
            parts = list(value.values())
            if len(parts[0].shape) == 1:
                parts = [part.unsqueeze(0) for part in parts]
            return torch.cat(parts, dim=-1)

        # pre-process observations
        obs = _as_tensor(self.normalize(bc_data["ob"]))
        if cfg.encoder_type == "cnn":
            frames = obs["ob"]
            if len(frames.shape) == 4:
                obs = {"ob": center_crop_images(frames, cfg.encoder_image_size)}
            else:
                obs = {"ob": obs}

        expert_ac = _merge(_as_tensor(bc_data["ac"]))

        # the actor loss
        if cfg.with_taskID:
            pred_ac, _ = self._actor(ob=obs, task_id=bc_data["id"])
        else:
            pred_ac, _ = self._actor(obs)
        pred_ac = _merge(pred_ac)

        if self._discrete:
            # Only the first row of the merged actions is used as the target.
            return self._bc_loss_fn(pred_ac, expert_ac.long()[0])
        return self._bc_loss_fn(expert_ac, pred_ac)

    def store_episode(self, rollouts, step=0):
        """Stores @rollouts to replay buffer.

        Not implemented in this base class: it always raises
        NotImplementedError.
        """
        raise NotImplementedError()

    def is_off_policy(self):
        """Returns True when a buffer is set, i.e. self._buffer is not None."""
        return self._buffer is not None

    def set_buffer(self, buffer):
        """Stores @buffer as self._buffer."""
        self._buffer = buffer

    def replay_buffer(self):
        """Returns the result of self._buffer.state_dict()."""
        return self._buffer.state_dict()

    def load_replay_buffer(self, state_dict):
        """Passes @state_dict to self._buffer.load_state_dict()."""
        self._buffer.load_state_dict(state_dict)

    def set_reward_function(self, predict_reward):
        """Stores @predict_reward as self._predict_reward."""
        self._predict_reward = predict_reward

    def set_bc_data_iter(self, bc_data_iter):
        """Stores @bc_data_iter as self._bc_data_iter."""
        self._bc_data_iter = bc_data_iter

    def set_bc_data_loader(self, bc_data_loader):
        """Stores @bc_data_loader as self._bc_data_loader."""
        self._bc_data_loader = bc_data_loader

    def sync_networks(self):
        """Not implemented in this base class: always raises NotImplementedError."""
        raise NotImplementedError()

    def train(self, step=0):
        """Not implemented in this base class: always raises NotImplementedError."""
        raise NotImplementedError()

    def _soft_update_target_network(self, target, source, tau):
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                (1 - tau) * source_param.data + tau * target_param.data
            )

    def _copy_target_network(self, target, source):
        """Calls _soft_update_target_network with tau=0.

        With tau=0 the blend puts weight 1 on the source parameters and
        weight 0 on the old target parameters.
        """
        self._soft_update_target_network(target, source, 0)

    def _load_dataset(self, path, batch_size=None, sampled_few_demo_index=None):
        """Build an ``ExpertDataset`` from @path plus a shuffling loader.

        Args:
            path: demonstration location, handed to ``ExpertDataset``.
            batch_size: loader batch size; a falsy value falls back to
                ``config.batch_size``.
            sampled_few_demo_index: forwarded to ``ExpertDataset``.

        Returns:
            ``(dataset, data_loader, data_iter, sampled_few_demo_index)`` where
            the last item is read back from the dataset.
        """
        cfg = self._config

        dataset = ExpertDataset(
            path,
            cfg.demo_subsample_interval,
            self._ac_space,
            use_low_level=cfg.demo_low_level,
            sample_range_start=cfg.demo_sample_range_start,
            sample_range_end=cfg.demo_sample_range_end,
            num_task=cfg.num_task,
            target_taskID=cfg.target_taskID,
            num_target_demos=cfg.num_target_demos,
            target_demo_path=cfg.target_demo_path,
            is_sqil=(cfg.algo == "sqil"),
            # Frame stacking is only passed on for the cnn encoder.
            frame_stack=(cfg.frame_stack if cfg.encoder_type == "cnn" else None),
            encoder_image_size=cfg.encoder_image_size,
            sampled_few_demo_index=sampled_few_demo_index,
        )

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size or cfg.batch_size,
            shuffle=True,
            drop_last=True,
        )

        return dataset, loader, iter(loader), dataset.sampled_few_demo_index

    def _load_traj(self, path, sampled_few_demo_index=None, batch_size=None):
        """Build an ``ExpertDataset_H_steps`` from @path plus a shuffling loader.

        Args:
            path: demonstration location, handed to ``ExpertDataset_H_steps``.
            sampled_few_demo_index: forwarded to the dataset.
            batch_size: used for both the dataset and the loader; a falsy
                value falls back to ``config.batch_size``.

        Returns:
            ``(dataset, data_loader, data_iter)``.
        """
        cfg = self._config
        effective_batch_size = batch_size or cfg.batch_size

        dataset = ExpertDataset_H_steps(
            path,
            cfg.traj_interval,  # subsample_interval
            self._ac_space,
            use_low_level=cfg.demo_low_level,
            sample_range_start=cfg.demo_sample_range_start,
            sample_range_end=cfg.demo_sample_range_end,
            num_task=cfg.num_task,
            target_taskID=None,
            step_size=cfg.traj_length,
            batch_size=effective_batch_size,
            num_target_demos=cfg.num_target_demos,
            target_demo_path=cfg.target_demo_path,
            sampled_few_demo_index=sampled_few_demo_index,
            # The image size is only passed on for the cnn encoder.
            encoder_image_size=(
                cfg.encoder_image_size if cfg.encoder_type == "cnn" else None
            ),
        )

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=effective_batch_size,
            shuffle=True,
            drop_last=True,
        )

        return dataset, loader, iter(loader)
