from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym

from .utils import MLP
from utils.pytorch import to_tensor
from .encoder import Encoder

class DiscriminatorEnsemble(nn.Module):
    """A group of independently initialized ``Discriminator`` members.

    Every member is built from the same constructor arguments; when an
    ``encoder`` is given, that same encoder object is handed to each member.
    """

    def __init__(self, ensemble_size, config, ob_space, ac_space=None, encoder=None):
        super().__init__()
        members = [
            Discriminator(config, ob_space, ac_space, encoder)
            for _ in range(ensemble_size)
        ]
        self.discriminators = nn.ModuleList(members)

    def forward(self, ob, ac=None, ID=None, output_mean=True):
        """Evaluate every member on the same inputs.

        Args:
            ob, ac, ID: forwarded unchanged to each member.
            output_mean: when True, average the member outputs over the
                ensemble axis; otherwise return them stacked.

        Returns:
            ``(outs, temps)``: ``outs`` is the stacked member outputs (new
            leading ensemble axis), reduced by mean over that axis when
            ``output_mean`` is set; ``temps`` is the list of each member's
            second return value, in member order.
        """
        results = [member(ob, ac, ID) for member in self.discriminators]
        stacked = torch.stack([res[0] for res in results])
        temps = [res[1] for res in results]
        if output_mean:
            return stacked.mean(dim=0), temps
        return stacked, temps

# Encoder design adapted from https://github.com/rohitrango/BC-regularized-GAIL/blob/master/a2c_ppo_acktr/algo/gail.py#L266
class Discriminator(nn.Module):
    """MLP discriminator over observations, optionally with actions and a task ID.

    Observations may first be passed through an encoder (one supplied by the
    caller, or a CNN ``Encoder`` built when ``config.encoder_type == "cnn"``).
    The result is concatenated with the flattened action and, if given, the
    task ID, and mapped by an MLP to a single output.
    """

    def __init__(self, config, ob_space, ac_space=None, encoder=None):
        super().__init__()
        self._config = config
        self._no_action = ac_space is None
        if encoder is not None:
            # Caller-supplied (possibly shared) encoder; its output width is
            # taken from the config.
            self.encoder = encoder
            input_dim = config.output_dim
        elif config.encoder_type == "cnn":
            self.encoder = Encoder(config, ob_space)
            input_dim = self.encoder.output_dim
        else:
            self.encoder = None
            input_dim = gym.spaces.flatdim(ob_space)

        if not self._no_action:
            input_dim += gym.spaces.flatdim(ac_space)
        # NOTE(review): config.num_task is not added to input_dim, so passing
        # ``ID`` to forward() produces a wider input than the MLP was built
        # for — confirm before enabling task-ID conditioning.

        self.fc = MLP(
            config,
            input_dim,
            1,
            config.discriminator_mlp_dim,
            getattr(F, config.discriminator_activation),
        )

    @staticmethod
    def _flatten_dict(x):
        """Concatenate a dict of tensors along the last axis; other inputs pass through.

        When the first value is 1-D, every value gets a leading batch
        dimension of size 1 before concatenation.
        """
        if not isinstance(x, dict):  # OrderedDict is a dict subclass
            return x
        parts = list(x.values())
        if len(parts[0].shape) == 1:
            parts = [p.unsqueeze(0) for p in parts]
        return torch.cat(parts, dim=-1)

    def forward(self, ob, ac=None, ID=None):
        """Score observations (and optional actions / task IDs).

        Args:
            ob: observation tensor or dict of tensors (encoded first if an
                encoder is present).
            ac: optional action tensor or dict of tensors, concatenated after
                the observation.
            ID: optional task ID tensor, or a list of tensors that is stacked
                and reshaped to ``(-1, config.num_task)``.

        Returns:
            ``(out, None)`` where ``out`` is the MLP output; the second
            element is always ``None``.
        """
        if self.encoder is not None:
            ob = self.encoder(ob)
        ob = self._flatten_dict(ob)

        if ac is not None:
            ob = torch.cat([ob, self._flatten_dict(ac)], dim=-1)

        if ID is not None:
            if isinstance(ID, list):
                ID = torch.stack(ID).reshape((-1, self._config.num_task))
            ob = torch.cat([ob, ID], dim=-1)

        return self.fc(ob), None

class DiscriminatorV2(nn.Module):
    """GAIL-V2 (DVD) discriminator for long trajectories: (s, a) vs. a sequence of (s', a').

    The single transition ``(ob, ac)`` is embedded by ``encoder_mlp``; the
    reference sequence ``ob2`` is run through an LSTM and its last-step output
    is used. The two embeddings are concatenated and scored by ``fc``.
    ``ac2`` is accepted but currently ignored.
    """

    def __init__(self, config, ob_space, ac_space=None, encoder=None):
        super().__init__()
        self._config = config
        self._no_action = ac_space is None
        if encoder is not None:
            # Caller-supplied (possibly shared) encoder; its output width is
            # taken from the config.
            self.encoder = encoder
            input_dim = config.output_dim
        elif config.encoder_type == "cnn":
            self.encoder = Encoder(config, ob_space)
            input_dim = self.encoder.output_dim
        else:
            self.encoder = None
            input_dim = gym.spaces.flatdim(ob_space)

        # The LSTM consumes observations only, so it is sized before the
        # action dimension is added to input_dim below.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=config.lstm_hidden_dim,
            num_layers=config.lstm_num_layers,
            batch_first=True,
        )

        if not self._no_action:
            input_dim += gym.spaces.flatdim(ac_space)

        self.encoder_mlp = MLP(config, input_dim, config.lstm_hidden_dim)

        # Scores [encoder_mlp(ob, ac), lstm(ob2)[:, -1]] -> 2 * lstm_hidden_dim inputs.
        self.fc = MLP(
            config,
            config.lstm_hidden_dim * 2,
            1,
            config.discriminator_mlp_dim,
            getattr(F, config.discriminator_activation),
        )
        # Not applied in forward(), which returns raw outputs of fc.
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _flatten_dict(x, unbatched_ndim=1):
        """Concatenate a dict of tensors along the last axis; other inputs pass through.

        When the first value has ``unbatched_ndim`` dimensions, every value
        gets a leading batch dimension of size 1 before concatenation.
        """
        if not isinstance(x, dict):  # OrderedDict is a dict subclass
            return x
        parts = list(x.values())
        if len(parts[0].shape) == unbatched_ndim:
            parts = [p.unsqueeze(0) for p in parts]
        return torch.cat(parts, dim=-1)

    @staticmethod
    def _concat_step(step):
        """Concatenate one per-step observation dict into a tensor along dim 0.

        Values that are not 1-D are squeezed first (all size-1 dims removed).
        """
        parts = list(step.values())
        if len(parts[0].shape) != 1:
            parts = [torch.squeeze(p) for p in parts]
        return torch.cat(parts, dim=0)

    def _encode_frame_stacked(self, ob2):
        """Encode ``ob2['ob']`` trajectory by trajectory using frame stacking.

        ``ob2['ob']`` is expected to be 5-D: (batch, traj_len, *frame) with 3-D
        frames — presumably (C, H, W) images; TODO confirm. Each trajectory is
        cut into windows of ``config.frame_stack`` frames starting every
        ``frame_stack`` steps; a window that would run past the end is
        replaced by the last ``frame_stack`` frames. The frames of a window
        are concatenated along their first axis, all windows of a trajectory
        are encoded as one batch, and the per-trajectory encodings are stacked
        along a new dim 0.
        """
        frames = ob2['ob']
        stack = self._config.frame_stack
        traj_len = frames.shape[1]
        encoded = []
        for i in range(frames.shape[0]):
            # Index rather than slice + squeeze(): squeeze() would also drop
            # size-1 frame dims (e.g. a single grayscale channel) and break the
            # 4-D indexing below.
            # TODO: try a different way to stack images: stack the closest n
            # images and keep traj_len data points.
            trajectory = frames[i]
            windows = []
            for j in range(0, traj_len, stack):
                if j + stack <= traj_len:
                    window = trajectory[j:j + stack, :, :, :]
                else:
                    window = trajectory[traj_len - stack:traj_len, :, :, :]
                windows.append(torch.cat(tuple(window), dim=0))
            encoded.append(self.encoder({'ob': torch.stack(windows, dim=0)}))
        return torch.stack(encoded, dim=0)

    def forward(self, ob, ac=None, ob2=None, ac2=None):
        """Score ``(ob, ac)`` against the reference sequence ``ob2``.

        Args:
            ob: observation tensor or dict of tensors for the single transition.
            ac: optional action (tensor or dict); ignored when the
                discriminator was built without an action space.
            ob2: reference sequence — a list of per-step observations (tensors
                or dicts), a dict of tensors, or a tensor. With an encoder and a
                non-list ``ob2``, ``ob2['ob']`` is encoded with frame stacking.
            ac2: unused.

        Returns:
            Output of ``fc`` (no sigmoid applied).
        """
        if self.encoder is not None:
            ob = self.encoder(ob)
            if isinstance(ob2, list):
                ob2 = [self.encoder(x) for x in ob2]
            else:
                ob2 = self._encode_frame_stacked(ob2)

        ob = self._flatten_dict(ob)
        if ac is not None and not self._no_action:
            ob = torch.cat([ob, self._flatten_dict(ac)], dim=-1)

        if isinstance(ob2, list):
            if isinstance(ob2[0], dict):
                # Build a new list so the caller's list is not modified in place.
                ob2 = [self._concat_step(step) for step in ob2]
            ob2 = torch.stack(ob2, dim=0)
            # NOTE(review): this keys off the configured batch size rather than
            # the actual input; confirm that matches every caller.
            if self._config.batch_size == 1:
                ob2 = torch.unsqueeze(ob2, dim=0)

        # A dict whose values are 2-D is treated as one unbatched sequence.
        ob2 = self._flatten_dict(ob2, unbatched_ndim=2)

        out1 = self.encoder_mlp(ob)

        # Zero initial LSTM state on ob2's device/dtype (also kept on self).
        hidden_shape = (self._config.lstm_num_layers, ob2.shape[0], self._config.lstm_hidden_dim)
        self.h0 = ob2.new_zeros(hidden_shape)
        self.c0 = ob2.new_zeros(hidden_shape)

        out2, _ = self.lstm(ob2, (self.h0, self.c0))
        # fc input: (batch, 2 * lstm_hidden_dim)
        return self.fc(torch.cat([out1, out2[:, -1, :]], dim=-1))

class RelativeDiscriminator(nn.Module):
    """Relative GAIL discriminator (not used currently; kept from an earlier idea).

    D(demo, transition1, transition2) --> does transition 1 solve the demo
    task better than transition 2? Both transitions are embedded by the shared
    ``encoder`` MLP, the demo observation sequence is summarized by an LSTM
    (last-step output), and the three embeddings are concatenated and scored
    by ``fc``.
    """

    def __init__(self, config, ob_space, ac_space=None):
        super().__init__()
        self._config = config
        self._no_action = ac_space is None
        input_dim = gym.spaces.flatdim(ob_space)

        # The LSTM consumes demo observations only, so it is sized before the
        # action dimension is added to input_dim.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=config.lstm_hidden_dim,
            num_layers=config.lstm_num_layers,
            batch_first=True,
        )

        if not self._no_action:
            input_dim += gym.spaces.flatdim(ac_space)

        self.encoder = MLP(config, input_dim, config.lstm_hidden_dim)

        # Scores [encoder(t1), encoder(t2), lstm(demo)[:, -1]] -> 3 * lstm_hidden_dim inputs.
        self.fc = MLP(
            config,
            config.lstm_hidden_dim * 3,
            1,
            config.discriminator_mlp_dim,
            getattr(F, config.discriminator_activation),
        )
        # Not applied in forward(), which returns raw outputs of fc.
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _flatten_dict(x):
        """Concatenate a dict of tensors along the last axis; other inputs pass through.

        When the first value is 1-D, every value gets a leading batch
        dimension of size 1 before concatenation.
        """
        if not isinstance(x, dict):  # OrderedDict is a dict subclass
            return x
        parts = list(x.values())
        if len(parts[0].shape) == 1:
            parts = [p.unsqueeze(0) for p in parts]
        return torch.cat(parts, dim=-1)

    @staticmethod
    def _concat_step(step):
        """Concatenate one per-step observation dict into a tensor along dim 0.

        Values that are not 1-D are squeezed first (all size-1 dims removed).
        """
        parts = list(step.values())
        if len(parts[0].shape) != 1:
            parts = [torch.squeeze(p) for p in parts]
        return torch.cat(parts, dim=0)

    def forward(self, ob1, ob2, ac1=None, ac2=None, demo_ob=None, demo_ac=None):
        """Score transition 1 against transition 2 relative to a demo sequence.

        Args:
            ob1, ob2: observations (tensor or dict of tensors) of the two transitions.
            ac1, ac2: optional actions (tensor or dict), appended to ob1 / ob2.
            demo_ob: demo observation sequence — a list of per-step
                observations (tensors or dicts), a dict of tensors, or a tensor.
            demo_ac: unused.

        Returns:
            Output of ``fc`` (no sigmoid applied).
        """
        ob1 = self._flatten_dict(ob1)
        ob2 = self._flatten_dict(ob2)
        if ac1 is not None:
            ob1 = torch.cat([ob1, self._flatten_dict(ac1)], dim=-1)
        if ac2 is not None:
            ob2 = torch.cat([ob2, self._flatten_dict(ac2)], dim=-1)

        if isinstance(demo_ob, list):
            if isinstance(demo_ob[0], dict):
                # Build a new list so the caller's list is not modified in place.
                demo_ob = [self._concat_step(step) for step in demo_ob]
            demo_ob = torch.stack(demo_ob, dim=0)
            # NOTE(review): keys off the configured batch size, not the input.
            if self._config.batch_size == 1:
                demo_ob = torch.unsqueeze(demo_ob, dim=0)
        demo_ob = self._flatten_dict(demo_ob)

        out1 = torch.cat([self.encoder(ob1), self.encoder(ob2)], dim=-1)

        # Zero initial LSTM state on demo_ob's device/dtype (also kept on self).
        hidden_shape = (self._config.lstm_num_layers, demo_ob.shape[0], self._config.lstm_hidden_dim)
        self.h0 = demo_ob.new_zeros(hidden_shape)
        self.c0 = demo_ob.new_zeros(hidden_shape)

        out2, _ = self.lstm(demo_ob, (self.h0, self.c0))
        # fc input: (batch, 3 * lstm_hidden_dim)
        return self.fc(torch.cat([out1, out2[:, -1, :]], dim=-1))

class ACDiscriminator(nn.Module):
    """Discriminator with a shared MLP trunk and two output heads.

    The trunk ``fc`` maps the flattened observation (plus optional action and
    task ID) to ``config.discriminator_output_dim`` features; ``fc1`` maps
    those to a 1-dim output and ``fc2`` to an auxiliary output whose width
    depends on the labelling mode (``label_taskID`` or ``label_goal_obs``).
    """

    def __init__(self, config, ob_space, ac_space=None):
        super().__init__()
        self._config = config
        self._no_action = ac_space is None

        # Auxiliary head width; label_goal_obs takes precedence when both
        # flags are set. Validated first so a bad config fails with a clear
        # message instead of an UnboundLocalError.
        if config.label_goal_obs:
            output_dim2 = config.num_task * 2
        elif config.label_taskID:
            output_dim2 = config.num_task
        else:
            raise ValueError(
                "ACDiscriminator requires config.label_taskID or "
                "config.label_goal_obs to be set"
            )
        output_dim1 = 1

        input_dim = gym.spaces.flatdim(ob_space)
        if not self._no_action:
            input_dim += gym.spaces.flatdim(ac_space)
        # NOTE(review): config.num_task is not added to input_dim, so passing
        # ``ID`` to forward() widens the trunk input — confirm before use.

        self.fc = MLP(
            config,
            input_dim,
            config.discriminator_output_dim,
            config.discriminator_mlp_dim,
            getattr(F, config.discriminator_activation),
        )
        self.fc1 = MLP(config, config.discriminator_output_dim, output_dim1)
        self.fc2 = MLP(config, config.discriminator_output_dim, output_dim2)

    @staticmethod
    def _flatten_dict(x):
        """Concatenate a dict of tensors along the last axis; other inputs pass through.

        When the first value is 1-D, every value gets a leading batch
        dimension of size 1 before concatenation.
        """
        if not isinstance(x, dict):  # OrderedDict is a dict subclass
            return x
        parts = list(x.values())
        if len(parts[0].shape) == 1:
            parts = [p.unsqueeze(0) for p in parts]
        return torch.cat(parts, dim=-1)

    def forward(self, ob, ac=None, ID=None):
        """Return ``(fc1(h), fc2(h))`` where ``h`` is the trunk output.

        Args:
            ob: observation tensor or dict of tensors.
            ac: optional action tensor or dict, appended after ``ob``.
            ID: optional task ID tensor, or a list of tensors that is stacked
                and reshaped to ``(-1, config.num_task)``.
        """
        ob = self._flatten_dict(ob)

        if ac is not None:
            ob = torch.cat([ob, self._flatten_dict(ac)], dim=-1)

        if ID is not None:
            if isinstance(ID, list):
                ID = torch.stack(ID).reshape((-1, self._config.num_task))
            ob = torch.cat([ob, ID], dim=-1)

        features = self.fc(ob)
        return self.fc1(features), self.fc2(features)

class AIRLDiscriminator(nn.Module):
    """AIRL-style discriminator built from a reward term ``g`` and a potential ``h``.

    ``f(s, s') = g(s) + gamma * (1 - done) * h(s') - h(s)`` and ``forward``
    returns ``f - log_pi``; the discriminator output is the sigmoid of that.
    Both ``g`` and ``h`` take observations only.
    """

    def __init__(self, config, ob_space, ac_space=None):
        super().__init__()
        self._config = config
        # ac_space is accepted for signature parity but not used here.

        input_dim = gym.spaces.flatdim(ob_space)
        # NOTE(review): config.num_task is not added to input_dim, although
        # f() appends ``ID`` to the observation — confirm before passing ID.

        self.g = MLP(
            config,
            input_dim,
            1,
            config.discriminator_mlp_dim_r,
            getattr(F, config.discriminator_activation),
        )

        self.h = MLP(
            config,
            input_dim,
            1,
            config.discriminator_mlp_dim_v,
            getattr(F, config.discriminator_activation),
        )
        self.gamma = config.gamma

    @staticmethod
    def _flatten_dict(x):
        """Concatenate a dict of tensors along the last axis; other inputs pass through.

        When the first value is 1-D, every value gets a leading batch
        dimension of size 1 before concatenation.
        """
        if not isinstance(x, dict):  # OrderedDict is a dict subclass
            return x
        parts = list(x.values())
        if len(parts[0].shape) == 1:
            parts = [p.unsqueeze(0) for p in parts]
        return torch.cat(parts, dim=-1)

    def f(self, ob, done, ob_next, ID):
        """Return ``g(s) + gamma * (1 - done) * h(s') - h(s)``.

        When ``ID`` is given it is appended to both ``ob`` and ``ob_next``.
        """
        if ID is not None:
            ob = torch.cat([ob, ID], dim=-1)
            ob_next = torch.cat([ob_next, ID], dim=-1)

        rs = self.g(ob)
        vs = self.h(ob)
        next_vs = self.h(ob_next)
        if torch.is_tensor(done) and done.dtype == torch.bool:
            # torch does not support `1 - bool_tensor`; cast to the value dtype.
            done = done.to(vs.dtype)
        return rs + self.gamma * (1 - done) * next_vs - vs

    def forward(self, ob, done, log_pis, ob_next, ID=None):
        """Return the discriminator logit ``f(ob, done, ob_next, ID) - log_pis``.

        ``done`` and ``log_pis`` may be lists of tensors (stacked and reshaped
        to ``(-1, 1)``); ``ID`` may be a list of tensors (stacked and reshaped
        to ``(-1, config.num_task)``); ``ob`` / ``ob_next`` may be dicts.
        """
        ob = self._flatten_dict(ob)
        ob_next = self._flatten_dict(ob_next)

        if isinstance(done, list):
            done = torch.stack(done).reshape((-1, 1))
        if isinstance(log_pis, list):
            log_pis = torch.stack(log_pis).reshape((-1, 1))
        if isinstance(ID, list):
            ID = torch.stack(ID).reshape((-1, self._config.num_task))

        # Discriminator's output is sigmoid(f - log_pi).
        return self.f(ob, done, ob_next, ID) - log_pis

    def calculate_reward(self, ob, done, log_pis, ob_next, ID=None):
        """Return ``forward(...)`` evaluated without gradient tracking.

        ``ID`` (optional, defaults to None) is forwarded so task-conditioned
        discriminators can also compute rewards.
        """
        with torch.no_grad():
            return self.forward(ob, done, log_pis, ob_next, ID)


