import math
import os
import pickle
from glob import glob
import random

import torch
import torch.nn as nn
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
from typing import List, Optional, Tuple, Union


class RelativePositionalEncoding(nn.Module):
    """Learned relative-position encoding over clipped pairwise distances.

    For a sequence of ``length`` positions, the pairwise distance ``j - i`` is clipped
    to ``[-k, k]``, shifted by ``k + current_segment * k`` to obtain a non-negative
    index, and looked up in an embedding table. The resulting
    ``(length, length, d_model)`` tensor is then collapsed by a projection chosen by
    ``method``:

    * ``'linear'``: ``nn.Linear(length, 1)`` over the key axis -> ``(length, d_model)``
    * ``'conv'``:   ``nn.Conv1d(length, 1, 1)`` over the query axis -> ``(length, d_model)``
    * anything else: ``nn.Linear(d_model, 1)`` over the feature axis -> ``(length, length)``

    NOTE(review): the projection layers are created lazily inside ``forward``. They are
    registered as submodules on creation, but an optimizer built before the first forward
    pass will not see them. The 'linear'/'conv' layers are re-created (re-initialised)
    whenever ``length`` changes.
    """

    def __init__(self, d_model, max_len=5000, k=10, method='linear'):
        super(RelativePositionalEncoding, self).__init__()
        self.d_model = d_model
        self.k = k
        self.method = method

        # Embedding table with 2 * max_len + 1 rows of size d_model. Lookup indices lie in
        # [current_segment * k, 2k + current_segment * k], so they must stay below 2 * max_len + 1.
        self.relative_positions = nn.Embedding(2 * max_len + 1, d_model)

        # Projection layers; created lazily in forward() because the 'linear' and 'conv'
        # variants depend on the sequence length.
        self.linear_transform = None
        self.conv_transform = None

    def forward(self, length, current_segment=None):
        """Return the projected relative-position map for a sequence of ``length``.

        Args:
            length: number of positions.
            current_segment: integer offset (in units of ``k``) added to every lookup
                index; ``None`` is treated as segment 0.

        Returns:
            ``(length, d_model)`` for 'linear'/'conv', otherwise ``(length, length)``.
        """
        device = self.relative_positions.weight.device
        segment = 0 if current_segment is None else current_segment

        # Pairwise distances j - i, shape [length, length].
        range_vec = torch.arange(length, device=device)
        distance_mat = range_vec[None, :] - range_vec[:, None]

        # Clip to [-k, k], then shift so every index is non-negative (and offset by segment).
        distance_mat = torch.clamp(distance_mat, -self.k, self.k)
        distance_mat += (self.k + segment * self.k)

        relative_position_embeddings = self.relative_positions(distance_mat)  # [length, length, d_model]

        if self.method == 'linear':
            if self.linear_transform is None or self.linear_transform.in_features != length:
                self.linear_transform = nn.Linear(length, 1, bias=False).to(device)
            # [length, d_model, length] -> [length, d_model, 1] -> [length, d_model]
            return self.linear_transform(relative_position_embeddings.permute(0, 2, 1)).squeeze(-1)

        if self.method == 'conv':
            if self.conv_transform is None or self.conv_transform.in_channels != length:
                self.conv_transform = nn.Conv1d(in_channels=length, out_channels=1, kernel_size=1).to(device)
            # [length(N), length(C), d_model] -> [length, 1, d_model] -> [length, d_model]
            return self.conv_transform(relative_position_embeddings.permute(1, 0, 2)).squeeze(1)

        # This layer maps d_model -> 1, so compare against d_model (not length); otherwise it
        # would be re-created with fresh random weights on every call.
        if self.linear_transform is None or self.linear_transform.in_features != self.d_model:
            self.linear_transform = nn.Linear(self.d_model, 1, bias=False).to(device)
        # [length, length, d_model] -> [length, length, 1] -> [length, length]
        return self.linear_transform(relative_position_embeddings).squeeze(-1)


class AstroAttentionForPretrained(nn.Module):
    """Kernel-feature (linearised) self-attention with an additive relative-position term.

    Exposes ``query``/``key``/``value`` projections under the same attribute names as
    ``pretrained_self_attention``, from which weights, biases and the dropout probability
    are copied when it is given. ``forward`` mirrors the argument list of a (presumably
    Hugging Face) BERT self-attention layer, but only ``hidden_states`` and
    ``attention_mask`` are used.

    Args:
        config: object with ``num_attention_heads``, ``hidden_size`` and
            ``attention_probs_dropout_prob`` attributes.
        pretrained_self_attention: optional module whose ``query``/``key``/``value``
            layers expose ``weight``/``bias`` and whose ``dropout`` exposes ``p``.
        scaleD: key-value products (and ``Hrel``) are divided by ``scaleD * head_dim``.
        alpha: exponent applied to the per-query normaliser.
        add_Hrel: add the relative-position term ``Hrel``.
        astro_sigmoid_nonlinearity: apply a sigmoid to the scaled key-value product.
        max_len: maximum supported sequence length.
        clip: relative distances are clipped to ``[-clip, clip]``.
    """

    def __init__(self, config, pretrained_self_attention=None, scaleD=100.0, alpha=0.25, add_Hrel=True,
                 astro_sigmoid_nonlinearity=True, max_len=5000, clip=10):
        super(AstroAttentionForPretrained, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()
        self.scaleD = scaleD
        self.alpha = alpha
        self.max_len = max_len
        self.add_Hrel = add_Hrel
        self.astro_sigmoid_nonlinearity = astro_sigmoid_nonlinearity

        # 'linear' method -> map of shape (seq_len, attention_head_size).
        self.relative_pos = RelativePositionalEncoding(self.attention_head_size, self.max_len, k=clip, method='linear')

        # Explicit None check instead of relying on the module's truthiness.
        if pretrained_self_attention is not None:
            self.query.weight.data.copy_(pretrained_self_attention.query.weight.data)
            self.query.bias.data.copy_(pretrained_self_attention.query.bias.data)
            self.key.weight.data.copy_(pretrained_self_attention.key.weight.data)
            self.key.bias.data.copy_(pretrained_self_attention.key.bias.data)
            self.value.weight.data.copy_(pretrained_self_attention.value.weight.data)
            self.value.bias.data.copy_(pretrained_self_attention.value.bias.data)
            self.dropout.p = pretrained_self_attention.dropout.p

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, all_head_size) to (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def phi_nonlinearity(self, x, gamma=100):
        """Positive feature map ``elu(x) + 1``; ``gamma`` is currently unused."""
        return self.elu(x) + 1

    def skew(self, QEr):
        """Left-pad the last dim by one, reshape (B, H, R, C+1) -> (B, H, C+1, R), drop row 0.

        Appears to be the relative-attention "skewing" trick; not used by ``forward``.
        """
        padded = nn.functional.pad(QEr, (1, 0))
        batch_size, num_heads, num_rows, num_cols = padded.shape
        reshaped = padded.reshape(batch_size, num_heads, num_cols, num_rows)
        return reshaped[:, :, 1:, :]

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.FloatTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
                past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                output_attentions: Optional[bool] = False,
                current_segment=None
                ) -> Tuple[torch.Tensor]:
        """Compute kernelised attention for ``hidden_states``.

        Args:
            hidden_states: (batch, seq_len, hidden_size) input.
            attention_mask: optional 4-D mask, presumably (batch, 1, 1, seq_len). It is
                permuted to (batch, 1, seq_len, 1), and keys/values where it equals 0 are
                zeroed. ``None`` disables masking.
                NOTE(review): HF "extended" masks use 0 for *kept* tokens; confirm which
                convention callers pass.
            head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value,
            output_attentions: accepted for signature compatibility; unused.
            current_segment: unused; the relative-position lookup always uses segment 0.

        Returns:
            ``(context_layer,)`` with shape (batch, seq_len, all_head_size).

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(hidden_states))
        value = self.transpose_for_scores(self.value(hidden_states))

        batch_size, seq_len, _ = hidden_states.size()

        # Only touch the mask when one was supplied (previously permuted unconditionally).
        if attention_mask is not None:
            permuted_attn_mask = attention_mask.permute(0, 1, 3, 2)
            # NOTE(review): phi(0) = 1, so masked keys still add to the normaliser below.
            value = value.masked_fill(permuted_attn_mask == 0, 0)
            key = key.masked_fill(permuted_attn_mask == 0, 0)

        seq_len = key.shape[-2]
        head_dim = key.shape[-1]
        scaling_factor = self.scaleD * head_dim
        if seq_len > self.max_len:
            raise ValueError("sequence length exceeds model capacity")

        # Relative-position term: (head_dim, seq_len) @ (B, H, seq_len, head_dim) -> (B, H, head_dim, head_dim).
        # Segment is fixed at 0: it can't easily be passed through BertModel's forward().
        rel_pos = self.relative_pos(seq_len, current_segment=0)
        R_t = self.phi_nonlinearity(rel_pos.transpose(0, 1))
        Hrel = torch.matmul(R_t, value)

        phi_key = self.phi_nonlinearity(key)
        phi_query = self.phi_nonlinearity(query)

        # Per-query normaliser: (sum_j phi(k_j)) . phi(q_i), shape (B, H, 1, seq_len).
        sum_phi_key = torch.sum(phi_key, dim=-2).reshape(batch_size, self.num_attention_heads, -1, phi_key.shape[-1])
        raw_normalizer = sum_phi_key @ phi_query.transpose(-2, -1)
        # As a column (B, H, seq_len, 1): scaling row i by 1/normaliser_i elementwise equals
        # left-multiplying by diag(1/normaliser) without materialising a seq_len x seq_len matrix.
        normalizer = (raw_normalizer ** self.alpha).transpose(-2, -1)
        non_zero_mask = normalizer != 0
        # Guard the division so torch.where does not back-propagate NaN/inf from 1/0.
        safe_normalizer = torch.where(non_zero_mask, normalizer, torch.ones_like(normalizer))
        reciprocal_normalizer = torch.where(non_zero_mask, 1.0 / safe_normalizer, torch.zeros_like(normalizer))

        key_value = phi_key.transpose(-2, -1) @ value / scaling_factor
        if self.add_Hrel:
            key_value = key_value + Hrel / scaling_factor
        if self.astro_sigmoid_nonlinearity:
            key_value = self.sigmoid(key_value)
        context_layer = reciprocal_normalizer * (phi_query @ key_value)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        context_layer = self.dropout(context_layer)

        return (context_layer,)


class SoftmaxAttention(nn.Module):
    """Multi-head softmax self-attention with an additive relative-position bias.

    The bias is a (seq_len, seq_len) map produced by ``RelativePositionalEncoding``
    (default projection method). It is broadcast over batch and heads and added to the
    scaled dot-product scores before the softmax.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, max_len=5000, clip=10):
        super(SoftmaxAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.query_projection = nn.Linear(embed_dim, embed_dim)
        self.key_projection = nn.Linear(embed_dim, embed_dim)
        self.value_projection = nn.Linear(embed_dim, embed_dim)
        self.output_projection = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.relative_pos = RelativePositionalEncoding(self.head_dim, max_len, k=clip, method=None)

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, embed_dim) to (batch, heads, seq, head_dim)."""
        new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, x, attention_mask=None, current_segment=None):
        """Attend over ``x`` and return the output projection.

        Args:
            x: (batch, seq_len, embed_dim) input.
            attention_mask: presumably (batch, seq_len) with 0 marking keys to ignore. It is
                broadcast as (batch, 1, 1, seq_len) and those scores are set to -1e9.
            current_segment: segment offset for the relative-position lookup; ``None`` is
                treated as segment 0.

        Returns:
            (batch, seq_len, embed_dim) tensor.
        """
        _, seq_len, _ = x.size()
        query = self.transpose_for_scores(self.query_projection(x))
        key = self.transpose_for_scores(self.key_projection(x))
        value = self.transpose_for_scores(self.value_projection(x))

        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)

        # The default current_segment=None would otherwise reach `None * k` in the lookup.
        segment = 0 if current_segment is None else current_segment
        rel_pos = self.relative_pos(seq_len, current_segment=segment)
        attention_scores = attention_scores + rel_pos.unsqueeze(0).unsqueeze(0)

        attention_probs = torch.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return self.output_projection(context_layer)


class CustomEmbedding(nn.Module):
    """Sum of token, absolute-position and (optional) token-type embeddings.

    The sum is passed through LayerNorm and dropout (p=0.1). Positions are 0..seq_len-1,
    so seq_len must not exceed ``max_len``.
    """

    def __init__(self, vocab_size, embed_dim=64, max_len=2048, type_vocab_size=1):
        super().__init__()
        # Creation order is kept fixed so parameter initialisation consumes the RNG identically.
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(max_len, embed_dim)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)
        self.embedding_dim = embed_dim
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, token_type_ids=None):
        """Embed ``input_ids`` of shape (batch, seq_len) into (batch, seq_len, embed_dim)."""
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)

        summed = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        if token_type_ids is not None:
            summed = summed + self.token_type_embeddings(token_type_ids)

        return self.dropout(self.layer_norm(summed))


def plot_metrics(train_losses, val_losses, train_metrics, val_metrics, task, save_path):
    """Plot per-epoch loss and metric curves side by side and save the figure.

    Args:
        train_losses, val_losses: per-epoch losses; the x-axis runs 1..len(train_losses).
        train_metrics, val_metrics: per-epoch metric values for the right-hand panel.
        task: selects the right panel's y-label: 'language_modeling' -> 'Perplexity',
            'classification' or 'mrpc_qqp' -> 'Accuracy', anything else -> 'Loss'.
        save_path: destination passed to ``plt.savefig``.

    The figure is closed afterwards, so repeated calls (e.g. once per epoch) do not
    accumulate open pyplot figures.
    """
    epochs = range(1, len(train_losses) + 1)

    if task == 'language_modeling':
        metric_label = 'Perplexity'
    elif task in ('classification', 'mrpc_qqp'):
        metric_label = 'Accuracy'
    else:
        metric_label = 'Loss'

    fig = plt.figure(figsize=(12, 5))
    try:
        plt.subplot(1, 2, 1)
        plt.plot(epochs, train_losses, label='Training Loss')
        plt.plot(epochs, val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(epochs, train_metrics, label='Training Metric')
        plt.plot(epochs, val_metrics, label='Validation Metric')
        plt.xlabel('Epochs')
        plt.ylabel(metric_label)
        plt.legend()

        plt.tight_layout()
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until closed; release it even if saving fails.
        plt.close(fig)


class AstroNormWithLayerNorm(nn.Module):
    """Accumulate a segment-weighted, layer-normalised memory across segments.

    Each ``forward(mem, segment)`` call scales ``mem`` by a retention weight that depends
    on ``segment``. The weights over segments 0..total_segments sum to 1. The layer-normalised
    result is stored in ``self.memory_retention`` and added to a running sum in
    ``self.memory_retention_sum``, which is itself layer-normalised after every update.
    The sum restarts at ``segment == 0``. ``forward`` returns ``None``; results are read
    from these attributes.
    """

    def __init__(self, d_model, total_segments):
        super(AstroNormWithLayerNorm, self).__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.total_segments = total_segments
        # Running (layer-normalised) sum of scaled memories; reset at segment 0.
        self.memory_retention_sum = None
        # Most recent scaled, layer-normalised memory.
        self.memory_retention = None

    def calculate_area(self, device):
        """Sum of the retention curve over integer segments 0..total_segments (inclusive)."""
        integer_x_values = torch.arange(0, self.total_segments + 1, device=device)
        return torch.sum(self.memory_retention_function(integer_x_values, self.total_segments))

    def scaling_factor(self, device):
        """Reciprocal of ``calculate_area``, so the scaled weights over all segments sum to 1."""
        return 1 / self.calculate_area(device)

    def memory_retention_function(self, x, x_max):
        """Unnormalised retention curve ``0.049787 + exp(-2x/x_max) - exp(-(2x/x_max + 1))``.

        0.049787 is approximately ``exp(-3)``.
        """
        return 0.049787 + torch.exp(-x / (x_max / 2)) - torch.exp(-(x / (x_max / 2) + 1))

    def scaled_memory_retention(self, x, x_max, device):
        """Retention curve at ``x`` multiplied by ``scaling_factor``."""
        k = self.scaling_factor(device)
        return k * self.memory_retention_function(x, x_max)

    def forward(self, mem, segment):
        """Update ``memory_retention`` and ``memory_retention_sum`` for ``segment``.

        Args:
            mem: floating-point memory tensor whose last dim is ``d_model``.
            segment: segment index (number or single-element tensor).

        Returns:
            None.
        """
        device = mem.device

        if not isinstance(segment, torch.Tensor):
            segment = torch.tensor(segment, dtype=mem.dtype, device=device)

        # NOTE(review): these flip requires_grad on the caller's tensors in place; confirm intended.
        if not mem.requires_grad:
            mem.requires_grad_(True)
        if not segment.requires_grad:
            segment.requires_grad_(True)

        memory_retention_factor = self.scaled_memory_retention(segment, self.total_segments, device)
        self.memory_retention = memory_retention_factor * mem

        self.layer_norm = self.layer_norm.to(device)

        # Start a fresh accumulator at segment 0, and also when none exists yet (a first call
        # with segment != 0 would otherwise add to None).
        if segment.item() == 0 or self.memory_retention_sum is None:
            self.memory_retention_sum = torch.zeros_like(mem, device=device)

        # Out-of-place update: the previous sum is part of the autograd graph.
        self.memory_retention_sum = self.memory_retention_sum + self.memory_retention
        self.memory_retention = self.layer_norm(self.memory_retention)
        self.memory_retention_sum = self.layer_norm(self.memory_retention_sum)

        return None


def count_batches(dataloader):
    """Return how many items one full iteration over ``dataloader`` yields.

    This consumes a complete pass over the iterable; for a DataLoader this presumably
    loads every batch.
    """
    total = 0
    for _batch in dataloader:
        total += 1
    return total


def calculate_accuracy(preds, labels):
    """Return sklearn's ``accuracy_score`` of ``preds`` against ``labels``.

    Both arguments are moved to CPU with ``.cpu()`` first (so presumably torch tensors).
    """
    y_true = labels.cpu()
    y_pred = preds.cpu()
    return accuracy_score(y_true, y_pred)


def calculate_perplexity(loss):
    """Return ``exp(loss)`` as a detached tensor (perplexity for a mean NLL loss).

    Args:
        loss: Python/NumPy number or a tensor.

    Tensors are copied with ``detach().clone()`` rather than passed to ``torch.tensor``,
    which emits a copy-construct UserWarning for tensor input. The result is the same
    detached copy either way.
    """
    if isinstance(loss, torch.Tensor):
        loss_tensor = loss.detach().clone()
    else:
        loss_tensor = torch.tensor(loss)
    return torch.exp(loss_tensor)


def evaluate_stsb(preds, labels):
    """Return ``(pearson, spearman)`` correlation coefficients of ``preds`` vs ``labels``.

    The p-values reported by scipy are discarded.
    """
    pearson_r, _pearson_p = pearsonr(preds, labels)
    spearman_rho, _spearman_p = spearmanr(preds, labels)
    return pearson_r, spearman_rho


def evaluate_cola(preds, labels):
    """Return the Matthews correlation coefficient of ``preds`` against ``labels``."""
    score = matthews_corrcoef(y_true=labels, y_pred=preds)
    return score


def evaluate_mrpc_qqp(preds, labels):
    """Return ``(accuracy, f1)`` of ``preds`` against ``labels`` (sklearn defaults, binary F1)."""
    metrics = (
        accuracy_score(y_true=labels, y_pred=preds),
        f1_score(y_true=labels, y_pred=preds),
    )
    return metrics


# Adapted from the google-research/long-range-arena code.
def create_learning_rate_scheduler(factors, config):
    """Build a factory that attaches a multiplicative LR schedule to an optimizer.

    ``factors`` is a '*'-separated list of factor names, applied left to right to a
    running value that starts at 1.0:

    * constant: multiply by ``config.learning_rate``.
    * linear_warmup: multiply by ``min(1, step / warmup_steps)`` (warmup_steps must be non-zero).
    * rsqrt_decay: divide by ``sqrt(max(step, warmup_steps))``.
    * rsqrt_normalized_decay: multiply by ``sqrt(warmup_steps)``, then divide by
      ``sqrt(max(step, warmup_steps))``.
    * decay_every: multiply by ``decay_factor ** (step // steps_per_decay)``.
    * cosine_decay: multiply by ``max(0, 0.5 * (1 + cos(pi * (progress % 1))))`` where
      ``progress = max(0, (step - warmup_steps) / steps_per_cycle)``.

    Args:
        factors: factor names separated by '*'.
        config: must expose ``config.learning_rate`` as an attribute and support
            ``config.get`` for ``warmup_steps`` (default 1000), ``decay_factor`` (0.5),
            ``steps_per_decay`` (20000) and ``steps_per_cycle`` (100000).

    Returns:
        A callable ``optimizer -> LambdaLR``. ``LambdaLR`` multiplies each param group's
        initial lr by the factor product.

    Raises:
        ValueError: for an unknown factor name, when the schedule is first evaluated.
    """
    factor_names = [name.strip() for name in factors.split('*')]
    base_learning_rate: float = config.learning_rate
    warmup_steps: int = config.get('warmup_steps', 1000)
    decay_factor: float = config.get('decay_factor', 0.5)
    steps_per_decay: int = config.get('steps_per_decay', 20000)
    steps_per_cycle: int = config.get('steps_per_cycle', 100000)

    # Each updater maps (running value, step) -> new running value.
    def _constant(value, step):
        return value * base_learning_rate

    def _linear_warmup(value, step):
        return value * np.minimum(1.0, step / warmup_steps)

    def _rsqrt_decay(value, step):
        return value / np.sqrt(np.maximum(step, warmup_steps))

    def _rsqrt_normalized_decay(value, step):
        return value * np.sqrt(warmup_steps) / np.sqrt(np.maximum(step, warmup_steps))

    def _decay_every(value, step):
        return value * (decay_factor ** (step // steps_per_decay))

    def _cosine_decay(value, step):
        progress = np.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
        return value * np.maximum(0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))

    updaters = {
        'constant': _constant,
        'linear_warmup': _linear_warmup,
        'rsqrt_decay': _rsqrt_decay,
        'rsqrt_normalized_decay': _rsqrt_normalized_decay,
        'decay_every': _decay_every,
        'cosine_decay': _cosine_decay,
    }

    def step_fn(step):
        """Map a scheduler step to the product of the configured factors."""
        value = 1.0
        for name in factor_names:
            updater = updaters.get(name)
            if updater is None:
                raise ValueError('Unknown factor %s.' % name)
            value = updater(value, step)
        return value

    def build_scheduler(optimizer):
        return LambdaLR(optimizer, step_fn)

    return build_scheduler


def create_pathfinder_splits(data_dir, diff_level, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_seed=42):
    """Create train/val/test splits of Pathfinder metadata files, with on-disk caching.

    Files matching ``{data_dir}/{diff_level}/metadata/*.npy`` are sorted, shuffled with a
    private ``random.Random(random_seed)`` and sliced by ratio (the test split takes the
    remainder). The result is pickled to ``{data_dir}/{diff_level}_metadata_splits.pkl``
    and reloaded from there on later calls.

    Returns:
        ``(train_files, val_files, test_files)`` lists of paths.

    Raises:
        ValueError: if the three ratios do not sum to 1.0 (within float tolerance).
    """
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError("Split ratios must sum to 1.0")

    cache_file = os.path.join(data_dir, f"{diff_level}_metadata_splits.pkl")

    if os.path.exists(cache_file):
        # NOTE(review): the cache is keyed only by diff_level; changed ratios/seed are ignored
        # until the file is deleted. pickle.load is only safe on trusted files.
        with open(cache_file, "rb") as f:
            train_files, val_files, test_files = pickle.load(f)
        print(f"Loaded cached metadata splits from {cache_file}")
        return train_files, val_files, test_files

    metadata_files = sorted(glob(f"{data_dir}/{diff_level}/metadata/*.npy"))
    # A private RNG yields the same permutation as seeding the module-level RNG, without
    # clobbering the caller's global random state.
    random.Random(random_seed).shuffle(metadata_files)

    num_files = len(metadata_files)
    train_size = int(train_ratio * num_files)
    val_size = int(val_ratio * num_files)

    train_files = metadata_files[:train_size]
    val_files = metadata_files[train_size:train_size + val_size]
    test_files = metadata_files[train_size + val_size:]

    if num_files == 0:
        # Don't persist empty splits (e.g. a wrong data_dir); they would be reloaded forever.
        print(f"No metadata files found under {data_dir}/{diff_level}/metadata; splits not cached")
        return train_files, val_files, test_files

    with open(cache_file, "wb") as f:
        pickle.dump((train_files, val_files, test_files), f)
    print(f"Created and cached metadata splits at {cache_file}")

    return train_files, val_files, test_files


def parse_comma_separated_string(s: str) -> List[str]:
    """Split ``s`` on commas and strip whitespace from each piece.

    Falsy input (``""`` or ``None``) yields ``[]``. Empty pieces such as the one after a
    trailing comma are kept as ``""``.
    """
    if not s:
        return []
    return list(map(str.strip, s.split(",")))
