import torch
import numpy as np
from datasets import load_dataset
from modelscope.msdatasets import MsDataset
# from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import pandas as pd
from scipy.stats import linregress

# Module-wide model dimensions.
# LAYERS: presumably the number of transformer layers in the analysed model (not
# referenced elsewhere in this module) — confirm against the model config.
LAYERS: int = 32
# NEURONS: node count that check_partition expects a valid partition to cover
# (ids 1..NEURONS); presumably the per-layer MLP width — confirm.
NEURONS: int = 14336

def tokenize_function(examples, tokenizer, max_length=1024):
    """
    Tokenize the ``"text"`` entry of a batch of dataset examples.

    Every sequence is truncated and padded to exactly ``max_length`` tokens,
    and PyTorch tensors are requested from the tokenizer.

    Args:
        examples: Batch of examples (as passed by ``Dataset.map(batched=True)``);
            must contain a ``"text"`` entry.
        tokenizer: Callable tokenizer accepting the keyword arguments used below.
        max_length: Target sequence length for padding and truncation
            (default 1024, the previously hard-coded value).

    Returns:
        Whatever ``tokenizer`` returns for the batch (typically a mapping with
        ``input_ids`` and ``attention_mask``).
    """
    # NOTE(review): padding="max_length" presumably requires the tokenizer to
    # have a pad token configured — confirm for the tokenizers used here.
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

# (subset name, split spec) passed to MsDataset.load for each supported
# medical-o1 reasoning dataset name.
_MEDICAL_O1_SPLITS = {
    "medical_o1_reasoning_SFT_en": ("en", "train[0%:90%]"),
    "medical_o1_reasoning_SFT_zh": ("zh", "train[0%:90%]"),
    "medical_o1_reasoning_SFT_en_val": ("en", "train[90%:100%]"),
    "medical_o1_reasoning_SFT_en_explain": ("en", "train[30%:70%]"),
}


def prepare_tokenized_dataset(dataset_path, dataset_name, tokenizer, batch_size=16):
    """
    Load a dataset, tokenize it and wrap it in a batched DataLoader.

    Supported ``dataset_name`` values:
        - ``"openwebtext"``: the ``train`` split, loaded in streaming mode.
        - ``"medical_o1_reasoning_SFT_en"`` / ``"..._zh"``: first 90% of ``train``.
        - ``"medical_o1_reasoning_SFT_en_val"``: last 10% of ``train`` (English).
        - ``"medical_o1_reasoning_SFT_en_explain"``: ``train[30%:70%]`` (English).
    For the medical datasets, the ``Question``, ``Complex_CoT`` and ``Response``
    columns are first merged into a single ``text`` column.

    Args:
        dataset_path: Path/identifier passed to ``MsDataset.load``.
        dataset_name: One of the names listed above.
        tokenizer: Tokenizer forwarded to ``tokenize_function``.
        batch_size: DataLoader batch size (default 16).

    Returns:
        A ``torch.utils.data.DataLoader`` over the tokenized dataset, or ``None``
        (after printing a message) if ``dataset_name`` is not supported.
    """
    if dataset_name == "openwebtext":
        # Iterable (streaming) dataset.
        dataset = MsDataset.load(dataset_path, streaming=True, split="train")
    elif dataset_name in _MEDICAL_O1_SPLITS:
        subset, split = _MEDICAL_O1_SPLITS[dataset_name]
        dataset = MsDataset.load(dataset_path, subset, split=split)
        # Combine the question, reasoning and response into a single text column.
        dataset = dataset.map(
            lambda x: {
                "text": f"Question: {x['Question']}\nReasoning: {x['Complex_CoT']}\nResponse: {x['Response']}"
            }
        )
    else:
        print(f"Unknown dataset: {dataset_name}")
        return None

    # Tokenize in batches, then expose the columns as PyTorch tensors.
    tokenized_dataset = dataset.map(
        lambda examples: tokenize_function(examples, tokenizer),
        batched=True,
    )
    tokenized_dataset = tokenized_dataset.with_format("torch")

    return torch.utils.data.DataLoader(tokenized_dataset, batch_size=batch_size)

def calculate_token_freq(dataloader: torch.utils.data.DataLoader, tokenizer, *, ignore_padding=True):
    """
    Count how often each token id occurs in a tokenized dataset.

    The resulting rank/frequency table is suitable for checking Zipf's law
    (e.g. with ``apply_zipf_law``).

    Args:
        dataloader: Iterable of batches; each batch is a mapping with an
            ``"input_ids"`` tensor and optionally an ``"attention_mask"``
            tensor of the same shape.
        tokenizer: Object with a ``decode(token_id)`` method, used to fill the
            ``token`` column.
        ignore_padding: If True (default) and a batch has an
            ``"attention_mask"``, positions whose mask is 0 are not counted.
            Batches padded to a fixed length would otherwise inflate the pad
            token's count and distort the rank/frequency relationship.

    Returns:
        DataFrame with columns ``token_id``, ``frequency``, ``token`` and
        ``rank``, sorted by ascending rank (rank 1 = most frequent). Tied
        frequencies share their average rank, truncated to an int.
    """
    from collections import Counter

    counts = Counter()
    for batch in dataloader:
        token_ids = batch['input_ids']
        if ignore_padding and 'attention_mask' in batch:
            # Boolean-mask indexing keeps only real tokens (result is 1-D).
            token_ids = token_ids[batch['attention_mask'].bool()]
        # One bulk tensor->list conversion instead of an int() per element.
        counts.update(token_ids.reshape(-1).tolist())

    df = pd.DataFrame(list(counts.items()), columns=["token_id", "frequency"])
    df['token'] = [tokenizer.decode(token_id) for token_id in df['token_id'].tolist()]
    # method='average' gives ties fractional ranks; astype(int) truncates them.
    df['rank'] = df['frequency'].rank(ascending=False, method='average').astype(int)
    df.sort_values(by='rank', ascending=True, inplace=True)
    return df

def apply_zipf_law(token_freq: pd.DataFrame):
    """
    Fit a straight line to a rank/frequency table in log10-log10 space.

    Runs ``scipy.stats.linregress`` with ``log10(rank)`` as x and
    ``log10(frequency)`` as y (classic Zipf's law predicts a slope near -1).

    Args:
        token_freq: DataFrame with positive ``frequency`` and ``rank`` columns.

    Returns:
        Tuple ``(slope, intercept, r_value, p_value, std_err)`` of the fit.
    """
    fit = linregress(
        np.log10(token_freq['rank']),
        np.log10(token_freq['frequency']),
    )
    return fit.slope, fit.intercept, fit.rvalue, fit.pvalue, fit.stderr

def process_batch(model, data):
    """
    Run one batch through ``model`` without building an autograd graph.

    The forward pass's return value is discarded; the call exists so that any
    hooks registered on ``model`` observe the batch.

    Args:
        model: Object exposing ``parameters()`` and callable with
            ``input_ids=`` / ``attention_mask=`` keyword arguments.
        data: Mapping (``dict`` or any ``collections.abc.Mapping``, e.g. a
            collated DataLoader batch) with an ``"input_ids"`` tensor and,
            optionally, an ``"attention_mask"`` tensor.

    Returns:
        None.

    Raises:
        ValueError: If ``data`` is not a mapping containing ``"input_ids"``.
    """
    from collections.abc import Mapping

    if not isinstance(data, Mapping) or "input_ids" not in data:
        print(f"Unexpected data format: {type(data)}")
        print(f"Data keys: {data.keys() if isinstance(data, Mapping) else 'Not a dictionary'}")
        raise ValueError("Data format is not as expected. Check the dataset structure.")

    # Inputs are moved to the device of the model's first parameter (looked up once).
    device = next(model.parameters()).device
    input_ids = data["input_ids"].to(device)
    attention_mask = data["attention_mask"].to(device) if "attention_mask" in data else None

    # no_grad: activations are only observed, so skip autograd bookkeeping.
    with torch.no_grad():
        model(input_ids=input_ids, attention_mask=attention_mask)
    

def check_partition(partitions, num_nodes=None, index_from=1):
    """
    Check whether ``partitions`` is an exact partition of the node ids.

    A valid partition uses every id in
    ``[index_from, index_from + num_nodes - 1]`` exactly once. The parts are
    pairwise disjoint, contain no repeated ids, and together cover the whole range.

    Args:
        partitions: Iterable of iterables of node ids (anything ``int()`` accepts).
        num_nodes: Expected number of nodes; defaults to the module constant
            ``NEURONS``.
        index_from: Smallest valid node id (default 1, matching the default
            output of ``load_partition_result``).

    Returns:
        True if the partition is valid, False otherwise.
    """
    if num_nodes is None:
        num_nodes = NEURONS
    all_nodes = [int(node) for partition in partitions for node in partition]
    # Total count == distinct count == num_nodes  <=>  no id appears twice
    # (parts are disjoint) and there are exactly num_nodes distinct ids.
    if len(all_nodes) != num_nodes or len(set(all_nodes)) != num_nodes:
        return False
    # num_nodes distinct ids all inside a range of size num_nodes => full cover.
    lowest, highest = index_from, index_from + num_nodes - 1
    return all(lowest <= node <= highest for node in all_nodes)

def load_partition_result(partition_result_path, index_from = 1):
    """
    Load a partition result file into lists of neuron ids.

    Each line describes one part. The first whitespace-separated field is skipped
    (presumably a part label — confirm against the file producer), and the
    remaining fields are parsed as integer neuron ids, stored 1-indexed in the
    file. Lines with no ids after the first field, including blank lines, are
    dropped.

    Args:
        partition_result_path: Path to the partition result file.
        index_from: Index that the file's id 1 maps to in the output. The default
            1 keeps ids as stored; 0 converts them to 0-indexed ids. Any other
            value shifts the ids by the same offset.

    Returns:
        List of lists of neuron ids, shifted so that the smallest possible id
        equals ``index_from``.
    """
    shift = index_from - 1
    partitions = []
    with open(partition_result_path, 'r') as f:
        for line in f:
            nodes = [int(node) + shift for node in line.split()[1:]]
            if nodes:
                partitions.append(nodes)
    return partitions

def load_coactivation_graph(file_path, edge_number_limit = None):
    """
    Load a coactivation hypergraph from a ``.hgr`` or ``.jsonl`` file.

    Args:
        file_path: Path ending in ``.hgr`` or ``.jsonl``.
        edge_number_limit: Its meaning depends on the file type:
            - ``.hgr``: maximum number of lines (hyperedges) to read.
            - ``.jsonl``: minimum ``len(activated_neurons)`` a row must have to
              be kept; rows with fewer neurons are dropped.
            ``None`` disables the limit in both cases.

    Returns:
        For ``.hgr``: ``(coactivation_graph, edge_degree)``.
            ``coactivation_graph`` is a list of int vertex-id lists, one per line.
            ``edge_degree`` is a NumPy array holding the number of vertices on
            each of those lines.
        For ``.jsonl``: ``(coactivation_dataframe, edge_degree)``.
            ``coactivation_dataframe`` is the (possibly filtered) DataFrame.
            ``edge_degree`` is a Series holding each row's
            ``len(activated_neurons)``.

    Raises:
        ValueError: If the file extension is neither ``.hgr`` nor ``.jsonl``.
    """
    if file_path.endswith('.hgr'):
        coactivation_graph = []
        # NOTE(review): every line, including the first, is parsed as a hyperedge.
        # hMETIS-style .hgr files begin with a header line; confirm the files
        # read here are written without one.
        with open(file_path, 'r') as f:
            for line in f:
                if edge_number_limit is not None and len(coactivation_graph) >= edge_number_limit:
                    break
                coactivation_graph.append([int(node) for node in line.split()])

        # Degree = number of vertices in each edge (not the line's character length).
        edge_degree = np.array([len(edge) for edge in coactivation_graph])
        return coactivation_graph, edge_degree
    elif file_path.endswith('.jsonl'):
        coactivation_dataframe = pd.read_json(file_path, lines=True)
        # Here edge_number_limit acts as a minimum-degree filter.
        if edge_number_limit is not None:
            coactivation_dataframe = coactivation_dataframe[coactivation_dataframe['activated_neurons'].apply(len) >= edge_number_limit]
        edge_degree = coactivation_dataframe['activated_neurons'].apply(len)
        return coactivation_dataframe, edge_degree
    else:
        raise ValueError(f"Unknown file type: {file_path}")
