import torch
import torch.nn as nn

import os, sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from utils_TransTEE.utils import get_initialiser
from transformers import CLIPModel, CLIPProcessor

class Truncated_power():
    """Truncated power spline basis; the data is assumed in [0,1].

    For degree ``p`` and knots ``k_1, ..., k_m`` the basis functions are
    ``1, x, ..., x**p, relu(x - k_1)**p, ..., relu(x - k_m)**p``.
    """

    def __init__(self, degree, knots):
        """
        This class construct the truncated power basis; the data is assumed in [0,1]
        :param degree: int, the degree of truncated basis; must be >= 1
        :param knots: list, the knots of the spline basis; two end points (0,1) should not be included
        :raises ValueError: if degree is not an int or is smaller than 1
        """
        # Validate before deriving anything from the arguments.
        if not isinstance(degree, int):
            raise ValueError('Degree should be int')
        if degree < 1:
            raise ValueError('Degree should be a positive int (0 is not allowed)')

        self.degree = degree
        self.knots = knots
        self.num_of_basis = self.degree + 1 + len(self.knots)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: torch.tensor holding one scalar per sample; shape (batch_size,) or (batch_size, 1)
        :return: the value of each basis given x; batch_size * self.num_of_basis,
                 allocated on the same device as x
        """
        # Flatten so that both (bs,) and (bs, 1) inputs can be written into a column.
        x = x.reshape(-1)
        out = torch.zeros(x.shape[0], self.num_of_basis, device=x.device)

        # Polynomial part: 1, x, ..., x**degree
        out[:, 0] = 1.
        for power in range(1, self.degree + 1):
            out[:, power] = x ** power

        # Truncated part: one column per knot. The same knot index is used for
        # every degree (for degree 1 the power is simply the identity), so
        # knots[0] is never skipped and the last column never indexes past the list.
        first_knot_col = self.degree + 1
        for offset, knot in enumerate(self.knots):
            out[:, first_knot_col + offset] = self.relu(x - knot) ** self.degree

        return out # bs, num_of_basis


class Dynamic_FC(nn.Module):
    """Fully connected layer whose weights can vary with a scalar treatment.

    With ``has_dose=True`` the first input column is treated as the treatment
    value: it is expanded with a ``Truncated_power`` basis of size ``d`` and the
    layer weight has shape (ind, outd, d). With ``has_dose=False`` the layer is
    a plain matrix product with a weight of shape (ind, outd).
    """

    def __init__(self, ind, outd, degree, knots, act='relu', isbias=1, islastlayer=0, has_dose=False):
        """
        :param ind: number of input features (excluding the treatment column)
        :param outd: number of output features
        :param degree: degree passed to Truncated_power
        :param knots: knots passed to Truncated_power
        :param act: 'relu', 'tanh' or 'sigmoid'; any other value means no activation
        :param isbias: truthy to create a bias parameter
        :param islastlayer: if falsy and has_dose is set, the treatment column is
                            prepended to the output again
        :param has_dose: whether the first input column is a treatment value
        """
        super(Dynamic_FC, self).__init__()
        self.ind = ind
        self.outd = outd
        self.degree = degree
        self.knots = knots

        self.islastlayer = islastlayer

        self.isbias = isbias

        self.spb = Truncated_power(degree, knots)
        self.d = self.spb.num_of_basis # num of basis
        if has_dose:
            self.weight = nn.Parameter(torch.rand(self.ind, self.outd, self.d), requires_grad=True)
            if self.isbias:
                self.bias = nn.Parameter(torch.rand(self.outd, self.d), requires_grad=True)
            else:
                self.bias = None
        else:
            self.weight = nn.Parameter(torch.rand(self.ind, self.outd), requires_grad=True)
            if self.isbias:
                self.bias = nn.Parameter(torch.rand(self.outd), requires_grad=True)
            else:
                self.bias = None

        if act == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif act == 'tanh':
            self.act = nn.Tanh()
        elif act == 'sigmoid':
            self.act = nn.Sigmoid()
        else:
            self.act = None
        self.has_dose = has_dose

    def forward(self, x):
        """
        :param x: batch_size * (treatment, other feature) when has_dose is set,
                  otherwise batch_size * features
        :return: batch_size * outd, or batch_size * (1 + outd) when has_dose is
                 set and this is not the last layer
        """
        if self.has_dose:
            x_treat = x[:, 0]
            x_feature = x[:, 1:]

            # (bs, ind) x (ind, outd, d) -> (bs, outd, d). einsum replaces the
            # former `.T` on a 3-D tensor, which is deprecated for non-2-D inputs.
            x_feature_weight = torch.einsum('bi,iod->bod', x_feature, self.weight)

            # Keep the basis on the input's device instead of forcing CUDA.
            x_treat_basis = self.spb.forward(x_treat).to(x.device) # bs, d

            # x_feature_weight * x_treat_basis; bs, outd, d
            out = torch.sum(x_feature_weight * torch.unsqueeze(x_treat_basis, 1), dim=2) # bs, outd

            if self.isbias:
                # (bs, d) x (d, outd) -> (bs, outd)
                out = out + torch.matmul(x_treat_basis, self.bias.T)
        else:
            # NOTE(review): self.bias is created when isbias is set but is not
            # applied on this path; kept as-is to preserve existing numerics.
            out = torch.matmul(x, self.weight) # bs, outd

        if self.act is not None:
            out = self.act(out)

        # concat the treatment for intermediate layer
        if self.has_dose and not self.islastlayer:
            out = torch.cat((torch.unsqueeze(x_treat, 1), out), 1)
        return out

def comp_grid(y, num_grid):
    """Locate each value of ``y`` between two neighbouring grid indices.

    :param y: torch.tensor of values to place on the grid
    :param num_grid: number of grid intervals; ``y`` is scaled by this factor
    :return: (lower, upper, inter) where ``lower`` and ``upper`` are Python
             lists of integer indices and ``inter`` is a tensor holding the
             distance from the scaled value to the lower index
    """
    scaled = y * num_grid

    # upper index: smallest integer not below the scaled value
    upper = torch.ceil(scaled)
    # distance to the lower integer
    frac = 1 - (upper - scaled)

    # lower index sits one below the upper one; a lower index of -1
    # (upper == 0) is shifted up by one so it does not go below zero
    lower = upper - 1
    lower = lower + (lower < 0).int()

    return lower.int().tolist(), upper.int().tolist(), frac

class Density_Block(nn.Module):
    """Softmax head over a grid, linearly interpolated at a given value."""

    def __init__(self, num_grid, ind, isbias=1):
        """
        Assume the variable is bounded by [0,1]
        the output grid: 0, 1/B, 2/B, ..., B/B; output dim = B + 1; num_grid = B

        :param num_grid: B, the number of grid intervals
        :param ind: number of input features
        :param isbias: truthy to create a bias parameter
        """
        super().__init__()
        self.ind = ind
        self.num_grid = num_grid
        self.outd = num_grid + 1
        self.isbias = isbias

        self.weight = nn.Parameter(torch.rand(self.ind, self.outd), requires_grad=True)
        self.bias = nn.Parameter(torch.rand(self.outd), requires_grad=True) if self.isbias else None

        self.softmax = nn.Softmax(dim=1)

    def forward(self, t, x):
        """
        :param t: values passed to comp_grid to find the neighbouring grid indices
        :param x: input features, batch_size * ind
        :return: per-sample softmax output interpolated between the lower and
                 upper grid index of t
        """
        logits = torch.matmul(x, self.weight)
        if self.isbias:
            logits = logits + self.bias
        probs = self.softmax(logits)

        lower, upper, frac = comp_grid(t, self.num_grid)
        rows = list(torch.arange(0, x.shape[0]))

        at_lower = probs[rows, lower]
        at_upper = probs[rows, upper]

        # linear interpolation between the two neighbouring grid points
        return at_lower + (at_upper - at_lower) * frac

class Vcnet(nn.Module):
    """Shared feature extractor followed by one Dynamic_FC network per treatment.

    A Density_Block head (``density_estimator_head``) is also constructed; it is
    not called inside ``forward``.
    """

    def __init__(self, cfg_density, num_grid, cfg, degree, knots, num_t=1, has_dose=False, cont_treatment=False):
        """
        cfg_density: cfg for the feature extractor; [(ind1, outd1, isbias1, activation), ....];
                     the cfg for density estimator head is not included
        num_grid: how many grid used for the density estimator head
        cfg: cfg for every prediction network; [(ind1, outd1, isbias1, activation), ....]
        degree, knots: spline settings forwarded to every Dynamic_FC layer
        num_t: number of prediction networks that are built (one per treatment index)
        has_dose: if set, ``d`` is prepended to the hidden features in forward and
                  forwarded to Dynamic_FC as has_dose
        cont_treatment: if set, ``t`` is prepended as well and only the first
                        prediction network is used
        """
        super(Vcnet, self).__init__()

        self.cfg_density = cfg_density
        self.num_grid = num_grid
        self.num_t = num_t

        self.cfg = cfg
        self.degree = degree
        self.knots = knots

        # construct the density estimator
        density_blocks = []
        density_hidden_dim = -1
        for layer_idx, layer_cfg in enumerate(cfg_density):
            # fc layer
            fc = nn.Linear(in_features=layer_cfg[0], out_features=layer_cfg[1], bias=layer_cfg[2])
            if layer_idx == 0:
                # weight connected to feature
                self.feature_weight = fc
            density_blocks.append(fc)
            density_hidden_dim = layer_cfg[1]
            if layer_cfg[3] == 'relu':
                density_blocks.append(nn.ReLU(inplace=True))
            elif layer_cfg[3] == 'tanh':
                density_blocks.append(nn.Tanh())
            elif layer_cfg[3] == 'sigmoid':
                density_blocks.append(nn.Sigmoid())
            else:
                print('No activation')

        self.hidden_features = nn.Sequential(*density_blocks)
        self.linear = nn.Linear(64, 1)

        self.density_hidden_dim = density_hidden_dim
        self.density_estimator_head = Density_Block(self.num_grid, density_hidden_dim, isbias=1)

        # construct the dynamics network: one stack of Dynamic_FC per treatment
        self.cont_treatment = cont_treatment
        last_idx = len(cfg) - 1
        Q_net_ls = []
        for _ in range(num_t):
            blocks = [
                Dynamic_FC(layer_cfg[0], layer_cfg[1], self.degree, self.knots, act=layer_cfg[3],
                           isbias=layer_cfg[2], islastlayer=int(layer_idx == last_idx), has_dose=has_dose)
                for layer_idx, layer_cfg in enumerate(cfg)
            ]
            Q_net_ls.append(nn.Sequential(*blocks))
        self.Q_net_ls = torch.nn.ModuleList(Q_net_ls)
        self.has_dose = has_dose

    def forward(self, x, t, d=None, test=False):
        """
        :param x: input features
        :param t: treatment per sample (index of the prediction network, or the
                  continuous treatment value when cont_treatment is set)
        :param d: dose per sample; required when has_dose is set
        :param test: when True (and cont_treatment is not set) every prediction
                     network is evaluated on every sample
        :return: ``(t_hidden, out)`` when training or with cont_treatment;
                 ``(t_hidden, out, full_out)`` in test mode, where full_out
                 concatenates the outputs of all prediction networks
        """
        hidden = self.hidden_features(x)
        if self.has_dose:
            t_hidden = torch.cat((d.view(hidden.shape[0], 1), hidden), 1)
        else:
            t_hidden = hidden

        if self.cont_treatment:
            t_hidden = torch.cat((t.view(hidden.shape[0], 1), t_hidden), 1)
            out = self.Q_net_ls[0](t_hidden)
            return t_hidden, out

        if not test:
            # Allocate on the input's device rather than forcing CUDA.
            out = torch.zeros(x.shape[0], 1, device=x.device)
            for i in range(self.num_t):
                # Row indices stay on-device; no numpy/set round trip is needed.
                idx = torch.where(t == i)[0]
                out[idx] = self.Q_net_ls[i](t_hidden[idx])
            return t_hidden, out

        full_out = torch.cat([self.Q_net_ls[i](t_hidden) for i in range(self.num_t)], dim=-1)
        # reshape(-1) keeps a 1-D index even for a batch of one
        out = full_out[torch.arange(len(full_out)), t.reshape(-1).long()]
        return t_hidden, out, full_out

    def _initialize_weights(self, initialiser):
        """Apply the initialiser returned by get_initialiser to all weights and zero all biases."""
        # TODO: maybe add more distribution for initialization
        initialiser = get_initialiser(initialiser)
        for m in self.modules():
            # All three layer types expose `weight` and a `bias` that is None when disabled.
            if isinstance(m, (Dynamic_FC, nn.Linear, Density_Block)):
                initialiser(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()


class Vcnet_image(nn.Module):
    """Vcnet variant that concatenates CLIP image features to the hidden features.

    A Density_Block head (``density_estimator_head``) is also constructed; it is
    not called inside ``forward``.
    """

    def __init__(self, cfg_density, num_grid, cfg, degree, knots, num_t=1, has_dose=False, cont_treatment=False):
        """
        cfg_density: cfg for the feature extractor; [(ind1, outd1, isbias1, activation), ....];
                     the cfg for density estimator head is not included
        num_grid: how many grid used for the density estimator head
        cfg: cfg for every prediction network; [(ind1, outd1, isbias1, activation), ....];
             needs at least two layers. cfg[0][0] is the size of the image
             projection; the first layer's width and the second layer's input
             are doubled internally. The caller's cfg is not modified.
        degree, knots: spline settings forwarded to every Dynamic_FC layer
        num_t: number of prediction networks that are built
        has_dose: if set, ``d`` is prepended to the hidden features in forward
        cont_treatment: stored; together with ``test`` it selects the branch in forward
        """
        super(Vcnet_image, self).__init__()

        # Work on a private copy so the caller's cfg is not doubled in place
        # (re-using one cfg for several models would otherwise compound).
        cfg = [list(layer_cfg) for layer_cfg in cfg]

        self.img_emb = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.img_emb.visual_projection = nn.Linear(self.img_emb.visual_projection.in_features, cfg[0][0])
        # cfg/cfg_density = [(ind1, outd1, isbias1, activation),....]
        self.cfg_density = cfg_density
        self.num_grid = num_grid
        self.num_t = num_t

        self.degree = degree
        self.knots = knots

        # construct the density estimator
        density_blocks = []
        density_hidden_dim = -1
        for layer_idx, layer_cfg in enumerate(cfg_density):
            # fc layer
            fc = nn.Linear(in_features=layer_cfg[0], out_features=layer_cfg[1], bias=layer_cfg[2])
            if layer_idx == 0:
                # weight connected to feature
                self.feature_weight = fc
            density_blocks.append(fc)
            density_hidden_dim = layer_cfg[1]
            if layer_cfg[3] == 'relu':
                density_blocks.append(nn.ReLU(inplace=True))
            elif layer_cfg[3] == 'tanh':
                density_blocks.append(nn.Tanh())
            elif layer_cfg[3] == 'sigmoid':
                density_blocks.append(nn.Sigmoid())
            else:
                print('No activation')

        self.hidden_features = nn.Sequential(*density_blocks)
        self.linear = nn.Linear(64, 1)

        self.density_hidden_dim = density_hidden_dim
        self.density_estimator_head = Density_Block(self.num_grid, density_hidden_dim, isbias=1)

        # construct the dynamics network; widths are doubled on the private copy
        cfg[0][0] *= 2
        cfg[0][1] *= 2
        cfg[1][0] *= 2
        self.cfg = cfg

        last_idx = len(cfg) - 1
        Q_net_ls = []
        for _ in range(num_t):
            # Dynamic_FC is built with its default has_dose=False here.
            blocks = [
                Dynamic_FC(layer_cfg[0], layer_cfg[1], self.degree, self.knots, act=layer_cfg[3],
                           isbias=layer_cfg[2], islastlayer=int(layer_idx == last_idx))
                for layer_idx, layer_cfg in enumerate(cfg)
            ]
            Q_net_ls.append(nn.Sequential(*blocks))
        self.Q_net_ls = torch.nn.ModuleList(Q_net_ls)
        self.has_dose = has_dose
        self.cont_treatment = cont_treatment

    def forward(self,image, x, t, d=None, test=False):
        """
        :param image: mapping of keyword arguments for CLIPModel.get_image_features
        :param x: input features
        :param t: treatment index per sample
        :param d: dose per sample; required when has_dose is set
        :param test: when True (and cont_treatment is not set) every prediction
                     network is evaluated on every sample
        :return: ``out`` when training or with cont_treatment; ``(out, full_out)``
                 otherwise
        """
        image_emb = self.img_emb.get_image_features(**image)
        hidden = self.hidden_features(x)
        if self.has_dose:
            t_hidden = torch.cat((d.view(hidden.shape[0], 1), hidden), 1)
        else:
            t_hidden = hidden

        t_hidden = torch.cat([t_hidden, image_emb], dim=1)

        if (not test) or self.cont_treatment:
            # Allocate on the input's device rather than forcing CUDA.
            out = torch.zeros(x.shape[0], 1, device=x.device)
            for i in range(self.num_t):
                # Row indices stay on-device; no numpy/set round trip is needed.
                idx = torch.where(t == i)[0]
                out[idx] = self.Q_net_ls[i](t_hidden[idx])
            return out

        full_out = torch.cat([self.Q_net_ls[i](t_hidden) for i in range(self.num_t)], dim=-1)
        out = full_out[torch.arange(len(full_out)), t.view(-1).long()]
        return out, full_out

    def _initialize_weights(self, initialiser):
        """Apply the initialiser returned by get_initialiser to all weights and zero all biases."""
        # TODO: maybe add more distribution for initialization
        initialiser = get_initialiser(initialiser)
        for m in self.modules():
            # All three layer types expose `weight` and a `bias` that is None when disabled.
            if isinstance(m, (Dynamic_FC, nn.Linear, Density_Block)):
                initialiser(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()


# Targeted Regularizer

class TR(nn.Module):
    """Targeted regularizer: a learned linear combination of a truncated power basis of t."""

    def __init__(self, degree, knots):
        """
        :param degree: degree passed to Truncated_power
        :param knots: knots passed to Truncated_power
        """
        super(TR, self).__init__()
        self.spb = Truncated_power(degree, knots)
        self.d = self.spb.num_of_basis # num of basis
        # Wrap the tensor in nn.Parameter *after* placing it on its device.
        # Calling .cuda() on a Parameter returns a plain non-leaf tensor, which
        # is not registered with the module (missing from parameters() and
        # state_dict()) and fails outright on hosts without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.weight = nn.Parameter(torch.rand(self.d).to(device), requires_grad=True)

    def forward(self, t):
        """
        :param t: treatment values, one scalar per sample
        :return: tensor with one value per sample (basis(t) @ weight)
        """
        # Follow the weight's device so the module also works after .to()/.cpu().
        out = self.spb.forward(t).to(self.weight.device)
        out = torch.matmul(out, self.weight)
        return out

    def _initialize_weights(self):
        """Reset the weight vector to zeros."""
        #self.weight.data.normal_(0, 0.01)
        self.weight.data.zero_()

# ------------------------------------------ Drnet and Tarnet ------------------------------------------- #

class Treat_Linear(nn.Module):
    """Linear layer over ``[treatment, features]`` with an optional treatment weight."""

    def __init__(self, ind, outd, act='relu', istreat=1, isbias=1, islastlayer=0):
        """
        :param ind: number of feature columns; does NOT include the extra concat treatment
        :param outd: number of output columns
        :param act: 'relu', 'tanh' or 'sigmoid'; any other value means no activation
        :param istreat: truthy to add a learned contribution of the treatment column
        :param isbias: truthy to add a bias
        :param islastlayer: if falsy, the treatment column is prepended to the output
        """
        super().__init__()
        self.ind, self.outd = ind, outd
        self.isbias = isbias
        self.istreat = istreat
        self.islastlayer = islastlayer

        # parameters are created in the order weight, bias, treat_weight
        self.weight = nn.Parameter(torch.rand(self.ind, self.outd), requires_grad=True)
        self.bias = nn.Parameter(torch.rand(self.outd), requires_grad=True) if self.isbias else None
        self.treat_weight = nn.Parameter(torch.rand(1, self.outd), requires_grad=True) if self.istreat else None

        activation = None
        for name, make in (('relu', lambda: nn.ReLU(inplace=True)),
                           ('tanh', nn.Tanh),
                           ('sigmoid', nn.Sigmoid)):
            if act == name:
                activation = make()
                break
        self.act = activation

    def forward(self, x):
        """
        :param x: batch_size * (treatment, other feature)
        :return: batch_size * outd for the last layer, otherwise
                 batch_size * (1 + outd) with the treatment in column 0
        """
        treat = x[:, [0]]
        feats = x[:, 1:]

        result = torch.matmul(feats, self.weight)
        if self.istreat:
            result = result + torch.matmul(treat, self.treat_weight)
        if self.isbias:
            result = result + self.bias
        if self.act is not None:
            result = self.act(result)

        if self.islastlayer:
            return result
        # intermediate layers carry the treatment along for the next layer
        return torch.cat((treat, result), 1)

class Multi_head(nn.Module):
    """One or five stacks of Treat_Linear layers, selected by the value in column 0.

    Without ``has_dose`` a single stack (``Q1``) handles every sample. With
    ``has_dose`` the range [0, h] is split into five equal buckets and each
    sample is routed to ``Q1`` .. ``Q5`` according to the value in its first
    input column.
    """

    def __init__(self, cfg, isenhance, h=1, has_dose=False):
        """
        :param cfg: [(ind, outd, isbias, activation), ...]; does NOT include the
                    extra dimension of concat treatment
        :param isenhance: set 1 to concat treatment every layer/ 0: only concat on first layer
        :param h: upper end of the bucketed range
        :param has_dose: build five bucket heads instead of one
        """
        super(Multi_head, self).__init__()

        self.cfg = cfg # cfg does NOT include the extra dimension of concat treatment
        self.isenhance = isenhance  # set 1 to concat treatment every layer/ 0: only concat on first layer

        # we default set num of heads = 5
        l = 0.0
        h -= l
        self.pt = [l, l+h/5, l+h*2/5, l+h*3/5, l+h*4/5, l+h]

        # output width of the last layer; stays -1 for an empty cfg
        self.outdim = cfg[-1][1] if len(cfg) > 0 else -1

        # construct the predicting networks (attribute names are kept so that
        # existing state_dict keys still match)
        self.Q1 = self._build_head()
        self.has_dose = has_dose
        if has_dose:
            self.Q2 = self._build_head()
            self.Q3 = self._build_head()
            self.Q4 = self._build_head()
            self.Q5 = self._build_head()

    def _build_head(self):
        """Build one stack of Treat_Linear layers described by self.cfg."""
        last_idx = len(self.cfg) - 1
        layers = []
        for layer_idx, layer_cfg in enumerate(self.cfg):
            # the treatment weight is used on the first layer, or on every layer when isenhance is set
            istreat = 1 if (layer_idx == 0 or self.isenhance) else 0
            layers.append(Treat_Linear(layer_cfg[0], layer_cfg[1], act=layer_cfg[3], istreat=istreat,
                                       isbias=layer_cfg[2], islastlayer=int(layer_idx == last_idx)))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: batch_size * (treatment, features)
        :return: batch_size * outdim; with has_dose, rows whose first column
                 falls outside [pt[0], pt[5]] stay zero
        """
        # x = [treatment, features]
        if not self.has_dose:
            return self.Q1(x)

        # Allocate on the input's device rather than forcing CUDA.
        out = torch.zeros(x.shape[0], self.outdim, device=x.device)
        t = x[:, 0]

        heads = (self.Q1, self.Q2, self.Q3, self.Q4, self.Q5)
        last_bucket = len(heads) - 1
        for k, head in enumerate(heads):
            lower, upper = self.pt[k], self.pt[k + 1]
            # buckets are half-open [lower, upper); the last one also includes its upper end
            below_upper = (t <= upper) if k == last_bucket else (t < upper)
            in_bucket = (t >= lower) & below_upper
            # empty buckets are skipped, so their head is not evaluated
            if in_bucket.any():
                out[in_bucket] = head(x[in_bucket])

        return out


class Drnet(nn.Module):
    """Shared feature extractor followed by one Multi_head network per treatment."""

    def __init__(self, cfg_density, num_grid, cfg, isenhance, h=1, att_layers=0, num_t=2, has_dose=False, cont_treatment=False):
        """
        :param cfg_density: cfg for the feature extractor; [(ind, outd, isbias, activation), ...]
        :param num_grid: stored on the instance
        :param cfg: cfg for every Multi_head; [(ind, outd, isbias, activation), ...].
                    With has_dose the first layer's input size is increased by
                    one internally. The caller's cfg is not modified.
        :param isenhance: forwarded to Multi_head
        :param h: forwarded to Multi_head
        :param att_layers: stored on the instance
        :param num_t: number of Multi_head networks that are built
        :param has_dose: if set, ``d`` is prepended to the hidden features in forward
        :param cont_treatment: if set, ``t`` is prepended as well and only the
                               first Multi_head is used
        """
        super(Drnet, self).__init__()

        # Work on a private copy so the caller's cfg is not changed in place
        # (re-using one cfg for several models would otherwise compound the +1).
        cfg = [list(layer_cfg) for layer_cfg in cfg]
        if has_dose:
            cfg[0][0] += 1

        self.cfg_density = cfg_density
        self.num_grid = num_grid
        self.cfg = cfg
        self.isenhance = isenhance
        self.att_layers = att_layers
        self.h = h
        self.num_t = num_t

        # assert(num_t < 4)

        # construct the density estimator
        density_blocks = []
        density_hidden_dim = -1
        for layer_idx, layer_cfg in enumerate(cfg_density):
            # fc layer
            fc = nn.Linear(in_features=layer_cfg[0], out_features=layer_cfg[1], bias=layer_cfg[2])
            if layer_idx == 0:
                # weight connected to feature
                self.feature_weight = fc
            density_blocks.append(fc)
            density_hidden_dim = layer_cfg[1]
            if layer_cfg[3] == 'relu':
                density_blocks.append(nn.ReLU(inplace=True))
            elif layer_cfg[3] == 'tanh':
                density_blocks.append(nn.Tanh())
            elif layer_cfg[3] == 'sigmoid':
                density_blocks.append(nn.Sigmoid())
            else:
                print('No activation')

        self.hidden_features = nn.Sequential(*density_blocks)
        self.linear = nn.Linear(64, 1)
        # multi-head outputs blocks
        self.Q_net_ls = torch.nn.ModuleList([Multi_head(cfg, isenhance, h=h, has_dose=has_dose) for _ in range(num_t)])

        self.has_dose = has_dose
        self.cont_treatment = cont_treatment

    def forward(self, x, t, d=None, test=False):
        """
        :param x: input features
        :param t: treatment per sample (index of the Multi_head, or the
                  continuous treatment value when cont_treatment is set)
        :param d: dose per sample; required when has_dose is set
        :param test: when True (and cont_treatment is not set) every Multi_head
                     is evaluated on every sample
        :return: ``(t_hidden, out)`` when training or with cont_treatment;
                 ``(t_hidden, out, full_out_tensor)`` otherwise
        """
        hidden = self.hidden_features(x)
        if self.has_dose:
            t_hidden = torch.cat((d.view(hidden.shape[0], 1), hidden), 1)
        else:
            t_hidden = hidden

        if (not test) or self.cont_treatment:
            if self.cont_treatment:
                t_hidden = torch.cat((t.view(hidden.shape[0], 1), t_hidden), 1)
                out = self.Q_net_ls[0](t_hidden)
            else:
                # Allocate on the input's device rather than forcing CUDA.
                out = torch.zeros(x.shape[0], 1, device=x.device)
                for i in range(self.num_t):
                    # Row indices stay on-device; no numpy/set round trip is needed.
                    idx = torch.where(t == i)[0]
                    curr_t_hidden = t_hidden[idx]
                    # Multi_head expects the treatment index in column 0
                    treat_col = torch.ones(len(curr_t_hidden), 1, device=x.device) * i
                    out[idx] = self.Q_net_ls[i](torch.cat([treat_col, curr_t_hidden], dim=-1))
            return t_hidden, out

        full_out = []
        for i in range(self.num_t):
            treat_col = torch.ones(len(t_hidden), 1, device=x.device) * i
            full_out.append(self.Q_net_ls[i](torch.cat([treat_col, t_hidden], dim=-1)))

        full_out_tensor = torch.cat(full_out, dim=-1)
        out = full_out_tensor[torch.arange(len(full_out_tensor)), t.view(-1).long()]
        return t_hidden, out, full_out_tensor

    def _initialize_weights(self, initialiser=None):
        """Re-initialise weights with fixed normal distributions; ``initialiser`` is ignored."""
        # TODO: maybe add more distribution for initialization
        # NOTE(review): Treat_Linear (the layer type Multi_head is built from) is
        # not matched by any branch below, so those layers keep their
        # torch.rand initialisation - confirm whether that is intended.
        for m in self.modules():
            if isinstance(m, Dynamic_FC):
                m.weight.data.normal_(0, 1.)
                if m.isbias:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, Density_Block):
                m.weight.data.normal_(0, 0.01)
                if m.isbias:
                    m.bias.data.zero_()
                    
                    
class Drnet_image(nn.Module):
    """Drnet variant that concatenates CLIP image features to the hidden features."""

    def __init__(self, cfg_density, num_grid, cfg, isenhance, h=1, att_layers=0, num_t=2, has_dose=False, cont_treatment=False):
        """
        :param cfg_density: cfg for the feature extractor; [(ind, outd, isbias, activation), ...]
        :param num_grid: stored on the instance
        :param cfg: cfg for every Multi_head; [(ind, outd, isbias, activation), ...];
                    needs at least two layers. cfg[0][0] is the size of the
                    image projection; the first layer's width and the second
                    layer's input are doubled internally. The caller's cfg is
                    not modified.
        :param isenhance: forwarded to Multi_head
        :param h: forwarded to Multi_head
        :param att_layers: stored on the instance
        :param num_t: number of Multi_head networks that are built
        :param has_dose: if set, ``d`` is prepended to the hidden features in forward
        :param cont_treatment: stored on the instance
        """
        super(Drnet_image, self).__init__()

        # Work on a private copy so the caller's cfg is not doubled in place
        # (re-using one cfg for several models would otherwise compound).
        cfg = [list(layer_cfg) for layer_cfg in cfg]

        self.cfg_density = cfg_density
        self.num_grid = num_grid

        self.isenhance = isenhance
        self.att_layers = att_layers
        self.h = h
        self.num_t = num_t

        # assert(num_t < 4)

        self.img_emb = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.img_emb.visual_projection = nn.Linear(self.img_emb.visual_projection.in_features, cfg[0][0])

        # construct the density estimator
        density_blocks = []
        density_hidden_dim = -1
        for layer_idx, layer_cfg in enumerate(cfg_density):
            # fc layer
            fc = nn.Linear(in_features=layer_cfg[0], out_features=layer_cfg[1], bias=layer_cfg[2])
            if layer_idx == 0:
                # weight connected to feature
                self.feature_weight = fc
            density_blocks.append(fc)
            density_hidden_dim = layer_cfg[1]
            if layer_cfg[3] == 'relu':
                density_blocks.append(nn.ReLU(inplace=True))
            elif layer_cfg[3] == 'tanh':
                density_blocks.append(nn.Tanh())
            elif layer_cfg[3] == 'sigmoid':
                density_blocks.append(nn.Sigmoid())
            else:
                print('No activation')

        self.hidden_features = nn.Sequential(*density_blocks)
        self.linear = nn.Linear(64, 1)
        # multi-head outputs blocks; widths are doubled on the private copy
        cfg[0][0] *= 2
        cfg[0][1] *= 2
        cfg[1][0] *= 2
        self.cfg = cfg
        self.Q_net_ls = torch.nn.ModuleList([Multi_head(cfg, isenhance, h=h, has_dose=has_dose) for _ in range(num_t)])

        self.has_dose = has_dose
        self.cont_treatment = cont_treatment

    def forward(self, images, x, t, d=None, test=False):
        """
        :param images: mapping of keyword arguments for CLIPModel.get_image_features
        :param x: input features
        :param t: treatment index per sample
        :param d: dose per sample; required when has_dose is set
        :param test: when True every Multi_head is evaluated on every sample
        :return: ``out`` when training; ``(out, full_out_tensor)`` in test mode
        """
        image_emb = self.img_emb.get_image_features(**images)

        hidden = self.hidden_features(x)
        if self.has_dose:
            t_hidden = torch.cat((d.view(hidden.shape[0], 1), hidden), 1)
        else:
            t_hidden = hidden
        t_hidden = torch.cat([t_hidden, image_emb], dim=1)

        if not test:
            # Allocate on the input's device rather than forcing CUDA.
            out = torch.zeros(x.shape[0], 1, device=x.device)

            for i in range(self.num_t):
                # Row indices stay on-device; no numpy/set round trip is needed.
                idx = torch.where(t == i)[0]
                curr_t_hidden = t_hidden[idx]
                # Multi_head expects the treatment index in column 0
                treat_col = torch.ones(len(curr_t_hidden), 1, device=x.device) * i
                out[idx] = self.Q_net_ls[i](torch.cat([treat_col, curr_t_hidden], dim=-1))
            return out

        full_out = []
        for i in range(self.num_t):
            treat_col = torch.ones(len(t_hidden), 1, device=x.device) * i
            full_out.append(self.Q_net_ls[i](torch.cat([treat_col, t_hidden], dim=-1)))

        full_out_tensor = torch.cat(full_out, dim=-1)
        out = full_out_tensor[torch.arange(len(full_out_tensor)), t.view(-1).long()]
        return out, full_out_tensor

    def _initialize_weights(self, initialiser=None):
        """Re-initialise weights with fixed normal distributions; ``initialiser`` is ignored."""
        # TODO: maybe add more distribution for initialization
        # NOTE(review): Treat_Linear (the layer type Multi_head is built from) is
        # not matched by any branch below, so those layers keep their
        # torch.rand initialisation - confirm whether that is intended.
        for m in self.modules():
            if isinstance(m, Dynamic_FC):
                m.weight.data.normal_(0, 1.)
                if m.isbias:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, Density_Block):
                m.weight.data.normal_(0, 0.01)
                if m.isbias:
                    m.bias.data.zero_()