# Standard Library Imports
import os
import math

# Third-Party Library Imports
import torch
from tqdm import tqdm

# Local Imports
from utils import save_dict

def convolution_size(
    given_size, num_layers, kernel_sizes, padding=0, strides=1, dilations=1,
    inverse=False
):
    """
    Computes the size of the convolutional output after applying several layers
    of convolution to an input of a given size. Alternatively, this can also
    compute the size of a convolutional input needed to create the given size
    for an output.
    Arguments:
        `given_size`: the size of an input sequence, or the size of a desired
            output sequence
        `num_layers`: number of convolutional layers to apply
        `kernel_sizes`: array of kernel sizes, to be applied in order; can also
            be an integer, which is the same kernel size for all layers
        `padding`: array of padding amounts, with each value being the amount of
            padding on each side of the input at each layer; can also be an
            integer, which is the same padding for all layers
        `strides`: array of stride values, with each value being the stride
            at each layer; can also be an integer, which is the same stride for
            all layers
        `dilations`: array of dilation values, with each value being the
            dilation at each layer; can also be an integer, which is the same
            dilation for all layers
        `inverse`: if True, computes the size of input needed to generate an
            output of size `given_size`
    Returns, when `inverse` is False, the size of the sequence after the
    convolutional layers are applied in order, using the same floor formula as
    `torch.nn.Conv1d`. A result <= 0 means the input is too short for these
    layers. When `inverse` is True, returns the smallest input size that
    yields an output of size `given_size` (the layers are walked in reverse).
    Raises AssertionError if a per-layer list does not have `num_layers`
    entries.
    """
    def per_layer(value, name):
        # Broadcast a scalar setting to every layer, or validate a per-layer list
        if isinstance(value, int):
            return [value] * num_layers
        assert len(value) == num_layers, (
            f"`{name}` has {len(value)} entries but num_layers={num_layers}"
        )
        return list(value)

    layers = list(zip(
        per_layer(kernel_sizes, "kernel_sizes"),
        per_layer(padding, "padding"),
        per_layer(strides, "strides"),
        per_layer(dilations, "dilations"),
    ))

    size = given_size
    if not inverse:
        for kernel, pad, stride, dilation in layers:
            # Floor division matches PyTorch's floor(). The old int(x / s)
            # truncated toward zero, which is wrong for negative numerators
            # (input shorter than the receptive field). It also lost
            # precision by going through float division.
            size = int(
                (size + 2 * pad - dilation * (kernel - 1) - 1) // stride
            ) + 1
    else:
        for kernel, pad, stride, dilation in reversed(layers):
            size = stride * (size - 1) - 2 * pad + dilation * (kernel - 1) + 1
    return size


class classifier(torch.nn.Module):
    """
    CNN binary classifier made of:
    - a stack of Conv1d -> ReLU (-> BatchNorm1d) blocks,
    - one max-pooling layer,
    - a stack of Linear -> ReLU (-> BatchNorm1d) blocks,
    - a single Linear unit followed by a sigmoid.
    """

    def __init__(
        self, num_conv_layers=3, conv_filter_sizes=None,
        conv_filter_nums=None, max_pool_size=40, max_pool_stride=40,
        num_fc_layers=2, fc_sizes=None, batch_norm=True, input_length=500,
        input_dim=4
    ):
        """
        Initializes a standard CNN architecture for regulatory genomics.
        Arguments:
            `num_conv_layers`: number of convolutional layers
            `conv_filter_sizes`: size of filters in each layer, a list;
                defaults to [10, 5, 5]
            `conv_filter_nums`: number of filters in each layer, a list;
                defaults to [8, 8, 8]
            `max_pool_size`: size of max-pool filter
            `max_pool_stride`: stride for max-pool filter
            `num_fc_layers`: number of linear layers at the end (may be 0)
            `fc_sizes`: number of hidden units in each linear layer, a list;
                defaults to [10, 5]
            `batch_norm`: whether or not to apply batch normalization
            `input_length`: length of input sequence
            `input_dim`: dimension of input sequence (e.g. 4 for DNA)
        Raises AssertionError if a list's length disagrees with its layer count.
        """
        super().__init__()

        # Build fresh lists per call. List literals as signature defaults
        # would be one object shared by every call. list(...) also lets
        # callers pass tuples to the concatenations below.
        conv_filter_sizes = list(
            [10, 5, 5] if conv_filter_sizes is None else conv_filter_sizes
        )
        conv_filter_nums = list(
            [8, 8, 8] if conv_filter_nums is None else conv_filter_nums
        )
        fc_sizes = list([10, 5] if fc_sizes is None else fc_sizes)

        assert len(conv_filter_sizes) == num_conv_layers
        assert len(conv_filter_nums) == num_conv_layers
        assert len(fc_sizes) == num_fc_layers

        # Convolutional blocks: channel depth goes input_dim -> filter counts
        depths = [input_dim] + conv_filter_nums
        self.conv_layers = torch.nn.ModuleList()
        for in_depth, out_depth, kernel_size in zip(
            depths[:-1], depths[1:], conv_filter_sizes
        ):
            layer = [
                torch.nn.Conv1d(in_depth, out_depth, kernel_size),
                torch.nn.ReLU()
            ]
            if batch_norm:
                layer.append(torch.nn.BatchNorm1d(out_depth))
            self.conv_layers.append(torch.nn.Sequential(*layer))

        # Max pooling layer
        self.max_pool_layer = torch.nn.MaxPool1d(
            max_pool_size, stride=max_pool_stride
        )

        # Size of the pooled output. The Conv1d layers use stride 1 and no
        # padding, which matches convolution_size's defaults. Pooling uses
        # the same floor formula with dilation 1.
        conv_output_size = convolution_size(
            input_length, num_conv_layers, conv_filter_sizes
        )
        pool_output_size = (
            (conv_output_size - max_pool_size) // max_pool_stride + 1
        )
        pool_output_depth = depths[-1]

        # Fully connected blocks on the flattened pooled features
        dims = [pool_output_size * pool_output_depth] + fc_sizes
        self.fc_layers = torch.nn.ModuleList()
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layer = [
                torch.nn.Linear(in_dim, out_dim),
                torch.nn.ReLU()
            ]
            if batch_norm:
                layer.append(torch.nn.BatchNorm1d(out_dim))
            self.fc_layers.append(torch.nn.Sequential(*layer))

        # Map the last hidden width to a single output. dims[-1] (rather than
        # fc_sizes[-1]) also works when there are no fc layers.
        self.last_fc_layer = torch.nn.Linear(dims[-1], 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_seq, return_interims=False):
        """
        Runs the forward pass of the model.
        Arguments:
            `input_seq`: a B x L x D tensor of the input sequence, where B is
                the batch dimension, L is the sequence length, and D is the
                feature dimension
            `return_interims`: if True, also return a dictionary of
                intermediates
        Returns a B x 1 tensor containing the predicted probabilities for each
        input sequence. If `return_interims` is True, also returns a dictionary
        containing a B x L' x F tensor of the first-layer convolutional
        activations (after ReLU, before any BatchNorm) under "conv_acts".
        """
        conv_out = torch.transpose(input_seq, 1, 2)  # Shape: B x D x L
        interims = {}
        for i, conv_layer in enumerate(self.conv_layers):
            if i == 0 and return_interims:
                # Sublayers 0 and 1 are Conv1d and ReLU. Capture their output,
                # then run the rest of the block (BatchNorm, if present).
                conv_out = conv_layer[1](conv_layer[0](conv_out))
                interims["conv_acts"] = torch.transpose(conv_out, 1, 2)
                for sublayer in list(conv_layer)[2:]:
                    conv_out = sublayer(conv_out)
            else:
                conv_out = conv_layer(conv_out)

        # Max pooling, then flatten channels x positions per example
        pool_out = self.max_pool_layer(conv_out)  # Shape: B x D' x L'
        fc_out = pool_out.flatten(start_dim=1)  # Shape: B x D'L'

        for fc_layer in self.fc_layers:
            fc_out = fc_layer(fc_out)

        out = self.sigmoid(self.last_fc_layer(fc_out))  # Shape: B x 1

        if return_interims:
            return out, interims
        return out

    def loss(self, pred_probs, true_vals):
        """
        Computes the mean binary cross-entropy of the predicted probabilities
        against the true values or probabilities.
        Arguments:
            `pred_probs`: a tensor of predicted probabilities
            `true_vals`: a tensor of binary labels or true probabilities, the
                same shape as `pred_probs`
        Returns a scalar (0-dimensional) tensor: the loss averaged over all
        elements.
        """
        return torch.nn.functional.binary_cross_entropy(
            pred_probs, true_vals, reduction="mean"
        )

    def _run_epoch(self, loader, optimizer, device, real_data, is_train):
        """
        Runs one pass over `loader` and returns a (mean loss, accuracy) pair.
        Both are averaged over individual samples. When `is_train` is True,
        each batch also takes an optimizer step. Grad mode is scoped to this
        call and restored on exit.
        """
        self.train(is_train)  # train(False) is equivalent to eval()
        if real_data:
            # NOTE(review): assumes the dataset defines on_epoch_start()
            loader.dataset.on_epoch_start()

        total_loss = 0.0
        total_correct = 0.0
        total_count = 0
        # Context-managed, so the caller's grad mode is restored afterwards.
        # A bare torch.set_grad_enabled(False) would leak past this method.
        with torch.set_grad_enabled(is_train):
            for x, y in tqdm(loader):
                x = x.float().to(device)
                y = y.float().to(device)
                pred_probs = self(x).squeeze(-1)

                loss = self.loss(pred_probs, y)
                y_hat = (pred_probs >= 0.5).float()
                # Weight by sample count so a short final batch doesn't skew
                # the epoch averages
                batch_count = y.numel()
                total_loss += loss.item() * batch_count
                total_correct += (y_hat == y).sum().item()
                total_count += batch_count

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

        if total_count == 0:
            raise ValueError("dataloader yielded no samples")
        return total_loss / total_count, total_correct / total_count

    def train_and_validate(
        self, dataloaders, optimizer, save_path, num_epochs=20, device='cuda',
        real_data=False
    ):
        """
        Trains for `num_epochs` epochs, validating after each one. Whenever
        validation accuracy ties or beats the best so far, saves the results
        and the model weights.
        Arguments:
            `dataloaders`: mapping with "train" and "val" entries, each an
                iterable of (x, y) batches with a `len`
            `optimizer`: optimizer over this model's parameters
            `save_path`: directory receiving "results.txt" (written via
                `save_dict`) and "clf.pt" (the state dict)
            `num_epochs`: number of training epochs
            `device`: device to move the model and batches to
            `real_data`: if truthy, call `loader.dataset.on_epoch_start()`
                before each pass
        Raises ValueError if a dataloader yields no samples.
        """
        self.to(device)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")

            train_loss, train_acc = self._run_epoch(
                dataloaders["train"], optimizer, device, real_data,
                is_train=True
            )
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

            val_loss, val_acc = self._run_epoch(
                dataloaders["val"], optimizer, device, real_data,
                is_train=False
            )
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            if val_acc >= best_acc:
                best_acc = val_acc
                print("Saving new model!")
                results_dict = {
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                    "epoch": epoch,
                }
                save_dict(results_dict, save_path, "results.txt")
                torch.save(
                    self.state_dict(), os.path.join(save_path, "clf.pt")
                )
                


