import os
import cv2
import pywt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import logging
logger = logging.getLogger(__name__)

from metrics.base_metrics_class import calculate_metrics_for_train
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC

from .base_detector import AbstractDetector

# -------------
# Wavelet utility
# -------------
def wavelet_denoise(image: np.ndarray,
                    wavelet='db1',
                    level=3,
                    threshold_fraction=1.0) -> np.ndarray:
    """
    Perform wavelet 'denoising' by zeroing out some detail coefficients.

    The approximation (low-frequency) coefficients returned by
    ``pywt.wavedec2`` are kept unchanged; every detail sub-band (cH, cV, cD at
    each level) is thresholded according to ``threshold_fraction``:
      - threshold_fraction <= 0.0: Keep ALL detail (no denoise).
      - threshold_fraction >= 1.0: Remove ALL detail.
      - 0 < threshold_fraction < 1: Per sub-band, the magnitude at rank
        int(size * threshold_fraction) of the sorted |coefficients| becomes a
        cutoff, and every coefficient strictly below it is zeroed (roughly
        the smallest ``threshold_fraction`` of that sub-band).

    Args:
        image: Array to denoise; the last two axes are the image plane.
        wavelet: Wavelet name passed to PyWavelets.
        level: Decomposition level passed to ``pywt.wavedec2``.
        threshold_fraction: Amount of detail to discard (see above).

    Returns:
        The ``pywt.waverec2`` reconstruction, cropped to ``image``'s size along
        the last two axes. ``waverec2`` can return one extra row/column for an
        odd-length axis; without the crop the result would not line up with
        the input.
    """
    image = np.asarray(image)

    # 1) Decompose: [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]
    coeffs = pywt.wavedec2(image, wavelet, level=level)
    new_coeffs = [coeffs[0]]  # low-frequency approximation, kept as-is

    # 2) Threshold details (the arrays returned by wavedec2 are edited in place)
    for cH, cV, cD in coeffs[1:]:
        if threshold_fraction > 0.0:
            for c in (cH, cV, cD):
                if threshold_fraction >= 1.0:
                    # Remove all details
                    c[...] = 0
                elif c.size:
                    # PARTIAL denoise: np.partition finds the same k-th
                    # smallest magnitude as a full sort, in linear time.
                    kth = min(int(c.size * threshold_fraction), c.size - 1)
                    cutoff_value = np.partition(np.abs(c), kth, axis=None)[kth]
                    c[np.abs(c) < cutoff_value] = 0
        new_coeffs.append((cH, cV, cD))

    # 3) Reconstruct, then drop any padding row/column added for odd sizes
    reconstructed = pywt.waverec2(new_coeffs, wavelet)
    return reconstructed[..., :image.shape[-2], :image.shape[-1]]


class WaveletDenoiseExtractor(nn.Module):
    """
    Converts an RGB image into a single-channel grayscale image and applies
    wavelet partial denoising. Output scaled to [-1,1].

    The wavelet work runs per image in NumPy on the CPU, so no gradient flows
    through this module and it registers no parameters.
    """
    def __init__(self, wavelet='db1', level=3, threshold_fraction=1.0):
        super().__init__()
        self.wavelet = wavelet
        self.level = level
        self.threshold_fraction = threshold_fraction

    def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        image_tensor: [B,C,H,W] in range [-1,1] (channels are averaged to gray).
        Returns: wavelet-denoised grayscale [B,1,H,W], float32, in range [-1,1],
        on image_tensor's device.

        Raises:
            ValueError: if image_tensor is not 4-D.
        """
        if image_tensor.ndim != 4:
            raise ValueError(f"Expected 4D tensor, got shape: {image_tensor.shape}")

        batch, _, height, width = image_tensor.shape
        if batch == 0:
            # np.stack cannot stack an empty list; return an empty batch directly
            return torch.zeros((0, 1, height, width), dtype=torch.float32,
                               device=image_tensor.device)

        # 1) Convert (B,C,H,W) to (B,1,H,W) and move to CPU NumPy.
        # detach(): .numpy() refuses tensors that require grad.
        gray_np = torch.mean(image_tensor.detach(), dim=1, keepdim=True).cpu().numpy()

        denoised_list = []
        for gray_2d in gray_np[:, 0]:
            # Convert from [-1,1] -> [0..255]; astype(uint8) truncates (no rounding)
            gray_255 = ((gray_2d + 1) * 127.5).clip(0, 255).astype(np.uint8)
            gray_float = gray_255.astype(np.float32)

            # PARTIAL zeroing of detail in wavelet domain
            denoised_2d = wavelet_denoise(
                gray_float,
                wavelet=self.wavelet,
                level=self.level,
                threshold_fraction=self.threshold_fraction
            )
            # Crop: the wavelet reconstruction may carry an extra row/column
            # for odd H or W, which would break stacking/concatenation later.
            denoised_2d = np.clip(denoised_2d[:height, :width], 0, 255)
            denoised_list.append(denoised_2d)

        # 2) Stack => (B,H,W) => (B,1,H,W)
        denoised_array = np.stack(denoised_list, axis=0).astype(np.float32)
        wavelet_tensor = torch.from_numpy(denoised_array).unsqueeze(1)

        # 3) Scale back to [-1,1]
        wavelet_tensor = wavelet_tensor / 127.5 - 1.0
        return wavelet_tensor.to(image_tensor.device)


# -------------
# LBP utility
# -------------
from skimage.feature import local_binary_pattern

class LBPExtractor(nn.Module):
    """
    Extracts LBP from a grayscale version of the input. Returns (B,1,H,W).

    LBP codes are rescaled to [-1,1] by dividing by ``max_code``, the largest
    code the configured ``method`` can produce.

    NOTE: earlier revisions divided 'uniform' codes by n_points+2 (the number
    of codes, not the largest one), so outputs topped out at 0.8 for P=8.
    Weights trained against that scaling may need fine-tuning.
    """
    def __init__(self, radius=1, n_points=8, method='uniform'):
        super().__init__()
        self.radius = radius
        self.n_points = n_points
        self.method = method

    @property
    def max_code(self) -> float:
        """
        Largest value skimage's local_binary_pattern emits for ``self.method``.

        - 'uniform': codes 0..P+1 (P+1 marks non-uniform patterns).
        - 'nri_uniform': codes 0..P*(P-1)+2.
        - 'default' / 'ror': P-bit codes, 0..2**P-1.
        - anything else (e.g. 'var', which is not a bounded code): falls back
          to P+2, so the output is NOT guaranteed to lie in [-1,1].
        """
        p = self.n_points
        if self.method == 'uniform':
            return float(p + 1)
        if self.method == 'nri_uniform':
            return float(p * (p - 1) + 2)
        if self.method in ('default', 'ror'):
            return float(2 ** p - 1)
        return p + 2.0

    def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        Steps:
          1) Convert to grayscale: [B,C,H,W] -> [B,1,H,W].
          2) Move to CPU & NumPy (detached; no gradient flows through LBP).
          3) Scale [-1,1] -> [0,255] as uint8 (truncating).
          4) Compute LBP => [H,W].
          5) Normalize LBP by ``max_code`` to [-1,1].

        Returns a float32 [B,1,H,W] tensor on image_tensor's device.

        Raises:
            ValueError: if image_tensor is not 4-D.
        """
        if image_tensor.ndim != 4:
            raise ValueError(f"Expected 4D tensor, got shape: {image_tensor.shape}")

        batch, _, height, width = image_tensor.shape
        if batch == 0:
            # np.stack cannot stack an empty list; return an empty batch directly
            return torch.zeros((0, 1, height, width), dtype=torch.float32,
                               device=image_tensor.device)

        # 1+2) Grayscale, then CPU NumPy. detach(): .numpy() refuses tensors
        # that require grad.
        gray_np = torch.mean(image_tensor.detach(), dim=1, keepdim=True).cpu().numpy()

        max_val = self.max_code  # loop-invariant
        lbp_list = []
        for gray_2d in gray_np[:, 0]:
            # 3) Rescale [-1,1] -> [0,255]
            gray_2d_255 = ((gray_2d + 1) * 127.5).clip(0, 255).astype(np.uint8)

            # 4) Compute LBP codes in [0, max_val]
            lbp_2d = local_binary_pattern(
                gray_2d_255,
                P=self.n_points,
                R=self.radius,
                method=self.method
            ).astype(np.float32)

            # 5) Normalize to [-1,1]
            lbp_list.append((lbp_2d / max_val) * 2.0 - 1.0)

        # Stack => (B,H,W) => (B,1,H,W)
        lbp_batch = np.stack(lbp_list, axis=0)
        lbp_tensor = torch.from_numpy(lbp_batch).unsqueeze(1).float()

        return lbp_tensor.to(image_tensor.device)

@DETECTOR.register_module(module_name='wavelet_lbp_4ch')
class WaveletLBP4chDetector(AbstractDetector):
    """
    A detector that uses:
      - 3-channel RGB
      - A fused channel of wavelet-denoised grayscale + LBP
    => total of 4 channels for the backbone.

    Config keys read here: 'backbone_name', 'backbone_config', 'pretrained'
    (path to a state dict loaded with torch.load) and 'loss_func'.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Build 4-channel backbone (pretrained 3-ch conv1 adapted to 4 channels)
        self.backbone = self.build_backbone(config)

        # Wavelet extractor: 3-level 'db1' decomposition with all detail removed
        self.wavelet_extractor = WaveletDenoiseExtractor(
            wavelet='db1',
            level=3,
            threshold_fraction=1.0
        )

        # LBP extractor: 8 sampling points at radius 1, 'uniform' codes
        self.lbp_extractor = LBPExtractor(
            radius=1,
            n_points=8,
            method='uniform'
        )

        # Loss function
        self.loss_func = self.build_loss(config)

        # Fusion block: merges wavelet + LBP => 1 channel
        # (Input shape: [B,2,H,W], output shape: [B,1,H,W]).
        # The trailing ReLU makes the fused channel non-negative.
        self.fusion_block = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )

        # Recorders filled by forward(..., inference=True)
        self.prob = []
        self.label = []
        self.correct = 0
        self.total = 0

    def build_backbone(self, config):
        """
        Build an Xception-like backbone that can handle 4 input channels:
        3-ch RGB + 1-ch fused (wavelet+LBP).

        Pretrained weights are loaded non-strictly, with 'fc'/'classifier'
        entries dropped. conv1 is then replaced by a bias-free 4-input-channel
        convolution that keeps the existing conv1's output channels, kernel
        size, stride and padding. If the backbone has no Conv2d named conv1,
        Xception's 32/3/2/0 is used. When the checkpoint provides
        conv1.weight, its kernels are averaged over the RGB axis and that mean
        kernel is copied into all 4 input channels.
        """
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        backbone = backbone_class(model_config)

        # NOTE: torch.load unpickles the file; config['pretrained'] must be trusted.
        state_dict = torch.load(config['pretrained'], map_location='cpu')

        adapted = {}
        for name, weights in state_dict.items():
            # Remove classifier layer (not reused)
            if 'fc' in name or 'classifier' in name:
                continue
            # Some models store "pointwise" conv as [out_chan, in_chan]; make it [out, in, 1, 1]
            if 'pointwise' in name and weights.ndim == 2:
                weights = weights.unsqueeze(-1).unsqueeze(-1)
            adapted[name] = weights
        state_dict = adapted

        # Pop out original conv1 weights (3-ch); they are adapted below
        conv1_data = state_dict.pop('conv1.weight', None)

        # Load partial weights
        missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
        logger.info(
            f"Loaded pretrained model from {config['pretrained']}. "
            f"Missing: {missing}, Unexpected: {unexpected}"
        )

        # Replace conv1 with a 4-channel version, keeping the original geometry
        old_conv1 = getattr(backbone, 'conv1', None)
        if isinstance(old_conv1, nn.Conv2d):
            conv_kwargs = dict(
                out_channels=old_conv1.out_channels,
                kernel_size=old_conv1.kernel_size,
                stride=old_conv1.stride,
                padding=old_conv1.padding,
            )
        else:
            conv_kwargs = dict(out_channels=32, kernel_size=3, stride=2, padding=0)
        backbone.conv1 = nn.Conv2d(in_channels=4, bias=False, **conv_kwargs)

        # If we found original 3-ch conv1_data, adapt to 4-ch:
        # [out,3,k,k] -> mean over RGB -> [out,1,k,k] -> replicate -> [out,4,k,k]
        if conv1_data is not None:
            new_conv1_data = conv1_data.mean(dim=1, keepdim=True).repeat(1, 4, 1, 1)
            if new_conv1_data.shape == backbone.conv1.weight.shape:
                with torch.no_grad():
                    # copy_ keeps the Parameter (and its dtype) instead of rebinding .data
                    backbone.conv1.weight.copy_(new_conv1_data)
            else:
                logger.warning(
                    f"Pretrained conv1.weight shape {tuple(conv1_data.shape)} cannot be "
                    f"adapted to {tuple(backbone.conv1.weight.shape)}; keeping random init."
                )

        logger.info("Modified conv1 to accept 4 input channels (RGB + fused wavelet+LBP).")
        return backbone

    def build_loss(self, config):
        """
        Build the specified loss function (LOSSFUNC[config['loss_func']]()).
        """
        loss_class = LOSSFUNC[config['loss_func']]
        return loss_class()

    def features(self, data_dict: dict) -> torch.Tensor:
        """
        Steps:
         1) Original RGB: (B,3,H,W), moved to the device of this model's weights
         2) Wavelet-denoised grayscale: (B,1,H,W)
         3) LBP from grayscale: (B,1,H,W)
         4) Fuse wavelet + LBP => (B,1,H,W) via self.fusion_block
         5) Concatenate with RGB => (B,4,H,W)
         6) Pass to backbone.features()
        """
        # Follow the module's own weights rather than "any CUDA device", so a
        # CPU-resident model on a CUDA host (or a model on cuda:N) still works.
        device = self.fusion_block[0].weight.device
        x_rgb = data_dict['image'].to(device)  # (B,3,H,W)

        x_wavelet = self.wavelet_extractor(x_rgb)  # (B,1,H,W)
        x_lbp = self.lbp_extractor(x_rgb)          # (B,1,H,W)

        # Fuse wavelet + LBP => single channel: [B,2,H,W] -> [B,1,H,W]
        wavelet_lbp = torch.cat([x_wavelet, x_lbp], dim=1)
        fused_channel = self.fusion_block(wavelet_lbp)

        # Combine fused channel with RGB => 4 channels
        x_4ch = torch.cat([x_rgb, fused_channel], dim=1)

        return self.backbone.features(x_4ch)

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """
        Pass features through the backbone's classifier head.
        """
        return self.backbone.classifier(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """
        Compute and return the loss dict: {'overall': loss_func(logits, label)}.
        """
        pred = pred_dict['cls']
        # features() moves the image itself, so labels may still be elsewhere
        label = data_dict['label'].to(pred.device)
        loss = self.loss_func(pred, label)
        return {'overall': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """
        Compute batch metrics (acc, auc, eer, ap) for training.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(
            label.detach(),
            pred.detach()
        )
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """
        Forward pass:
         1) Extract wavelet+LBP fused features.
         2) Classify => logits.
         3) Compute softmax probability of class index 1 and, when
            ``inference`` is True, append probs/labels and update accuracy counts.
        """
        feats = self.features(data_dict)
        pred = self.classifier(feats)
        prob = torch.softmax(pred, dim=1)[:, 1]

        pred_dict = {
            'cls': pred,   # logits
            'prob': prob,  # probability of class index 1
            'feat': feats
        }

        if inference:
            label = data_dict['label']

            # Collect probabilities and labels
            self.prob.append(prob.detach().cpu().numpy())
            self.label.append(label.detach().cpu().numpy())

            # Track accuracy (compare on the logits' device)
            _, prediction_class = torch.max(pred, 1)
            self.correct += (prediction_class == label.to(pred.device)).sum().item()
            self.total += label.size(0)

        return pred_dict
