import math
from collections import Counter, namedtuple
from typing import Optional

import h5py
import numpy as np
from geomloss import SamplesLoss
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from . import MODELS
    from .video_mae import VideoMAE
    from .audio_model import W2V2_Model
    from .fusion_model import FusionModel
    from .optimal_transport import Sinkhorn_low_level, Sinkhorn_high_level
    from .map_model import Map
except ImportError:
    from models import MODELS
    from models.video_mae import VideoMAE
    from models.audio_model import W2V2_Model
    from models.fusion_model import FusionModel
    from models.optimal_transport import Sinkhorn_low_level, Sinkhorn_high_level
    from models.map_model import Map

# Full-precision pi, kept as a module-level name for code that imports it from here.
pi = math.pi



@torch.no_grad()
def cheb_coeffs(
    K: int,
    mode: str = "T_fixed",
    device=None,
    dtype=torch.float32,
    alpha: Optional[torch.Tensor] = None,
    dolph_R: float = 60.0
):

    device = device or torch.device("cpu")
    if K <= 0:
        return torch.empty(0, device=device, dtype=dtype)

    k = torch.arange(1, K + 1, device=device, dtype=dtype)
    theta_fixed = (2 * k - 1) * (pi / (2.0 * K))

    if mode == "T_fixed":
        C = torch.cos(theta_fixed)

    elif mode == "U":
        theta0 = pi / (2.0 * K)
        denom = torch.sin(torch.tensor(theta0, device=device, dtype=dtype))
        C = torch.sin(theta_fixed) / (denom + 1e-12)

        C = C / (C.norm() + 1e-12) * (torch.cos(theta_fixed).norm() + 1e-12)

    elif mode == "T_learnable":
        if alpha is None:
            alpha = torch.tensor(1.0, device=device, dtype=dtype)
        theta_base = (pi / (2.0 * K)) * F.softplus(alpha)
        theta = (2 * k - 1) * theta_base
        C = torch.cos(theta)
        C = torch.clamp(C, -1.0, 1.0)

    elif mode == "dolph":

        Rlin = 10.0 ** (dolph_R / 20.0)
        beta = torch.cosh(torch.acosh(torch.tensor(Rlin, device=device, dtype=dtype)) / max(K - 1, 1))
        n = torch.arange(1, K + 1, device=device, dtype=dtype)
        x = beta * torch.cos(pi * n / (K + 1e-12))


        def cheb_T_m(xx, m):
            if m == 0: return torch.ones_like(xx)
            if m == 1: return xx
            Tm_2, Tm_1 = torch.ones_like(xx), xx
            for _ in range(2, m + 1):
                Tm = 2 * xx * Tm_1 - Tm_2
                Tm_2, Tm_1 = Tm_1, Tm
            return Tm_1

        num = cheb_T_m(x, K - 1)
        den = cheb_T_m(torch.tensor(beta, device=device, dtype=dtype), K - 1)
        C = num / (den + 1e-12)

        C = C / (C.norm() + 1e-12) * (torch.cos(theta_fixed).norm() + 1e-12)

    else:
        raise ValueError(f"Unknown chebyshev mode: {mode}")

    return C


class FrequencyDomainProcessing(nn.Module):
    """Scale the power spectrum of a 3-D input by a learnable complex filter.

    The filter is a weighted sum of ``num_filter`` learnable complex filters of
    length ``d_model``. The weights come from ``cheb_coeffs``. The input is
    Fourier-transformed along dim 1, and its power spectrum is multiplied by the
    filter along the last dim.
    """

    def __init__(
        self,
        d_model,
        num_filter=2,
        dropout=0.15,

        kernel_size=None,

        cheb_mode: str = "T_fixed",
        alpha_init: float = 1.0,
        dolph_R: float = 60.0,
        **kwargs
    ):
        """
        Args:
            d_model: Size of the input's last dimension (the filter length).
            num_filter: Number of filters in the bank, i.e. Chebyshev coefficients.
            dropout: Forwarded to ``AddNorm``.
            kernel_size: Stored only; not used by this class.
            cheb_mode: Coefficient mode forwarded to ``cheb_coeffs``.
            alpha_init: Initial value of the learnable ``alpha`` ("T_learnable" only).
            dolph_R: Forwarded to ``cheb_coeffs`` (used by the "dolph" mode).
            **kwargs: Stored in ``extra_kwargs``; not used by this class.
        """
        super().__init__()
        self.d_model = d_model
        self.num_filter = num_filter
        self.cheb_mode = cheb_mode
        self.dolph_R = dolph_R

        # NOTE(review): ``weight`` is never read in this class; kept so existing
        # state_dicts still load.
        self.weight = nn.Parameter(torch.randn(d_model, 2))
        # One complex filter per coefficient, stored as (real, imag) in the last dim.
        self.filter_bank = nn.Parameter(torch.randn(num_filter, d_model, 2))
        # NOTE(review): AddNorm is neither imported nor defined in this module and
        # add_norm is not used by forward(); confirm where AddNorm should come from.
        self.add_norm = AddNorm(d_model, dropout)

        # Only the "T_learnable" mode has a trainable angle scale.
        if cheb_mode == "T_learnable":
            self.alpha = nn.Parameter(torch.tensor(alpha_init, dtype=torch.float32))
        else:
            self.register_parameter("alpha", None)

        self.kernel_size = kernel_size
        self.extra_kwargs = kwargs

    @staticmethod
    def _to_complex(w2):
        """View a real tensor whose last dim is (real, imag) as a complex tensor."""
        return torch.view_as_complex(w2)

    def unimodal_spectrum_compression(self, x):
        """Filter the power spectrum of ``x`` along dim 1.

        Args:
            x: 3-D tensor ``(B, T, D)``; ``D`` must broadcast against ``d_model``.
                Real input uses ``rfft`` (``T // 2 + 1`` bins). Complex input uses
                the full ``fft`` (``T`` bins), because ``rfft`` rejects complex
                tensors.

        Returns:
            Complex tensor: ``|FFT(x)|**2 / T`` times the combined filter, which
            is broadcast over the first two dims.
        """
        _, T, _ = x.shape
        if torch.is_complex(x):
            x_fft = torch.fft.fft(x, dim=1, norm='ortho')
        else:
            x_fft = torch.fft.rfft(x, dim=1, norm='ortho')
        power = (x_fft.real ** 2 + x_fft.imag ** 2) / float(T)

        # ``power`` is always real, so its dtype is valid for cheb_coeffs even
        # when ``x`` is complex.
        C = cheb_coeffs(
            self.num_filter,
            mode=self.cheb_mode,
            device=x.device,
            dtype=power.dtype,
            alpha=self.alpha,
            dolph_R=self.dolph_R
        )

        Kc = self._to_complex(self.filter_bank)
        # Cast the real coefficients to the complex filter dtype so einsum does
        # not see mixed real/complex operands.
        w = torch.einsum('k,kd->d', C.to(Kc.dtype), Kc)
        return power * w.view(1, 1, -1)

    def forward(self, x):
        """Alias for :meth:`unimodal_spectrum_compression`."""
        return self.unimodal_spectrum_compression(x)




@MODELS.register
class CKST(nn.Module):
    """Backbone classifier with optional optimal-transport transfer from cached source features.

    When ``xi > 1e-8`` (``self.trans``), the model does the following:
    loads per-class source features from an HDF5 file, couples them to the
    target batch with ``Sinkhorn_low_level`` / ``Sinkhorn_high_level``, and mixes
    the transported source representation into the classifier input with weight
    ``xi``. Otherwise it is a plain classifier on the backbone features.
    """

    def __init__(
            self,
            source_feature,
            epsilon: float = 0.01,
            max_iter: int = 200,
            reduction: Optional[str] = None,
            xi: float = 0.2,
            delta: float = 0.5,
            thresh: float = 1e-5,
            num_map_layer: int = 2,
            hidden_dim_map: int = 128,
            nu: float = 0.1,
            alpha: float = 0.95,
            backbone_type: str = "FusionModel",
            feature_dim: int = 512,
            num_classes: int = 2,
            num_filter: int = 2,
            dropout: float = 0.0,
            **kwargs
    ) -> None:
        """
        Args:
            source_feature: Path to an HDF5 file with ``features`` / ``labels``;
                read only when transfer is enabled.
            epsilon, max_iter, reduction, thresh: Forwarded to both Sinkhorn modules.
            xi: Mixing weight of the transported features; transfer is enabled
                when ``xi > 1e-8``.
            delta: Weight of the ``SamplesLoss`` term in the transfer loss.
            num_map_layer, hidden_dim_map: Sizes of the ``Map`` networks.
            nu: Offset subtracted from the per-row std of the plan at inference.
            alpha: EMA momentum used by :meth:`update_prototype`.
            backbone_type: Name of a backbone class visible in this module.
            feature_dim: Size of the backbone features.
            num_classes: Number of target classes.
            num_filter, dropout: Forwarded to ``FrequencyDomainProcessing``.
            **kwargs: Forwarded to the backbone constructor.
        """
        super().__init__()
        self.xi = xi
        self.trans = xi > 1e-8
        self.delta = delta
        self.ot_thresh = thresh
        self.num_classes = num_classes
        self.nu = nu
        self.alpha = alpha

        # NOTE(review): eval() executes arbitrary code; if backbone_type can come
        # from untrusted config, replace with an explicit name -> class mapping.
        self.target_model = eval(backbone_type)(
            feature_dim=feature_dim,
            num_classes=num_classes,
            **kwargs,
        )

        # Length of a one-sided rFFT of a feature_dim-long signal. Defined
        # unconditionally: it was previously only bound inside ``if self.trans``,
        # so constructing with xi == 0 raised NameError below.
        fft_dim = feature_dim // 2 + 1

        if self.trans:
            self.load_source_feature(source_feature)
            self.sinkhorn_low_level = Sinkhorn_low_level(
                eps=epsilon,
                max_iter=max_iter,
                reduction=reduction,
                thresh=self.ot_thresh,
            )
            self.sinkhorn_high_level = Sinkhorn_high_level(
                eps=epsilon,
                max_iter=max_iter,
                reduction=reduction,
                thresh=self.ot_thresh,
            )
            self.knowledge_tran = Map(
                in_dim=fft_dim,
                out_dim=fft_dim,
                hidden_dim=hidden_dim_map,
                n_layer=num_map_layer,
            )
            self.target_mapping = Map(
                in_dim=feature_dim,
                out_dim=fft_dim,
                hidden_dim=hidden_dim_map,
                n_layer=num_map_layer,
            )
            self.source_mapping = Map(n_layer=0)
            self.loss_map = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)
            # (num_classes, num_source_classes), initialised uniform.
            self.prototype = torch.ones(
                (num_classes, len(self.source_classes)), requires_grad=False
            ) / len(self.source_classes)
        else:
            self.prototype = None

        # In transfer mode the classifier sees the fft_dim-sized mapped features;
        # without transfer, forward() feeds it the raw feature_dim-sized features.
        classifier_in_dim = fft_dim if self.trans else feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_in_dim, num_classes),
        )
        self.loss_cls = nn.CrossEntropyLoss(reduction="mean")
        self.result = namedtuple("Result", ["output", "loss"])
        self.freq_processor = FrequencyDomainProcessing(fft_dim, num_filter, dropout)

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to submodules and parameters, then move the cached source tensors.

        The source tensors are plain attributes, not buffers, so ``nn.Module``
        would not move them. Overriding ``_apply`` instead of ``to`` covers
        ``.to()``, ``.cuda()`` and ``.cpu()``. Only the device follows the
        parameters; the tensors' dtype is left unchanged.
        """
        module = super()._apply(fn, *args, **kwargs)
        if self.trans:
            device = self.device
            self.prototype = self.prototype.to(device)
            self.source_feature = self.source_feature.to(device)
            self.sample_probabi = self.sample_probabi.to(device)
            self.source_mean = self.source_mean.to(device)
        return module

    def forward(self, data_collection, train_stage=False):
        """Classify a batch and compute the loss.

        Args:
            data_collection: Batch object with a ``label`` attribute; also passed
                to ``target_model.extract_feature``.
            train_stage: If True, EMA-update the class prototypes from the
                transport plan. Otherwise, blend each plan row with its nearest
                prototype.

        Returns:
            ``Result(output, loss)``: classifier logits and the scalar loss.
        """
        label = data_collection.label.to(self.device)
        target_feature = self.target_model.extract_feature(data_collection)

        if not self.trans:
            output = self.classifier(target_feature)
            return self.result(output=output, loss=self.loss_cls(output, label))

        # Source side: rFFT over the feature axis, learned frequency filter, magnitude.
        source_feature_fft = torch.fft.rfft(self.source_feature.to(self.dtype), dim=2, norm='ortho')
        source_feature_fft = torch.abs(self.freq_processor(source_feature_fft))
        # NOTE(review): the target side is only made non-negative, not
        # Fourier-transformed, before mapping to fft_dim; confirm this is intended.
        target_feature = self.target_mapping(torch.abs(target_feature))
        source_feature = self.source_mapping(source_feature_fft.to(self.dtype))

        cost_inner, _, _ = self.sinkhorn_low_level(
            source_feature, target_feature.unsqueeze(0), self.sample_probabi
        )
        source_mean = source_feature.mean(axis=1)
        _, plan, _ = self.sinkhorn_high_level(source_mean, target_feature, cost_inner)
        # Transpose to (batch, source class) and scale by the batch size.
        plan = plan.permute(1, 0).to(self.dtype) * len(label)

        if train_stage:
            self.update_prototype(pi=plan, label=label)
        else:
            # Rows whose std exceeds nu keep more of their own plan; flatter rows
            # are replaced (up to fully) by the nearest prototype.
            tp = torch.clamp(torch.std(plan, dim=1, keepdim=True) - self.nu, 0, 1)
            plan = (1 - tp) * self.select_prototype(pi=plan) + tp * plan

        # Plan-weighted sum of source class means, mapped, then inverse-rFFT'd.
        tran = (plan.unsqueeze(-1) * source_mean.unsqueeze(0)).sum(axis=1)
        tran = self.knowledge_tran(tran)
        tran = torch.fft.irfft(tran, n=target_feature.size(-1), dim=-1, norm='ortho')

        output = self.classifier((1 - self.xi) * target_feature + self.xi * tran)
        loss = self.loss_cls(output, label) + self.delta * self.loss_map(
            target_feature, source_mean
        )
        return self.result(output=output, loss=loss)

    @torch.no_grad()
    def update_prototype(self, pi, label):
        """EMA-update each class prototype with that class's mean plan row in the batch.

        ``prototype[c] <- alpha * prototype[c] + (1 - alpha) * mean(pi[label == c])``.
        Classes absent from the batch are left unchanged.
        """
        for cls in range(self.num_classes):
            mask = label == cls
            if not torch.any(mask).item():
                continue
            self.prototype[cls] = (
                self.alpha * self.prototype[cls]
                + (1 - self.alpha) * pi[mask, :].mean(axis=0)
            )

    @torch.no_grad()
    def select_prototype(self, pi):
        """Return, for each row of ``pi``, the prototype row closest in Euclidean distance."""
        k = torch.cdist(pi, self.prototype).argmin(axis=1)
        return self.prototype[k, :]

    def load_source_feature(self, source_feature_path):
        """Load cached source features and build the per-class tensors used for transport.

        Reads the ``features`` and ``labels`` datasets from the HDF5 file and sets:

        * ``source_feature``: float64 tensor ``(C, M, F)``. Each class's samples
          are repeated cyclically up to ``M``, the size of the largest class.
        * ``sample_probabi``: ``(C, M)`` per-class weights,
          ``softmax(p / 0.3)`` over the class's rows. ``p`` is a
          ``LogisticRegression`` probability of that class.
        * ``source_classes``: sorted list of the ``C`` distinct labels; row ``i``
          of the tensors above belongs to ``source_classes[i]``.
        * ``source_labels``: the raw labels as a tensor.
        * ``source_mean``: ``(C, F)`` mean of each padded class block.
        """
        with h5py.File(source_feature_path, "r") as f:
            data = f["features"][:]
            label = f["labels"][:]

        classifier_source = LogisticRegression(max_iter=1000).fit(X=data, y=label)

        class_counts = Counter(label.tolist())
        max_item = max(class_counts.values())
        # Sorted to match ``classifier_source.classes_`` (sorted unique labels).
        # Row ``idx`` below then also indexes the right predict_proba column,
        # even when labels are not 0..C-1.
        source_classes = sorted(class_counts)

        self.sample_probabi = torch.zeros([len(source_classes), max_item])
        feature_all_reshape = np.zeros(shape=[len(source_classes), max_item, data.shape[1]])
        for idx, source_class in enumerate(source_classes):
            feature = data[label == source_class, ...]
            # Cyclic repetition up to max_item rows (identity for the largest class).
            feature_all_reshape[idx, ...] = feature[np.arange(max_item) % len(feature)]
            predict_prob = classifier_source.predict_proba(feature_all_reshape[idx, ...])
            self.sample_probabi[idx, ...] = F.softmax(
                torch.from_numpy(predict_prob[:, idx] / 0.3), dim=0
            )
        self.source_feature = torch.from_numpy(feature_all_reshape)
        self.source_classes = source_classes
        self.source_labels = torch.from_numpy(label)
        self.source_mean = self.source_feature.mean(axis=1)

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    def set_xi(self, xi):
        """Set the mixing weight (does not toggle ``self.trans``)."""
        self.xi = xi
