import os
import cv2
import pywt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import logging
logger = logging.getLogger(__name__)

from metrics.base_metrics_class import calculate_metrics_for_train
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC

from .base_detector import AbstractDetector

# -------------
# Wavelet utility
# -------------
def wavelet_denoise(image: np.ndarray,
                    wavelet='db1',
                    level=3,
                    threshold_fraction=1.0) -> np.ndarray:
    """Suppress small wavelet detail coefficients of an image and reconstruct it.

    The image is decomposed with ``pywt.wavedec2``. The approximation band is
    kept unchanged. Every detail band (horizontal, vertical, diagonal) at every
    level is thresholded according to ``threshold_fraction``:

    * ``<= 0``: detail bands are left untouched (no denoising).
    * ``>= 1``: detail bands are zeroed entirely, so only the low-frequency
      approximation survives.
    * otherwise: within each band, let ``k = int(band.size * threshold_fraction)``
      (clamped to a valid index). Coefficients whose magnitude is strictly
      below the k-th smallest magnitude of that band are zeroed, i.e. roughly
      the lowest ``threshold_fraction`` share of the band.

    Args:
        image: Image to denoise; the transform runs over its last two axes.
        wavelet: Wavelet name understood by PyWavelets.
        level: Decomposition depth passed to ``pywt.wavedec2``.
        threshold_fraction: Share of each detail band to suppress (see above).

    Returns:
        The reconstructed image, cropped to the input's spatial shape.
    """
    image = np.asarray(image)
    coeffs = pywt.wavedec2(image, wavelet, level=level)
    new_coeffs = [coeffs[0]]  # approximation band is never thresholded

    for cH, cV, cD in coeffs[1:]:
        for band in (cH, cV, cD):
            if threshold_fraction <= 0.0 or band.size == 0:
                continue
            if threshold_fraction >= 1.0:
                band[:] = 0
                continue
            magnitudes = np.abs(band).ravel()
            cutoff_index = int(magnitudes.size * threshold_fraction)
            cutoff_index = max(0, min(cutoff_index, magnitudes.size - 1))
            # np.partition puts the k-th smallest value at index k in O(n):
            # the same cutoff a full sort would yield, without the O(n log n).
            cutoff_value = np.partition(magnitudes, cutoff_index)[cutoff_index]
            band[np.abs(band) < cutoff_value] = 0
        new_coeffs.append((cH, cV, cD))

    reconstructed = pywt.waverec2(new_coeffs, wavelet)
    # waverec2 may return one extra sample along an odd-length axis; trim so
    # callers get back exactly the spatial size they passed in.
    return reconstructed[..., :image.shape[-2], :image.shape[-1]]


class WaveletDenoiseExtractor(nn.Module):
    """Parameter-free module producing a wavelet-denoised grayscale channel.

    Each sample is averaged over the channel axis, quantised to 8-bit, passed
    through ``wavelet_denoise`` with the stored settings, and mapped back to
    [-1, 1]. The work happens in NumPy on the CPU, so no gradient flows
    through this module.
    """

    def __init__(self, wavelet='db1', level=3, threshold_fraction=1.0):
        super().__init__()
        # Forwarded unchanged to wavelet_denoise().
        self.wavelet = wavelet
        self.level = level
        self.threshold_fraction = threshold_fraction

    def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Return the denoised grayscale view of ``image_tensor``.

        Args:
            image_tensor: Batch whose dim 1 is the channel axis (presumably
                (B, C, H, W) -- TODO confirm), with values expected in [-1, 1].
                Values outside that range are clipped by the 8-bit quantisation.

        Returns:
            Tensor of shape (B, 1, H, W) with values in [-1, 1], on the same
            device and with the same dtype as ``image_tensor``.
        """
        batch_size = image_tensor.shape[0]
        height, width = image_tensor.shape[-2], image_tensor.shape[-1]
        if batch_size == 0:
            # np.stack below cannot stack an empty list.
            return image_tensor.new_zeros((0, 1, height, width))

        # Detach: the NumPy round trip is not differentiable, and .numpy()
        # refuses tensors that require grad.
        gray_np = image_tensor.detach().mean(dim=1).cpu().numpy()

        denoised_list = []
        for gray_2d in gray_np:
            # [-1, 1] -> [0, 255]; astype(np.uint8) truncates rather than rounds.
            gray_255 = ((gray_2d + 1) * 127.5).clip(0, 255).astype(np.uint8)
            denoised_2d = wavelet_denoise(
                gray_255.astype(np.float32),
                wavelet=self.wavelet,
                level=self.level,
                threshold_fraction=self.threshold_fraction,
            )
            # Reconstruction can gain a row/column on odd sizes; crop so the
            # result can be concatenated with the input along the channel axis.
            denoised_2d = denoised_2d[:height, :width]
            denoised_list.append(np.clip(denoised_2d, 0, 255))

        denoised_array = np.stack(denoised_list, axis=0).astype(np.float32)
        wavelet_tensor = torch.from_numpy(denoised_array).unsqueeze(1) / 127.5 - 1.0
        return wavelet_tensor.to(device=image_tensor.device, dtype=image_tensor.dtype)


@DETECTOR.register_module(module_name='wavelet_phase_meso4')
class WaveletPhase_Meso4_Detector(AbstractDetector):
    """Meso4-based detector fed RGB plus one fused wavelet/phase channel.

    ``features`` builds two single-channel views of the input batch:

    * a wavelet-denoised grayscale image;
    * a phase-only (unit-magnitude FFT) reconstruction of the grayscale image.

    It stacks them, collapses them to one channel with ``fusion_block``
    (1x1 conv + BatchNorm + ReLU), and appends that channel to the RGB input
    before calling ``backbone.features``.

    Config keys read:

    * ``backbone_name`` and ``backbone_config`` (an ``inc`` of 4 widens conv1);
    * ``pretrained`` (optional backbone checkpoint);
    * ``loss_func``;
    * ``fusion_pretrained`` (optional fusion-block checkpoint, frozen once
      loaded).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.backbone = self.build_backbone(config)
        self.wavelet_extractor = WaveletDenoiseExtractor(
            wavelet='db1',
            level=3,
            threshold_fraction=1.0
        )
        self.loss_func = self.build_loss(config)
        # Collapses the stacked (wavelet, phase) pair into a single channel.
        self.fusion_block = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )

        # Set to True once pretrained fusion weights are loaded and frozen;
        # train() reads it to keep that block's BatchNorm statistics fixed.
        self._fusion_frozen = False
        fusion_pretrained = config.get('fusion_pretrained', None)
        self.load_and_freeze_fusion_block_weights(fusion_pretrained)

        # Inference accumulators, filled by forward(..., inference=True).
        self.prob = []
        self.label = []
        self.correct = 0
        self.total = 0

    def train(self, mode: bool = True):
        """Set train/eval mode while keeping a frozen ``fusion_block`` in eval.

        ``requires_grad=False`` alone does not stop BatchNorm from updating its
        running statistics in training mode. Without this, a "frozen" fusion
        block would still drift during training.
        """
        super().train(mode)
        if getattr(self, '_fusion_frozen', False):
            self.fusion_block.eval()
        return self

    def build_backbone(self, config):
        """Instantiate the backbone, optionally load weights, and widen conv1.

        Args:
            config: Detector config. ``config['backbone_name']`` selects the
                class from ``BACKBONE``; ``config['backbone_config']`` is passed
                to its constructor. ``config.get('pretrained')`` is an optional
                checkpoint path.

        Returns:
            The backbone module. When ``backbone_config['inc'] == 4``, its
            ``conv1`` is replaced by a 4-input-channel convolution.
        """
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        backbone = backbone_class(model_config)

        pretrained_ckpt = config.get('pretrained', None)
        if pretrained_ckpt:
            if os.path.isfile(pretrained_ckpt):
                self._load_backbone_weights(backbone, pretrained_ckpt)
            else:
                logger.warning(f"Pretrained checkpoint not found, skipping: {pretrained_ckpt}")

        if model_config.get('inc', 3) == 4:
            backbone.conv1 = self._widen_conv1(backbone.conv1, in_channels=4)
            logger.info("Modified Meso4 conv1 to accept 4 input channels (RGB + fused wavelet+phase).")

        return backbone

    @staticmethod
    def _load_backbone_weights(backbone, ckpt_path):
        """Best-effort load of ``ckpt_path`` into ``backbone``, skipping shape mismatches."""
        logger.info(f"Loading pretrained weights from: {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        # Accept both raw state dicts and {'state_dict': ...} wrappers, as the
        # fusion-block loader does.
        state_dict = checkpoint.get('state_dict', checkpoint)

        # strict=False tolerates missing/unexpected keys, but a shape mismatch
        # (e.g. a 3-channel conv1 against a 4-channel backbone) still raises.
        own_state = backbone.state_dict()
        mismatched = [
            k for k, v in state_dict.items()
            if k in own_state and hasattr(v, 'shape') and own_state[k].shape != v.shape
        ]
        filtered = {k: v for k, v in state_dict.items() if k not in mismatched}

        missing, unexpected = backbone.load_state_dict(filtered, strict=False)
        logger.info(f"Partial load from {ckpt_path}. Missing: {missing}, Unexpected: {unexpected}")
        if mismatched:
            logger.warning(f"Skipped shape-mismatched keys from {ckpt_path}: {mismatched}")

    @staticmethod
    def _widen_conv1(old_conv, in_channels):
        """Return a conv like ``old_conv`` but taking ``in_channels`` inputs.

        Every input channel of the new kernel is initialised with the old
        kernel averaged over its input channels. Output channels, kernel size,
        stride, padding, dilation and bias presence are copied from ``old_conv``.
        """
        new_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            bias=old_conv.bias is not None,
            padding_mode=old_conv.padding_mode,
        )
        with torch.no_grad():
            avg_weight = old_conv.weight.mean(dim=1, keepdim=True)
            new_conv.weight.copy_(avg_weight.repeat(1, in_channels, 1, 1))
            if old_conv.bias is not None:
                new_conv.bias.copy_(old_conv.bias)
        return new_conv

    def load_and_freeze_fusion_block_weights(self, fusion_pretrained_path):
        """Load ``fusion_block.*`` weights from a checkpoint and freeze the block.

        Args:
            fusion_pretrained_path: Checkpoint path, or None/empty to skip. The
                checkpoint may be a raw state dict or wrap one under
                ``'state_dict'``; only keys prefixed ``fusion_block.`` are used.

        If the checkpoint contains no such keys, a warning is logged and the
        block is left trainable rather than freezing its random initialisation.
        """
        if not fusion_pretrained_path:
            logger.info("No 'fusion_pretrained' specified; skipping fusion_block loading.")
            return

        logger.info(f"Loading fusion_block weights from: {fusion_pretrained_path}")
        checkpoint = torch.load(fusion_pretrained_path, map_location='cpu')
        state_dict = checkpoint.get('state_dict', checkpoint)

        prefix = 'fusion_block.'
        # Strip only the leading prefix (str.replace would hit every occurrence).
        fusion_state_dict = {
            k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
        }
        if not fusion_state_dict:
            logger.warning(
                f"No '{prefix}*' keys found in {fusion_pretrained_path}; "
                "fusion_block left trainable with its initial weights."
            )
            return

        missing, unexpected = self.fusion_block.load_state_dict(fusion_state_dict, strict=False)
        logger.info(f"Loaded fusion_block. Missing: {missing}, Unexpected: {unexpected}")

        for param in self.fusion_block.parameters():
            param.requires_grad = False
        # Also pin BatchNorm statistics; train() re-applies this on mode changes.
        self._fusion_frozen = True
        self.fusion_block.eval()

    def build_loss(self, config):
        """Instantiate the loss class registered under ``config['loss_func']``."""
        loss_class = LOSSFUNC[config['loss_func']]
        return loss_class()

    def phase_without_amplitude(self, img: torch.Tensor) -> torch.Tensor:
        """Reconstruct the grayscale image from its Fourier phase alone.

        The input is averaged over dim 1 (assumes ``img`` is (B, C, H, W) --
        TODO confirm). It is then 2-D FFT'd over the last two axes, every
        coefficient's magnitude is replaced by 1 while its phase is kept, and it
        is inverse-transformed. The real part, shape (B, 1, H, W), is returned.
        """
        gray_img = torch.mean(img, dim=1, keepdim=True)
        X = torch.fft.fftn(gray_img, dim=(-2, -1))
        phase_spectrum = torch.angle(X)
        reconstructed_X = torch.exp(1j * phase_spectrum)
        reconstructed_x = torch.fft.ifftn(reconstructed_X, dim=(-2, -1))
        return torch.real(reconstructed_x)

    def features(self, data_dict: dict) -> torch.Tensor:
        """Build the RGB + fused-channel input and run ``backbone.features`` on it."""
        # Follow the device the module actually lives on. Guessing cuda-vs-cpu
        # breaks CPU runs on CUDA hosts and models placed on a non-default GPU.
        device = next(self.parameters()).device
        x_rgb = data_dict['image'].to(device)
        x_wavelet = self.wavelet_extractor(x_rgb)
        x_phase = self.phase_without_amplitude(x_rgb)
        fused_channel = self.fusion_block(torch.cat([x_wavelet, x_phase], dim=1))
        x_4ch = torch.cat([x_rgb, fused_channel], dim=1)
        return self.backbone.features(x_4ch)

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone's classifier head."""
        return self.backbone.classifier(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return ``{'overall': loss_func(logits, labels)}``."""
        label = data_dict['label']
        pred = pred_dict['cls']
        loss = self.loss_func(pred, label)
        return {'overall': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute acc/auc/eer/ap for the batch via ``calculate_metrics_for_train``."""
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run the detector on a batch.

        Returns:
            ``{'cls': logits, 'prob': softmax probability of class index 1,
            'feat': backbone features}``.

        When ``inference`` is True, the batch probabilities and labels are also
        appended to ``self.prob``/``self.label``, and ``self.correct``/
        ``self.total`` are updated with argmax accuracy counts.
        """
        feats = self.features(data_dict)
        pred = self.classifier(feats)
        # Index 1 of the softmax: assumes a (at least) two-logit head.
        prob = torch.softmax(pred, dim=1)[:, 1]

        pred_dict = {
            'cls': pred,
            'prob': prob,
            'feat': feats
        }

        if inference:
            label = data_dict['label']
            self.prob.append(prob.detach().cpu().numpy())
            self.label.append(label.detach().cpu().numpy())
            prediction_class = pred.argmax(dim=1)
            # Labels may still be on the loader's device; compare on pred's.
            self.correct += (prediction_class == label.to(pred.device)).sum().item()
            self.total += label.size(0)

        return pred_dict
