import torch.nn as nn
import torch
import gym

from iq_learn.agent import make_agent
from iq_learn.utils.utils import get_irl_reward




class Ensemble(nn.Module):
    """Hold several networks and evaluate all of them on the same inputs.

    Args:
        create_net_fn: Zero-argument callable; invoked once per member to
            obtain that member's network.
        num_ensembles: Number of members to create.
    """

    def __init__(self, create_net_fn, num_ensembles):
        super().__init__()
        members = []
        for _ in range(num_ensembles):
            members.append(create_net_fn())
        self.nets = nn.ModuleList(members)

    def forward(self, *argv):
        """Pass ``argv`` to every member, in order, and collect the results.

        Returns:
            The outputs stacked along a new leading dimension when the first
            member's output is a tensor; otherwise the plain list of outputs.
        """
        results = [member(*argv) for member in self.nets]
        first = results[0]
        return torch.stack(results) if isinstance(first, torch.Tensor) else results

class InjectNet(nn.Module):
    """Chain ``base_net`` -> optional concatenation -> Linear+Tanh -> ``head_net``.

    Args:
        base_net: Module applied to the primary input.
        head_net: Module applied to the output of the Linear+Tanh layer.
        in_dim: Feature size produced by ``base_net``.
        hidden_dim: Output size of the Linear+Tanh layer.
        inject_dim: Feature size of the injected input; treated as 0 when
            ``should_inject`` is false.
        should_inject: Whether ``inject_x`` is concatenated in ``forward``.
    """

    def __init__(
        self, base_net, head_net, in_dim, hidden_dim, inject_dim, should_inject
    ):
        super().__init__()
        self.base_net = base_net
        self.head_net = head_net
        # Without injection nothing is concatenated, so the mixing layer
        # only ever sees the in_dim features coming out of base_net.
        extra_dim = inject_dim if should_inject else 0
        mixer = nn.Linear(in_dim + extra_dim, hidden_dim)
        self.inject_layer = nn.Sequential(mixer, nn.Tanh())
        self.should_inject = should_inject

    def forward(self, x, inject_x):
        """Run the chain; ``inject_x`` is ignored unless injection is enabled."""
        features = self.base_net(x)
        if self.should_inject:
            features = torch.cat((features, inject_x), dim=-1)
        return self.head_net(self.inject_layer(features))


class ImgEncoder(nn.Module):
    """Small CNN followed by one Linear+ReLU layer, yielding a flat feature vector.

    Args:
        input_dim: Sequence whose first entry is the channel count and whose
            remaining entries are the spatial size of one input image.
        out_dim: Size of the feature vector returned by ``forward``.
        args: Accepted so call sites can pass their config; not used here.
    """

    def __init__(self, input_dim=(4, 19, 19), out_dim=64, args=None):
        super().__init__()
        self.channels = input_dim[0]
        self.img_width_height = input_dim[1:]
        self.out_dim = out_dim

        # Convolutional stack (the original author notes it is modeled off
        # of Mnih et al.).
        conv_stack = [
            nn.Conv2d(self.channels, 16, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=1),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # The Linear layer's input size is measured by probing the CNN once.
        self.fc_layer_inputs = self.cnn_out_dim(self.img_width_height)
        self.fully_connected = nn.Sequential(
            nn.Linear(self.fc_layer_inputs, self.out_dim),
            nn.ReLU(),
        )

    def cnn_out_dim(self, input_dim):
        """Return the flattened CNN output size for one all-zero image.

        Args:
            input_dim: Spatial size of the probe image (channels come from
                ``self.channels``).
        """
        probe = torch.zeros(1, self.channels, *input_dim)
        return self.cnn(probe).numel()

    def forward(self, x, *args):
        """Encode ``x``; any extra positional arguments are ignored."""
        flat = self.cnn(x).reshape(-1, self.fc_layer_inputs)
        return self.fully_connected(flat)


class IdentityBase(nn.Module):
    """Pass-through module that reports its input shape as its output shape.

    Args:
        input_shape: Stored as-is and returned by ``output_shape``.
    """

    def __init__(self, input_shape):
        super().__init__()
        self.input_shape = input_shape

    @property
    def output_shape(self):
        """The shape supplied at construction, unchanged."""
        return self.input_shape

    def forward(self, inputs):
        """Return ``inputs`` untouched."""
        return inputs

    def net(self, x):
        """Return ``x`` untouched."""
        return x


class AbstractRewardModel(nn.Module):
    """Base class for reward models.

    Subclasses are expected to override ``forward`` and ``update``; both
    raise ``NotImplementedError`` here. ``save`` is shared.
    """

    def __init__(self):
        super().__init__()

    def forward(self, state, action, next_state, done):
        """Compute the reward for a transition. Must be overridden."""
        raise NotImplementedError

    def update(self, agent, expert_batch):
        """Update the model from ``agent`` and ``expert_batch``. Must be overridden."""
        raise NotImplementedError

    def save(self, path, suffix=""):
        """Write ``state_dict()`` to ``path`` followed by ``suffix`` and print the target."""
        destination = "{}{}".format(path, suffix)
        weights = self.state_dict()
        torch.save(weights, destination)
        print("Saved model to {}".format(destination))


class DistReward(AbstractRewardModel):
    """Reward model that scores a transition as a difference of per-state values.

    The reward is ``value(next_state) - value(state)``. The value network is
    selected by ``args.reward_gen.model.dist.type``:

    * ``'proximity'`` -- an ``Ensemble`` of ``InjectNet`` members (injection
      disabled). The reward can additionally be reduced by the ensemble
      variance at the next state (``panel_var``).
    * ``'iq_critic'`` -- one ``ImgEncoder`` followed by a small MLP head.

    Any other type raises ``NotImplementedError`` at construction.
    """

    def __init__(self, ob_space, ac_space, args=None):
        super().__init__()
        obs_shape = ob_space.shape
        # Box action spaces contribute their shape; anything else is asked
        # for its ``n`` attribute.
        self.action_dim = (
            ac_space.shape if isinstance(ac_space, gym.spaces.Box) else ac_space.n
        )
        self.gamma = args.gamma
        dist_cfg = args.reward_gen.model.dist
        self.arc_type = dist_cfg.type

        if self.arc_type == 'proximity':
            self._build_proximity(obs_shape, args, dist_cfg)
        elif self.arc_type == 'iq_critic':
            self._build_iq_critic(obs_shape, args)
        else:
            raise NotImplementedError
        self.optimizer = torch.optim.Adam(self.parameters(), lr=args.reward_gen.model.lr)

    def _build_proximity(self, obs_shape, args, dist_cfg):
        """Create the ensemble value network and read the proximity options."""
        shared_base = ImgEncoder(input_dim=obs_shape, args=args)
        shared_head = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # NOTE(review): the factory closes over ONE base net and ONE head
        # net, so every ensemble member wraps the same two module objects;
        # only each member's own inject_layer is created separately.
        # Confirm this sharing is intended.
        def make_member():
            return InjectNet(
                shared_base, shared_head, in_dim=64, hidden_dim=64, inject_dim=64, should_inject=False
            )

        self.num_ensembles = dist_cfg.num_ensembles
        self.panel_var = dist_cfg.panel_var
        self.mse_before_mean = dist_cfg.mse_before_mean
        self.var_for_both_train_and_eval = dist_cfg.var_for_both_train_and_eval
        if self.var_for_both_train_and_eval:
            assert self.panel_var, "Cannot use var_for_both_train_and_eval without panel_var"
        self.reward_net = Ensemble(make_member, num_ensembles=self.num_ensembles)

    def _build_iq_critic(self, obs_shape, args):
        """Create the single encoder + head value network."""
        self.encoder_net = ImgEncoder(input_dim=obs_shape, args=args)
        feat_dim = self.encoder_net.out_dim
        self.head_net = nn.Sequential(
            nn.Linear(feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def get_proximity_iq(self, states):
        """Value of ``states`` under the ``'iq_critic'`` network."""
        return self.head_net(self.encoder_net(states))

    def get_proximity_prox(self, states, return_mean=True):
        """Value of ``states`` under the ensemble.

        Returns the per-member outputs, or their mean over the member
        dimension when ``return_mean`` is true.
        """
        per_member = self.reward_net(states, None)
        return per_member.mean(dim=0) if return_mean else per_member

    def getV(self, states):
        """Value of ``states`` for the configured architecture, without gradients."""
        with torch.no_grad():
            if self.arc_type == 'proximity':
                return self.get_proximity_prox(states, return_mean=True)
            if self.arc_type == 'iq_critic':
                return self.get_proximity_iq(states)
            raise NotImplementedError

    def get_variance(self, states):
        """Variance of the ensemble outputs over members (``'proximity'`` only).

        With a single member the variance is returned as zeros of shape
        ``[states.shape[0], 1]`` on the same device as ``states``.
        """
        if self.num_ensembles != 1:
            return self.reward_net(states, None).var(dim=0)
        return torch.zeros([states.shape[0], 1], device=states.device)

    def _get_reward_iq(self, states, next_states):
        """``value(next_states) - value(states)`` using the iq_critic network."""
        current = self.get_proximity_iq(states)
        upcoming = self.get_proximity_iq(next_states)
        return upcoming - current

    def _get_reward_prox(self, states, next_states, panel_var=False, return_mean=True):
        """``value(next_states) - value(states)`` using the ensemble.

        When ``panel_var`` is truthy the ensemble variance at ``next_states``
        is subtracted. Per-member rewards (``return_mean=False``, used during
        training with ``mse_before_mean``) cannot be combined with it.
        """
        if not return_mean:
            assert panel_var == False, "Cannot return panel variance when return_mean is False"
        current = self.get_proximity_prox(states, return_mean=return_mean)
        upcoming = self.get_proximity_prox(next_states, return_mean=return_mean)
        delta = upcoming - current
        return delta - self.get_variance(next_states) if panel_var else delta

    def forward(self, state, action, next_state, done, eval=True, return_mean=True):
        """Reward for the transition ``state -> next_state``.

        ``action`` and ``done`` are accepted but not used. For the
        ``'proximity'`` architecture, ``eval`` gates the variance penalty
        (applied only when ``self.panel_var`` is also set) and forces the
        ensemble mean; ``return_mean`` only matters when ``eval`` is false.
        """
        if self.arc_type == 'iq_critic':
            return self._get_reward_iq(state, next_state)
        if self.arc_type == 'proximity':
            return self._get_reward_prox(
                state,
                next_state,
                panel_var=self.panel_var and eval,
                return_mean=return_mean or eval,
            )
        return None

    def update(self, agent, expert_batch):
        """Take one Adam step regressing the predicted reward onto the IRL reward.

        Args:
            agent: Passed through to ``get_irl_reward`` (defined outside this
                file) to obtain the regression target.
            expert_batch: Tuple ``(obs, next_obs, action, reward, done)``.

        Returns:
            ``{'reward_loss': <float>}``.
        """
        expert_obs, expert_next_obs, expert_action, expert_reward, expert_done = expert_batch
        target = get_irl_reward(agent, expert_obs, expert_next_obs, expert_action, expert_done, self.gamma)

        # Pick how forward() is invoked for the configured training mode.
        per_member_loss = False
        if self.arc_type == 'iq_critic':
            fwd_kwargs = {'eval': False}
        elif self.arc_type == 'proximity':
            if self.mse_before_mean:
                per_member_loss = True
                fwd_kwargs = {'eval': False, 'return_mean': False}
            elif self.var_for_both_train_and_eval:
                fwd_kwargs = {'eval': True, 'return_mean': True}
            else:
                fwd_kwargs = {'eval': False, 'return_mean': True}
        else:
            raise NotImplementedError

        predicted = self.forward(expert_obs, expert_action, expert_next_obs, expert_done, **fwd_kwargs)
        if per_member_loss:
            # Replicate the target along the member dimension so it matches
            # the per-member predictions.
            target = target.repeat(predicted.shape[0], 1, 1)
        loss = torch.nn.functional.mse_loss(predicted, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'reward_loss': loss.item()}



class BasicReward(AbstractRewardModel):
    """Encoder + MLP reward model whose inputs are chosen by ``input_config``.

    ``args.reward_gen.model.basic.input_config`` selects what the head sees
    and what it returns (``enc`` is the shared ``ImgEncoder``):

    * ``'sas'`` -- ``enc(state)`` and ``enc(next_state)`` concatenated; one
      output per action, the entry at ``action`` is returned.
    * ``'ss'``  -- same concatenated input; single output.
    * ``'sa'``  -- ``enc(state)``; one output per action, gathered at ``action``.
    * ``'s'``   -- ``enc(state)``; single output.
    * ``'ns'``  -- ``enc(next_state)``; single output.

    Any other value raises ``NotImplementedError``.
    """

    def __init__(self, ob_space=None, ac_space=None, args=None):
        super().__init__()
        obs_shape = ob_space.shape
        # Box action spaces contribute their shape; anything else is asked
        # for its ``n`` attribute.
        n_actions = ac_space.shape if isinstance(ac_space, gym.spaces.Box) else ac_space.n

        self.action_dim = n_actions
        self.input_config = args.reward_gen.model.basic.input_config

        self.encoder_net = ImgEncoder(input_dim=obs_shape, args=args)
        feat_dim = self.encoder_net.out_dim

        config = self.input_config
        # Head input width: two encodings for the state-pair configs, one otherwise.
        if config in ('sas', 'ss'):
            head_in = feat_dim * 2
        elif config in ('sa', 's', 'ns'):
            head_in = feat_dim
        else:
            raise NotImplementedError
        # Head output width: per-action values only for the action-aware configs.
        head_out = n_actions if config in ('sas', 'sa') else 1
        self.head_net = nn.Sequential(
            nn.Linear(head_in, 64),
            nn.ReLU(),
            nn.Linear(64, head_out),
        )

        self.gamma = args.gamma
        self.optimizer = torch.optim.Adam(self.parameters(), lr=args.reward_gen.model.lr)

    def _get_reward(self, state, action, next_state):
        """Encode the configured inputs, run the head and select the output.

        For ``'sas'`` / ``'sa'`` the head output is indexed along dim 1 with
        ``action.long()``; otherwise the head output is returned directly.
        """
        config = self.input_config
        if config in ('sas', 'ss'):
            enc_state = self.encoder_net(state)
            enc_next = self.encoder_net(next_state)
            head_input = torch.cat([enc_state, enc_next], dim=1)
        elif config in ('sa', 's'):
            head_input = self.encoder_net(state)
        elif config == 'ns':
            head_input = self.encoder_net(next_state)
        else:
            raise NotImplementedError

        values = self.head_net(head_input)
        if config in ('sas', 'sa'):
            return values.gather(1, action.long())
        return values

    def forward(self, state, action, next_state, done):
        """Reward for the transition; ``done`` is accepted but not used."""
        return self._get_reward(state, action, next_state)

    def update(self, agent, expert_batch):
        """Take one Adam step regressing the predicted reward onto the IRL reward.

        Args:
            agent: Passed through to ``get_irl_reward`` (defined outside this
                file) to obtain the regression target.
            expert_batch: Tuple ``(obs, next_obs, action, reward, done)``.

        Returns:
            ``{'reward_loss': <float>}``.
        """
        expert_obs, expert_next_obs, expert_action, expert_reward, expert_done = expert_batch
        target = get_irl_reward(agent, expert_obs, expert_next_obs, expert_action, expert_done, self.gamma)
        predicted = self._get_reward(expert_obs, expert_action, expert_next_obs)
        loss = torch.nn.functional.mse_loss(predicted, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'reward_loss': loss.item()}




class RegReward(AbstractRewardModel):
    """State-pair reward model with an optional regularizer on the encodings.

    The reward head sees ``enc(state)`` concatenated with ``enc(next_state)``
    and returns a single value. ``args.reward_gen.model.reg.type`` selects the
    regularizer (``'none'``, ``'dist_constraint'`` or ``'dim_reduction'``) and
    ``args.reward_gen.model.reg.coef`` its weight in the training loss.
    """

    def __init__(self, ob_space=None, ac_space=None, args=None):
        super(RegReward, self).__init__()
        ob_shape = ob_space.shape
        # Box action spaces contribute their shape; anything else is asked
        # for its ``n`` attribute.
        if isinstance(ac_space, gym.spaces.Box):
            ac_shape = ac_space.shape
        else:
            ac_shape = ac_space.n

        self.action_dim = ac_shape
        self.reg_type = args.reward_gen.model.reg.type
        self.reg_coef = args.reward_gen.model.reg.coef

        self.encoder_net = ImgEncoder(input_dim=ob_shape, args=args)
        enc_out_dim = self.encoder_net.out_dim
        self.head_net = nn.Sequential(
            nn.Linear(enc_out_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.gamma = args.gamma
        self.optimizer = torch.optim.Adam(self.parameters(), lr=args.reward_gen.model.lr)

    def _get_reward(self, state, next_state):
        """Encode both states, concatenate along dim 1 and run the head."""
        state = self.encoder_net(state)
        next_state = self.encoder_net(next_state)
        cat_state = torch.cat([state, next_state], dim=1)
        return self.head_net(cat_state)

    def forward(self, state, action, next_state, done):
        """Reward for the transition; ``action`` and ``done`` are not used."""
        return self._get_reward(state, next_state)

    @staticmethod
    def _dim_reduction_penalty(output):
        """Sum over columns ``i`` (0-based) of ``mean(|output[:, i] ** (i + 1)|)``."""
        penalty = output.new_zeros(())
        for power, column in enumerate(output.unbind(dim=1), start=1):
            # Mean over the batch dimension, one term per encoding dimension.
            penalty = penalty + column.pow(power).abs().mean()
        return penalty

    def get_reg_loss(self, state, next_state):
        """Return the (unweighted) regularization loss as a 0-dim tensor.

        * ``'none'``: zero.
        * ``'dist_constraint'``: penalizes the distance between the encodings
          of ``state`` and ``next_state`` wherever it exceeds 1.
        * ``'dim_reduction'``: ``dim_1 + dim_2^2 + ...`` applied to both
          encodings, pushing higher encoding dimensions towards zero.

        Raises:
            NotImplementedError: if ``self.reg_type`` is none of the above.
        """
        if self.reg_type == 'none':
            # A tensor (not the int 0) so that callers can uniformly scale
            # it, add it to a loss and call ``.item()`` on the result.
            return torch.zeros((), device=state.device)
        if self.reg_type not in ('dist_constraint', 'dim_reduction'):
            # Fail loudly instead of silently returning None.
            raise NotImplementedError(f"Unknown reg_type: {self.reg_type!r}")

        encoded_state = self.encoder_net(state)
        encoded_next_state = self.encoder_net(next_state)
        if self.reg_type == 'dist_constraint':
            # The distance between state and next_state should be less than 1.
            dist = torch.norm(encoded_state - encoded_next_state, dim=1)
            return torch.nn.functional.relu(dist - 1).mean()
        return (
            self._dim_reduction_penalty(encoded_state)
            + self._dim_reduction_penalty(encoded_next_state)
        )

    def update(self, agent, expert_batch):
        """Take one Adam step on ``reward_loss + reg_coef * reg_loss``.

        Args:
            agent: Passed through to ``get_irl_reward`` (defined outside this
                file) to obtain the regression target.
            expert_batch: Tuple ``(obs, next_obs, action, reward, done)``.

        Returns:
            ``{'reward_loss': <float>, 'reg_loss': <float>}`` where
            ``reg_loss`` is the already-weighted regularization term.
        """
        expert_obs, expert_next_obs, expert_action, expert_reward, expert_done = expert_batch
        # Regression target from the agent.
        irl_reward = get_irl_reward(agent, expert_obs, expert_next_obs, expert_action, expert_done, self.gamma)
        predicted_reward = self._get_reward(expert_obs, expert_next_obs)
        reward_loss = torch.nn.functional.mse_loss(predicted_reward, irl_reward)

        # get_reg_loss always returns a tensor, so reg_loss supports .item()
        # even when regularization is disabled.
        reg_loss = self.reg_coef * self.get_reg_loss(expert_obs, expert_next_obs)
        loss = reward_loss + reg_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'reward_loss': reward_loss.item(), 'reg_loss': reg_loss.item()}



class IQGenReward(AbstractRewardModel):
    """Reward read off an IQ agent: ``Q(s, a) - (1 - done) * gamma * V(s')``.

    The agent is obtained from ``make_agent`` with
    ``load_agent_path=args.pretrain``. This class does not override
    ``update``; both of its methods run without gradient tracking.
    """

    def __init__(self, ob_space, ac_space, args=None):
        super().__init__()
        self.gamma = args.gamma
        self.iq_model = make_agent(ob_space, ac_space, args, load_agent_path=args.pretrain)

    def forward(self, state, action, next_state, done):
        """Return the critic value minus the discounted next-state value."""
        with torch.no_grad():
            q_value = self.iq_model.critic(state, action)
            next_value = self.iq_model.getV(next_state)
            # The bootstrap term is zeroed where ``done`` equals 1.
            bootstrap = (1 - done) * self.gamma * next_value
            return q_value - bootstrap

    def getV(self, state):
        """Return the wrapped agent's ``getV(state)``."""
        with torch.no_grad():
            return self.iq_model.getV(state)








    
