# Standard Library Imports
import os
from copy import deepcopy

# Third-Party Library Imports
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Local Imports
from layers import encoder
from utils import save_dict

class scrambler(nn.Module):
    """Learn per-position importance scores that explain a frozen classifier.

    The encoder produces a strictly positive score per position. During
    training the scores are combined with the input and a uniform background
    PSSM (``log(B) + scores * x`` and ``log(B) + x / scores``). Samples are
    drawn from the two resulting PSSMs with Gumbel-softmax, and the frozen
    classifier is evaluated on those samples.

    Assumes inputs are (batch, 500, 4), presumably one-hot sequences. This is
    inferred from the (1, 500, 4) background and the 4-channel encoder;
    TODO confirm.
    """

    def __init__(self, clf):
        super(scrambler, self).__init__()

        # Work on a private copy so the caller's classifier is left untouched.
        self.clf = deepcopy(clf)
        if self.clf.training:
            self.clf.eval()
            print(
                "Warning: Classifier was not in evaluation mode. Switched to evaluation mode."
            )

        # Freeze the classifier's weights. Gradients still flow *through* it
        # to its inputs, but its own parameters are never updated.
        for param in self.clf.parameters():
            param.requires_grad = False
        print("Froze the classifier parameters")

        self.encoder = encoder(in_channels=4)
        self.softplus = nn.Softplus()

        # Uniform background PSSM: probability 0.25 for each of 4 symbols at
        # each of 500 positions. It is a non-persistent buffer, so
        # `.to(device)` moves it with the module while it stays out of
        # `state_dict()`. Existing checkpoints therefore keep the same keys.
        self.register_buffer("B", torch.full((1, 500, 4), 0.25), persistent=False)

    def forward(self, x):
        """Return strictly positive importance scores for ``x``.

        The encoder output is squeezed on dim 1 and passed through Softplus.
        Scores must be > 0 because training divides by them.
        """
        out = self.encoder(x).squeeze(1)
        return self.softplus(out)

    def inverse_bce(self, input, target, eps=0.01):
        """Binary cross-entropy against the flipped target, averaged.

        Computes ``mean(-(target * log(1 - input) + (1 - target) * log(input)))``.
        This is minimized when ``input`` moves *away* from ``target``.

        Args:
            input: predicted probabilities.
            target: 0/1 targets with the same shape as ``input``.
            eps: ``input`` is clamped to ``[eps, 1 - eps]`` so both logs stay
                finite.

        Returns:
            Scalar tensor.

        The clamp is done with autograd enabled on purpose. Running it under
        ``no_grad`` would detach the loss from the graph, and this term would
        then never train the scrambler.
        """
        input = torch.clamp(input, eps, 1.0 - eps)
        loss = -(target * torch.log(1 - input) + (1 - target) * torch.log(input))
        return loss.mean()

    def compute_losses(
        self, f_xS, f_xSc, y_hat, alpha, log_pssm_S, log_pssm_Sc, t_bits
    ):
        """Compute the objective, KL penalty and error rates for one batch.

        Args:
            f_xS: classifier probabilities on the S samples, shape (batch,).
            f_xSc: classifier probabilities on the Sc samples, shape (batch,).
            y_hat: 0/1 labels thresholded from the unscrambled prediction.
            alpha: 1 selects the sufficiency objective (KL on the S PSSM);
                0 selects the necessity objective (KL on the Sc PSSM).
            log_pssm_S: log-PSSM of the S branch.
            log_pssm_Sc: log-PSSM of the Sc branch.
            t_bits: target value for the mean per-position KL.

        Returns:
            ``(obj_loss, kl_loss, err_S, err_Sc)``. The error rates are the
            fraction of samples whose thresholded prediction disagrees with
            ``y_hat``.

        Raises:
            ValueError: if ``alpha`` is not 0 or 1. The KL target is only
                defined for those two values.
        """
        if alpha == 1:
            log_pssm = log_pssm_S
        elif alpha == 0:
            log_pssm = log_pssm_Sc
        else:
            raise ValueError(f"alpha must be 0 or 1, got {alpha!r}")

        n = len(y_hat)
        err_S = 1 - torch.sum((f_xS >= 0.5) * 1.0 == y_hat) / n
        err_Sc = 1 - torch.sum((f_xSc >= 0.5) * 1.0 == y_hat) / n

        suff = F.binary_cross_entropy(f_xS, y_hat, reduction="mean")
        necc = self.inverse_bce(f_xSc, y_hat)
        obj_loss = alpha * suff + (1 - alpha) * necc

        # F.kl_div(input, target) is target * (log(target) - input), so this
        # is KL(B || PSSM) per position once summed over the symbol axis.
        kl_divs = F.kl_div(log_pssm, self.B, reduction="none")
        mean_kl = kl_divs.sum(dim=-1).mean()
        kl_loss = (t_bits - mean_kl) ** 2

        return obj_loss, kl_loss, err_S, err_Sc

    def _batch_losses(self, x, alpha, t_bits, num_bkgd_samples):
        """Scramble one batch and return ``compute_losses`` for it."""
        # Reference decision of the frozen classifier on the unscrambled
        # input. It is never differentiated, so autograd is skipped.
        with torch.no_grad():
            f_x = self.clf(x).squeeze(-1)
        y_hat = (f_x >= 0.5) * 1.0

        # Broadcast one score per position across the 4 symbol channels.
        scores = self(x).unsqueeze(-1).repeat(1, 1, 4)
        log_B = torch.log(self.B)

        # log_softmax instead of log(softmax(...)): if a probability
        # underflows to 0, log() gives -inf, which turns the KL into inf and
        # the gradients into NaN.
        log_pssm_S = (
            torch.log_softmax(log_B + scores * x, dim=2)
            .unsqueeze(1)
            .repeat(1, num_bkgd_samples, 1, 1)
        )
        log_pssm_Sc = (
            torch.log_softmax(log_B + x / scores, dim=2)
            .unsqueeze(1)
            .repeat(1, num_bkgd_samples, 1, 1)
        )
        pssm_S = log_pssm_S.exp()
        pssm_Sc = log_pssm_Sc.exp()

        # NOTE(review): torch.logit(p) is log(p / (1 - p)), but
        # F.gumbel_softmax expects (unnormalized) log-probabilities. With 4
        # symbols the samples therefore do not follow the PSSM exactly. This
        # is kept unchanged to preserve training behavior; consider passing
        # log_pssm_S / log_pssm_Sc instead.
        x_S = F.gumbel_softmax(
            torch.logit(pssm_S, eps=0.01), tau=1, hard=True
        ).float()
        x_Sc = F.gumbel_softmax(
            torch.logit(pssm_Sc, eps=0.01), tau=1, hard=True
        ).float()

        # Merge the (batch, samples) axes for the classifier, then average
        # its output over the background samples of each sequence.
        f_xS = self.clf(x_S.flatten(0, 1)).view(-1, num_bkgd_samples).mean(-1)
        f_xSc = self.clf(x_Sc.flatten(0, 1)).view(-1, num_bkgd_samples).mean(-1)

        return self.compute_losses(
            f_xS, f_xSc, y_hat, alpha, log_pssm_S, log_pssm_Sc, t_bits
        )

    def train_and_validate(
        self,
        dataloaders,
        optimizer,
        save_path,
        alpha=1,
        kl_mult=1,
        t_bits=0.01,
        num_epochs=20,
        num_bkgd_samples=32,
        device="cuda",
        log_frac=0.25,
        wandb_log=False,
    ):
        """Train the scrambler and checkpoint the best model by validation score.

        Args:
            dataloaders: mapping with ``"train"`` and ``"val"`` iterables.
                Each item unpacks as ``(x, _)``; the second element is ignored.
            optimizer: optimizer stepped once per training batch.
            save_path: directory for ``scrambler.pt`` and
                ``scrambler_results.txt``; ``None`` disables saving.
            alpha: 1 for the sufficiency objective, 0 for necessity.
            kl_mult: weight of the KL penalty in the total loss.
            t_bits: target mean KL passed to ``compute_losses``.
            num_epochs: number of train+val epochs.
            num_bkgd_samples: scrambled samples drawn per input sequence.
            device: device the module and each batch are moved to.
            log_frac: fraction of a training epoch between metric prints/logs.
            wandb_log: if true, also log metrics with ``wandb``.

        Raises:
            ValueError: if ``alpha`` is not 0 or 1 (checked before any work).
        """
        if alpha not in (0, 1):
            raise ValueError(f"alpha must be 0 or 1, got {alpha!r}")

        if wandb_log:
            wandb.init(project="alpha_" + str(alpha) + "_scrambler_training")

        # Log every `log_step` training batches. Clamp to at least 1 so a
        # short loader or small log_frac cannot cause a modulo by zero.
        log_step = max(1, int(len(dataloaders["train"]) * log_frac))

        # Also moves the background buffer B.
        self.to(device)
        best_scrambler_loss = float("inf")

        for epoch in range(num_epochs):

            print(f"Epoch [{epoch+1}/{num_epochs}]")

            for op in ("train", "val"):
                is_train = op == "train"
                if is_train:
                    self.train()
                    # Module.train() recurses into submodules. Put the frozen
                    # classifier back in eval mode so dropout stays off and
                    # BatchNorm running statistics are not updated.
                    self.clf.eval()
                else:
                    self.eval()

                dataloader = dataloaders[op]

                running_total_loss = 0.0
                running_err_S = 0.0
                running_err_Sc = 0.0
                running_obj_loss = 0.0
                running_kl_loss = 0.0

                # Scoped grad mode: the caller's global grad setting is
                # restored when the loop exits.
                with torch.set_grad_enabled(is_train):
                    for i, (x, _) in enumerate(tqdm(dataloader)):
                        x = x.float().to(device)

                        obj_loss, kl_loss, err_S, err_Sc = self._batch_losses(
                            x, alpha, t_bits, num_bkgd_samples
                        )
                        total_loss = obj_loss + kl_mult * kl_loss

                        running_obj_loss += obj_loss.item()
                        running_err_S += err_S.item()
                        running_err_Sc += err_Sc.item()
                        running_kl_loss += kl_loss.item()
                        running_total_loss += total_loss.item()

                        if not is_train:
                            continue

                        optimizer.zero_grad()
                        total_loss.backward()
                        optimizer.step()

                        if (i + 1) % log_step == 0:
                            results_dict = {
                                "train_total_loss": running_total_loss / log_step,
                                "train_obj_loss": running_obj_loss / log_step,
                                "train_kl_loss": running_kl_loss / log_step,
                            }

                            for key, value in results_dict.items():
                                print(f"{key}: {value}")

                            if wandb_log:
                                wandb.log(results_dict)

                            running_total_loss = 0.0
                            running_err_S = 0.0
                            running_err_Sc = 0.0
                            running_obj_loss = 0.0
                            running_kl_loss = 0.0

                if op == "val":
                    n_batches = len(dataloader)
                    results_dict = {
                        "val_total_loss": running_total_loss / n_batches,
                        "val_obj_loss": running_obj_loss / n_batches,
                        "val_err_S": running_err_S / n_batches,
                        "val_err_Sc": running_err_Sc / n_batches,
                        "val_kl_loss": running_kl_loss / n_batches,
                        "epoch": epoch + 1,
                    }

                    for key, value in results_dict.items():
                        if key != "epoch":
                            print(f"{key}, {value}")

                    if wandb_log:
                        wandb.log(results_dict)

                    # Lower is better. Penalize the S error, reward the Sc
                    # error (weighted by alpha), and add the KL penalty.
                    stopping_loss = (
                        alpha * results_dict["val_err_S"]
                        - (1 - alpha) * results_dict["val_err_Sc"]
                        + results_dict["val_kl_loss"]
                    )
                    if stopping_loss <= best_scrambler_loss:
                        best_scrambler_loss = stopping_loss
                        if save_path is not None:
                            print("Saving new  model")
                            torch.save(
                                self.state_dict(),
                                os.path.join(save_path, "scrambler.pt"),
                            )
                            # Save results and params used
                            save_dict(results_dict, save_path, "scrambler_results.txt")
