import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from torch.utils.data import DataLoader, Dataset

# 1. Read the parallel corpus. Line i of en.txt is treated as the source
#    sentence paired with line i of ch.txt.
#    Text mode already normalizes "\r\n" to "\n"; stripping "\n" keeps the line
#    terminator out of the training text. readlines() would have kept it.
with open("en.txt", "r", encoding="utf-8") as en_file:
    source_texts = [line.rstrip("\n") for line in en_file]

with open("ch.txt", "r", encoding="utf-8") as ch_file:
    target_texts = [line.rstrip("\n") for line in ch_file]

# Pairs are matched purely by line index. A count mismatch means the files are
# misaligned, so fail fast instead of training on wrong pairs.
if len(source_texts) != len(target_texts):
    raise ValueError(
        f"en.txt and ch.txt must have the same number of lines: "
        f"got {len(source_texts)} vs {len(target_texts)}"
    )

# 2. Initialize the tokenizer and model
# NOTE(review): this is a 20B-parameter checkpoint. Weights plus AdamW state
# need far more memory than a single typical GPU provides. Confirm the target
# hardware (e.g. pass torch_dtype / device_map to from_pretrained).
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
# padding="max_length" requires a pad token. Reuse EOS when none is defined.
# Padded positions are excluded from attention and loss via the attention mask.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 3. Create a custom dataset class
class TranslationDataset(Dataset):
    """Aligned source/target sentence pairs encoded as causal-LM examples.

    Each item feeds one source sentence and its target sentence to the
    tokenizer as a text pair. The pair is padded/truncated to ``max_length``.
    Labels are a copy of the input ids, with padding positions set to -100 so
    the loss ignores them.

    Args:
        source_texts: Source-language sentences, one per example.
        target_texts: Target-language sentences, aligned index-by-index with
            ``source_texts``.
        tokenizer: A Hugging Face tokenizer. If it has no pad token, its EOS
            token is assigned as the pad token; this mutates the tokenizer.
        max_length: Fixed sequence length after padding/truncation.

    Raises:
        ValueError: If ``source_texts`` and ``target_texts`` differ in length.
    """

    def __init__(self, source_texts, target_texts, tokenizer, max_length=128):
        # Misaligned inputs would otherwise surface as an IndexError mid-epoch.
        if len(source_texts) != len(target_texts):
            raise ValueError(
                f"source_texts and target_texts must be aligned: got "
                f"{len(source_texts)} source vs {len(target_texts)} target items"
            )
        self.source_texts = source_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        # padding="max_length" raises without a pad token. Causal-LM tokenizers
        # often lack one, so reuse EOS. Padding is then masked out of the loss
        # in __getitem__.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    def __len__(self):
        return len(self.source_texts)

    def __getitem__(self, idx):
        # Drop line terminators (e.g. from readlines()) so they are not trained on.
        source_text = self.source_texts[idx].rstrip("\r\n")
        target_text = self.target_texts[idx].rstrip("\r\n")

        encoding = self.tokenizer(
            source_text,
            target_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].flatten()
        attention_mask = encoding["attention_mask"].flatten()

        # The tokenizer does not produce labels. For a causal LM the labels are
        # the input ids themselves (the model shifts them internally). -100 on
        # padding keeps pad tokens out of the cross-entropy loss.
        # NOTE(review): this also trains on the source tokens and appends no EOS
        # after the target. Consider masking the source span and adding EOS if
        # the model should learn to stop after the translation.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

# 4. Create data loaders
translation_dataset = TranslationDataset(source_texts, target_texts, tokenizer)
train_dataloader = DataLoader(translation_dataset, batch_size=32, shuffle=True)

# 5. Configure fine-tuning parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
num_epochs = 5
output_dir = "/output/path"  # Replace with your output path

# DataLoader yields CPU tensors. Send each batch to wherever the model's
# (first) parameters live, so the loop works on CPU or GPU unchanged.
device = next(model.parameters()).device

# 6. Fine-tune the model
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss}")

# 7. Save the fine-tuned model together with its tokenizer (including the pad
#    token configured for training) so output_dir reloads on its own.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
