# Standard Library Imports
import os
import wandb
from copy import deepcopy

# Third Party Library Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Local Imports
from layers import encoder
from utils import sample_masked_X, save_dict


class mex(nn.Module):
    """Mask-generating explainer for a frozen binary classifier.

    For every input the module predicts two vectors, ``w`` and ``sigma``,
    which ``_get_mask`` combines with the distance matrix ``D`` into a mask
    with values in (0, 1). ``train_and_validate`` optimises the explainer so
    that the classifier's prediction is preserved on inputs masked with the
    mask and changed on inputs masked with its complement (weighted by
    ``alpha``), plus regularisers on the mask and on ``sigma``.

    The classifier being explained is deep-copied, frozen and always kept in
    evaluation mode (see ``train``).
    """

    def __init__(self, clf, D, bkdg_seqs, init_mask="ones", freeze=True):
        """Build the explainer around a copy of ``clf``.

        Args:
            clf: Classifier to explain. It is deep-copied, so the caller's
                instance is never modified. Must expose an ``encoder``
                attribute when ``freeze`` is true.
            D: Distance matrix used by ``_get_mask``. It has to broadcast
                against a ``(batch, hidden_size, 1)`` tensor; presumably it
                is ``(hidden_size, hidden_size)`` -- TODO confirm with caller.
            bkdg_seqs: Background sequences. ``train_and_validate`` indexes it
                with an integer index array and passes the result to
                ``torch.tensor`` (presumably a NumPy array -- verify).
            init_mask: Initialisation scheme for the last linear layers; only
                ``"ones"`` is handled, any other value keeps PyTorch's
                default initialisation.
            freeze: If true, reuse a frozen copy of ``clf.encoder``; otherwise
                create a fresh, trainable ``encoder(in_channels=4)``.
        """
        super().__init__()

        # Background seqs
        self.B = bkdg_seqs

        # Distance matrix
        self.D = D

        # Deepcopy classifier to explain
        self.clf = deepcopy(clf)
        if self.clf.training:
            self.clf.eval()
            print(
                "Warning: Classifier was not in evaluation mode. Switched to evaluation mode."
            )

        # Freeze weights of classifier
        self.clf.requires_grad_(False)
        print("Froze the classifier parameters")

        # Copy the original clf encoder and freeze weights if necessary
        if freeze:
            self.encoder = deepcopy(clf.encoder)
            self.encoder.requires_grad_(False)
        else:
            self.encoder = encoder(in_channels=4)

        # Define layers for w and sigma
        self.num_layers = 1
        self.hidden_size = 500

        # NOTE: layers are created in a fixed order (w head first, then the
        # sigma head) so that random initialisation and the state_dict layout
        # stay reproducible.

        # w layers
        self.w_fc = nn.ModuleList(
            [
                nn.Linear(self.hidden_size, self.hidden_size)
                for _ in range(self.num_layers)
            ]
        )
        self.w_bn = nn.ModuleList([nn.BatchNorm1d(1) for _ in range(self.num_layers)])

        self.w_last_fc_layer = nn.Linear(self.hidden_size, self.hidden_size)
        self.w_bn_last = nn.BatchNorm1d(1)

        # sigma layers
        self.sigma_fc = nn.ModuleList(
            [
                nn.Linear(self.hidden_size, self.hidden_size)
                for _ in range(self.num_layers)
            ]
        )
        self.sigma_bn = nn.ModuleList(
            [nn.BatchNorm1d(1) for _ in range(self.num_layers)]
        )

        self.sigma_last_fc_layer = nn.Linear(self.hidden_size, self.hidden_size)
        self.sigma_bn_last = nn.BatchNorm1d(1)

        # additional layers
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()

        self.init_mask = init_mask
        self._initialize_weights()

    def train(self, mode=True):
        """Switch the explainer's mode while keeping ``clf`` in eval mode.

        ``nn.Module.train`` recurses into every sub-module, which would put
        the classifier being explained back into training mode and undo the
        ``eval()`` enforced in ``__init__``. ``eval()`` is routed through this
        method as well, since it calls ``train(False)``.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.clf.eval()
        return self

    def _apply_head(self, features, fc_layers, bn_layers, last_fc, last_bn):
        """Run one (Linear -> BatchNorm -> ReLU)* -> Linear -> BatchNorm head."""
        hidden = features
        for fc, bn in zip(fc_layers, bn_layers):
            hidden = self.relu(bn(fc(hidden)))
        return last_bn(last_fc(hidden))

    def forward(self, x):
        """Predict ``w`` and ``sigma`` for a batch of inputs.

        The encoder output must be ``(batch, 1, hidden_size)``: the linear
        layers act on the last dimension and ``BatchNorm1d(1)`` expects a
        single channel.

        Args:
            x: Batch of inputs accepted by ``self.encoder``.

        Returns:
            Tuple ``(w, sigma)`` of tensors shaped like the encoder output.
            ``sigma`` is passed through softplus and is therefore positive.
        """
        # Compute encoded x
        encoded = self.encoder(x)

        # w computation
        w_out = self._apply_head(
            encoded, self.w_fc, self.w_bn, self.w_last_fc_layer, self.w_bn_last
        )

        # Compute sigma
        sigma_out = self._apply_head(
            encoded,
            self.sigma_fc,
            self.sigma_bn,
            self.sigma_last_fc_layer,
            self.sigma_bn_last,
        )
        sigma_out = self.softplus(sigma_out)

        return w_out, sigma_out

    def _initialize_weights(self):
        """Initialise the last linear layers according to ``self.init_mask``.

        With ``"ones"`` the last ``w`` layer outputs a constant 1 and the last
        ``sigma`` layer outputs ``log(e - 1)``, whose softplus is 1. Both
        outputs still pass through a batch-norm layer in ``forward``, so the
        final values are exactly constant only up to that normalisation.
        Any other ``init_mask`` value leaves the default initialisation.
        """
        if self.init_mask == "ones":

            # Zero weights + bias 1 -> the layer outputs 1 for every input
            nn.init.constant_(self.w_last_fc_layer.weight, 0)
            nn.init.constant_(self.w_last_fc_layer.bias, 1)

            # softplus(log(e - 1)) == 1
            nn.init.constant_(self.sigma_last_fc_layer.weight, 0)
            nn.init.constant_(self.sigma_last_fc_layer.bias, np.log(np.exp(1) - 1))

    def _get_mask(self, w, sigma):
        """Combine ``w`` and ``sigma`` with the distance matrix into masks.

        Computes ``sigmoid(sum_i w_i * exp(-D[i, j]**2 / sigma_i))`` for every
        batch element.

        Args:
            w: Tensor of shape ``(batch, 1, n)``.
            sigma: Positive tensor of shape ``(batch, 1, n)``.

        Returns:
            Tensor of shape ``(batch, n)`` with values in (0, 1) (for an
            ``(n, n)`` distance matrix).
        """
        w_t = w.transpose(1, 2)
        sigma_t = sigma.transpose(1, 2)

        # self.D is a plain attribute, so it does not follow ``module.to()``;
        # align it with the inputs to avoid a device mismatch. This is a no-op
        # when D already lives on the right device.
        D = torch.as_tensor(self.D, device=sigma.device)

        kernel = torch.exp(-(D ** 2) / sigma_t)
        masks = torch.sigmoid(torch.sum(w_t * kernel, dim=1))

        return masks

    def compute_losses(self, f_xS, f_xSc, y_hat, alpha, masks, sigma):
        """Compute the objective, the regularisers and the error rates.

        Args:
            f_xS: Classifier outputs in [0, 1] for inputs masked with the mask.
            f_xSc: Classifier outputs in [0, 1] for inputs masked with the
                complement.
            y_hat: Hard 0/1 predictions of the classifier on the original
                inputs (float tensor).
            alpha: Weight of the sufficiency term; ``1 - alpha`` weights the
                necessity term.
            masks: Masks, indexed as ``(batch, position, channel)``.
            sigma: Positive ``sigma`` tensor from ``forward``.

        Returns:
            Tuple ``(obj_loss, l1_loss, sm_loss, tv_loss, err_S, err_Sc)`` of
            scalar tensors.
        """
        # Fraction of samples whose thresholded prediction differs from y_hat
        err_S = 1 - torch.sum((f_xS >= 0.5) * 1.0 == y_hat) / len(y_hat)
        err_Sc = 1 - torch.sum((f_xSc >= 0.5) * 1.0 == y_hat) / len(y_hat)

        # Sufficiency: keep the prediction on x_S. Necessity: the BCE on the
        # complement is negated, i.e. it is maximised.
        suff = F.binary_cross_entropy(f_xS, y_hat, reduction="mean")
        necc = -F.binary_cross_entropy(f_xSc, y_hat, reduction="mean")
        obj_loss = alpha * suff + (1 - alpha) * necc

        # Mean absolute mask value
        l1_loss = torch.mean(torch.mean(torch.abs(masks), dim=1))
        # Negative mean log(sigma): minimising it pushes sigma up
        sm_loss = -torch.mean(torch.mean(torch.log(sigma), dim=1))
        # Mean absolute difference between neighbouring positions
        tv_loss = torch.mean(
            torch.mean(torch.abs((masks[:, :-1, :] - masks[:, 1:, :])), dim=1)
        )

        return obj_loss, l1_loss, sm_loss, tv_loss, err_S, err_Sc

    def train_and_validate(
        self,
        dataloaders,
        optimizer,
        save_path,
        alpha=1,
        l1_mult=1,
        sm_mult=1,
        num_epochs=20,
        num_bkgd_samples=100,
        device="cuda",
        log_frac=0.25,
        wandb_log=False,
    ):
        """Train the explainer and validate it after every epoch.

        Every epoch draws ``num_bkgd_samples`` background sequences without
        replacement, runs a training pass and then a validation pass. The
        model is saved whenever the validation stopping loss does not exceed
        the best value seen so far.

        Args:
            dataloaders: Mapping with ``"train"`` and ``"val"`` loaders that
                yield ``(x, label)`` pairs; the label is ignored.
            optimizer: Optimizer over the explainer's trainable parameters.
            save_path: Directory for ``explainer.pt`` and ``results.txt``, or
                ``None`` to disable saving. Created if missing.
            alpha: Trade-off between sufficiency and necessity.
            l1_mult: Weight of the L1 mask penalty.
            sm_mult: Weight of the ``sigma`` penalty.
            num_epochs: Number of epochs.
            num_bkgd_samples: Background sequences drawn per epoch; must not
                exceed ``len(self.B)``.
            device: Device the batches and background samples are moved to.
            log_frac: Fraction of the training loader between two training
                log outputs.
            wandb_log: If true, also log to Weights & Biases.
        """
        if wandb_log:
            wandb.init(project="alpha_" + str(alpha) + "_mex_training")

        # Log step for training; at least 1 so the modulo below is always
        # defined, even for very short loaders or a tiny log_frac.
        log_step = max(1, int(len(dataloaders["train"]) * log_frac))

        loss_names = ("total", "obj", "err_S", "err_Sc", "l1", "sm", "tv")
        best_explainer_loss = float("inf")

        for epoch in range(num_epochs):

            indices = np.random.choice(len(self.B), num_bkgd_samples, replace=False)
            bkgd_samples = torch.tensor(self.B[indices], device=device)

            print(f"Epoch [{epoch+1}/{num_epochs}]")

            for op in ("train", "val"):
                is_train = op == "train"
                self.train(is_train)

                dataloader = dataloaders[op]
                running = dict.fromkeys(loss_names, 0.0)

                # Context manager: the caller's autograd mode is restored on
                # exit instead of staying globally disabled after validation.
                with torch.set_grad_enabled(is_train):
                    for i, data in enumerate(tqdm(dataloader)):

                        # Load sequences
                        x, _ = data
                        x = x.float().to(device)

                        # Compute original prediction
                        f_x = self.clf(x).squeeze(-1).detach()
                        y_hat = (f_x >= 0.5) * 1.0

                        # Compute w, sigma, and mask via explainer
                        w, sigma = self(x)
                        masks = self._get_mask(w, sigma).unsqueeze(-1)

                        # Generate samples x_S and x_Sc using masks
                        x_S, x_Sc = sample_masked_X(x, masks, bkgd_samples)

                        # Evaluate f_xS and f_xSc, averaged over the
                        # background samples.
                        # NOTE(review): the (500, 4) sequence shape is
                        # hard-coded here -- confirm it matches the data.
                        f_xS = (
                            self.clf(x_S.view(-1, 500, 4))
                            .view(-1, num_bkgd_samples)
                            .mean(-1)
                        )
                        f_xSc = (
                            self.clf(x_Sc.view(-1, 500, 4))
                            .view(-1, num_bkgd_samples)
                            .mean(-1)
                        )

                        # Evaluate losses
                        obj_loss, l1_loss, sm_loss, tv_loss, err_S, err_Sc = (
                            self.compute_losses(
                                f_xS, f_xSc, y_hat, alpha, masks, sigma
                            )
                        )
                        # NOTE(review): tv_loss is tracked and used in the
                        # stopping criterion but is not part of the optimised
                        # loss -- confirm this is intended.
                        total_loss = obj_loss + l1_mult * l1_loss + sm_mult * sm_loss

                        batch_values = {
                            "total": total_loss,
                            "obj": obj_loss,
                            "err_S": err_S,
                            "err_Sc": err_Sc,
                            "l1": l1_loss,
                            "sm": sm_loss,
                            "tv": tv_loss,
                        }
                        for name, value in batch_values.items():
                            running[name] += value.item()

                        if not is_train:
                            continue

                        optimizer.zero_grad()
                        total_loss.backward()
                        optimizer.step()

                        if (i + 1) % log_step == 0:

                            results_dict = {
                                "train_total_loss": running["total"] / log_step,
                                "train_obj_loss": running["obj"] / log_step,
                                "train_l1_loss": running["l1"] / log_step,
                                "train_sm_loss": running["sm"] / log_step,
                                "train_tv_loss": running["tv"] / log_step,
                            }

                            for key, value in results_dict.items():
                                print(f"{key}: {value}")

                            if wandb_log:
                                wandb.log(results_dict)

                            # Start a fresh logging window for every metric
                            # (including tv, which must not keep accumulating).
                            running = dict.fromkeys(loss_names, 0.0)

                if is_train:
                    continue

                num_batches = len(dataloader)
                results_dict = {
                    "val_total_loss": running["total"] / num_batches,
                    "val_objective_loss": running["obj"] / num_batches,
                    "val_err_S": running["err_S"] / num_batches,
                    "val_err_Sc": running["err_Sc"] / num_batches,
                    "val_l1_loss": running["l1"] / num_batches,
                    "val_sm_loss": running["sm"] / num_batches,
                    "val_tv_loss": running["tv"] / num_batches,
                    "epoch": epoch + 1,
                }

                for key, value in results_dict.items():
                    if key != "epoch":
                        print(f"{key}, {value}")

                if wandb_log:
                    wandb.log(results_dict)

                # Model selection uses error rates rather than the BCE terms
                stopping_loss = (
                    alpha * results_dict["val_err_S"]
                    - (1 - alpha) * results_dict["val_err_Sc"]
                    + results_dict["val_l1_loss"]
                    + results_dict["val_tv_loss"]
                )
                if stopping_loss <= best_explainer_loss:
                    best_explainer_loss = stopping_loss
                    if save_path is not None:
                        print("Saving new model")
                        os.makedirs(save_path, exist_ok=True)
                        torch.save(
                            self.state_dict(),
                            os.path.join(save_path, "explainer.pt"),
                        )
                        # Save results and params used
                        save_dict(results_dict, save_path, "results.txt")
