import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.feature import local_binary_pattern
from metrics.base_metrics_class import calculate_metrics_for_train

from .base_detector import AbstractDetector
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC

import logging
logger = logging.getLogger(__name__)


@DETECTOR.register_module(module_name='lbp_4ch')
class LBP4ChDetector(AbstractDetector):
    """
    A simplified LBP-based detector that:
      1) Extracts a single-channel LBP map from the 3-channel RGB image.
      2) Concatenates RGB (3 channels) + LBP (1 channel) => 4 channels.
      3) Uses a single backbone whose first conv (``conv1``) is rebuilt for
         4-channel input and initialised from the pretrained 3-channel stem.
      4) No feature-fusion module is used.
    """
    def __init__(self, config):
        """
        Args:
            config: detector configuration dict. Keys read by this class:
                'backbone_name', 'backbone_config', 'pretrained', 'loss_func'.
        """
        super().__init__()
        self.config = config

        # Single backbone that will handle 4-channel input
        self.backbone = self.build_backbone(config)

        # Simple LBP extractor (single-channel)
        self.lbp_extractor = LBPExtractor(
            radius=1,
            n_points=8,
            method='uniform'
        )

        # Prepare the loss function
        self.loss_func = self.build_loss(config)

        # Recorders filled by forward(..., inference=True) for metric computation
        self.prob, self.label = [], []
        self.correct, self.total = 0, 0

    def build_backbone(self, config):
        """
        Build the backbone and adapt its ``conv1`` to accept 4 input channels.

        Pretrained weights from ``config['pretrained']`` are loaded with
        ``strict=False`` after dropping classification-head keys. The
        pretrained ``conv1.weight`` is averaged over its input channels and
        replicated 4 times to initialise the new 4-channel ``conv1``. The new
        conv takes its out-channels / kernel size from that weight, and its
        stride / padding / dilation from the backbone's original ``conv1``.

        Returns:
            The backbone module with a 4-input-channel ``conv1``.
        """
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        backbone = backbone_class(model_config)

        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # config['pretrained'] at trusted checkpoints.
        state_dict = torch.load(config['pretrained'], map_location='cpu')
        for name, weights in list(state_dict.items()):
            # Some checkpoints store "pointwise" convs as 2-D [out, in];
            # expand to 4-D [out, in, 1, 1] conv weights.
            if 'pointwise' in name and weights.ndim == 2:
                state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)

        # Drop classification-head weights (their shape may not match this head).
        state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k and 'classifier' not in k}

        # Pull out the stem weights so they are not loaded into the old conv1
        # (a shape mismatch there would raise even with strict=False).
        conv1_data = state_dict.pop('conv1.weight', None)

        # Load the remaining partial weights
        missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
        logger.info(
            f"Loaded pretrained model from {config['pretrained']}. "
            f"Missing: {missing}, Unexpected: {unexpected}"
        )

        old_conv1 = backbone.conv1
        if conv1_data is None:
            logger.warning(
                "Pretrained checkpoint has no 'conv1.weight'; initialising the "
                "4-channel conv1 from the backbone's own conv1 weights."
            )
            conv1_data = old_conv1.weight.detach().clone()

        # Replace conv1 with a 4-channel version that keeps the original geometry
        # instead of hard-coding it, so downstream layers still line up.
        backbone.conv1 = nn.Conv2d(
            in_channels=4,
            out_channels=conv1_data.shape[0],
            kernel_size=tuple(conv1_data.shape[2:]),
            stride=old_conv1.stride,
            padding=old_conv1.padding,
            dilation=old_conv1.dilation,
            bias=False,
        )

        with torch.no_grad():
            # Average across the original input channels => [out, 1, kh, kw]
            avg_conv1_data = conv1_data.mean(dim=1, keepdim=True)
            # Repeat for 4 channels => [out, 4, kh, kw]
            backbone.conv1.weight.copy_(avg_conv1_data.repeat(1, 4, 1, 1))

        logger.info("Modified conv1 to accept 4 input channels.")
        return backbone

    def build_loss(self, config):
        """Instantiate the loss registered under ``config['loss_func']``."""
        loss_class = LOSSFUNC[config['loss_func']]
        return loss_class()

    def features(self, data_dict: dict) -> torch.Tensor:
        """
        Build the 4-channel input (RGB + LBP) and run the backbone features.

        Args:
            data_dict: must contain 'image', a [B, 3, H, W] tensor. The LBP
                extractor rescales it from [-1, 1] to [0, 255], so inputs are
                assumed to be normalised to [-1, 1] — TODO confirm against the
                dataset transform.

        Returns:
            Output of ``self.backbone.features`` on the [B, 4, H, W] input.
        """
        # Follow the model's own device rather than guessing "cuda if available":
        # the guess breaks when the model is kept on CPU on a CUDA machine.
        device = self.backbone.conv1.weight.device

        x = data_dict['image'].to(device)  # [B, 3, H, W]
        lbp_x = self.lbp_extractor(x)       # [B, 1, H, W]

        # Combine => [B, 4, H, W]
        combined = torch.cat([x, lbp_x], dim=1)
        return self.backbone.features(combined)

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Apply the backbone's own classification head to ``features``."""
        return self.backbone.classifier(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return ``{'overall': loss_func(pred_dict['cls'], data_dict['label'])}``."""
        label = data_dict['label']
        pred = pred_dict['cls']
        loss = self.loss_func(pred, label)
        return {'overall': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute acc / auc / eer / ap for a training batch."""
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """
        Run the detector.

        Args:
            data_dict: batch dict with 'image' (and 'label' when ``inference``).
            inference: if True, append probabilities/labels to the recorders
                and update the running correct/total counters.

        Returns:
            dict with 'cls' (logits), 'prob' (softmax probability of class
            index 1) and 'feat' (backbone features).
        """
        feats = self.features(data_dict)
        pred = self.classifier(feats)     # => [B, num_classes]
        # Probability of class index 1 (presumably the "fake" class — confirm).
        prob = torch.softmax(pred, dim=1)[:, 1]

        pred_dict = {'cls': pred, 'prob': prob, 'feat': feats}

        if inference:
            label = data_dict['label']
            # Save probabilities and labels for metric computation
            self.prob.append(prob.detach().cpu().numpy())
            self.label.append(label.detach().cpu().numpy())

            # Batch-level accuracy; align devices since the label may not
            # have been moved alongside the image.
            _, pred_class = torch.max(pred, dim=1)
            self.correct += (pred_class == label.to(pred_class.device)).sum().item()
            self.total += label.size(0)

        return pred_dict


class LBPExtractor(nn.Module):
    """
    Extract a single-channel LBP map from an image batch.

    The image is collapsed to grayscale by averaging its channels, rescaled
    from [-1, 1] to uint8 [0, 255], passed per sample through
    ``skimage.feature.local_binary_pattern``, and the resulting codes are
    mapped linearly onto [-1, 1]. Output is (B, 1, H, W), float32, on the
    input's device.

    Args:
        radius: LBP radius ``R``.
        n_points: number of circular neighbours ``P``.
        method: skimage LBP method with a bounded code range: 'default',
            'ror', 'uniform' or 'nri_uniform'.

    Raises:
        ValueError: if ``method`` has no known bounded code range (e.g. 'var').
    """
    def __init__(self, radius=1, n_points=8, method='uniform'):
        super(LBPExtractor, self).__init__()
        self.radius = radius
        self.n_points = n_points
        self.method = method
        # Largest code the chosen method can emit; used to normalise to [-1, 1].
        self.max_code = self._max_code(n_points, method)

    @staticmethod
    def _max_code(n_points: int, method: str) -> int:
        """Return the maximum LBP code skimage produces for ``method`` with ``P=n_points``."""
        if method in ('default', 'ror'):
            # P-bit patterns => codes in [0, 2**P - 1]
            return 2 ** n_points - 1
        if method == 'uniform':
            # P+1 uniform codes (0..P) plus one non-uniform bin (P+1)
            return n_points + 1
        if method == 'nri_uniform':
            # P*(P-1)+2 uniform patterns plus one non-uniform bin
            return n_points * (n_points - 1) + 2
        raise ValueError(
            f"Unsupported LBP method {method!r}; expected one of "
            f"'default', 'ror', 'uniform', 'nri_uniform'"
        )

    def _normalize(self, codes: np.ndarray) -> np.ndarray:
        """Map LBP codes from [0, max_code] linearly onto [-1, 1]."""
        return (codes / self.max_code) * 2.0 - 1.0

    def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        Compute the normalised LBP map for a batch.

        Args:
            image_tensor: 4-D tensor (B, C, H, W), assumed to be in [-1, 1]
                (values outside are clipped after rescaling).

        Returns:
            float32 tensor of shape (B, 1, H, W) in [-1, 1], on the input's device.

        Raises:
            ValueError: if ``image_tensor`` is not 4-D.
        """
        if image_tensor.ndim != 4:
            raise ValueError(f"Expected 4D tensor, got shape: {image_tensor.shape}")

        batch, _, height, width = image_tensor.shape
        if batch == 0:
            # np.stack cannot stack an empty list; return an empty map instead.
            return torch.zeros((0, 1, height, width), dtype=torch.float32, device=image_tensor.device)

        # detach(): LBP is not differentiable and .numpy() rejects grad-tracking tensors.
        # float(): NumPy has no bfloat16, and fp32 avoids low-precision rounding in the mean.
        gray_np = image_tensor.detach().float().mean(dim=1).cpu().numpy()  # (B, H, W)

        # Rescale [-1, 1] -> [0, 255] for the whole batch at once.
        gray_u8 = ((gray_np + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

        lbp_maps = [
            self._normalize(
                local_binary_pattern(
                    img,
                    P=self.n_points,
                    R=self.radius,
                    method=self.method,
                ).astype(np.float32)
            )
            for img in gray_u8
        ]

        lbp_batch = np.stack(lbp_maps, axis=0)  # (B, H, W)
        return torch.from_numpy(lbp_batch).unsqueeze(1).float().to(image_tensor.device)
