import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig, AutoTokenizer, CanineModel, BertModel
from utils import AstroAttentionForPretrained, RelativePositionalEncoding, CustomEmbedding, SoftmaxAttention


class PretrainedWithCustomTransformer(nn.Module):
    """Classifier: embeddings/backbone -> prepended memory tokens -> custom transformer -> linear head.

    The input representation comes from one of:

    * ``model_name == 'google/canine-c'``: the ``char_embeddings`` sub-module of the loaded model;
    * ``model_name == 'custom'`` with ``dataset_name`` in ``{'pathfinder32', 'listops', 'aan'}``:
      a freshly constructed ``CustomEmbedding``;
    * anything else: ``AutoModel.from_pretrained(model_name)``, reduced to its ``.embeddings``
      sub-module when ``use_only_embeddings`` is true.

    The first two sources are always embedding modules, so ``self.use_only_embeddings`` is forced
    to ``True`` for them regardless of the constructor argument (a warning is emitted when the
    caller asked for the opposite).

    Args:
        model_name: Backbone selector, see above.
        num_labels: Output size of the final linear layer.
        num_heads, hidden_dim, num_layers, dropout, scaleD, alpha: Forwarded to ``TransformerEncoder``.
        num_memory_tokens: Number of learned memory vectors prepended to every sequence.
        pooling: ``'cls'`` classifies from the first non-memory position; any other value mean-pools
            over all non-memory positions.
        replace_attention: If true (and a full backbone is loaded), swap backbone self-attention
            modules for ``AstroAttentionForPretrained``.
        layers_to_replace: Backbone layer indices to swap; ``None`` means every layer.
        use_only_embeddings: Use only the backbone's embedding module instead of the full model.
        add_Hrel, astro_sigmoid_nonlinearity, attention_type: Forwarded to ``TransformerEncoder``
            (the first two are also forwarded to ``AstroAttentionForPretrained``).
        dataset_name: Selects the ``CustomEmbedding`` configuration when ``model_name == 'custom'``.
    """

    def __init__(self, model_name, num_labels, num_heads, hidden_dim, num_layers, dropout, scaleD, alpha,
                 num_memory_tokens, pooling, replace_attention=False, layers_to_replace=None,
                 use_only_embeddings=True, add_Hrel=True, astro_sigmoid_nonlinearity=True,
                 attention_type='astro', dataset_name=None):  # Add attention_type parameter
        super(PretrainedWithCustomTransformer, self).__init__()
        self.add_Hrel = add_Hrel
        self.astro_sigmoid_nonlinearity = astro_sigmoid_nonlinearity
        self.attention_type = attention_type  # Store the attention type
        self.dataset_name = dataset_name

        # CustomEmbedding constructor arguments per supported dataset.
        custom_embedding_specs = {
            'pathfinder32': dict(vocab_size=257, embed_dim=1024, type_vocab_size=1),
            'listops': dict(vocab_size=128, embed_dim=512, type_vocab_size=1),
            'aan': dict(vocab_size=30000, embed_dim=512, type_vocab_size=2),
        }

        if model_name == 'google/canine-c':
            self.pretrained = AutoModel.from_pretrained(model_name).char_embeddings  # google/canine-c
            self.hidden_size = self.pretrained.config.hidden_size
            embeddings_only = True
        elif model_name == 'custom' and dataset_name in custom_embedding_specs:
            self.pretrained = CustomEmbedding(**custom_embedding_specs[dataset_name])
            self.hidden_size = self.pretrained.embedding_dim
            embeddings_only = True
        else:
            embeddings_only = use_only_embeddings
            if use_only_embeddings:
                self.pretrained = AutoModel.from_pretrained(model_name).embeddings
                self.hidden_size = self.pretrained.word_embeddings.embedding_dim
            else:
                self.pretrained = AutoModel.from_pretrained(model_name)
                self.hidden_size = self.pretrained.config.hidden_size

        if embeddings_only and not use_only_embeddings:
            # The canine/custom branches never load a full encoder. Previously forward() would then
            # pass attention_mask to an embedding module and read `.last_hidden_state` from its output.
            import warnings
            warnings.warn(
                f"model_name={model_name!r} only provides an embedding module; "
                "use_only_embeddings=False and replace_attention are ignored.",
                stacklevel=2,
            )
        # Reflects what was actually loaded, which is what forward() must branch on.
        self.use_only_embeddings = embeddings_only

        self.num_memory_tokens = num_memory_tokens
        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, self.hidden_size))
        nn.init.xavier_normal_(self.memory_tokens, gain=1)
        self.pooling = pooling
        # NOTE: constructed (and therefore present in the state dict) but not applied in forward().
        self.position_encoding = PositionalEncoding(self.hidden_size, dropout)
        self.custom_transformer = TransformerEncoder(
            embed_dim=self.hidden_size,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            scaleD=scaleD,
            alpha=alpha,
            add_Hrel=add_Hrel,
            astro_sigmoid_nonlinearity=astro_sigmoid_nonlinearity,
            attention_type=attention_type  # Pass the attention type to the transformer encoder
        )
        self.fc = nn.Linear(self.hidden_size, num_labels)

        # Replace attention layers if specified (only possible when a full backbone was loaded)
        if not self.use_only_embeddings and replace_attention:
            self.replace_attention_layers(layers_to_replace, scaleD, alpha)

        # All-ones mask entries for the memory tokens, so they are never treated as padding.
        self.register_buffer('memory_mask', torch.ones((1, num_memory_tokens)))

    def replace_attention_layers(self, layers_to_replace, scaleD, alpha):
        """Swap ``encoder.layer[idx].attention.self`` of the backbone for ``AstroAttentionForPretrained``.

        Args:
            layers_to_replace: Iterable of layer indices, or ``None`` for all
                ``config.num_hidden_layers`` layers.
            scaleD, alpha: Forwarded to ``AstroAttentionForPretrained``.
        """
        config = self.pretrained.config
        if layers_to_replace is None:
            layers_to_replace = list(range(config.num_hidden_layers))

        for idx in layers_to_replace:
            pretrained_self_attention = self.pretrained.encoder.layer[idx].attention.self
            self.pretrained.encoder.layer[idx].attention.self = AstroAttentionForPretrained(
                config, pretrained_self_attention, scaleD, alpha, self.add_Hrel, self.astro_sigmoid_nonlinearity)

    def init_memory(self, batch_size):
        """Return the learned memory tokens expanded (without copying) to ``(batch_size, num_memory_tokens, hidden_size)``."""
        memory = self.memory_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        return memory

    def forward(self, input_ids, attention_mask, current_segment=None, memory=None, token_type_ids=None):
        """Classify a batch of sequences.

        Args:
            input_ids: Token ids passed to the embedding module / backbone.
            attention_mask: ``(batch_size, sequence_length)`` mask; it is concatenated after the
                memory mask along dim 1 and handed to the custom transformer.
            current_segment: Forwarded unchanged to the custom transformer.
            memory: Optional memory to prepend instead of the learned initial memory tokens.
            token_type_ids: Optional segment ids; cast to ``long`` when given.

        Returns:
            ``(logits, memory)`` where ``memory`` is the encoder output at the memory positions.
        """
        if token_type_ids is not None:
            token_type_ids = token_type_ids.long()  # Ensure token_type_ids are of type Long
        if self.use_only_embeddings:
            outputs = self.pretrained(input_ids=input_ids, token_type_ids=token_type_ids)
        else:
            outputs = self.pretrained(input_ids=input_ids, attention_mask=attention_mask,
                                      token_type_ids=token_type_ids).last_hidden_state
        embedded = outputs  # shape: (batch_size, sequence_length, hidden_size)
        batch_size = embedded.size(0)
        if memory is None:
            memory = self.init_memory(batch_size)
        # Concatenate memory tokens to the input embeddings
        embedded = torch.cat((memory, embedded), dim=1)
        memory_mask = self.memory_mask.expand(batch_size, -1).to(input_ids.device)
        attention_mask = torch.cat((memory_mask, attention_mask), dim=1)
        encoded_output = self.custom_transformer(embedded, attention_mask=attention_mask,
                                                 current_segment=current_segment)
        memory = encoded_output[:, :self.num_memory_tokens]

        if self.pooling == 'cls':
            cls_token_output = encoded_output[:, self.num_memory_tokens, :]  # Use the CLS token
            logits = self.fc(cls_token_output)
        else:
            # NOTE: the mean runs over every non-memory position, padded ones included.
            pooled_output = encoded_output[:, self.num_memory_tokens:].mean(dim=1)  # Average pooling
            logits = self.fc(pooled_output)
        return logits, memory


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a ``(batch, seq, embed_dim)`` input, then apply dropout.

    Even feature columns hold ``sin(pos * freq)`` and odd columns ``cos(pos * freq)``. Both even and
    odd ``embed_dim`` are supported.

    Args:
        embed_dim: Size of the last input dimension.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, embed_dim, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even column, so drop the
        # surplus cosine column instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles)[:, :embed_dim // 2]
        pe = pe.unsqueeze(0)  # leading axis broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class AstroAttention(nn.Module):
    """Multi-head attention built on the feature map ``phi(x) = ELU(x) + 1`` instead of a softmax.

    Per head, keys and values are first combined into a ``(head_dim, head_dim)`` matrix
    (optionally offset by a relative-position term and squashed by a sigmoid); queries then read
    from that matrix and each query row is divided by a signed, ``alpha``-powered normaliser.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the projected output.
        scaleD: Multiplied by ``head_dim`` to form the divisor of the key/value and relative terms.
        alpha: Exponent applied to the magnitude of the normaliser.
        add_Hrel: Whether the relative-position term is added.
        astro_sigmoid_nonlinearity: Whether the key/value matrix goes through a sigmoid.
        max_len: Longest accepted sequence; longer inputs raise ``ValueError``.
        clip: Passed as ``k`` to ``RelativePositionalEncoding`` and used to shift its output.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, scaleD=100.0, alpha=0.25,
                 add_Hrel=True, astro_sigmoid_nonlinearity=True, max_len=5000, clip=10):
        super(AstroAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        self.query_projection = nn.Linear(embed_dim, embed_dim)
        self.key_projection = nn.Linear(embed_dim, embed_dim)
        self.value_projection = nn.Linear(embed_dim, embed_dim)
        self.output_projection = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()
        self.scaleD = scaleD
        self.alpha = alpha
        self.max_len = max_len
        self.add_Hrel = add_Hrel
        self.astro_sigmoid_nonlinearity = astro_sigmoid_nonlinearity
        self.clip = clip

        self.relative_pos = RelativePositionalEncoding(self.head_dim, self.max_len, k=clip, method='linear')

    def phi_nonlinearity(self, x, gamma=100):
        """Return ``ELU(x) + 1``. ``gamma`` is accepted for compatibility but is not used."""
        return self.elu(x) + 1

    def forward(self, x, attention_mask=None, current_segment=None):
        """Apply the attention to ``x`` of shape ``(batch_size, seq_len, embed_dim)``.

        Args:
            x: Input sequence.
            attention_mask: Optional ``(batch_size, seq_len)`` mask; positions equal to 0 have
                their query, key and value vectors zeroed.
            current_segment: Forwarded to the relative positional encoding.

        Returns:
            Tensor of shape ``(batch_size, seq_len, embed_dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        batch_size, seq_len, embed_dim = x.size()
        if seq_len > self.max_len:
            raise ValueError("sequence length exceeds model capacity")

        def split_heads(projected):
            # (B, S, E) -> (B, H, S, D)
            return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.query_projection(x))
        key = split_heads(self.key_projection(x))
        value = split_heads(self.value_projection(x))
        if attention_mask is not None:
            padded = attention_mask.unsqueeze(1).unsqueeze(-1) == 0  # (B, 1, S, 1)
            query = query.masked_fill(padded, 0)
            value = value.masked_fill(padded, 0)
            key = key.masked_fill(padded, 0)
            # NOTE(review): phi(0) == 1, so masked keys still add to sum_phi_key below even though
            # their zeroed values contribute nothing to the key/value matrix. Confirm this is intended.
        scaling_factor = self.scaleD * self.head_dim

        # Compute relative positional encodings
        # assumes relative_pos returns (seq_len, head_dim), so Hrel is (B, H, D, D) -- TODO confirm
        rel_pos = self.relative_pos(seq_len, current_segment=current_segment)
        R_t = rel_pos.transpose(0, 1) - (seq_len // 2) / self.clip
        phi_R_t = self.phi_nonlinearity(R_t).to(x.device)
        Hrel = torch.matmul(phi_R_t, value)

        phi_key = self.phi_nonlinearity(key)
        phi_query = self.phi_nonlinearity(query)

        # One normaliser per query position: <sum_j phi(k_j), phi(q_i)>, shape (B, H, 1, S).
        sum_phi_key = torch.sum(phi_key, dim=-2).reshape(batch_size, self.num_heads, -1, phi_key.shape[-1])
        raw_normalizer = sum_phi_key @ torch.transpose(phi_query, -2, -1)
        scaled_normalizer = torch.sign(raw_normalizer) * (torch.abs(raw_normalizer) ** self.alpha)

        # Left-multiplying by diag(1 / n) is a per-row scaling, so keep the reciprocal as a
        # (B, H, S, 1) column and broadcast it instead of materialising a dense (B, H, S, S) matrix.
        # Exact zeros map to 0; the placeholder 1 keeps the division (and its gradient) finite.
        per_query = scaled_normalizer.transpose(-2, -1)
        non_zero_mask = per_query != 0
        safe_denominator = torch.where(non_zero_mask, per_query, torch.ones_like(per_query))
        reciprocal_normalizer = torch.where(non_zero_mask, 1.0 / safe_denominator,
                                            torch.zeros_like(per_query))

        # Key/value matrix per head, shape (B, H, D, D).
        kv_state = phi_key.transpose(-2, -1) @ value / scaling_factor
        if self.add_Hrel:
            kv_state = kv_state + Hrel / scaling_factor
        if self.astro_sigmoid_nonlinearity:
            kv_state = self.sigmoid(kv_state)

        attention_weights = reciprocal_normalizer * (phi_query @ kv_state)  # (B, H, S, D)
        weighted_values = attention_weights.transpose(1, 2).contiguous()
        output = self.output_projection(weighted_values.view(batch_size, seq_len, embed_dim))
        return self.dropout(output)


class TransformerEncoderLayerWithAstroAttention(nn.Module):
    """One encoder layer: self-attention then a two-layer ReLU MLP, each followed by residual add + LayerNorm.

    ``attention_type == 'astro'`` selects ``AstroAttention``; every other value selects
    ``SoftmaxAttention``. The same dropout module is applied to both sub-block outputs.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim, dropout=0.1, scaleD=100.0, alpha=0.25, add_Hrel=True,
                 astro_sigmoid_nonlinearity=True, attention_type='astro'):
        super().__init__()

        # Attention sub-block.
        if attention_type != 'astro':
            self.self_attention = SoftmaxAttention(embed_dim, num_heads, dropout)
        else:
            self.self_attention = AstroAttention(
                embed_dim, num_heads, dropout, scaleD, alpha, add_Hrel, astro_sigmoid_nonlinearity)

        # Position-wise feed-forward sub-block: embed_dim -> hidden_dim -> embed_dim.
        expand = nn.Linear(embed_dim, hidden_dim)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_dim, embed_dim)
        self.feedforward = nn.Sequential(expand, activation, contract)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, attention_mask=None, current_segment=None):
        """Run the layer on ``src``; the mask and segment index are passed to the attention module only."""
        attended = self.self_attention(src, attention_mask=attention_mask, current_segment=current_segment)
        hidden = self.norm1(src + self.dropout(attended))

        transformed = self.feedforward(hidden)
        return self.norm2(hidden + self.dropout(transformed))


class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` identical-configuration encoder layers applied in sequence."""

    def __init__(self, embed_dim, num_heads, hidden_dim, num_layers, dropout=0.1, scaleD=100.0, alpha=0.25,
                 add_Hrel=True, astro_sigmoid_nonlinearity=True, attention_type='astro'):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(
                TransformerEncoderLayerWithAstroAttention(
                    embed_dim, num_heads, hidden_dim, dropout, scaleD, alpha,
                    add_Hrel, astro_sigmoid_nonlinearity, attention_type,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, src, attention_mask=None, current_segment=None):
        """Feed ``src`` through every layer in order, passing the same mask and segment index to each."""
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, attention_mask=attention_mask, current_segment=current_segment)
        return hidden


if __name__ == "__main__":
    # Smoke-test script: build the model, run it over one data split and report accuracy.
    from dataloader import create_data_loader
    from tqdm import tqdm
    import torchmetrics
    from transformers import AutoTokenizer

    # ---- Dataset groupings used to derive the settings below ----
    canine_datasets = ['imdb_lra', 'imdb_long', 'cifar10', 'listops']
    binary_datasets = ['imdb', 'imdb_long', 'imdb_lra']
    dict_batch_datasets = ['imdb', 'imdb_long']  # batches arrive as a single dict

    # ---- Model and dataset parameters ----
    dataset_name = 'imdb_lra'
    if dataset_name in canine_datasets:
        model_name = 'google/canine-c'
    else:
        model_name = 'bert-base-uncased'
    is_binary = dataset_name in binary_datasets
    num_labels = 2 if is_binary else 10

    batch_size = 32
    max_length = 2048
    split = 'eval'

    num_heads = 8
    hidden_dim = 1024
    num_layers = 1
    dropout = 0.1
    scaleD = 100.0
    alpha = 0.25
    num_memory_tokens = 10
    pooling = 'average'
    replace_attention = False
    layers_to_replace = None
    add_Hrel = True
    astro_sigmoid_nonlinearity = True
    attention_type = 'astro'

    # ---- Initialize model ----
    model = PretrainedWithCustomTransformer(
        model_name=model_name,
        num_labels=num_labels,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
        scaleD=scaleD,
        alpha=alpha,
        num_memory_tokens=num_memory_tokens,
        pooling=pooling,
        replace_attention=replace_attention,
        layers_to_replace=layers_to_replace,
        add_Hrel=add_Hrel,
        astro_sigmoid_nonlinearity=astro_sigmoid_nonlinearity,
        attention_type=attention_type,
        dataset_name=dataset_name,
    )

    print("Model Architecture:\n")
    print(model)

    # ---- Load data ----
    data_loader, _ = create_data_loader(
        model_name,
        dataset_name,
        batch_size=batch_size,
        max_length=max_length,
        split=split,
        shuffle=False,
        sample_percentage=1,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # ---- Accuracy metric ----
    accuracy_metric = torchmetrics.Accuracy(
        task='binary' if is_binary else 'multiclass',
        num_classes=num_labels,
    ).to(device)

    model.eval()
    batches_are_dicts = dataset_name in dict_batch_datasets

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            if batches_are_dicts:
                # IMDb-style loaders yield one dict holding inputs and labels.
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
            else:
                # cifar10 / listops / imdb_lra loaders yield (inputs_dict, labels); both carry an
                # extra axis at dim 1 that is squeezed away before use.
                features, labels = batch
                features = {name: tensor.squeeze(1).to(device) for name, tensor in features.items()}
                labels = labels.squeeze(1).to(device)
                input_ids = features['input_ids']
                attention_mask = features['attention_mask']

            logits, _memory = model(input_ids=input_ids, attention_mask=attention_mask, current_segment=0)

            predictions = torch.argmax(logits, dim=1)
            accuracy_metric.update(predictions, labels)

    accuracy = accuracy_metric.compute()
    print(f"\nAccuracy on {split} set: {accuracy * 100:.2f}%")
