import torch.nn as nn
import torch
import numpy as np
import torch.nn.init as init


class RNDModel(nn.Module):
    """Trainable predictor network paired with a frozen, randomly initialised target.

    Both networks are identical MLPs ``input_size -> 64 -> 64 -> output_size``
    with ReLU activations between the linear layers. Every linear layer gets an
    orthogonal weight initialisation (gain sqrt(2)) and a zero bias.

    Args:
        input_size: size of the last dimension of the observation tensor.
        output_size: size of the feature vector produced by each network.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size

        def build_mlp():
            return nn.Sequential(
                nn.Linear(input_size, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, output_size),
            )

        # Predictor is built first, then target, so parameter creation order is fixed.
        self.predictor = build_mlp()
        self.target = build_mlp()

        gain = np.sqrt(2)
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            init.orthogonal_(layer.weight, gain)
            init.zeros_(layer.bias)

        # The target never receives gradients; only the predictor is trainable.
        self.target.requires_grad_(False)

    def forward(self, next_obs):
        """Run both networks on ``next_obs``.

        Returns:
            Tuple ``(predict_feature, target_feature)``.
        """
        target_out = self.target(next_obs)
        return self.predictor(next_obs), target_out


class DRNDModel(nn.Module):
    """Trainable predictor plus an ensemble of ``num`` frozen random target networks.

    Every network is an MLP ``input_size -> 64 -> 64 -> output_size`` with ReLU
    activations. Every linear layer (predictor and targets) gets an orthogonal
    weight initialisation (gain sqrt(2)) and a zero bias.

    Args:
        input_size: size of the last dimension of the observation tensor.
        output_size: size of the feature vector produced by each network.
        num: number of frozen target networks.
    """

    def __init__(self, input_size, output_size, num=10):
        super(DRNDModel, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.num_target = num
        self.predictor = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, output_size),
        )

        # nn.ModuleList (not a plain list) registers the targets as submodules.
        # Without it they would:
        #   - be skipped by the init loop below,
        #   - stay behind on `.to(device)`,
        #   - be missing from `state_dict()` (targets must survive a save/load).
        self.target = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, output_size),
            )
            for _ in range(num)
        ])

        for p in self.modules():
            if isinstance(p, nn.Linear):
                init.orthogonal_(p.weight, np.sqrt(2))
                p.bias.data.zero_()

        # Targets are fixed random networks: never updated by the optimizer.
        for param in self.target.parameters():
            param.requires_grad = False

    def forward(self, next_obs):
        """Run the predictor and every target network on ``next_obs``.

        Args:
            next_obs: tensor whose last dimension is ``input_size``.

        Returns:
            Tuple ``(predict_feature, target_feature)``:
              - ``predict_feature`` has shape ``(*batch, output_size)``.
              - ``target_feature`` stacks the target outputs along a new leading
                dim, giving shape ``(num, *batch, output_size)``.
        """
        if len(self.target) == 0:
            # torch.stack rejects an empty list; keep returning an empty ensemble.
            target_feature = next_obs.new_zeros(
                (0, *next_obs.shape[:-1], self.output_size)
            )
        else:
            target_feature = torch.stack(
                [t_net(next_obs) for t_net in self.target], dim=0
            )

        predict_feature = self.predictor(next_obs)

        return predict_feature, target_feature.to(predict_feature.device)


class CFNModel(nn.Module):
    """Single trainable MLP ``input_size -> 64 -> 64 -> output_size``.

    Uses ReLU between linear layers (none after the last one). Every linear
    layer gets an orthogonal weight initialisation (gain sqrt(2)) and a zero bias.

    Args:
        input_size: size of the last dimension of the observation tensor.
        output_size: size of the produced feature vector.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size

        widths = (input_size, 64, 64, output_size)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        layers.pop()  # no activation after the output layer
        self.predictor = nn.Sequential(*layers)

        gain = np.sqrt(2)
        for layer in filter(lambda mod: isinstance(mod, nn.Linear), self.modules()):
            init.orthogonal_(layer.weight, gain)
            init.zeros_(layer.bias)

    def forward(self, next_obs):
        """Return ``predictor(next_obs)``."""
        return self.predictor(next_obs)