import torch
import torchvision
from torch import nn

class ReluLSTMCell(nn.Module):
    """Single-step LSTM-style cell that uses ReLU where a classic LSTM uses tanh.

    The input, forget and output gates are sigmoid-activated; the candidate
    value and the cell-to-hidden read-out both go through ReLU.

    Args:
        input_size: Feature size of the per-step input.
        hidden_size: Feature size of the hidden and cell states.
        bias: Whether the input projection carries a bias term.  The hidden
            projection never has one.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(ReluLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # All four gates are produced by one projection each from the input
        # and from the previous hidden state, hence the factor of four.
        gate_width = 4 * hidden_size
        self.fc_x = nn.Linear(input_size, gate_width, bias=bias)
        self.fc_h = nn.Linear(hidden_size, gate_width, bias=False)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden, cell):
        """Advance the cell by one time step.

        Args:
            input: Current input; gates are split along dim 1, so a 2-D
                ``(batch, input_size)`` tensor is expected.
            hidden: Previous hidden state, ``(batch, hidden_size)``.
            cell: Previous cell state, ``(batch, hidden_size)``.

        Returns:
            Tuple ``(new_hidden, new_cell)``.
        """
        from_input = self.fc_x(input)
        from_hidden = self.fc_h(hidden)
        pre_activation = from_input + from_hidden

        # Layout along dim 1: input gate, forget gate, output gate, candidate.
        in_pre, forget_pre, out_pre, cand_pre = pre_activation.split(
            self.hidden_size, dim=1)
        in_gate = self.sigmoid(in_pre)
        forget_gate = self.sigmoid(forget_pre)
        out_gate = self.sigmoid(out_pre)
        candidate = self.relu(cand_pre)

        kept = forget_gate * cell
        written = in_gate * candidate
        new_cell = kept + written
        new_hidden = out_gate * self.relu(new_cell)
        return new_hidden, new_cell


class Resnet50LSTM(nn.Module):
    """Video classifier: per-frame ResNet-50 features -> FC layers -> ReluLSTMCell.

    Every frame is run through a truncated torchvision ResNet-50 (built with
    ``pretrained=True``), reduced to 512 features by one or two fully
    connected layers, and fed to a 256-unit ``ReluLSTMCell``.  The hidden
    state of every step is classified by ``fc8_final`` and the result is
    averaged over time.

    Args:
        num_classes: Number of outputs of the final linear layer.
        drop_rate: Dropout probability applied after each mid FC layer and to
            the LSTM hidden state before classification.
        num_mid_fc: Number of FC layers between the backbone and the LSTM;
            ``2`` gives num_feat -> 1024 -> 512, ``1`` gives num_feat -> 512.
        fc_precede_mean: If true, classify every step and average the
            predictions; otherwise average the hidden states and classify once.
        resnet_feat_type: ``'layer4'`` keeps all but the last child module of
            the ResNet (2048 features); ``'layer3'`` keeps all but the last
            three (1024 channels, spatially averaged in ``forward``).

    Raises:
        ValueError: If ``resnet_feat_type`` or ``num_mid_fc`` is unsupported.
    """

    # resnet_feat_type -> (trailing ResNet children to drop, feature channels)
    _BACKBONE_SPECS = {'layer4': (1, 2048), 'layer3': (3, 1024)}
    _LSTM_INPUT_SIZE = 512
    _LSTM_HIDDEN_SIZE = 256

    def __init__(self, num_classes, drop_rate=0.5,
                 num_mid_fc=2, fc_precede_mean=True, resnet_feat_type='layer4'):
        super(Resnet50LSTM, self).__init__()

        # Validate before building the backbone so that a bad option fails
        # immediately instead of surfacing later as a missing attribute or a
        # shape mismatch inside forward().
        if resnet_feat_type not in self._BACKBONE_SPECS:
            raise ValueError(
                "resnet_feat_type must be one of {}, got {!r}".format(
                    sorted(self._BACKBONE_SPECS), resnet_feat_type))
        if num_mid_fc not in (1, 2):
            raise ValueError(
                "num_mid_fc must be 1 or 2, got {!r}".format(num_mid_fc))

        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.num_mid_fc = num_mid_fc
        self.fc_precede_mean = fc_precede_mean
        self.resnet_feat_type = resnet_feat_type

        num_dropped, self.num_feat = self._BACKBONE_SPECS[resnet_feat_type]
        # NOTE(review): recent torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True``; kept as-is for compatibility
        # with whatever version this project pins.
        backbone = torchvision.models.resnet50(pretrained=True)
        self.resnet = nn.Sequential(*list(backbone.children())[:-num_dropped])

        if self.num_mid_fc == 2:
            self.fc6 = nn.Linear(self.num_feat, 1024, bias=True)
            self.fc7 = nn.Linear(1024, self._LSTM_INPUT_SIZE, bias=True)
        else:
            self.fc6 = nn.Linear(self.num_feat, self._LSTM_INPUT_SIZE, bias=True)

        self.dropout = nn.Dropout(p=self.drop_rate)
        self.relu = nn.ReLU(True)
        self.lstm1 = ReluLSTMCell(
            self._LSTM_INPUT_SIZE, self._LSTM_HIDDEN_SIZE, bias=True)
        self.fc8_final = nn.Linear(
            self._LSTM_HIDDEN_SIZE, self.num_classes, bias=True)

    def forward(self, video_tensor):
        """Classify a batch of clips.

        Args:
            video_tensor: Tensor of shape batch_size x 3 x num_f x H x W.

        Returns:
            Tensor of shape batch_size x num_classes.
        """
        batch_size = video_tensor.shape[0]
        hidden_size = self.lstm1.hidden_size

        # new_zeros allocates directly on the input's device and with the
        # input's dtype, so the state matches the features it is mixed with.
        h_top = video_tensor.new_zeros(batch_size, hidden_size)
        c_top = video_tensor.new_zeros(batch_size, hidden_size)

        # Only the 'layer4' variant keeps the ResNet's own pooling stage; any
        # other variant still has a spatial map that must be averaged here.
        needs_spatial_mean = self.resnet_feat_type != 'layer4'

        h_tops = []
        for img_tensor in video_tensor.unbind(dim=2):
            y = self.resnet(img_tensor)
            if needs_spatial_mean:
                y = y.mean(dim=3, keepdim=True).mean(dim=2, keepdim=True)
            y = y.view(batch_size, -1)

            y = self.dropout(self.relu(self.fc6(y)))
            if self.num_mid_fc == 2:
                y = self.dropout(self.relu(self.fc7(y)))

            # The previous state is detached, so no gradient flows back
            # through time via h/c; each step only trains on its own frame.
            # NOTE(review): confirm this truncation is intentional.
            h_top, c_top = self.lstm1(y, h_top.detach(), c_top.detach())
            h_tops.append(self.dropout(h_top))

        h_tops = torch.stack(h_tops, dim=1)  # bs x numf x C
        if self.fc_precede_mean:
            pred = torch.mean(self.fc8_final(h_tops), dim=1)  # bs x num_cls
        else:
            pred = self.fc8_final(torch.mean(h_tops, dim=1))
        return pred







