# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, load_from_disk
import torch
import pandas as pd
import  os

def get_dataset(dataset_name: str, tokenizer) -> torch.Tensor:
    """Tokenize an evaluation corpus as one long sequence of token ids.

    The records of the selected corpus are joined with blank lines
    (``"\\n\\n"``) and handed to ``tokenizer`` in a single call.

    Args:
        dataset_name: One of ``"wikitext-2"``, ``"ptb_text_only"`` or ``"c4"``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``;
            its result must expose an ``input_ids`` attribute.

    Returns:
        The ``input_ids`` the tokenizer produced for the joined text
        (presumably shaped ``[1, total_len]`` for a Hugging Face tokenizer --
        TODO confirm against the tokenizer in use).

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name not in ("wikitext-2", "ptb_text_only", "c4"):
        raise ValueError(f"Unknown dataset {dataset_name}")

    # Imported locally under an alias on purpose: this module also defines its
    # own ``load_dataset`` function further down, which rebinds the
    # module-level name. A bare ``load_dataset(...)`` call here would reach
    # that checkpoint loader instead of the Hugging Face one.
    from datasets import load_dataset as hf_load_dataset

    if dataset_name == "wikitext-2":
        # Fetched from the Hugging Face hub; the test split is evaluated.
        records = hf_load_dataset("wikitext", "wikitext-2-raw-v1")["test"]
        text = "\n\n".join(records["text"])

    elif dataset_name == "ptb_text_only":
        # Read from a hard-coded location on local disk; the test split is used.
        records = load_from_disk('.../ptb_text_only')["test"]
        text = "\n\n".join(records["sentence"])

    else:  # "c4"
        corpus = hf_load_dataset("c4", "en", split="train")

        # Only the first ``num_samples`` records are evaluated.
        num_samples = 500
        if len(corpus) < num_samples:
            print(f"Dataset only has {len(corpus)} samples, using all.")
            subset = corpus
        else:
            subset = corpus.select(range(num_samples))

        text = "\n\n".join(subset["text"])

    return tokenizer(text, return_tensors="pt").input_ids


def get_dataset_by_length(dataset_name: str, tokenizer, seq_len=2048) -> torch.Tensor:
    """Tokenize an evaluation corpus and cut it into fixed-length rows.

    The corpus text is joined with blank lines, tokenized without special
    tokens, split into consecutive chunks of ``seq_len`` tokens and stacked.
    A shorter final chunk is right-padded with the tokenizer's pad token id
    (``0`` when the tokenizer defines none).

    Args:
        dataset_name: One of ``"wikitext-2"``, ``"ptb_text_only"`` or ``"c4"``.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", add_special_tokens=False)``;
            its result must expose ``input_ids`` and the tokenizer itself a
            ``pad_token_id`` attribute.
        seq_len: Number of tokens per row; must be positive.

    Returns:
        A tensor of shape ``[num_chunks, seq_len]``. When the corpus tokenizes
        to nothing, an empty ``[0, seq_len]`` tensor is returned.

    Raises:
        ValueError: If ``seq_len`` is not positive or ``dataset_name`` is not
            one of the supported names.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len}")
    if dataset_name not in ("wikitext-2", "ptb_text_only", "c4"):
        raise ValueError(f"Unknown dataset {dataset_name}")

    # Imported locally under an alias on purpose: this module also defines its
    # own ``load_dataset`` function further down, which rebinds the
    # module-level name. A bare ``load_dataset(...)`` call here would reach
    # that checkpoint loader instead of the Hugging Face one.
    from datasets import load_dataset as hf_load_dataset

    if dataset_name == "wikitext-2":
        records = hf_load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        all_text = "\n\n".join(records["text"])

    elif dataset_name == "ptb_text_only":
        # Read from a hard-coded location on local disk; the test split is used.
        records = load_from_disk('.../ptb_text_only')["test"]
        all_text = "\n\n".join(records["sentence"])

    else:  # "c4"
        # ``split="train"`` already yields that single split, so it is used
        # directly rather than indexed by split name again.
        corpus = hf_load_dataset("c4", "en", split="train")
        num_samples = 500

        if len(corpus) < num_samples:
            print(f"The dataset only has {len(corpus)} samples, cannot get {num_samples} samples.")
            subset = corpus
        else:
            subset = corpus.select(range(num_samples))

        # The column is "text", the same one ``get_dataset`` reads for c4.
        all_text = "\n\n".join(subset["text"])

    # Tokenize without special tokens so chunks are plain corpus tokens.
    token_ids = tokenizer(all_text, return_tensors="pt", add_special_tokens=False).input_ids[0]  # shape [total_len]

    total_len = token_ids.size(0)
    if total_len == 0:
        # Nothing to chunk; return an empty batch instead of failing in cat().
        return token_ids.new_empty((0, seq_len))

    # Pad once at the tail so the length becomes a multiple of seq_len, then
    # reshape; this yields the same rows as padding only the last chunk.
    remainder = total_len % seq_len
    if remainder:
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        padding = torch.full(
            (seq_len - remainder,),
            pad_token_id,
            dtype=token_ids.dtype,
            device=token_ids.device,
        )
        token_ids = torch.cat([token_ids, padding], dim=0)

    return token_ids.reshape(-1, seq_len)  # shape [num_chunks, seq_len]





def load_dataset(input_file='Llama-3.1-8B_dataset.pth'):
    """Read a saved dataset from ``input_file`` with ``torch.load``.

    NOTE: this function has the same name as ``load_dataset`` imported from
    ``datasets`` at the top of the file, so from this definition onward the
    module-level name refers to this function.

    Args:
        input_file: Path of the file to read.

    Returns:
        The object stored in the file, or ``None`` when ``input_file`` is not
        an existing regular file.
    """
    print("Loading dataset")

    # Guard clause: report the missing file and bail out early.
    if not os.path.isfile(input_file):
        print(f"File {input_file} does not exist.")
        return None

    records = torch.load(input_file, weights_only=True)
    print(f"Successfully loaded dataset containing {len(records)} records.")
    return records


def extract_weight_matrix(dataset, layer_index, layer_attribute):
    """
    Look up the weight stored for one (layer index, attribute) pair.

    Args:
        dataset: Iterable of ``(index, attribute, weight)`` triples
        layer_index: Layer index to search for
        layer_attribute: Layer attribute to search for

    Returns:
        The weight of the first matching triple, or None when nothing matches
    """
    for record in dataset:
        rec_layer, rec_attribute, rec_weight = record
        if not (rec_layer == layer_index and rec_attribute == layer_attribute):
            continue
        # First match wins; later duplicates are never inspected.
        print(f"Extracted weight matrix for Layer {layer_index}, Attribute {layer_attribute}.")
        return rec_weight

    print(f"Weight matrix for Layer {layer_index}, Attribute {layer_attribute} not found.")
    return None










