import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.special import sph_harm
from sklearn.metrics import mean_squared_error
from torch.utils.data import Dataset, DataLoader

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Transformer size
d_model = 8
nhead = 1                  # values tried: 1, 2, 4, 8 (must divide d_model)
num_layers = 5             # num_layers - 1 ReLU-attention layers + 1 softmax-attention layer
dim_feedforward = 8        # hidden width of each layer's 6-linear feed-forward stack
max_pairs = 50_000         # stored on the model (self.max_pairs); not read elsewhere in this file

# Optimisation
batch_size = 100
num_epochs = 100
learning_rate = 5e-4

# Data
d = 3                      # model input dimension: Cartesian (x, y, z) of points on the unit sphere
d_spherical = 10           # number of spherical-harmonic (l, m) features used to build targets (3 was also tried)
num_samples = 50_000       # sequences per dataset
num_pairs = 16             # labelled (x, y) context pairs per sequence; one query x follows them

max_seq_len = num_pairs + 1  # context pairs + the query token

# NOTE: there is no fixed global weight vector; SyntheticPairDataset draws a
# fresh random w for every example.

# Multi-head self-attention with ReLU weights instead of softmax.
class ReLUAttention(nn.Module):
    """Multi-head self-attention whose weights are ``relu(Q K^T / sqrt(d_k))``.

    The weights are not normalised: no softmax and no row-sum division.

    Args:
        d_model: token width; must be divisible by ``nhead``.
        nhead: number of attention heads.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        assert d_model % nhead == 0
        self.nhead = nhead
        self.d_k = d_model // nhead
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, attn_mask=None):
        """Attend over a ``(L, B, D)`` sequence.

        ``attn_mask`` is accepted for interface compatibility but is currently
        ignored, so every position attends to every other position.

        Returns:
            ``(output, scores)``. ``output`` is ``(L, B, D)``. ``scores`` are the
            raw scaled dot products *before* the ReLU, shape ``(nhead, B, L, L)``;
            the same tensor is also kept on ``self.attn_scores``.
        """
        seq_len, batch, width = x.size()

        def split_heads(t):
            # (L, B, D) -> (L, B, h, d_k) -> (h, B, L, d_k)
            return t.view(seq_len, batch, self.nhead, self.d_k).transpose(0, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        self.attn_scores = scores  # raw scores, kept for later inspection

        context = torch.matmul(F.relu(scores), v)  # (h, B, L, d_k)
        merged = context.transpose(0, 2).contiguous().view(seq_len, batch, width)
        return self.out_proj(merged), self.attn_scores

# Standard (softmax-normalised) multi-head self-attention.
class SoftmaxAttention(nn.Module):
    """Multi-head self-attention with weights ``softmax(Q K^T / sqrt(d_k))``.

    Args:
        d_model: token width; must be divisible by ``nhead``.
        nhead: number of attention heads.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        assert d_model % nhead == 0
        self.nhead = nhead
        self.d_k = d_model // nhead
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, attn_mask=None):
        """Attend over a ``(L, B, D)`` sequence.

        ``attn_mask`` is accepted for interface compatibility but is currently
        ignored, so every position attends to every other position.

        Returns:
            ``(output, scores)``. ``output`` is ``(L, B, D)``. ``scores`` are the
            scaled dot products *before* the softmax, shape ``(nhead, B, L, L)``;
            the same tensor is also kept on ``self.attn_scores``.
        """
        seq_len, batch, width = x.size()

        def split_heads(t):
            # (L, B, D) -> (L, B, h, d_k) -> (h, B, L, d_k)
            return t.view(seq_len, batch, self.nhead, self.d_k).transpose(0, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        self.attn_scores = scores  # pre-softmax scores, kept for later inspection

        probs = F.softmax(scores, dim=-1)  # each row sums to 1
        context = torch.matmul(probs, v)  # (h, B, L, d_k)
        merged = context.transpose(0, 2).contiguous().view(seq_len, batch, width)
        return self.out_proj(merged), self.attn_scores

# Encoder block built on ReLUAttention.
class CustomEncoderLayer_ReLU(nn.Module):
    """Post-LayerNorm residual block: ReLU attention, then a 6-Linear FFN.

    ``h = norm1(src + attn(src))`` followed by ``norm2(h + ffn(h))``.
    The FFN is ``d_model -> H -> H -> H -> H -> H -> d_model`` with ReLU after
    every Linear except the last (``H = dim_feedforward``). Both dropouts use
    p = 0.0 and are therefore identity maps.
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.self_attn = ReLUAttention(d_model, nhead)

        # Layer widths: d_model, then five hidden stages, then back to d_model.
        widths = [d_model] + [dim_feedforward] * 5 + [d_model]
        self.ff_layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(0.0)
        self.dropout2 = nn.Dropout(0.0)

    def forward(self, src, src_mask=None):
        """Map ``(L, B, D)`` -> ``(L, B, D)``; also return the attention scores."""
        attended, attn_scores = self.self_attn(src, attn_mask=src_mask)
        hidden = self.norm1(src + self.dropout1(attended))

        *body, final = self.ff_layers
        ff = hidden
        for linear in body:
            ff = F.relu(linear(ff))
        ff = final(ff)  # no activation on the output projection

        return self.norm2(hidden + self.dropout2(ff)), attn_scores

# Encoder block built on SoftmaxAttention.
class CustomEncoderLayer_Softmax(nn.Module):
    """Post-LayerNorm residual block: softmax attention, then a 6-Linear FFN.

    ``h = norm1(src + attn(src))`` followed by ``norm2(h + ffn(h))``.
    The FFN is ``d_model -> H -> H -> H -> H -> H -> d_model`` with ReLU after
    every Linear except the last (``H = dim_feedforward``). Both dropouts use
    p = 0.0 and are therefore identity maps.
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.self_attn = SoftmaxAttention(d_model, nhead)

        # Layer widths: d_model, then five hidden stages, then back to d_model.
        widths = [d_model] + [dim_feedforward] * 5 + [d_model]
        self.ff_layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(0.0)
        self.dropout2 = nn.Dropout(0.0)

    def forward(self, src, src_mask=None):
        """Map ``(L, B, D)`` -> ``(L, B, D)``; also return the attention scores."""
        attended, attn_scores = self.self_attn(src, attn_mask=src_mask)
        hidden = self.norm1(src + self.dropout1(attended))

        *body, final = self.ff_layers
        ff = hidden
        for linear in body:
            ff = F.relu(linear(ff))
        ff = final(ff)  # no activation on the output projection

        return self.norm2(hidden + self.dropout2(ff)), attn_scores


# Transformer for in-context regression: (num_layers - 1) ReLU-attention layers
# followed by one softmax-attention layer, predicting y for the final (query) token.
class DecoderOnlyTransformer(nn.Module):
    """In-context regression model over a sequence of ``(x, y)`` tokens.

    Token ``t`` is ``input_proj(concat(x_t, y_t))``. The final (query) token
    uses ``y = 0`` because its label is what the model predicts. Two further
    terms are added to every token:

    * a fixed positional table (sin/cos in the last two channels);
    * a type embedding (type id 0 for every position).

    The sequence then runs through ``num_layers - 1`` ReLU-attention layers and
    one softmax-attention layer. A linear head reads out the query token.

    Args:
        d_model: token width.
        nhead: attention heads per layer.
        num_layers: total number of layers; the last one uses softmax attention.
        dim_feedforward: hidden width of each layer's feed-forward stack.
        max_pairs: stored as ``self.max_pairs``; not used by ``forward``.
        max_seq_len: padded sequence length; inputs may have at most this many
            tokens.
        dim_x: dimension of each ``x_t``.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, max_pairs, max_seq_len, dim_x):
        super().__init__()
        self.d_model = d_model
        self.max_pairs = max_pairs
        self.max_seq_len = max_seq_len
        self.concat_dim = dim_x + 1
        self.input_proj = nn.Linear(self.concat_dim, d_model)
        # x_emb / y_emb are not used by forward(). They are kept so that the
        # state_dict keys and the parameter-initialisation RNG sequence stay unchanged.
        self.x_emb = nn.Linear(dim_x, d_model)
        self.y_emb = nn.Linear(1, d_model)
        self.type_emb = nn.Embedding(2, d_model)

        # Fixed positional table: every channel is zero except the last two, which
        # hold sin / cos of an angle rising linearly from 0 (first position) to
        # pi/2 (last position).
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        # max(..., 1) avoids a ZeroDivisionError when max_seq_len == 1.
        div_term = math.pi / (2 * max(max_seq_len - 1, 1))
        pe[:, -1:] = torch.sin(position * div_term)
        pe[:, -2:-1] = torch.cos(position * div_term)
        self.register_buffer('pos_table', pe)

        # First num_layers - 1 layers use ReLU attention; the last uses softmax.
        self.layers = nn.ModuleList(
            CustomEncoderLayer_ReLU(d_model, nhead, dim_feedforward)
            for _ in range(num_layers - 1)
        )
        self.layers.append(CustomEncoderLayer_Softmax(d_model, nhead, dim_feedforward))

        self.head = nn.Linear(d_model, 1)

    def forward(self, xs, ys):
        """Predict ``y`` for the last ``x`` of each sequence.

        Args:
            xs: ``(B, n, dim_x)`` inputs; ``xs[:, -1]`` is the query.
            ys: ``(B, n - 1, 1)`` labels for the first ``n - 1`` inputs.

        Returns:
            ``(pred, attn_scores)``. ``pred`` has shape ``(B,)``. ``attn_scores``
            are the last layer's pre-softmax scores, shape ``(nhead, B, L, L)``
            with ``L = max_seq_len``.

        Raises:
            ValueError: if ``n`` exceeds ``max_seq_len``.
        """
        B, seq_len, _ = xs.size()  # seq_len = (n - 1) context pairs + 1 query
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence has {seq_len} tokens but the model supports at most {self.max_seq_len}"
            )

        # Pair each x with its label; the query's unknown label is padded with 0.
        pad_ys = torch.cat([ys, torch.zeros(B, 1, 1, device=xs.device)], dim=1)
        concat = torch.cat([xs, pad_ys], dim=-1)  # (B, n, dim_x + 1)
        embedded = self.input_proj(concat)  # (B, n, d_model)

        # Right-pad with zero tokens up to max_seq_len.
        seq = torch.zeros(B, self.max_seq_len, self.d_model, device=xs.device)
        seq[:, :seq_len, :] = embedded

        # Size everything by this model's max_seq_len, not a module-level global.
        pos_emb = self.pos_table[:self.max_seq_len]
        type_ids = torch.zeros(self.max_seq_len, dtype=torch.long, device=xs.device)
        type_emb = self.type_emb(type_ids)
        seq = seq + pos_emb.unsqueeze(0) + type_emb.unsqueeze(0)

        # Causal mask whose window spans the whole sequence.
        # NOTE: the attention modules in this file currently ignore attn_mask
        # (their masking code is commented out), so attention is effectively
        # bidirectional.
        mask = generate_window_mask(self.max_seq_len, self.max_seq_len, xs.device)

        out = seq.transpose(0, 1)  # layers expect (L, B, D)
        attn_scores = None
        for layer in self.layers:
            out, attn_scores = layer(out, src_mask=mask)

        last = out[seq_len - 1]  # query token, (B, D)
        return self.head(last).squeeze(-1), attn_scores


# On-the-fly synthetic data for in-context regression on the unit sphere.
class SyntheticPairDataset(Dataset):
    """Random sequences of points on the unit sphere with spherical-harmonic targets.

    Every ``__getitem__`` call draws fresh randomness; ``idx`` is ignored. Each
    call draws:

    * ``n = num_pairs + 1`` points on the unit sphere;
    * a weight vector ``w``, uniform in ``[0, 1)^dim`` and then L2-normalised.

    Targets are ``y_i = w . phi(x_i)``. ``phi`` maps a point to ``dim``
    harmonic features, one per ``(l, m)`` pair, taken in the order
    (0,0), (1,-1), (1,0), (1,1), (2,-2), ... The model never sees ``phi(x)``:
    each example returns the Cartesian points.

    Each item is ``(xxs, ys, y_n)``:
        xxs: ``(n, 3)`` float tensor of Cartesian points.
        ys: ``(n - 1, 1)`` targets for the first ``n - 1`` points.
        y_n: ``(1, 1)`` target for the last (query) point.

    Args:
        num_samples: value reported by ``len()``.
        num_pairs: number of labelled context points per sequence.
        dim: number of harmonic features; must be >= 1.
        real_basis: controls how each feature is built from SciPy's complex
            ``Y_l^m``.

            * False (default, original behaviour): each feature is
              ``Re(Y_l^m)``. With the Condon-Shortley convention
              ``Re(Y_l^{-m}) = (-1)^m Re(Y_l^m)``, so the +m / -m features are
              collinear and the ``sin(m * phi)`` components never appear.
            * True: use the standard orthonormal real spherical-harmonic basis
              instead.

    Raises:
        ValueError: if ``dim < 1``.
    """

    def __init__(self, num_samples, num_pairs, dim, real_basis=False):
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")
        self.num_samples = num_samples
        self.num_pairs = num_pairs  # n - 1 labelled pairs, plus one query x_n
        self.dim = dim
        self.real_basis = real_basis
        # First `dim` (l, m) pairs ordered by l, then m = -l..l.
        # Degrees 0..isqrt(dim) provide (isqrt(dim) + 1)^2 > dim pairs.
        self.harmonic_pairs = [
            (l, m) for l in range(math.isqrt(dim) + 1) for m in range(-l, l + 1)
        ][:dim]

    def __len__(self):
        return self.num_samples

    def _harmonic_feature(self, l, m, phi, theta):
        """Return one harmonic feature per point (shape ``(n,)``).

        SciPy's ``sph_harm(m, l, azimuth, polar)`` argument order is used.
        """
        if not self.real_basis:
            return sph_harm(m, l, phi, theta).real
        if m == 0:
            return sph_harm(0, l, phi, theta).real
        if m > 0:
            return math.sqrt(2) * (-1) ** m * sph_harm(m, l, phi, theta).real
        return math.sqrt(2) * (-1) ** m * sph_harm(-m, l, phi, theta).imag

    def __getitem__(self, idx):
        n = self.num_pairs + 1
        # Colatitude theta in [0, pi) and longitude phi in [0, 2pi), each uniform.
        # NOTE(review): a uniform theta is not uniform on the sphere (points
        # cluster near the poles); kept as-is.
        theta = np.random.rand(n) * math.pi
        phi = np.random.rand(n) * 2*math.pi
        xxs = np.stack([np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)], axis=1)
        xxs = torch.from_numpy(xxs).float()  # (n, 3)

        feats = [self._harmonic_feature(l, m, phi, theta) for (l, m) in self.harmonic_pairs]
        xs = torch.from_numpy(np.stack(feats, axis=1)).float()  # (n, dim)

        # A fresh weight vector per example.
        w = torch.rand(self.dim)
        w = F.normalize(w, p=2, dim=0)
        ys = xs[:-1].matmul(w)   # (n-1,)
        y_n = xs[-1:].matmul(w)  # (1,)
        return xxs, ys.unsqueeze(-1), y_n.unsqueeze(-1)

# Batch assembly for the DataLoader.
def collate_fn(batch):
    """Stack a list of ``(xs, ys, y_n)`` examples into batch tensors.

    Each field is stacked along a new leading batch axis. ``y_n`` then has its
    trailing singleton axis removed, so per-example targets of shape
    ``(1, 1)`` become ``(B, 1)``.

    Returns:
        ``(xs_batch, ys_batch, y_n_batch)``.
    """
    xs_batch, ys_batch, y_n_batch = (
        torch.stack([example[field] for example in batch], dim=0)
        for field in range(3)
    )
    return xs_batch, ys_batch, y_n_batch.squeeze(-1)

def generate_window_mask(seq_len, window_size, device):
    """Build an additive sliding-window causal attention mask.

    Query position ``i`` may attend to key positions ``j`` with
    ``i - window_size <= j <= i``: itself plus at most ``window_size`` earlier
    tokens. Allowed entries are ``0.0`` and all others are ``-inf``.

    Built with one vectorized comparison instead of a per-row Python loop. The
    loop issued one slice write (one kernel launch on CUDA) per row on every
    forward pass.

    Args:
        seq_len: number of positions; the mask is ``seq_len x seq_len``.
        window_size: how many past positions each query may see.
        device: device on which to allocate the mask.

    Returns:
        Float32 tensor of shape ``(seq_len, seq_len)``.
    """
    positions = torch.arange(seq_len, device=device)
    # lag[i, j] = i - j: how many steps key j lies behind query i.
    lag = positions.unsqueeze(1) - positions.unsqueeze(0)
    allowed = (lag >= 0) & (lag <= window_size)
    mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
    return mask.masked_fill(allowed, 0.0)

# ---------------------------------------------------------------------------
# Model, optimiser and loss
# ---------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DecoderOnlyTransformer(
    d_model, nhead, num_layers, dim_feedforward, max_pairs, max_seq_len, d,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# ---------------------------------------------------------------------------
# Train / validation data. Both datasets are identically configured; the
# dataset draws fresh random sequences on each access.
# ---------------------------------------------------------------------------
split_kwargs = dict(num_samples=num_samples, num_pairs=num_pairs, dim=d_spherical)
dataset = SyntheticPairDataset(**split_kwargs)
val_dataset = SyntheticPairDataset(**split_kwargs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# ---------------------------------------------------------------------------
# Training loop: prints the epoch-average training and validation MSE.
# SyntheticPairDataset samples new sequences on every access, so each epoch
# trains and validates on freshly generated data.
# ---------------------------------------------------------------------------
for epoch in range(1, num_epochs + 1):
    # -- optimisation pass --
    model.train()
    total_train_loss = 0.0
    for batch in dataloader:
        xs, ys, y_n = (t.to(device) for t in batch)
        optimizer.zero_grad()
        preds, _ = model(xs, ys)
        loss = criterion(preds, y_n.squeeze(-1))  # y_n: (B, 1) -> (B,) to match preds
        loss.backward()
        optimizer.step()
        # weight by batch size so the epoch average is per sample
        total_train_loss += loss.item() * xs.size(0)
    avg_train_loss = total_train_loss / len(dataset)

    # -- evaluation pass (no gradients) --
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            xs_val, ys_val, y_n_val = (t.to(device) for t in batch)
            preds_val, _ = model(xs_val, ys_val)
            batch_loss = criterion(preds_val, y_n_val.squeeze(-1)).item()
            total_val_loss += batch_loss * xs_val.size(0)
    avg_val_loss = total_val_loss / len(val_dataset)

    print(f"Epoch {epoch:3d} | Train Loss = {avg_train_loss:.3e} | Val Loss = {avg_val_loss:.3e}")


# Quick test on one large, freshly generated validation batch.
# Inputs go to `device` only for the forward pass. xs / ys / y_n stay on the
# CPU and the outputs are copied back, so the numpy / sklearn calls here (and
# the analysis that follows) also work when the model lives on a CUDA device.
model.eval()

val_dataset = SyntheticPairDataset(num_samples=num_samples, num_pairs=num_pairs, dim=d_spherical)
val_loader = DataLoader(val_dataset, batch_size=num_samples, shuffle=False, collate_fn=collate_fn)

with torch.no_grad():
    xs, ys, y_n = next(iter(val_loader))
    print(xs.shape, ys.shape)
    pred, attn_scores = model(xs.to(device), ys.to(device))
    pred, attn_scores = pred.cpu(), attn_scores.cpu()
    y_true = y_n.squeeze(-1)  # (B,)
    print("Ground truth y_n:", y_true[:5])
    print("Predictions    :", pred[:5])
    print(mean_squared_error(y_true.numpy(), pred.numpy()))
    print(np.linalg.norm(pred.numpy() - y_true.numpy()) / np.linalg.norm(y_true.numpy()))



import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
import statistics

print(xs.shape, attn_scores.shape)

# For randomly chosen validation sequences, compare the last layer's softmax
# attention (head 0) from the query token to the context tokens with a
# normalised Gaussian kernel exp(-||x_query - x_i||^2) on the raw inputs.
# Everything is converted to CPU numpy once, so this works after CUDA training too.
# NOTE(review): row/column -1 of the scores is treated as the query token, which
# assumes sequences fill max_seq_len exactly (true for the configuration above).
attn_scores_np = attn_scores.detach().cpu().numpy()  # (nhead, B, L, L)
xs_np = xs.detach().cpu().numpy()                     # (B, n, 3)
num_val_sequences = xs_np.shape[0]

total_pearson_corr = []
total_pearson_pvalue = []

for _ in range(5000):
    idx = np.random.randint(num_val_sequences, size=1)
    idx = idx[0]
    print(idx)

    attn_scores_x = attn_scores_np[0, idx, :, :]
    # Row-wise softmax; subtracting the row max prevents overflow in exp.
    e_x = np.exp(attn_scores_x - attn_scores_x.max(axis=1, keepdims=True))
    attn_dist = e_x / e_x.sum(axis=1, keepdims=True)

    exp_norms = np.exp(-(np.linalg.norm(xs_np[idx, -1, :] - xs_np[idx, :-1, :], axis=1))**2)
    exp_dist = exp_norms/exp_norms.sum()

    pearson_corr, pearson_pvalue = pearsonr(exp_dist, attn_dist[-1, :-1])
    # NOTE: undefined (NaN) and negative correlations are discarded, so the
    # statistics below summarise only the non-negative correlations.
    if np.isnan(pearson_corr) or pearson_corr < 0.0:
        continue

    print(pearson_corr, pearson_pvalue)
    total_pearson_corr.append(pearson_corr)
    total_pearson_pvalue.append(pearson_pvalue)

print(len(total_pearson_corr))
if len(total_pearson_corr) < 2:
    # statistics.stdev needs at least two data points.
    print("Fewer than two non-negative correlations collected; skipping summary statistics.")
else:
    mean_value = statistics.mean(total_pearson_corr)
    print(f"Mean: {mean_value}")

    # Sample standard deviation (n - 1 in the denominator)
    stdev_sample = statistics.stdev(total_pearson_corr)
    print(f"Sample Standard Deviation: {stdev_sample}")

    mean_value = statistics.mean(total_pearson_pvalue)
    print(f"Mean: {mean_value}")

    # Sample standard deviation (n - 1 in the denominator)
    stdev_sample = statistics.stdev(total_pearson_pvalue)
    print(f"Sample Standard Deviation: {stdev_sample}")


# Scores for the last sampled sequence. attn_dist / exp_dist / pearson_corr come
# from the final loop iteration above, which may have been a skipped sample.
# Each plot gets its own figure: with a non-interactive backend plt.show() leaves
# the current figure open, and the histogram would be drawn on top of this plot.
plt.figure()
plt.plot(sorted(attn_dist[-1,:-1], reverse=True), '-o', label = 'softmax attention score', color = 'blue', linewidth=2)
# Kernel values reordered to follow the descending attention order.
plt.plot([x for _,x in sorted(zip(-attn_dist[-1,:-1], exp_dist))], '-o', label = 'kernel function', color = 'red', linewidth=2)
plt.xlabel("sorted index", fontsize=15)
plt.ylabel("scores", fontsize=15)

plt.legend(loc='upper right', fontsize=12)
plt.title(f"Pearson correlation = {pearson_corr: .3f}", fontsize=15)
plt.savefig("plot_scores.png")
plt.show()

# Distribution of the collected (non-negative) correlations.
plt.figure()
plt.hist(total_pearson_corr, bins=30)
plt.xlabel("Pearson correlation", fontsize=15)
plt.ylabel("counts", fontsize=15)
plt.savefig("plot_hist.png")
plt.show()



