from torch.utils.data import DataLoader, Dataset

import torch
from torch import nn, optim, Tensor
from torch.utils import data


class ADA(nn.Module):
    """Channel attention over the last axis of a ``(batch, length, channels)`` tensor.

    The input is average- and max-pooled over the length axis. Both pooled
    descriptors go through a shared bottleneck MLP, are summed and squashed
    with a sigmoid. The result rescales every channel of the input.

    Args:
        dim: Number of channels, i.e. the size of the input's last axis.
        reduction: Bottleneck reduction factor of the shared MLP. The hidden
            width is ``dim // reduction``, clamped to at least 1.
    """

    def __init__(self, dim, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)

        # For dim < reduction, dim // reduction is 0. A zero-width bottleneck
        # makes the MLP output just the last layer's bias, so the gate would
        # stop depending on the input. Keep at least one hidden unit.
        hidden = max(1, dim // reduction)

        # channel
        self.mlp = nn.Sequential(
            nn.Linear(in_features=dim, out_features=hidden),
            nn.ReLU(),
            nn.Linear(in_features=hidden, out_features=dim),
        )

    def forward(self, x):
        """Rescale each channel of ``x`` by its learned attention weight.

        Args:
            x: Rank-3 tensor ``(b, l, c)`` with ``c == dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, _, _ = x.shape  # also enforces rank 3
        # (b, c, l): the 1-D adaptive pools reduce over the last axis (length).
        x = x.transpose(-2, -1)

        avg_out = self.mlp(self.avg_pool(x).view(b, -1))
        max_out = self.mlp(self.max_pool(x).view(b, -1))
        channel_scale = torch.sigmoid(avg_out + max_out)

        # (b, c) -> (b, c, 1), broadcast across the length axis.
        x = x * channel_scale.unsqueeze(-1)
        return x.transpose(-2, -1)


class MLP(nn.Module):
    """Position-wise feed-forward block: ``dim -> 2 * dim -> dim`` with a ReLU.

    Args:
        dim: Input and output feature size; the hidden layer is twice as wide.
    """

    def __init__(self, dim: int = 1024):
        super().__init__()
        hidden_dim = 2 * dim
        # Registration order (act, fc1, fc2) is kept so parameter order and
        # initialisation under a fixed seed are unchanged.
        self.act = nn.ReLU()
        self.fc1 = nn.Linear(in_features=dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=dim)

    def forward(self, x: Tensor) -> Tensor:
        """Expand to the hidden width, apply ReLU, project back to ``dim``."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)


class Cluster(nn.Module):
    """Encoder used for clustering: ``ADA`` channel attention, ``MLP``, then LayerNorm.

    Args:
        dim_size: Feature size of the input's last axis.
        identifier: Stored unchanged on ``self.identifier``; not used by this
            class itself (its meaning is defined by callers).
        matrix: Optional object stored unchanged on ``self.matrix`` (labelled
            "baseline"). NOTE(review): if this is a Tensor it is a plain
            attribute, not a registered parameter/buffer, so ``.to(device)``
            and ``state_dict()`` will not include it — confirm that is intended.
    """

    def __init__(self, dim_size, identifier, matrix=None):
        super().__init__()
        self.identifier = identifier

        # 3 common modules for clustering
        self.encoder = nn.Sequential(ADA(dim_size), MLP(dim_size))
        self.en_norm = nn.LayerNorm(dim_size)

        # baseline
        self.matrix = matrix

    def forward(self, x: Tensor) -> Tensor:
        """Encode and layer-normalise ``x``.

        Args:
            x: Rank-3 tensor ``(b, l, dim_size)``; ``ADA`` requires rank 3.

        Returns:
            A single tensor with the same shape as ``x``.
        """
        z = self.encoder(x)
        z = self.en_norm(z)
        return z


class CoupleDataset(data.Dataset):
    """Map-style dataset that pairs each embedding with its label by position.

    Args:
        embeds: Indexable sequence of embeddings. Its length, captured once at
            construction, is the dataset length.
        labels: Indexable sequence of labels aligned with ``embeds``.
    """

    def __init__(self, embeds, labels):
        self.length = len(embeds)
        self.embeds, self.labels = embeds, labels

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        embed = self.embeds[item]
        label = self.labels[item]
        return embed, label


class MarkedDataset(Dataset):
    """Dataset that tokenizes one comment per item, lazily in ``__getitem__``.

    Args:
        comments: Indexable sequence of raw texts.
        labels: Indexable sequence of labels aligned with ``comments``.
        tokenizer: Object exposing ``encode_plus`` (presumably a Hugging Face
            tokenizer — confirm).
        max_length: Stored on ``self.max_length`` but NOT passed to the
            tokenizer; truncation uses the tokenizer's own default limit.
            NOTE(review): confirm whether ignoring it is intended.
    """

    def __init__(self, comments, labels, tokenizer, max_length=1024):
        self.comments = comments
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.comments)

    def __getitem__(self, index):
        """Return token ids, attention mask and the raw label for one comment.

        No padding is requested, so sequences vary in length. A padding collate
        function is presumably required downstream.
        """
        text, label = self.comments[index], self.labels[index]
        enc = self.tokenizer.encode_plus(text, return_tensors="pt", truncation=True)
        # return_tensors="pt" yields a leading batch axis of size 1; drop it.
        return {
            "input_ids": enc['input_ids'].squeeze(0),
            "attention_mask": enc['attention_mask'].squeeze(0),
            "labels": label,
        }
