import json
import math
import os
import random

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

def load_data(file_path):
    """Read a UTF-8 encoded JSON file and return its parsed content.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever object the JSON document decodes to.
    """
    print(f"Loading data: {file_path}")
    with open(file_path, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def convert_to_chatml(data):
    """Turn instruction-style records into two-message chat conversations.

    Args:
        data: Iterable of dict-like records. Each record must provide the
            keys "instruction" and "output"; the key "input" is optional.

    Returns:
        A list with one entry per record. Each entry is a list of two
        message dicts: a "user" message followed by an "assistant" message.

    Raises:
        KeyError: If a record lacks "instruction" or "output".
    """
    print("Converting data to ChatML format...")

    def _as_messages(record):
        # Build the [user, assistant] pair for a single record.
        prompt = record["instruction"]
        extra = record.get("input", "")
        reply = record["output"]

        # A non-empty "input" is appended to the instruction after a blank line.
        user_content = f"{prompt}\n\n{extra}" if extra else prompt

        return [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": reply},
        ]

    return [_as_messages(record) for record in tqdm(data, desc="Converting data")]

def split_data_function(data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Split *data* into train, validation and test subsets.

    The split is done in two steps with ``train_test_split`` (both using
    ``random_state=42``): first the training part is separated from the
    rest, then the rest is divided into validation and test parts.

    Args:
        data: Sequence of samples to split.
        train_ratio: Fraction of samples for the training subset.
        val_ratio: Fraction of samples for the validation subset.
        test_ratio: Fraction of samples for the test subset.

    Returns:
        A ``(train_data, val_data, test_data)`` tuple.

    Raises:
        AssertionError: If the three ratios do not sum to 1.0 (within
            floating-point tolerance).
    """
    print("Splitting dataset...")
    # Compare with a tolerance: exact equality rejects valid inputs such as
    # 0.7 + 0.2 + 0.1, which evaluates to 0.9999999999999999 in binary floats.
    assert math.isclose(train_ratio + val_ratio + test_ratio, 1.0), (
        "train_ratio, val_ratio and test_ratio must sum to 1.0"
    )

    # Step 1: everything that is not training data goes into a holdout pool.
    holdout_ratio = val_ratio + test_ratio
    train_data, temp_data = train_test_split(data, test_size=holdout_ratio, random_state=42)

    # Step 2: re-express the validation share relative to the holdout pool.
    val_ratio_adjusted = val_ratio / holdout_ratio
    val_data, test_data = train_test_split(temp_data, test_size=(1 - val_ratio_adjusted), random_state=42)

    return train_data, val_data, test_data

class ChatDataset(Dataset):
    """Dataset of two-message chats, pre-truncated to fit a token budget.

    Every sample is a list ``[user_message, assistant_message]`` of dicts
    with a "content" key. During construction the user message is shortened
    (keeping its tail) when the conversation would not fit into
    ``max_length`` tokens. ``__getitem__`` then tokenizes the conversation
    and builds labels that are -100 everywhere except on the assistant part.
    """

    def __init__(self, data, tokenizer, max_length=2048, name="Unnamed", device=None):
        """Preprocess all samples and print summary statistics.

        Args:
            data: Sliceable sequence of ``[user, assistant]`` message lists.
                The input is not modified.
            tokenizer: Object providing ``encode``, ``decode``,
                ``apply_chat_template`` and ``__call__`` (presumably a
                Hugging Face tokenizer -- confirm against caller).
            max_length: Token budget per sample; also the padded length
                returned by ``__getitem__``.
            name: Label used in progress and summary output.
            device: Stored on the instance; defaults to CUDA when available,
                otherwise CPU.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.name = name
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Truncation statistics, filled in by _preprocess_batch.
        self.truncated_count = 0          # number of samples whose user message was cut
        self.total_orig_tokens = 0        # user tokens before truncation, summed
        self.total_truncated_tokens = 0   # user tokens after truncation, summed
        self.assistant_tokens_stats = []  # assistant token count per sample
        self.effective_context_used = []  # percent of max_length used per sample

        self.preprocessed_data = []
        print(f"Preprocessing {self.name} dataset...")

        batch_size = 32
        for i in range(0, len(data), batch_size):
            batch = data[i:i+batch_size]
            processed_batch = self._preprocess_batch(batch, start_idx=i)
            self.preprocessed_data.extend(processed_batch)

            # Overwrite the same console line with the current progress.
            progress = min(100, int((i + len(batch)) / len(data) * 100))
            print(f"\rPreprocessing {self.name} dataset: {progress}% completed", end="")

        print()

        if self.truncated_count > 0:
            print(f"{self.name} dataset processing completed: {self.truncated_count}/{len(data)} samples were truncated")
            print(f"Average truncation rate: {100 * (self.total_orig_tokens - self.total_truncated_tokens) / self.total_orig_tokens:.2f}%")
        else:
            print(f"{self.name} dataset processing completed: No samples need truncation")

        if self.assistant_tokens_stats:
            avg_assistant_tokens = sum(self.assistant_tokens_stats) / len(self.assistant_tokens_stats)
            print(f"Average assistant reply tokens: {avg_assistant_tokens:.1f}")

        if self.effective_context_used:
            avg_context_used = sum(self.effective_context_used) / len(self.effective_context_used)
            print(f"Average context usage: {avg_context_used:.1f}%")

    def _preprocess_batch(self, batch, start_idx):
        """Truncate the user message of each sample in *batch* where needed.

        Args:
            batch: Slice of the dataset (list of ``[user, assistant]`` lists).
            start_idx: Index of the first batch element within the dataset;
                used only for log messages.

        Returns:
            A list of dicts with the keys "messages", "was_truncated",
            "assistant_tokens", "total_tokens" and "context_used_percent".
        """
        result = []

        # The template overhead does not depend on the sample, so estimate it
        # once per batch instead of once per sample.
        template_overhead = self._estimate_template_overhead()

        for i, messages in enumerate(batch):
            idx = start_idx + i
            # Copy every message dict: a plain list copy would share the
            # dicts, and truncation below would overwrite the caller's data.
            messages_copy = [dict(message) for message in messages]

            # NOTE(review): encode() may add special tokens depending on the
            # tokenizer, so these counts can be slightly high -- confirm.
            assistant_message = messages_copy[1]["content"]
            assistant_tokens = self.tokenizer.encode(assistant_message)
            assistant_tokens_count = len(assistant_tokens)
            self.assistant_tokens_stats.append(assistant_tokens_count)

            user_message = messages_copy[0]["content"]
            user_tokens = self.tokenizer.encode(user_message)
            orig_user_tokens_count = len(user_tokens)
            self.total_orig_tokens += orig_user_tokens_count

            # Budget left for the user message. Clamp at zero: a negative
            # budget used as ``tokens[-budget:]`` would keep an arbitrary
            # suffix, and a zero budget would keep the whole message.
            max_allowed_user_tokens = max(
                0, self.max_length - assistant_tokens_count - template_overhead
            )

            was_truncated = False
            if orig_user_tokens_count > max_allowed_user_tokens:
                was_truncated = True

                # Keep the last max_allowed_user_tokens tokens of the message.
                keep_from = orig_user_tokens_count - max_allowed_user_tokens
                truncated_user_tokens = user_tokens[keep_from:]
                truncated_user_message = self.tokenizer.decode(truncated_user_tokens)
                messages_copy[0]["content"] = truncated_user_message
                truncated_tokens_count = len(truncated_user_tokens)

                # Re-measure the templated conversation; decoding and
                # re-encoding may not reproduce the estimated length.
                full_text = self.tokenizer.apply_chat_template(messages_copy, tokenize=False)
                full_tokens = self.tokenizer.encode(full_text)
                full_tokens_count = len(full_tokens)

                if full_tokens_count > self.max_length:
                    excess = full_tokens_count - self.max_length

                    # Second pass: drop the excess plus a 10-token margin
                    # from the front, if the user message is long enough.
                    if truncated_tokens_count > excess + 10:
                        new_user_tokens = truncated_user_tokens[excess+10:]
                        messages_copy[0]["content"] = self.tokenizer.decode(new_user_tokens)
                        truncated_tokens_count = len(new_user_tokens)

                self.truncated_count += 1
                self.total_truncated_tokens += truncated_tokens_count

                # Log roughly 50 truncation examples over the whole dataset.
                if idx % max(1, len(self.data) // 50) == 0:
                    print(f"Sample {idx} truncated: {orig_user_tokens_count} -> {truncated_tokens_count} tokens")
            else:
                self.total_truncated_tokens += orig_user_tokens_count

            final_text = self.tokenizer.apply_chat_template(messages_copy, tokenize=False)
            final_tokens = self.tokenizer.encode(final_text)
            final_tokens_count = len(final_tokens)

            context_used_percent = (final_tokens_count / self.max_length) * 100
            self.effective_context_used.append(context_used_percent)

            if context_used_percent < 70 and idx % 500 == 0:
                print(f"Warning: Sample {idx} context usage only {context_used_percent:.1f}%")

            result.append({
                "messages": messages_copy,
                "was_truncated": was_truncated,
                "assistant_tokens": assistant_tokens_count,
                "total_tokens": final_tokens_count,
                "context_used_percent": context_used_percent
            })

        return result

    def _estimate_template_overhead(self):
        """Estimate how many tokens the chat template adds to a conversation.

        A two-message probe conversation with the content "test" is
        templated and tokenized; the token count of the two raw contents is
        subtracted.

        Returns:
            The measured overhead plus a safety margin of 10 tokens.
        """
        simple_messages = [
            {"role": "user", "content": "test"},
            {"role": "assistant", "content": "test"}
        ]

        # Both probe messages have the same content, hence the same length.
        user_tokens = len(self.tokenizer.encode("test"))
        assistant_tokens = len(self.tokenizer.encode("test"))
        raw_tokens = user_tokens + assistant_tokens

        template_text = self.tokenizer.apply_chat_template(simple_messages, tokenize=False)
        template_tokens = len(self.tokenizer.encode(template_text))

        overhead = template_tokens - raw_tokens

        return overhead + 10

    def __len__(self):
        """Return the number of preprocessed samples."""
        return len(self.preprocessed_data)

    def __getitem__(self, idx):
        """Tokenize sample *idx* and build its training labels.

        Returns:
            A dict with 1-D tensors "input_ids", "attention_mask" and
            "labels", each of length ``max_length``. Labels are -100 on the
            prompt and on padding positions. If anything raises while the
            sample is built, the error is printed and an all-zero sample
            with all labels set to -100 is returned instead.
        """
        item = self.preprocessed_data[idx]
        messages = item["messages"]

        try:
            full_text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False
            )

            # Prompt part only (user turn plus generation prompt); its token
            # count marks where the assistant reply starts.
            user_only = [messages[0]]
            user_text = self.tokenizer.apply_chat_template(
                user_only,
                tokenize=False,
                add_generation_prompt=True
            )

            encoded = self.tokenizer(
                full_text,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )

            input_ids = encoded["input_ids"][0]
            attention_mask = encoded["attention_mask"][0]

            # Start with every position ignored by the loss.
            labels = torch.full_like(input_ids, -100)

            # NOTE(review): indexing from assistant_start assumes padding is
            # appended on the right -- confirm tokenizer.padding_side.
            user_tokens = self.tokenizer.encode(user_text)
            assistant_start = len(user_tokens)

            if assistant_start < self.max_length:
                labels[assistant_start:] = input_ids[assistant_start:]

            # Counted before padding is masked, so the trigger below fires
            # only when the prompt itself leaves fewer than 10 positions.
            non_ignored = (labels != -100).sum().item()
            if non_ignored < 10 and idx % 500 == 0:
                print(f"Warning: Sample {idx} has only {non_ignored} tokens participating in loss calculation")

                # NOTE(review): this relabelling shares the rate-limited
                # warning's condition and therefore only runs for every
                # 500th index -- confirm whether that is intended.
                assistant_text = messages[1]["content"]
                assistant_token_count = len(self.tokenizer.encode(assistant_text))

                adjusted_start = max(0, self.max_length - assistant_token_count - 10)
                if adjusted_start < self.max_length:
                    labels = torch.full_like(input_ids, -100)
                    labels[adjusted_start:] = input_ids[adjusted_start:]

            # Padding must not contribute to the loss: without this, every
            # pad token after the reply would be kept as a training target.
            labels[attention_mask == 0] = -100

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels
            }
        except Exception as e:
            # Deliberate best effort: report the failure and hand back a
            # sample that is fully masked out instead of aborting iteration.
            print(f"Error processing sample {idx}: {e}")

            empty_input = torch.zeros((self.max_length,), dtype=torch.long)
            return {
                "input_ids": empty_input,
                "attention_mask": torch.zeros_like(empty_input),
                "labels": torch.ones_like(empty_input) * -100
            }


def check_dataset_samples(dataset, tokenizer, dataset_name, num_samples=5):
    """Print a diagnostic report about a preprocessed dataset.

    The report shows the samples with the highest and lowest context usage
    and a random selection of samples with their token breakdown.

    Args:
        dataset: Object exposing ``preprocessed_data`` (list of dicts with
            "context_used_percent", "total_tokens", "assistant_tokens" and
            "was_truncated"), ``__len__`` and ``__getitem__`` returning a
            dict with "labels" and optionally "attention_mask". Its
            ``max_length`` attribute is used when present, else 2048.
        tokenizer: Unused; kept so existing callers keep working.
        dataset_name: Label used in the printed report.
        num_samples: Upper bound on the number of random samples shown.
    """
    print(f"\nChecking {dataset_name} dataset processing results:")

    records = dataset.preprocessed_data
    # Use the dataset's real token budget rather than a hard-coded 2048.
    max_length = getattr(dataset, "max_length", 2048)

    def _usage(position):
        return records[position]["context_used_percent"]

    if records:
        # max()/min() return the first extreme on ties and, unlike a scan
        # seeded with 0 and 100, also work when every usage exceeds 100%.
        max_usage_idx = max(range(len(records)), key=_usage)
        min_usage_idx = min(range(len(records)), key=_usage)

        for extreme, position in (("highest", max_usage_idx), ("lowest", min_usage_idx)):
            item = records[position]
            print(f"\nSample with {extreme} context usage {position}:")
            print(f"  - Context usage: {item['context_used_percent']:.1f}%")
            print(f"  - Total tokens: {item['total_tokens']}")
            print(f"  - Assistant tokens: {item['assistant_tokens']}")
            print(f"  - Was truncated: {'Yes' if item['was_truncated'] else 'No'}")

    print(f"\nRandomly checking {num_samples} {dataset_name} samples:")
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for idx in indices:
        item = records[idx]
        status = "truncated" if item["was_truncated"] else "not truncated"

        sample = dataset[idx]
        labels = sample["labels"]
        supervised = labels != -100

        # Leave padding out of both parts so the split reflects real tokens.
        attention_mask = sample.get("attention_mask")
        if attention_mask is not None:
            real = attention_mask != 0
            supervised = supervised & real
            real_count = real.sum().item()
        else:
            real_count = len(labels)

        non_ignored = supervised.sum().item()
        user_part = real_count - non_ignored

        print(f"Sample {idx} ({status}):")
        print(f"  - Context usage: {item['context_used_percent']:.1f}%")
        print(f"  - Total tokens: {item['total_tokens']}/{max_length}")
        print(f"  - User part: {user_part} tokens ({user_part/max_length*100:.1f}%)")
        print(f"  - Assistant part: {non_ignored} tokens ({non_ignored/max_length*100:.1f}%)")


def save_processed_data(data, output_file):
    """Write the messages of a preprocessed dataset to a JSON file.

    Args:
        data: Object whose ``preprocessed_data`` attribute is an iterable of
            dicts that each contain a "messages" entry.
        output_file: Destination path; the file is written as UTF-8 JSON
            with two-space indentation and non-ASCII characters preserved.
    """
    print(f"Saving processed data to {output_file}")

    # Only the conversation itself is stored; the statistics are dropped.
    processed_messages = [
        {"messages": record["messages"]} for record in data.preprocessed_data
    ]

    with open(output_file, 'w', encoding='utf-8') as sink:
        json.dump(processed_messages, sink, ensure_ascii=False, indent=2)

    print(f"Successfully saved {len(processed_messages)} records")


def prepare_data(file_path, tokenizer_path, output_dir):
    """Run the full pipeline: load, convert, split, preprocess and save.

    Args:
        file_path: JSON file with the raw records.
        tokenizer_path: Location passed to ``AutoTokenizer.from_pretrained``.
        output_dir: Directory for all output files; created if missing.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple of
        ``ChatDataset`` instances.
    """
    raw_records = load_data(file_path)
    print(f"Loaded {len(raw_records)} data entries")

    conversations = convert_to_chatml(raw_records)
    print("Data has been converted to ChatML format")

    train_data, val_data, test_data = split_data_function(conversations)
    print(f"Training set: {len(train_data)}, Validation set: {len(val_data)}, Test set: {len(test_data)}")

    os.makedirs(output_dir, exist_ok=True)

    # (file stem, samples) per split, in the order they are processed.
    splits = (("train", train_data), ("val", val_data), ("test", test_data))

    print("Saving original split datasets...")
    split_files = {}
    for split_name, split_data in splits:
        json_data = [
            {"messages": item}
            for item in tqdm(split_data, desc=f"Saving {split_name} original data")
        ]

        split_file = f"{output_dir}/{split_name}_original.json"
        with open(split_file, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)

        split_files[split_name] = split_file

    print(f"Original split datasets have been saved to {output_dir} directory")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Build all three datasets first, then save them, then inspect them.
    display_names = ("Training", "Validation", "Test")
    datasets = [
        ChatDataset(split_data, tokenizer, name=display_name, device=device)
        for (_, split_data), display_name in zip(splits, display_names)
    ]

    for (split_name, _), dataset in zip(splits, datasets):
        save_processed_data(dataset, f"{output_dir}/{split_name}.json")

    for dataset, display_name in zip(datasets, display_names):
        check_dataset_samples(dataset, tokenizer, display_name)

    train_dataset, val_dataset, test_dataset = datasets
    return train_dataset, val_dataset, test_dataset


if __name__ == "__main__":
    # Input, tokenizer and output locations used when run as a script.
    file_path = "./Fine_tunning/Data_processing_and_data/pems_dataset_processed.json"
    tokenizer_path = "./Others/TinyLlama-1.1B-Chat-v1.0"
    output_dir = "./Others/Data_processing_and_data/"

    train_dataset, val_dataset, test_dataset = prepare_data(file_path, tokenizer_path, output_dir)

    # (label, dataset, file stem) rows drive both parts of the summary.
    summary_rows = (
        ("Training", train_dataset, "train"),
        ("Validation", val_dataset, "val"),
        ("Test", test_dataset, "test"),
    )

    print("\nData processing completed!")
    for label, dataset, _ in summary_rows:
        print(f"{label} set: {len(dataset)} samples")

    print("\nProcessed data has been saved to:")
    for label, _, stem in summary_rows:
        print(f"  - {label} set: {output_dir}/{stem}.json")