import sys
import torch
import torch.nn.functional as F

from tqdm import tqdm

class InstanceExplainer:
    """Optimise a soft saliency mask explaining a classifier's score on one image.

    Each optimisation step samples hard binary masks from the soft mask
    (straight-through Gumbel-softmax), builds "kept" inputs (sampled pixels
    from the image, the rest from a background) and their complements, and
    minimises:

    * a sufficiency/necessity objective - kept inputs should reproduce the
      thresholded original prediction, complement inputs should score ``y_0``;
    * a sparsity loss, a total-variation loss and a Laplacian-norm loss on
      the soft mask.

    ``sample``, ``compute_losses`` and ``step`` read ``self.h``, ``self.w``,
    ``self.num_mask_entries`` (and ``self.ones_k`` when ``constraint`` is
    truthy); these attributes are assigned by ``__call__``.

    Args:
        clf: Model mapping a ``[B, C, h, w]`` batch to scores. Its output is
            reshaped to exactly one value per input and ``.item()`` is called
            on the score of the original image, so it presumably returns one
            probability-like score per image (it is thresholded at 0.5).
        device: Torch device for the model, mask, input and background.
        alpha: Weight of the sufficiency term; ``1 - alpha`` weights necessity.
        constraint: If truthy, the sparsity loss compares the descending-sorted
            mask against a fixed-area 0/1 profile; otherwise it is the mask mean.
        num_channels: Number of image channels ``C``.
    """

    def __init__(
        self,
        clf,
        device,
        alpha,
        constraint,
        num_channels,
    ):
        self.clf = clf
        self.device = device
        self.alpha = alpha
        self.constraint = constraint
        self.num_channels = num_channels
        # 3x3 discrete Laplacian stencil shaped [1, 1, 3, 3] (out, in, kh, kw)
        # for a single-channel F.conv2d over the mask.
        self.kernel = (
            torch.tensor(
                [[0, 1, 0], [1, -4, 1], [0, 1, 0]],
                dtype=torch.float32,
                device=self.device,
            )
            .unsqueeze(0)
            .unsqueeze(0)
        )

    def sample(self, x, mask, bkgd, num_samples):
        """Draw hard binary masks from ``mask`` and build the masked inputs.

        Every mask entry is treated as the probability of keeping that pixel.
        ``num_samples`` hard Gumbel-softmax samples are drawn per image; the
        straight-through estimator lets gradients flow back to ``mask``.

        Args:
            x: Input images, viewable as ``[B, C, h*w]``.
            mask: Soft mask, viewable as ``[B, h*w]``.
            bkgd: Background, viewable as ``[-1, h*w]`` and broadcast against
                the channel dimension.
            num_samples: Number of mask samples ``N`` per image.

        Returns:
            ``(mask_01, x_S, x_Sc)``. ``mask_01`` is ``[B, N, h, w]``.
            ``x_S`` and ``x_Sc`` are ``[B, N, C, h, w]``: ``x_S`` keeps the
            sampled pixels of ``x`` and fills the rest with ``bkgd``, while
            ``x_Sc`` is the complement.
        """
        batch = x.shape[0]

        # Flatten the spatial dimensions.
        bkgd_flat = bkgd.view(-1, self.num_mask_entries).to(self.device)
        x_flat = x.view(batch, self.num_channels, self.num_mask_entries)  # [B, C, d^2]
        mask_flat = mask.view(batch, -1)  # [B, d^2]
        mask_rep = mask_flat.unsqueeze(1).expand(-1, num_samples, -1)  # [B, N, d^2]

        # Two-class (keep, drop) probabilities -> logits. eps clamps p into
        # [0.02, 0.98] so mask values at (or beyond) 0/1 give finite logits.
        p_S = torch.stack((mask_rep, 1 - mask_rep), dim=-1)  # [B, N, d^2, 2]
        logits_S = torch.logit(p_S, eps=0.02)
        # Hard one-hot samples with straight-through gradients; class 0 = keep.
        keep = F.gumbel_softmax(logits_S, tau=1, hard=True)[..., 0]  # [B, N, d^2]

        keep_c = keep.unsqueeze(2)  # [B, N, 1, d^2], broadcasts over channels
        x_rep = x_flat.unsqueeze(1)  # [B, 1, C, d^2]
        x_S_flat = keep_c * x_rep + (1 - keep_c) * bkgd_flat
        x_Sc_flat = (1 - keep_c) * x_rep + keep_c * bkgd_flat

        img_shape = (batch, num_samples, self.num_channels, self.h, self.w)
        mask_01 = keep.view(-1, num_samples, self.h, self.w)  # [B, N, h, w]
        return mask_01, x_S_flat.view(img_shape), x_Sc_flat.view(img_shape)

    def compute_losses(self, f_xS, f_xSc, y_hat, y_0, alpha, mask):
        """Compute the explanation objective and the mask regularisers.

        Args:
            f_xS: Classifier scores on the kept-region inputs.
            f_xSc: Classifier scores on the complement inputs.
            y_hat: Target for ``f_xS`` (``__call__`` passes the thresholded
                original prediction).
            y_0: Target for ``f_xSc``.
            alpha: Sufficiency weight; ``1 - alpha`` weights necessity.
            mask: A single soft mask of shape ``[1, h, w]``. The Laplacian
                term adds one leading dim and feeds it to a one-channel
                ``conv2d``.

        Returns:
            ``(uni_loss, sp_loss, sm_loss, sh_loss)``:

            * ``uni_loss``: weighted mean-absolute sufficiency/necessity error.
            * ``sp_loss``: mean L1 distance between the descending-sorted
              mask and ``self.ones_k`` when ``constraint`` is truthy, else
              the mask mean.
            * ``sm_loss``: total variation (mean absolute differences along
              height plus along width).
            * ``sh_loss``: Frobenius norm of the mask's Laplacian over the
              interior (no padding).
        """
        # Sufficiency: scores on kept regions should match the prediction.
        suff_loss = torch.mean(torch.abs(f_xS - y_hat))
        # Necessity: scores with the kept regions removed should reach y_0.
        necc_loss = torch.mean(torch.abs(f_xSc - y_0))
        uni_loss = alpha * suff_loss + (1 - alpha) * necc_loss

        # Sparsity.
        if self.constraint:
            mask_flat = mask.view(mask.shape[0], -1)  # [B, d^2]
            sorted_mask, _ = torch.sort(mask_flat, dim=-1, descending=True)
            sp_loss = torch.mean(torch.abs(sorted_mask - self.ones_k))
        else:
            sp_loss = torch.mean(mask)

        # Total variation.
        abs_h_diffs = torch.abs(mask[:, 1:, :] - mask[:, :-1, :])
        abs_w_diffs = torch.abs(mask[:, :, 1:] - mask[:, :, :-1])
        sm_loss = torch.mean(abs_w_diffs) + torch.mean(abs_h_diffs)

        # Laplacian smoothness: [1, h, w] -> [1, 1, h, w] for conv2d; the
        # result is the Laplacian on the (h-2) x (w-2) interior.
        laplacians = F.conv2d(mask.unsqueeze(0), self.kernel, padding=0)
        sh_loss = torch.norm(laplacians)

        return uni_loss, sp_loss, sm_loss, sh_loss

    def _mean_score(self, inputs):
        """Score a ``[B, N, C, h, w]`` batch with ``clf`` and average over N -> ``[B]``."""
        batch, n = inputs.shape[:2]
        scores = self.clf(
            inputs.view(batch * n, self.num_channels, self.h, self.w)
        )  # [B*N, 1]
        return scores.view(batch, n, 1).mean(dim=1).squeeze(-1)

    def step(self, x, y_hat, y_0, mask, bkgd, num_samples):
        """Sample masks, score the masked inputs and compute all losses.

        Returns:
            ``(uni_loss, sp_loss, sm_loss, sh_loss, mask_01)``: the four
            losses from ``compute_losses`` and the ``[B, N, h, w]`` sampled
            binary masks.
        """
        mask_01, x_S, x_Sc = self.sample(x, mask, bkgd, num_samples)

        f_xS = self._mean_score(x_S)
        f_xSc = self._mean_score(x_Sc)

        uni_loss, sp_loss, sm_loss, sh_loss = self.compute_losses(
            f_xS, f_xSc, y_hat, y_0, self.alpha, mask
        )
        return uni_loss, sp_loss, sm_loss, sh_loss, mask_01

    def _init_mask(self, x, init_mask, C):
        """Create the trainable ``[1, H, W]`` mask (a leaf tensor) for ``x``.

        Raises:
            ValueError: If ``init_mask`` is not ``"rand"``, ``"uni_rand"`` or
                ``"constant"``.
        """
        shape = (1, *x.shape[2:])
        if init_mask == "rand":
            # Gaussian around 0.5 with std 0.5/3.
            mask = torch.randn(shape, dtype=torch.float32, device=self.device) * (0.5 / 3) + 0.5
        elif init_mask == "uni_rand":
            mask = torch.rand(shape, dtype=torch.float32, device=self.device)
        elif init_mask == "constant":
            mask = (torch.ones(shape, dtype=torch.float32, device=self.device) * C).to(self.device)
        else:
            raise ValueError(
                f"Unknown init_mask {init_mask!r}; "
                "expected 'rand', 'uni_rand' or 'constant'"
            )
        return mask.requires_grad_()

    def __call__(
        self,
        x,
        y_0,
        bkgd,
        num_samples,
        num_steps,
        learning_rate,
        sp_lambda,
        sm_lambda,
        sh_lambda,
        area,
        C,
        M,
        log_frac=0.5,
        init_mask="constant",
        return_logs=False,
    ):
        """Optimise a saliency mask for a single image with Adam.

        Args:
            x: Input image of shape ``(1, C, H, W)``.
            y_0: Target score for the complement (masked-out) inputs.
            bkgd: Background used in place of removed pixels.
            num_samples: Number of binary mask samples drawn per step.
            num_steps: Number of optimisation steps.
            learning_rate: Adam learning rate.
            sp_lambda: Weight of the sparsity loss.
            sm_lambda: Weight of the total-variation loss.
            sh_lambda: Weight of the Laplacian loss.
            area: Fraction of pixels set to 1 in the target profile used by
                the sparsity loss when ``constraint`` is truthy.
            C: Initial value of every mask entry for ``init_mask="constant"``.
            M: Unused here; kept for interface compatibility.
            log_frac: Print the loss terms every ``num_steps * log_frac``
                steps; ``log_frac <= 0`` disables this printing.
            init_mask: ``"constant"``, ``"rand"`` or ``"uni_rand"``.
            return_logs: If true, also return the per-step loss histories.

        Returns:
            ``(mask, mask_01)`` or ``(mask, mask_01, logs)``.

            * ``mask``: the detached ``[1, H, W]`` soft mask, clamped to
              [0, 1] after every step.
            * ``mask_01``: the detached ``[1, num_samples, H, W]`` binary
              samples from the final step (drawn before that step's update).
              When ``num_steps == 0`` they are drawn from the initial mask.
            * ``logs``: a dict of per-step loss lists.

        Raises:
            ValueError: If ``init_mask`` is not a supported name.
        """
        # The mask is [1, H, W], so only single-image batches are supported.
        assert x.dim() == 4 and x.shape[0] == 1, "expected x of shape (1, C, H, W)"

        # The input is a constant of the optimisation. Detaching avoids
        # mutating the caller's tensor; requires_grad_(False) would also
        # raise on non-leaf tensors.
        x = x.detach().to(self.device)
        self.clf = self.clf.to(self.device)
        bkgd = bkgd.to(self.device)
        self.h, self.w = x.shape[-2:]
        self.num_mask_entries = self.h * self.w

        # Area-constraint target: k ones followed by zeros, shaped [1, 1, d^2],
        # compared against the descending-sorted mask in compute_losses.
        k = int(area * self.num_mask_entries)
        ones_k = torch.zeros(self.num_mask_entries, device=self.device)
        ones_k[:k] = 1
        self.ones_k = ones_k.unsqueeze(0).unsqueeze(0)

        logs = {"total_loss": [], "uni_loss": [], "sp_loss": [], "sm_loss": [], "sh_loss": []}

        mask = self._init_mask(x, init_mask, C)
        optimizer = torch.optim.Adam([mask], lr=learning_rate)

        # Reference prediction, thresholded at 0.5. No autograd graph needed.
        with torch.no_grad():
            score = self.clf(x).item()
        y_hat = (score >= 0.5) * 1.0

        log_every = num_steps * log_frac
        mask_01 = None
        for i in tqdm(range(num_steps)):
            uni_loss, sp_loss, sm_loss, sh_loss, mask_01 = self.step(
                x, y_hat, y_0, mask, bkgd, num_samples
            )
            total_loss = (
                uni_loss
                + sp_lambda * sp_loss
                + sm_lambda * sm_loss
                + sh_lambda * sh_loss
            )

            for name, value in (
                ("total_loss", total_loss),
                ("uni_loss", uni_loss),
                ("sp_loss", sp_loss),
                ("sm_loss", sm_loss),
                ("sh_loss", sh_loss),
            ):
                logs[name].append(value.item())

            # Guard log_every > 0: a zero modulus would raise ZeroDivisionError.
            if log_every > 0 and (i + 1) % log_every == 0:
                tqdm.write(f"Objective loss = {uni_loss.item()}")
                tqdm.write(f"sp loss = {sp_loss.item()}")
                tqdm.write(f"sm loss = {sm_loss.item()}")
                tqdm.write(f"sh loss = {sh_loss.item()}")

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Keep the mask values in [0, 1] after each update.
            with torch.no_grad():
                mask.clamp_(min=0, max=1)

        if mask_01 is None:
            # No step ran (num_steps == 0): sample from the initial mask so
            # the return value is always a tensor.
            with torch.no_grad():
                mask_01, _, _ = self.sample(x, mask, bkgd, num_samples)
        mask_01 = mask_01.detach()

        if return_logs:
            return mask.detach(), mask_01, logs
        return mask.detach(), mask_01
