import matplotlib
from matplotlib import rc
from matplotlib import pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy import sparse
from sentence_transformers import SentenceTransformer
import scipy.stats as stats

def plot_data(data, ylabel, xlabel, file_name):
    """Plot one or more labelled line series and save the figure as a PDF.

    Args:
        data: Sequence of series. Each series is indexed as
            ``series[0]`` (x values), ``series[1]`` (y values) and
            ``series[2]`` (legend label).
        ylabel: Text of the y-axis label (rendered with ``usetex=True``).
        xlabel: Text of the x-axis label (rendered with ``usetex=True``).
        file_name: Base name, without extension, of the PDF written to
            ``./Plots/``.

    Side effects:
        Changes global matplotlib rc settings (font and LaTeX text
        rendering), creates the output directory if it is missing, and
        writes ``./Plots/<file_name>.pdf``.
    """
    import os

    matplotlib.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
    matplotlib.rc('text', usetex=True)
    matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"
    matplotlib.rc('font', family='sans-serif', size=20)

    xsize = 24  # axis-label font size
    tsize = 24  # tick-label font size
    lsize = 20  # legend font size
    figsize = (7.5, 5.8)
    clrs = [
        "#d62728",  # Brick red
        "#1f77b4",  # Muted blue
        '#FFA500',  # Orange
        "#ff7f0e",  # Safety orange
        "#2ca02c",  # Cooked asparagus green
        "#9467bd",  # Muted purple
        "#8c564b",  # Chestnut brown
        "#e377c2",  # Raspberry yogurt pink
        "#7f7f7f",  # Middle gray
        "#bcbd22",  # Curry yellow-green
        "#17becf",  # Blue-teal
    ]
    mrk = ['o', 's', '^', '*']

    plt.figure(figsize=figsize)
    plt.grid()
    for i, series in enumerate(data):
        # Cycle through the palettes so that plotting more series than there
        # are markers/colours does not raise IndexError.
        plt.plot(
            series[0],
            series[1],
            linewidth=3,
            label=series[2],
            color=clrs[i % len(clrs)],
            marker=mrk[i % len(mrk)],
            markersize=12,
        )

    plt.xticks(fontsize=tsize, usetex=True, fontname="Times")
    plt.yticks(fontsize=tsize, usetex=True, fontname="Times")
    plt.xlabel(xlabel, fontsize=xsize, usetex=True)
    plt.ylabel(ylabel, fontsize=xsize, usetex=True)
    plt.legend(fontsize=lsize, loc='best', frameon=True, fancybox=True, framealpha=0.8, edgecolor='k')
    plt.tight_layout()

    out_path = f'./Plots/{file_name}.pdf'
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, dpi=None, facecolor='w', format='pdf', bbox_inches="tight")

def load_imdb(n, d, seed=42):
    """Build a TF-IDF design matrix and target vector from a random IMDB subset.

    Args:
        n: Number of training examples to sample without replacement.
        d: Number of TF-IDF columns to keep. The columns with the largest
            total TF-IDF weight over the sampled texts are selected; if the
            vocabulary has fewer than ``d`` terms, all of them are kept.
        seed: Seed applied to NumPy's global random state before sampling.

    Returns:
        Tuple ``(A, b)`` of torch tensors: ``A`` is the dense matrix of the
        selected TF-IDF columns (one row per sampled text) and ``b`` is the
        ``'label'`` column as float64 with shape ``(n, 1)``. Both are placed
        on CUDA when it is available and on the CPU otherwise.
    """
    np.random.seed(seed)
    # NOTE(review): `ignore_verifications` appears to be deprecated in newer
    # `datasets` releases in favour of `verification_mode` - confirm against
    # the pinned version before changing it.
    dataset = load_dataset("imdb", ignore_verifications=True)['train']
    random_indices = np.random.choice(len(dataset), size=n, replace=False)
    random_subset = dataset.select(random_indices)

    texts = random_subset['text']
    b_matrix = np.array(random_subset['label'], dtype=np.float64).reshape(-1, 1)

    vectorizer = TfidfVectorizer()
    X_tfidf = vectorizer.fit_transform(texts)

    # Rank vocabulary terms by their summed TF-IDF weight and keep the top d.
    sum_tfidf = np.array(np.sum(X_tfidf, axis=0)).flatten()
    top_d_indices = np.argsort(sum_tfidf).flatten()[::-1][:d]
    X_reduced = X_tfidf[:, top_d_indices]

    A_matrix = X_reduced.toarray()

    # Fall back to the CPU instead of crashing on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.tensor(A_matrix).to(device), torch.tensor(b_matrix).to(device)


def load_android_app(n, d, seed=42):
    """Build a TF-IDF design matrix and target vector from random app reviews.

    Args:
        n: Number of training examples to sample without replacement from
            the ``app_reviews`` dataset.
        d: Number of TF-IDF columns to keep. The columns with the largest
            total TF-IDF weight over the sampled reviews are selected; if the
            vocabulary has fewer than ``d`` terms, all of them are kept.
        seed: Seed applied to NumPy's global random state before sampling.

    Returns:
        Tuple ``(A, b)`` of torch tensors: ``A`` is the dense matrix of the
        selected TF-IDF columns (one row per sampled review) and ``b`` is the
        ``'star'`` column as float64 with shape ``(n, 1)``. Both are placed
        on CUDA when it is available and on the CPU otherwise.
    """
    np.random.seed(seed)
    # NOTE(review): `ignore_verifications` appears to be deprecated in newer
    # `datasets` releases in favour of `verification_mode` - confirm against
    # the pinned version before changing it.
    dataset = load_dataset("app_reviews", ignore_verifications=True)['train']
    random_indices = np.random.choice(len(dataset), size=n, replace=False)
    random_subset = dataset.select(random_indices)

    texts = random_subset['review']
    b_matrix = np.array(random_subset['star'], dtype=np.float64).reshape(-1, 1)

    vectorizer = TfidfVectorizer()
    X_tfidf = vectorizer.fit_transform(texts)

    # Rank vocabulary terms by their summed TF-IDF weight and keep the top d.
    sum_tfidf = np.array(np.sum(X_tfidf, axis=0)).flatten()
    top_d_indices = np.argsort(sum_tfidf).flatten()[::-1][:d]
    X_reduced = X_tfidf[:, top_d_indices]

    A_matrix = X_reduced.toarray()

    # Fall back to the CPU instead of crashing on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.tensor(A_matrix).to(device), torch.tensor(b_matrix).to(device)

def load_android_app_transformer(n, top_k=300, seed=42):
    """Embed random app reviews with a masked-LM and keep the strongest features.

    Args:
        n: Number of training examples to sample without replacement from
            the ``app_reviews`` dataset.
        top_k: Number of embedding columns to keep, chosen by largest column
            L2 norm across the sampled reviews.
        seed: Seed applied to NumPy's global random state before sampling.

    Returns:
        Tuple ``(embeddings, b)`` of torch tensors on CUDA when available,
        otherwise on the CPU. ``embeddings`` holds the retained columns (one
        row per review) and ``b`` is the ``'star'`` column as float64 with
        shape ``(n, 1)``.
    """
    np.random.seed(seed)
    reviews = load_dataset("app_reviews", split='train', ignore_verifications=True)
    chosen = np.random.choice(len(reviews), size=n, replace=False)
    sample = reviews.select(chosen)
    texts = sample['review']
    targets = np.array(sample['star'], dtype=np.float64).reshape(-1, 1)

    model_id = 'naver/splade-cocondenser-ensembledistil'
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    def embed(text):
        """Return the masked, log-saturated logits max-pooled over dim 1."""
        encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = model(**encoded).logits

        # Zero out positions excluded by the attention mask before pooling.
        mask = encoded['attention_mask'].unsqueeze(-1)
        weighted = torch.log(1 + torch.relu(logits)) * mask
        return weighted.max(dim=1).values.squeeze()

    # Texts are embedded one at a time and gathered on the host.
    vectors = np.array([embed(text).cpu().numpy() for text in texts])

    # Keep the top_k columns with the largest L2 norm over all reviews.
    column_norms = np.linalg.norm(vectors, axis=0)
    strongest = np.argsort(-column_norms)[:top_k]
    reduced = vectors[:, strongest]

    return torch.tensor(reduced, device=device), torch.tensor(targets, device=device)

def create_matrix_pair(n, d, mean=0, std=1, random_seed=42):
    """Sample two ``(n, d)`` NumPy matrices from a truncated normal distribution.

    The distribution has location ``mean`` and scale ``std`` and is clipped
    to ``[mean - std, mean + std]``. NumPy's global random state is seeded
    with ``random_seed`` before sampling.

    Args:
        n: Number of rows of each matrix.
        d: Number of columns of each matrix.
        mean: Location of the underlying normal distribution.
        std: Scale of the underlying normal distribution.
        random_seed: Seed passed to ``np.random.seed``.

    Returns:
        Tuple ``(matA, matB)`` of arrays, each of shape ``(n, d)``.
    """
    np.random.seed(random_seed)

    lower_bound = mean - std
    upper_bound = mean + std
    # truncnorm expects its clip points expressed in units of the scale.
    clip_lo = (lower_bound - mean) / std
    clip_hi = (upper_bound - mean) / std
    sampler = stats.truncnorm(clip_lo, clip_hi, loc=mean, scale=std)

    shape = (n, d)
    first = sampler.rvs(shape)
    second = sampler.rvs(shape)
    return first, second


def create_matrix_pair_with_outlier(n, d, mean=0, std=1, random_seed=42, mean_outlier=10, std_outlier=5,
                                    outlier_fraction=0.1, zero_row_fraction=0.1):
    """Create two ``(n, d)`` tensors with injected outliers and zeroed entries.

    The base matrices come from ``create_matrix_pair``. A shared random set
    of flattened positions is then overwritten in both matrices with samples
    from a second truncated normal (location ``mean_outlier``, scale
    ``std_outlier``, clipped to one scale either side of the location).
    Finally, an independent random set of positions is zeroed in each
    matrix; these may overlap the outlier positions.

    Args:
        n: Number of rows.
        d: Number of columns.
        mean: Location of the base distribution.
        std: Scale of the base distribution.
        random_seed: Seed passed to ``np.random.seed``.
        mean_outlier: Location of the outlier distribution.
        std_outlier: Scale of the outlier distribution.
        outlier_fraction: Fraction of the ``n * d`` entries replaced by
            outliers (the count is rounded).
        zero_row_fraction: Fraction of the ``n * d`` entries set to zero in
            each matrix (the count is truncated). NOTE: despite the name,
            individual entries are zeroed, not whole rows.

    Returns:
        Tuple ``(matA, matB)`` of torch tensors of shape ``(n, d)``, placed
        on CUDA when it is available and on the CPU otherwise.
    """
    np.random.seed(random_seed)
    matA, matB = create_matrix_pair(n, d, mean, std, random_seed)

    low, upp = mean_outlier - std_outlier, mean_outlier + std_outlier
    X2 = stats.truncnorm((low - mean_outlier) / std_outlier, (upp - mean_outlier) / std_outlier, loc=mean_outlier,
                         scale=std_outlier)

    total = n * d
    outlier_size = round(total * outlier_fraction)
    # The same positions receive outliers in both matrices.
    outlier_indices = np.random.choice(total, size=outlier_size, replace=False)

    outliersA = X2.rvs(outlier_size)
    outliersB = X2.rvs(outlier_size)

    # flatten() copies, so the base matrices are left untouched.
    matA_flat = matA.flatten()
    matB_flat = matB.flatten()

    np.put(matA_flat, outlier_indices, outliersA)
    np.put(matB_flat, outlier_indices, outliersB)
    print(matA.shape, matB.shape)

    # Zero positions are drawn independently per matrix; the draw order
    # (A first, then B) is kept so results for a given seed do not change.
    zero_count = int(total * zero_row_fraction)
    zero_idx_A = np.random.choice(total, zero_count, replace=False)
    matA_flat[zero_idx_A] = 0
    zero_idx_B = np.random.choice(total, zero_count, replace=False)
    matB_flat[zero_idx_B] = 0

    matA = matA_flat.reshape(n, d)
    matB = matB_flat.reshape(n, d)

    # Fall back to the CPU instead of crashing on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.tensor(matA).to(device), torch.tensor(matB).to(device)