from transformers import (
    AutoModelForCausalLM, 
    AutoConfig, 
    AutoTokenizer, 
    HfArgumentParser,
    DataCollatorForLanguageModeling
)
from typing import List, Optional, Tuple, Union, Dict
import torch
import torch.nn as nn
from itertools import chain
from datasets import load_dataset, load_from_disk, concatenate_datasets
from args import TrainingArguments, ModelArguments
import numpy as np
from tqdm import tqdm

# from 


import warnings
# Ignore every UserWarning for the rest of the process. No module filter is
# given, so this also hides UserWarnings raised by this script itself.
warnings.filterwarnings("ignore", category=UserWarning)


def parse_hf_args():
    """Parse the command line into ``(ModelArguments, TrainingArguments)``.

    Parsing is done with ``return_remaining_strings=True``; the leftover
    strings it returns are discarded.
    """
    hf_parser = HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, train_args, _leftover = hf_parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    return model_args, train_args


def _sample_domain_splits(dataset, train_num_data, proportions, test_num_data):
    """Draw per-domain train/test subsets; returns ``(train_parts, test_parts)`` lists."""
    train_parts, test_parts = [], []
    for name, proportion in proportions.items():
        target_meta = {"redpajama_set_name": name}
        # Bind the target as a default argument so the predicate never depends
        # on the loop variable's value at call time (late-binding closure).
        domain = dataset.filter(lambda x, target_meta=target_meta: x['meta'] == target_meta)
        if len(domain) == 0:
            # Nothing to sample from; the fractional test split below would
            # otherwise divide by zero.
            continue

        split = domain.train_test_split(test_size=(test_num_data * proportion) / len(domain))
        test_parts.append(split['test'])

        remainder = split['train']
        n_wanted = int(train_num_data * proportion)
        if n_wanted <= 0:
            # A fractional split asking for < 1 row would raise ("train set
            # will be empty"); contribute no train rows instead.
            continue
        if n_wanted >= len(remainder):
            # A fractional split would need test_size <= 0 and raise; use
            # everything that is available.
            train_parts.append(remainder)
        else:
            train_parts.append(remainder.train_test_split(train_size=n_wanted)['train'])
    return train_parts, test_parts


def _tokenize_and_chunk(ds, tokenizer, text_column_name, block_size):
    """Tokenize ``ds[text_column_name]`` and regroup the tokens into ``block_size`` rows.

    All original columns are dropped. The result has the tokenizer's output
    columns plus ``labels`` (a copy of ``input_ids``). Tokens left over at the
    end of each map batch that do not fill a whole block are discarded.
    """
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    def group_texts(examples):
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    ds = ds.map(
        tokenize_function,
        batched=True,
        remove_columns=ds.column_names,
        desc="Running tokenizer on dataset",
    )
    return ds.map(
        group_texts,
        batched=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )


def process_datasets(dataset, train_num_data, tokenizer, block_size=512, test_num_data=300):
    '''Sample per-domain train/test subsets of ``dataset`` and turn them into LM blocks.

    Each RedPajama domain is selected by ``x['meta'] == {"redpajama_set_name": name}``.
    Each domain contributes about ``test_num_data * proportion`` test rows and
    ``int(train_num_data * proportion)`` train rows. Fewer train rows are taken
    if the domain is smaller, and none if the count rounds down to 0. Domains
    with no rows are skipped.

    We divided the proportions of RedPajamaCommonCrawl, RedPajamaArXiv,
    and RedPajamaBook by a normalization value because the data length
    in these domains is higher than in other domains.

    Args:
        dataset: Hugging Face ``Dataset`` with a ``meta`` column and a text
            column (``"text"`` if present, otherwise its first column).
        train_num_data: total train rows to request before per-domain weighting.
        tokenizer: tokenizer applied to batches of text. Its ``pad_token`` is
            set to its ``eos_token`` (mutated in place).
        block_size: tokens per output row (512 by default, as before).
        test_num_data: total test rows to request before per-domain weighting
            (300 by default, as before).

    Returns:
        ``(train_dataset, test_dataset)``. Each has the tokenizer's output
        columns plus ``labels``, and every row is ``block_size`` tokens long.
    '''
    proportions = {
        "RedPajamaC4": 0.492,
        "RedPajamaStackExchange": 0.01,
        "RedPajamaCommonCrawl": 0.361 / 3,
        "RedPajamaGithub": 0.008,
        "RedPajamaWikipedia": 0.031,
        "RedPajamaArXiv": 0.007 / 20,
        "RedPajamaBook": 0.091 / 200
    }

    train_parts, test_parts = _sample_domain_splits(
        dataset, train_num_data, proportions, test_num_data
    )
    train_dataset = concatenate_datasets(train_parts)
    test_dataset = concatenate_datasets(test_parts)

    tokenizer.pad_token = tokenizer.eos_token

    column_names = train_dataset.column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    train_dataset = _tokenize_and_chunk(train_dataset, tokenizer, text_column_name, block_size)
    test_dataset = _tokenize_and_chunk(test_dataset, tokenizer, text_column_name, block_size)

    return train_dataset, test_dataset


@torch.no_grad()
def get_hidden_state_vectors(model, dataset, num_data, device, save_path="hidden_states.npy"):
    """Collect sequence-mean-pooled hidden states of every layer for random examples.

    Args:
        model: called as ``model(input_ids, output_hidden_states=True)``. Its
            output must expose ``.hidden_states``, a sequence of tensors.
            ``model.to(device)`` and ``model.eval()`` are applied; for an
            ``nn.Module`` both modify it in place. For Hugging Face causal LMs
            the first entry is presumably the embedding output — confirm.
        dataset: indexable collection whose items have an ``'input_ids'`` list.
        num_data: number of distinct examples to sample. Fewer are used if
            ``len(dataset)`` is smaller.
        device: torch device to run the forward passes on.
        save_path: currently unused; kept for interface compatibility.

    Returns:
        float32 CPU tensor of shape ``(n, len(hidden_states), hidden_dim)``,
        where ``n = min(num_data, len(dataset))``. This assumes each hidden
        state is ``(1, seq_len, hidden_dim)`` — TODO confirm for non-HF models.
    """
    model = model.to(device)
    model.eval()

    data_index = torch.randperm(len(dataset))[:num_data].tolist()
    hidden_states_list = []

    for i in tqdm(data_index, desc="Collecting hidden states"):
        input_ids = torch.tensor(dataset[i]['input_ids']).reshape(1, -1).to(device)

        hidden_states = model(input_ids, output_hidden_states=True).hidden_states
        # Average over the sequence axis, accumulating in float32. Half-precision
        # models then pool accurately, and the result can be converted with
        # ``.numpy()`` (NumPy has no bfloat16 dtype).
        pooled = [
            layer.mean(dim=1, dtype=torch.float32).squeeze(0).cpu()
            for layer in hidden_states
        ]  # (hidden_dim,) per layer

        hidden_states_list.append(torch.stack(pooled))  # (num_layers, hidden_dim)

        # Drop the references before the next forward pass. Otherwise the
        # previous activations stay alive while the new ones are allocated.
        del input_ids, hidden_states

    return torch.stack(hidden_states_list)  # (num_data, num_layers, hidden_dim)

# Script part 1: load the model and data, then collect per-layer features.
args, training_args = parse_hf_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): trust_remote_code=True runs code shipped with the model repo;
# only use it with a trusted args.model_name.
model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True)
# config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
# Use EOS as the padding token (process_datasets also sets this).
tokenizer.pad_token = tokenizer.eos_token

# Load and tokenize the dataset. Only the train part is used below;
# test_dataset is not used in the rest of this script.
dataset = load_dataset('DKYoon/SlimPajama-6B')['train']
dataset, test_dataset = process_datasets(dataset, training_args.train_num_data, tokenizer)

# Mean-pooled hidden states for up to 100 randomly chosen training blocks,
# shaped (num_samples, num_layers, hidden_dim) per get_hidden_state_vectors.
all_hidden_layer_embeddings = get_hidden_state_vectors(model, dataset, 100, device)

# Reorder to (num_layers, num_samples, hidden_dim), so that each entry along
# the first axis is one layer's feature matrix with one row per sample.
all_hidden_layer_embeddings = np.transpose(all_hidden_layer_embeddings.numpy(), (1, 0, 2))




# for manifold formation from feature matrix using PGM

from sagman_utils.utils import spade,hnsw,construct_adj, spectral_embedding_eig,SPF,construct_weighted_adj,spade_nonetworkx

import numpy as np
from scipy.sparse import csr_matrix

def to_unweighted_csr(adj_matrix_csr):
    """Return a copy of an adjacency matrix with every stored entry set to 1.0.

    The sparsity pattern (which positions are stored) is preserved. For sparse
    input, every stored value becomes 1.0, including explicitly stored zeros.

    Args:
        adj_matrix_csr: scipy sparse matrix. CSR is expected, but other sparse
            formats or a dense array are converted to CSR first.

    Returns:
        A new float64 ``csr_matrix`` of the same shape. It shares no arrays
        with the input.
    """
    adj = csr_matrix(adj_matrix_csr)
    # csr_matrix((data, indices, indptr)) does not copy its index arrays.
    # Copy them so the result cannot alias the input's structure.
    indices = adj.indices.copy()
    indptr = adj.indptr.copy()
    unweighted_data = np.ones(len(indices), dtype=np.float64)
    return csr_matrix((unweighted_data, indices, indptr), shape=adj.shape)

def feature_matrix_to_adj_mat(feature_mat, k=20, l=3):
    """Turn a feature matrix into a dense, unweighted adjacency matrix.

    Pipeline: ``hnsw`` neighbour search with ``k`` neighbours, then
    ``construct_weighted_adj`` on the neighbours and distances, then ``SPF``
    with level ``l``. Finally every stored edge weight is set to 1.

    Args:
        feature_mat: presumably one row per node (num_points, num_features);
            TODO confirm against ``hnsw``.
        k: neighbour count passed to ``hnsw``.
        l: level argument passed to ``SPF``.

    Returns:
        Dense NumPy array with ones at the stored edge positions.
    """
    knn_ids, knn_dists = hnsw(feature_mat, k=k)
    weighted_adj = construct_weighted_adj(knn_ids, knn_dists)
    # SPF returns a second matrix as well; it is not needed here.
    filtered_adj, _ = SPF(weighted_adj, l)
    return to_unweighted_csr(filtered_adj).toarray()



def calculate_eigenvalues_sum(a, b):
    """Return ``sum(log(lambda_i) ** 2)`` over the eigenvalues of ``a^-1 @ b``.

    For symmetric positive-definite ``a`` and ``b``, this equals the squared
    affine-invariant Riemannian distance between them.

    Args:
        a: square, invertible matrix.
        b: square matrix of the same size.

    Returns:
        NumPy scalar. It is complex if ``eigvals`` returns complex eigenvalues
        (possible because ``a^-1 @ b`` is not symmetric).
    """
    # solve(a, b) == inv(a) @ b without forming the explicit inverse, which is
    # slower and less accurate.
    m = np.linalg.solve(a, b)

    # Only the eigenvalues are needed, so skip computing eigenvectors.
    eigenvalues = np.linalg.eigvals(m)

    return np.sum(np.log(eigenvalues) ** 2)

def calculate_eigenvalues(a, b):
    """Return the real parts of the eigenvalues of ``a^-1 @ b``.

    Args:
        a: square, invertible matrix.
        b: square matrix of the same size.

    Returns:
        1-D real NumPy array, in the order ``eigvals`` returns the eigenvalues
        (unsorted). Any imaginary parts are discarded.
    """
    # solve(a, b) == inv(a) @ b without forming the explicit inverse, which is
    # slower and less accurate.
    m = np.linalg.solve(a, b)

    # Only the eigenvalues are needed, so skip computing eigenvectors.
    eigenvalues = np.linalg.eigvals(m)

    return np.real(eigenvalues)


the_k = 20
level = 3
# One adjacency matrix per layer, built from that layer's feature matrix
# (one row per sampled example, after the transpose above).
all_adj_matrix = [
    feature_matrix_to_adj_mat(layer_features, k=the_k, l=level)
    for layer_features in all_hidden_layer_embeddings
]


# Compute the RiemMaps values (Riem distances) between pairs of layer graphs.

def compute_consecutive_riem(all_adj_matrix, alpha=0.001, interval=1):
    """
    Compute Riem distances between consecutive adjacency matrices.

    Parameters:
    - all_adj_matrix: list of np.ndarray, each of shape (N, N)
    - alpha: float, small value added to diagonal for normalization

    Returns:
    - riem_distances: list of float, length (len(all_adj_matrix) - 1)
    """
    num_graphs = len(all_adj_matrix)
    riem_distances = []

    shape = all_adj_matrix[0].shape  # Assumes all matrices have the same shape
    diagonal_matrix = np.diag([alpha] * min(shape))

    for i in tqdm(range(num_graphs - interval)):
        # Compute Laplacians
        degree_matrix_1 = np.diag(np.sum(all_adj_matrix[i], axis=1))
        lap1 = degree_matrix_1 - all_adj_matrix[i]
        mod_lap1 = lap1 + diagonal_matrix  # Modified Laplacian

        degree_matrix_2 = np.diag(np.sum(all_adj_matrix[i + interval], axis=1))
        lap2 = degree_matrix_2 - all_adj_matrix[i + interval]
        mod_lap2 = lap2 + diagonal_matrix  # Modified Laplacian

        # Compute Riem distance
        riem_value = calculate_eigenvalues_sum(mod_lap1, mod_lap2)
        riem_distances.append(riem_value)

    return riem_distances


riems = compute_consecutive_riem(
    all_adj_matrix,
    alpha=0.001,
    interval=training_args.layer_intervals,
)
# Position i of the smallest distance (first one on ties). riems[i] compares
# layer graph i with layer graph i + interval.
best_layer_id = min(range(len(riems)), key=riems.__getitem__)
print("Best layer ID: ")
print(best_layer_id)