from abc import abstractmethod, ABC
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureDropout(nn.Module):
    """Drop whole feature channels of a 3-D input during training.

    For every sample an independent Bernoulli draw decides, per feature
    channel, whether that channel is zeroed. The same decision is applied to
    every position along dimension 1. Kept channels are passed through
    unchanged (there is no ``1 / (1 - p)`` rescaling).
    """

    def __init__(self, p=0.2):
        """Store the per-channel drop probability ``p``."""
        super().__init__()
        self.p = p

    def forward(self, x, epoch=None):
        """Return ``x`` with randomly selected feature channels zeroed.

        Args:
            x: 3-D tensor; dimension 0 indexes samples and dimension 2
                indexes the feature channels that may be dropped.
            epoch: Accepted for call-compatibility with
                ``AdvancedFeatureDropout``; not used here.

        Returns:
            ``x`` itself when the module is in eval mode or ``p == 0``,
            otherwise a masked copy of ``x``.
        """
        if not self.training or self.p == 0:
            return x

        batch, _, dim = x.shape
        dropped = torch.rand(batch, 1, dim, device=x.device) < self.p

        # Build the mask in the input's own floating dtype so that half /
        # bfloat16 activations are not promoted to float32 by the multiply.
        # Non-floating inputs keep the previous default-dtype behaviour.
        mask_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        mask = (~dropped).to(dtype=mask_dtype)
        return x * mask

class AdvancedFeatureDropout(nn.Module):
    """Feature-channel dropout with element noise, a schedule and a keep floor.

    In training mode ``forward`` applies two independent corruptions to a
    3-D input:

    1. Element level: each element is, with probability ``noise_p``, replaced
       by Gaussian noise of standard deviation ``noise_std`` (or zeroed when
       ``noise_std <= 0``).
    2. Channel level: each feature channel of each sample is dropped with a
       probability that ramps linearly from ``min_p`` to ``max_p`` over
       ``curriculum_epochs`` epochs. Channels listed in ``main_feature_idx``
       are never dropped, and at least ``min_keep`` channels survive per
       sample.

    Kept values are not rescaled.
    """

    def __init__(self,
                 min_p=0.0,
                 max_p=0.5,
                 curriculum_epochs=30,
                 min_keep=1,
                 main_feature_idx=None,
                 noise_p=0.2,
                 noise_std=0.1):
        """Store the configuration.

        Args:
            min_p: Channel drop probability at epoch 0.
            max_p: Channel drop probability once the curriculum is complete,
                and whenever no epoch is supplied.
            curriculum_epochs: Number of epochs over which the probability
                ramps from ``min_p`` to ``max_p``; ``<= 0`` disables the ramp.
            min_keep: Minimum number of channels kept per sample.
            main_feature_idx: Optional index (or indices) of channels that
                must never be dropped.
            noise_p: Per-element probability of being replaced/zeroed.
            noise_std: Standard deviation of the replacement noise.
        """
        super().__init__()
        self.min_p = min_p
        self.max_p = max_p
        self.curr_epochs = curriculum_epochs
        self.min_keep = min_keep
        self.main_idx = (torch.as_tensor(main_feature_idx)
                         if main_feature_idx is not None else None)
        self.noise_p = noise_p
        self.noise_std = noise_std

    def _curr_p(self, epoch):
        """Return the channel drop probability scheduled for ``epoch``."""
        if (epoch is None) or (self.curr_epochs <= 0):
            return self.max_p
        progress = min(epoch / self.curr_epochs, 1.0)
        return self.min_p + (self.max_p - self.min_p) * progress

    def forward(self, x: torch.Tensor, epoch=None):
        """Corrupt ``x`` as described in the class docstring.

        Args:
            x: 3-D tensor; dimension 0 indexes samples and dimension 2
                indexes feature channels.
            epoch: Current epoch for the curriculum, or ``None`` to use
                ``max_p`` directly.

        Returns:
            ``x`` unchanged in eval mode, otherwise the corrupted tensor.
        """
        if not self.training:
            return x
        B, _, D = x.shape
        device = x.device

        # --- element-level corruption -----------------------------------
        if self.noise_p > 0.0:
            elem_keep = torch.empty_like(x).bernoulli_(1 - self.noise_p)
            if self.noise_std > 0:
                noise = torch.randn_like(x) * self.noise_std
                x = torch.where(elem_keep.bool(), x, noise)
            else:
                x = x * elem_keep

        # --- channel-level dropout --------------------------------------
        p = self._curr_p(epoch)
        if p <= 0.0:
            return x

        rand = torch.rand(B, D, device=device)

        # torch.rand draws from [0, 1); forcing protected channels to 1.1
        # guarantees they pass the `rand > p` test below.
        if self.main_idx is not None and self.main_idx.numel() > 0:
            rand[:, self.main_idx] = 1.1

        keep = rand > p  # bool (B, D)

        # Never ask for more channels than exist.
        min_keep = min(self.min_keep, D)
        if min_keep > 0:
            need_fix = keep.sum(dim=1, keepdim=True) < min_keep  # (B, 1)
            if need_fix.any():
                # kthvalue returns the k-th *smallest* entry, so the
                # min_keep-th largest draw is the (D - min_keep + 1)-th
                # smallest. Everything at or above it is the top-min_keep set.
                kth_largest = torch.kthvalue(
                    rand, k=D - min_keep + 1, dim=1).values.unsqueeze(1)
                # Only repair the rows that fell below the floor; other rows
                # keep their sampled mask untouched.
                keep = keep | (need_fix & (rand >= kth_largest))

        return x * keep.unsqueeze(1)


class AbstractEndoModel(nn.Module, ABC):
    """Abstract base for the models in this file.

    The constructor only records its arguments as attributes; it creates no
    layers. Concrete subclasses build their layers from these attributes and
    must provide ``name`` and ``forward``.
    """

    def __init__(
            self,
            input_dim,
            output_dim,
            num_users,
            num_sports,
            num_genders=1,
            user_dim=5,
            sport_dim=5,
            gender_dim=5,
            device_dim=5,
            num_devices=3,
            hidden_dim=64,
            includeUser=True,
            includeSport=True,
            includeGender=True,
            includeDevice=False,
            includeTemporal=True,
            contrastive_loss=True,
            includeFullTemporal=True,
            fullTemporalLength=None,
            feature_dropout=True,
            advanced_feature_dropout=True
    ):
        """Record every constructor argument under an attribute of the same name."""
        super().__init__()

        # Input / output widths and recurrent hidden size.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        # Vocabulary sizes of the categorical side inputs.
        self.num_users = num_users
        self.num_sports = num_sports
        self.num_genders = num_genders
        self.num_devices = num_devices

        # Embedding widths of the categorical side inputs.
        self.user_dim = user_dim
        self.sport_dim = sport_dim
        self.gender_dim = gender_dim
        self.device_dim = device_dim

        # Switches selecting which optional inputs a subclass consumes.
        self.includeUser = includeUser
        self.includeSport = includeSport
        self.includeGender = includeGender
        self.includeDevice = includeDevice
        self.includeTemporal = includeTemporal
        self.includeFullTemporal = includeFullTemporal
        self.fullTemporalLength = fullTemporalLength

        # Regularisation / auxiliary-loss switches.
        self.feature_dropout = feature_dropout
        self.advanced_feature_dropout = advanced_feature_dropout
        self.contrastive_loss = contrastive_loss

    @property
    @abstractmethod
    def name(self):
        """Identifier of the concrete model; subclasses must override."""

    @abstractmethod
    def forward(
            self,
            main_input,
            user_input=None,
            sport_input=None,
            gender_input=None,
            device_input=None,
            context_in1=None,
            context_in2=None,
            full_context_in1=None,
            full_context_in2=None,
            epoch=None
    ):
        """Run the model; subclasses must override with this signature."""

class EndoLSTMModel_Full1(AbstractEndoModel):
    """Two-layer LSTM sequence model with optional side inputs and context.

    The per-step input fed to the LSTM stack is the main feature tensor
    concatenated (along the last dimension) with whichever of the following
    are enabled by the constructor flags:

    * user / sport / gender / device embeddings,
    * a "temporal" context: two LSTMs over ``context_in1`` / ``context_in2``
      whose outputs are projected to ``hidden_dim``,
    * a "full temporal" context: a summary vector of ``P`` previous
      sequences, broadcast to every step of the main sequence.
    """

    @property
    def name(self):
        """Identifier of this model."""
        return 'lstm_full1'

    def __init__(self, *args, **kwargs):
        """Build all layers; arguments are forwarded to ``AbstractEndoModel``."""
        super().__init__(*args, **kwargs)

        # Categorical embeddings. padding_idx=0 pins row 0 to a zero vector.
        if self.includeUser:
            self.user_embedding = nn.Embedding(self.num_users, self.user_dim, padding_idx=0)
        if self.includeSport:
            self.sport_embedding = nn.Embedding(self.num_sports, self.sport_dim, padding_idx=0)
        if self.includeGender:
            self.gender_embedding = nn.Embedding(self.num_genders, self.gender_dim, padding_idx=0)
        if self.includeDevice:
            self.device_embedding = nn.Embedding(self.num_devices, self.device_dim, padding_idx=0)

        # Step-aligned context: one LSTM per context stream, then a projection
        # of the concatenated outputs back to hidden_dim.
        if self.includeTemporal:
            self.context_lstm_1 = nn.LSTM(
                input_size=self.input_dim + 1,
                hidden_size=self.hidden_dim,
                batch_first=True,
                bidirectional=False)
            self.context_lstm_2 = nn.LSTM(
                input_size=self.output_dim,
                hidden_size=self.hidden_dim,
                batch_first=True,
                bidirectional=False)
            self.context_projection = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(self.hidden_dim * 2, self.hidden_dim),
                nn.GELU()
            )
            context_dim = self.hidden_dim
        else:
            context_dim = 0

        # History context: per-sequence bidirectional encoders, a GRU across
        # the P encoded sequences, attention over the GRU outputs, and a
        # fusion MLP producing one vector per sample.
        if self.includeFullTemporal:
            time_emb_dim = 32
            full_context_dim = self.hidden_dim

            self.time_encoding = nn.Sequential(
                nn.Linear(1, time_emb_dim),
                nn.ReLU(),
                nn.Linear(time_emb_dim, time_emb_dim)
            )

            # Bidirectional with hidden_dim // 2 per direction, so the two
            # final hidden states concatenate to hidden_dim (for even
            # hidden_dim).
            self.workout_feature_lstm = nn.LSTM(
                input_size=self.input_dim + 1 + time_emb_dim,
                hidden_size=self.hidden_dim // 2,
                batch_first=True,
                bidirectional=True
            )

            self.workout_hr_lstm = nn.LSTM(
                input_size=self.output_dim + time_emb_dim,
                hidden_size=self.hidden_dim // 2,
                batch_first=True,
                bidirectional=True
            )

            self.cross_workout_gru = nn.GRU(
                input_size=self.hidden_dim * 2,
                hidden_size=self.hidden_dim,
                batch_first=True,
                bidirectional=False
            )

            self.workout_attention = nn.MultiheadAttention(
                embed_dim=self.hidden_dim,
                num_heads=4,
                batch_first=True
            )

            self.feature_fusion = nn.Sequential(
                nn.Linear(self.hidden_dim * 2, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(self.hidden_dim, full_context_dim)
            )

        if self.feature_dropout:
            if self.advanced_feature_dropout:
                self.feature_dropout_layer = AdvancedFeatureDropout()
            else:
                self.feature_dropout_layer = FeatureDropout(p=0.2)

        # Width of the concatenated per-step vector fed to lstm1. It counts
        # every *enabled* component, so forward() must actually receive the
        # matching inputs for the widths to agree.
        base_input_dim = self.input_dim
        if self.includeUser:
            base_input_dim += self.user_dim
        if self.includeSport:
            base_input_dim += self.sport_dim
        if self.includeGender:
            base_input_dim += self.gender_dim
        if self.includeDevice:
            base_input_dim += self.device_dim
        if self.includeTemporal:
            base_input_dim += context_dim
        if self.includeFullTemporal:
            base_input_dim += full_context_dim

        self.lstm1 = nn.LSTM(
            input_size=base_input_dim,
            hidden_size=self.hidden_dim,
            batch_first=True,
            bidirectional=False)
        self.dropout1 = nn.Dropout(p=0.2)
        self.lstm2 = nn.LSTM(
            input_size=self.hidden_dim,
            hidden_size=self.hidden_dim,
            batch_first=True,
            bidirectional=False)
        self.dropout2 = nn.Dropout(p=0.2)

        self.output_layer = nn.Linear(self.hidden_dim, self.output_dim)
        self.activation = nn.GELU()

    def _encode_full_context(self, full_context_in1, full_context_in2, epoch=None):
        """Summarise the history tensors into one vector per sample.

        Args:
            full_context_in1: 4-D tensor ``(B, P, T, D)``; its last feature
                channel is fed through ``time_encoding``.
            full_context_in2: 4-D tensor sharing the leading ``(B, P, T)``
                dimensions with ``full_context_in1``.
            epoch: Forwarded to the feature-dropout layer.

        Returns:
            Tensor of shape ``(B, hidden_dim)``.
        """
        B, P, T, D = full_context_in1.shape

        if self.feature_dropout:
            # The dropout layers expect 3-D input, so fold P into the batch.
            # reshape (not view) because the caller's tensor may be
            # non-contiguous.
            fc1_flat = full_context_in1.reshape(B * P, T, D)
            fc1_flat = self.feature_dropout_layer(fc1_flat, epoch)
            full_context_in1 = fc1_flat.reshape(B, P, T, D)

        # NOTE(review): the time channel is read *after* feature dropout, so
        # it can itself be dropped/noised during training — confirm intended.
        time_diffs = full_context_in1[..., -1].reshape(-1, 1)
        time_enc = self.time_encoding(time_diffs).view(B, P, T, -1)

        feat_flat = torch.cat([full_context_in1, time_enc], dim=-1).view(B * P, T, -1)
        hr_flat = torch.cat([full_context_in2, time_enc], dim=-1).view(B * P, T, -1)

        _, (feat_h, _) = self.workout_feature_lstm(feat_flat)
        _, (hr_h, _) = self.workout_hr_lstm(hr_flat)

        # h_n is (2, B*P, hidden_dim // 2); bring batch first and merge the
        # two directions into one vector per encoded sequence.
        feat_workout = feat_h.transpose(0, 1).contiguous().view(B, P, -1)
        hr_workout = hr_h.transpose(0, 1).contiguous().view(B, P, -1)

        workout_combined = torch.cat([feat_workout, hr_workout], dim=-1)
        cross_workout, _ = self.cross_workout_gru(workout_combined)

        # The last GRU step queries all P steps.
        query = cross_workout[:, -1:, :]
        attn_out, _ = self.workout_attention(query, cross_workout, cross_workout)

        final_context = torch.cat([
            cross_workout[:, -1, :],
            attn_out.squeeze(1)
        ], dim=-1)
        return self.feature_fusion(final_context)

    def forward(
            self,
            main_input,
            user_input=None,
            sport_input=None,
            gender_input=None,
            device_input=None,
            context_in1=None,
            context_in2=None,
            full_context_in1=None,
            full_context_in2=None,
            epoch=None,
    ):
        """Predict one output vector per step of ``main_input``.

        Args:
            main_input: 3-D tensor ``(B, S, input_dim)``.
            user_input, sport_input, gender_input, device_input: Index
                tensors for the corresponding embeddings; the embedded result
                is concatenated to ``main_input`` along dimension 2.
            context_in1, context_in2: Inputs of ``context_lstm_1`` /
                ``context_lstm_2``; used only when both are given.
            full_context_in1, full_context_in2: 4-D history tensors, required
                when ``includeFullTemporal`` is set.
            epoch: Forwarded to the feature-dropout layer.

        Returns:
            ``(predictions, out2)`` when ``contrastive_loss`` is set, else
            ``predictions``. ``predictions`` is the output layer applied to
            the second LSTM's (dropout-regularised) output ``out2``.

        Note:
            An optional component whose flag is enabled but whose input is
            ``None`` is skipped (except device and full-temporal, which
            assert); the concatenated width then no longer matches
            ``lstm1``'s ``input_size``.
        """
        # Keep the main sequence length under its own name: the history
        # tensors have their own sequence dimension, which must not be used
        # when broadcasting the history summary over the main sequence.
        _, seq_len, _ = main_input.shape

        # NOTE(review): the main input is only corrupted by the *advanced*
        # dropout layer; with the plain FeatureDropout only the context
        # streams are dropped — confirm this asymmetry is intended.
        if self.feature_dropout and self.advanced_feature_dropout:
            main_input = self.feature_dropout_layer(main_input, epoch)

        predict_vector = main_input

        if self.includeUser and user_input is not None:
            user_emb = self.user_embedding(user_input)
            predict_vector = torch.cat([predict_vector, user_emb], dim=2)

        if self.includeSport and sport_input is not None:
            sport_emb = self.sport_embedding(sport_input)
            predict_vector = torch.cat([predict_vector, sport_emb], dim=2)

        if self.includeGender and gender_input is not None:
            gender_emb = self.gender_embedding(gender_input)
            predict_vector = torch.cat([predict_vector, gender_emb], dim=2)

        if self.includeDevice:
            assert device_input is not None, f"device input is none!:check value:{device_input}"
        if self.includeDevice and device_input is not None:
            device_emb = self.device_embedding(device_input)
            predict_vector = torch.cat([predict_vector, device_emb], dim=2)

        if self.includeTemporal and (context_in1 is not None) and (context_in2 is not None):
            if self.feature_dropout:
                context_in1 = self.feature_dropout_layer(context_in1, epoch)
            c_out1, _ = self.context_lstm_1(context_in1)
            c_out2, _ = self.context_lstm_2(context_in2)
            c_proj = self.context_projection(torch.cat([c_out1, c_out2], dim=2))
            # The projected context is placed *before* the other features.
            predict_vector = torch.cat([c_proj, predict_vector], dim=2)

        if self.includeFullTemporal:
            assert full_context_in1 is not None and full_context_in2 is not None
            context_feature = self._encode_full_context(
                full_context_in1, full_context_in2, epoch)
            # Repeat the per-sample summary at every step of the main sequence.
            context_feature = context_feature.unsqueeze(1).expand(-1, seq_len, -1)
            predict_vector = torch.cat([predict_vector, context_feature], dim=2)

        out1, _ = self.lstm1(predict_vector)
        out1 = self.dropout1(out1)
        out2, _ = self.lstm2(out1)
        out2 = self.dropout2(out2)

        predictions = self.output_layer(out2)
        if self.contrastive_loss:
            return predictions, out2
        return predictions

def contrastive_loss_func(embeddings, user_ids, temperature=0.1):
    """Contrastive loss that pulls together embeddings sharing a user id.

    For every row ``i`` that has at least one other row with the same id, the
    loss is ``-log(sum_pos exp(s_ij) / sum_{j != i} exp(s_ij))`` where
    ``s_ij`` is the cosine similarity divided by ``temperature`` and the
    numerator runs over the other rows with the same id. The result is the
    mean over those rows.

    The ratio is evaluated with ``logsumexp`` so that it stays finite for
    small temperatures, where ``exp(1 / temperature)`` would overflow.

    Args:
        embeddings: ``(N, E)`` tensor, or ``(N, S, E)`` in which case the
            last position along dimension 1 is used.
        user_ids: Tensor of ids compared pairwise via broadcasting; expected
            to be 1-D of length ``N``.
        temperature: Similarity scale.

    Returns:
        Scalar loss tensor; a constant ``0.0`` (detached from the graph) when
        no row has a same-id partner.
    """
    if embeddings.ndim == 3:
        embeddings = embeddings[:, -1, :]
    embeddings = F.normalize(embeddings, dim=1)

    logits = torch.matmul(embeddings, embeddings.T) / temperature

    n = logits.size(0)
    self_pair = torch.eye(n, dtype=torch.bool, device=logits.device)
    positives = (user_ids.unsqueeze(1) == user_ids.unsqueeze(0)) & ~self_pair

    valid_rows = positives.any(dim=1)
    if not valid_rows.any():
        return torch.tensor(0.0, device=embeddings.device)

    # Select the usable rows *before* logsumexp: a row with no positive would
    # reduce over only -inf and yield NaN gradients even if masked afterwards.
    logits = logits[valid_rows]
    positives = positives[valid_rows]
    self_pair = self_pair[valid_rows]

    neg_inf = float("-inf")
    log_all = torch.logsumexp(logits.masked_fill(self_pair, neg_inf), dim=1)
    log_pos = torch.logsumexp(logits.masked_fill(~positives, neg_inf), dim=1)
    return (log_all - log_pos).mean()