from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym.spaces

from .distributions import (
    FixedCategorical,
    FixedNormal,
    Identity,
    MixedDistribution,
)
from .utils import MLP, flatten_ac
from .encoder import Encoder
from utils.pytorch import to_tensor
from utils.logger import logger


class Actor(nn.Module):
    """Policy network producing per-key action distributions.

    Observations go through ``self.encoder`` and the MLP trunk ``self.fc``.
    The result is optionally concatenated with an LSTM embedding of a
    demonstration (``config.demo_conditioned_policy``) or with a task ID
    (``config.with_taskID``). It is then fed to one output head per key of
    ``ac_space`` (``self.fcs``).
    """

    def __init__(self, config, ob_space, ac_space, tanh_policy, encoder=None):
        super().__init__()
        self._config = config
        self._ac_space = ac_space
        self._activation_fn = getattr(F, config.policy_activation)
        self._tanh = tanh_policy
        self._gaussian = config.gaussian_policy

        # True when self.encoder is the pretrained ResNet-18 / R3M image model
        # built below. Those models are called on one image tensor at a time
        # (not on the observation dict with a detach_conv kwarg), so
        # forward() encodes each stacked frame separately (see _encode).
        self._per_frame_encoder = encoder is None and config.pretrained_encoder in (
            "resnet",
            "r3m",
        )

        # Compare against None: a supplied module must be used as-is even if
        # it defines __len__ and happens to be falsy.
        if encoder is not None:
            self.encoder = encoder
        elif config.pretrained_encoder == "resnet":
            from torchvision.models import resnet18

            resnet = resnet18(pretrained=True)
            resnet.eval()
            # NOTE: unlike the R3M branch, the ResNet weights stay trainable.
            self.encoder = resnet
            # One fc-output-sized feature vector per stacked frame.
            self.encoder.output_dim = resnet.fc.out_features * config.frame_stack
        elif config.pretrained_encoder == "r3m":
            from r3m.r3m import load_r3m

            r3m = load_r3m("resnet50")  # alternatives: "resnet18", "resnet34"
            r3m.eval()
            for param in r3m.parameters():
                param.requires_grad = False  # frozen feature extractor

            self.encoder = r3m
            # One feature vector per stacked frame.
            self.encoder.output_dim = self.encoder.module.outdim * config.frame_stack
        elif config.pretrained_encoder == "vae":
            from .encoder import Decoder

            self.encoder = Encoder(config, ob_space)
            self.decoder = Decoder(config, ob_space)
        else:
            self.encoder = Encoder(config, ob_space)

        input_dim = self.encoder.output_dim

        # Shared trunk: encoder features -> policy_mlp_dim[-1] features.
        self.fc = MLP(
            config, input_dim, config.policy_mlp_dim[-1], config.policy_mlp_dim[:-1]
        )

        # Input width of the per-key heads: the trunk output plus EITHER the
        # demo embedding OR the task ID. forward() mirrors this if/elif.
        head_input_dim = config.policy_mlp_dim[-1]

        if self._config.demo_conditioned_policy:
            # NOTE(review): input_size is the encoder output width, so demos
            # are presumably sequences of that width — confirm with callers.
            self.lstm = nn.LSTM(
                input_size=input_dim,
                hidden_size=config.lstm_hidden_dim,
                num_layers=config.lstm_num_layers,
                batch_first=True,
            )
            if config.lstm_output_dim == config.lstm_hidden_dim:
                self.lstm_embed = nn.Identity()
            else:
                self.lstm_embed = nn.Linear(
                    in_features=config.lstm_hidden_dim,
                    out_features=config.lstm_output_dim,
                )
            head_input_dim += config.lstm_output_dim

        elif self._config.with_taskID:
            head_input_dim += config.num_task

        self.fcs = nn.ModuleDict()
        self._dists = {}
        for k, v in ac_space.spaces.items():
            if isinstance(v, gym.spaces.Box):
                # Box heads always emit a mean half and a log-std half, even
                # for a non-Gaussian policy, so weights transfer between BC
                # and Gaussian policies.
                self.fcs[k] = MLP(config, head_input_dim, gym.spaces.flatdim(v) * 2)
                if self._gaussian:
                    self._dists[k] = lambda m, s: FixedNormal(m, s)
                else:
                    self._dists[k] = lambda m, s: Identity(m)
            else:
                self.fcs[k] = MLP(config, head_input_dim, gym.spaces.flatdim(v))
                self._dists[k] = lambda m, s: FixedCategorical(logits=m)

    @property
    def info(self):
        return {}

    def _encode(self, ob, detach_conv):
        """Return encoder features for ``ob``."""
        if not self._per_frame_encoder:
            return self.encoder(ob, detach_conv=detach_conv)

        # Pretrained image models see one 3-channel frame per call; dim 1 of
        # ob["ob"] holds config.frame_stack such frames stacked together.
        frames = ob["ob"]
        features = []
        for i in range(self._config.frame_stack):
            img = frames[:, 3 * i : 3 * i + 3, :, :]
            if self._config.pretrained_encoder == "r3m" and not (img.max() > 1.0):
                # Rescale [0, 1] images to [0, 255] (presumably R3M's expected
                # input range — confirm against the r3m package).
                img = img.float() * 255.0
            features.append(self.encoder(img))
        return torch.cat(features, dim=-1)

    def _demo_to_tensor(self, demo):
        """Convert ``demo`` (list of tensors/dicts, dict, or tensor) to a tensor.

        Does not modify the caller's ``demo`` container.
        """
        if isinstance(demo, list):
            if isinstance(demo[0], dict):  # also covers OrderedDict
                steps = []
                for step in demo:
                    parts = list(step.values())
                    if len(parts[0].shape) != 1:
                        parts = [torch.squeeze(x) for x in parts]
                    steps.append(torch.cat(parts, dim=0))
                demo = steps
            demo = torch.stack(demo, dim=0)
            # NOTE(review): kept from the original implementation; the
            # author marked the purpose of this unsqueeze as unclear.
            if self._config.batch_size == 1:
                demo = torch.unsqueeze(demo, dim=0)
        elif isinstance(demo, dict):
            parts = list(demo.values())
            if len(parts[0].shape) == 2:
                parts = [x.unsqueeze(0) for x in parts]
            demo = torch.cat(parts, dim=-1)
        return demo

    def _embed_demo(self, demo):
        """Run the demo through the LSTM and return the last step's embedding."""
        demo = self._demo_to_tensor(demo)
        # nn.LSTM starts from zero hidden/cell states when none are passed.
        demo_out, _ = self.lstm(demo)
        demo_out = self.lstm_embed(demo_out)
        return demo_out[:, -1, :]

    def forward(self, ob: dict, detach_conv=False, task_id=None, demo=None):
        """Compute per-key action-distribution parameters.

        Args:
            ob: observation given to the encoder.
            detach_conv: forwarded to the project encoder; not used by the
                pretrained ResNet / R3M models.
            task_id: appended to the features when ``config.with_taskID`` is
                set (and demo conditioning is not).
            demo: demonstration embedded by the LSTM when
                ``config.demo_conditioned_policy`` is set.

        Returns:
            ``(means, stds)``, two OrderedDicts keyed like
            ``ac_space.spaces``. Box keys map to the mean and the std
            (exp of the range-squashed log-std). Other keys map to the raw
            head output and ``None``.
        """
        out = self._encode(ob, detach_conv)
        out = self._activation_fn(self.fc(out))

        # Append the extra features the heads were sized for in __init__.
        if self._config.demo_conditioned_policy:
            if demo is not None:
                out = torch.cat([out, self._embed_demo(demo)], dim=-1)
        elif self._config.with_taskID and task_id is not None:
            task_id = to_tensor(task_id, self._config.device)
            out = torch.cat([out, task_id], dim=-1)

        means, stds = OrderedDict(), OrderedDict()
        for k, v in self._ac_space.spaces.items():
            if isinstance(v, gym.spaces.Box):
                mean, log_std = self.fcs[k](out).chunk(2, dim=-1)
                log_std_min, log_std_max = (
                    self._config.log_std_min,
                    self._config.log_std_max,
                )
                # Map tanh's (-1, 1) onto (log_std_min, log_std_max).
                log_std = torch.tanh(log_std)
                log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (
                    log_std + 1
                )
                std = log_std.exp()
            else:
                mean, std = self.fcs[k](out), None

            means[k] = mean
            stds[k] = std

        return means, stds

    def act(
        self,
        ob,
        demo=None,
        deterministic=False,
        activations=None,
        return_log_prob=False,
        detach_conv=False,
        task_id=None,
        ac=None,
    ):
        """Sample an action for rollout.

        Args:
            ob: observation passed to ``forward``.
            demo: demonstration passed to ``forward``.
            deterministic: use the distribution mode instead of ``rsample``.
            activations: pre-squash values to use instead of sampling.
            return_log_prob: also return log-probs and entropy (only computed
                for Gaussian policies).
            detach_conv: passed to ``forward``.
            task_id: passed to ``forward``.
            ac: if given, log-probs are evaluated at ``ac``.

        Returns:
            ``(actions, activations, log_probs, entropy)``. ``actions`` are the
            tanh-squashed activations for Box keys when ``tanh_policy`` is set.
            ``log_probs`` and ``entropy`` are ``None`` unless
            ``return_log_prob`` is set and the policy is Gaussian.
        """
        means, stds = self.forward(
            ob, detach_conv=detach_conv, task_id=task_id, demo=demo
        )

        dists = OrderedDict()
        for k in means.keys():
            dists[k] = self._dists[k](means[k], stds[k])
        mixed_dist = MixedDistribution(dists)

        if activations is None:
            activations = mixed_dist.mode() if deterministic else mixed_dist.rsample()

        want_log_prob = return_log_prob and self._gaussian
        log_probs = None
        if want_log_prob:
            # NOTE(review): when `ac` is given, log-probs are taken at `ac`,
            # but the tanh correction below uses `activations` — confirm.
            log_probs = mixed_dist.log_probs(ac if ac is not None else activations)

        actions = OrderedDict()
        for k, v in self._ac_space.spaces.items():
            z = activations[k]
            if self._tanh and isinstance(v, gym.spaces.Box):
                action = torch.tanh(z)
                if want_log_prob:
                    # follow the Appendix C. Enforcing Action Bounds:
                    # log(1 - tanh(z)^2) == 2 * (log 2 - z - softplus(-2z)),
                    # a numerically stable form of the tanh log-det-Jacobian.
                    log_det_jacobian = 2 * (np.log(2.0) - z - F.softplus(-2.0 * z)).sum(
                        dim=-1, keepdim=True
                    )
                    log_probs[k] = log_probs[k] - log_det_jacobian
            else:
                action = z

            actions[k] = action

        if want_log_prob:
            log_probs = torch.cat(list(log_probs.values()), -1).sum(-1, keepdim=True)
            entropy = mixed_dist.entropy()
        else:
            log_probs = None
            entropy = None

        return actions, activations, log_probs, entropy


class Critic(nn.Module):
    """Value network (optionally an ensemble) on top of an observation encoder.

    For a continuous ``ac_space`` the flattened action is concatenated to the
    encoded observation and every head outputs one value. With
    ``discrete_ac=True`` every head outputs ``flatdim(ac_space)`` values, and
    ``forward`` does not use its ``ac`` argument.
    """

    def __init__(
        self, config, ob_space, ac_space=None, encoder=None, discrete_ac=False
    ):
        super().__init__()
        self._config = config
        self._discrete_ac = discrete_ac

        # Compare against None: a supplied module must be used as-is even if
        # it defines __len__ and happens to be falsy.
        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder = Encoder(config, ob_space)

        input_dim = self.encoder.output_dim
        output_dim = 1

        if ac_space is not None:
            if discrete_ac:
                # One output per discrete action.
                output_dim = gym.spaces.flatdim(ac_space)
            else:
                # Continuous action is an extra input.
                input_dim += gym.spaces.flatdim(ac_space)

        # config.critic_ensemble independent heads sharing the encoder.
        self.fcs = nn.ModuleList(
            [
                MLP(config, input_dim, output_dim, config.critic_mlp_dim)
                for _ in range(config.critic_ensemble)
            ]
        )

    def forward(self, ob, ac=None, detach_conv=False):
        """Evaluate every ensemble head.

        Args:
            ob: observation given to the encoder.
            ac: action concatenated to the features (continuous critics only).
            detach_conv: forwarded to the encoder.

        Returns:
            A single tensor when there is one head, otherwise a list with one
            tensor per head.
        """
        out = self.encoder(ob, detach_conv=detach_conv)

        if not self._discrete_ac and ac is not None:
            out = torch.cat([out, flatten_ac(ac)], dim=-1)
        # Heads expect rank-2 input (presumably (batch, features)).
        assert len(out.shape) == 2

        values = [fc(out) for fc in self.fcs]
        # NOTE: discrete critics return values for all actions; selecting the
        # entry for a given `ac` is left to the caller (TODO from the original).
        return values[0] if len(values) == 1 else values
