import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset

def read_text_files(source_file_path, target_file_path):
    """Load the source and target corpora as lists of raw lines.

    Both files are opened as UTF-8 text, source first and then target.
    Every returned element is one line exactly as ``readlines`` yields it,
    so trailing newline characters are preserved.

    Args:
        source_file_path: Path of the file holding the source-side lines.
        target_file_path: Path of the file holding the target-side lines.

    Returns:
        A ``(source_texts, target_texts)`` tuple of two lists of strings.
    """
    loaded = []
    for path in (source_file_path, target_file_path):
        with open(path, "r", encoding="utf-8") as handle:
            loaded.append(handle.readlines())

    source_lines, target_lines = loaded
    return source_lines, target_lines

# 3. create dataset class
class TranslationDataset(Dataset):
    """Paired source/target texts tokenized for causal-LM fine-tuning.

    Each item encodes one ``(source_text, target_text)`` pair as a single
    text-pair tokenizer call, padded/truncated to ``max_length``.

    Args:
        source_texts: Sequence of source-side strings.
        target_texts: Sequence of target-side strings, index-aligned with
            ``source_texts``.
        tokenizer: Callable tokenizer that accepts a text pair plus the
            ``truncation``/``padding``/``max_length``/``return_tensors``
            keyword arguments and returns a mapping containing
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_length: Sequence length passed to the tokenizer for padding
            and truncation.
    """

    # Label value that the cross-entropy loss of Hugging Face LM heads skips.
    IGNORE_INDEX = -100

    def __init__(self, source_texts, target_texts, tokenizer, max_length=128):
        self.source_texts = source_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of source texts."""
        return len(self.source_texts)

    def __getitem__(self, idx):
        """Tokenize the pair at ``idx`` and return flat model inputs.

        Returns:
            A dict with 1-D ``"input_ids"``, ``"attention_mask"`` and
            ``"labels"`` tensors. ``labels`` is a copy of ``input_ids`` in
            which every position whose attention mask is 0 (padding) is
            replaced by ``IGNORE_INDEX``.
        """
        source_text = self.source_texts[idx]
        target_text = self.target_texts[idx]

        encoding = self.tokenizer(
            source_text,
            target_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].flatten()
        attention_mask = encoding["attention_mask"].flatten()

        # A plain tokenizer call returns no "labels" entry, so build the
        # causal-LM targets from the inputs. Masking by attention mask
        # (rather than by pad id) stays correct when pad and EOS share an id.
        labels = input_ids.masked_fill(attention_mask == 0, self.IGNORE_INDEX)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


def fine_tune_model(model, train_dataloader, num_epochs, optimizer):
    """Run a plain training loop and print the mean batch loss per epoch.

    Args:
        model: Callable model providing ``train()`` and accepting
            ``input_ids``, ``attention_mask`` and ``labels`` keyword
            arguments; its return value must expose a ``loss`` tensor.
        train_dataloader: Iterable of batches, each a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` entries.
            It is re-iterated once per epoch and need not support ``len()``.
        num_epochs: Number of passes over ``train_dataloader``.
        optimizer: Optimizer that updates the model's parameters.
    """
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        # Count batches as they are consumed instead of calling len(), so
        # loaders without __len__ work and an empty loader cannot trigger
        # a division by zero.
        num_batches = 0

        for batch in train_dataloader:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An epoch that saw no batches has no meaningful mean; report NaN.
        average_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss}")



def fine_turning_model(index,tgt_path,src_path,model_path,lr,batch_size,num_epochs):
    """Fine-tune one pretrained causal LM on a parallel corpus and save it.

    Args:
        index: Identifier used to name the output directory.
        tgt_path: Path of the target-side text file.
        src_path: Path of the source-side text file.
        model_path: Model name or path handed to ``from_pretrained`` for
            both the tokenizer and the model.
        lr: Learning rate for the AdamW optimizer.
        batch_size: Batch size of the training ``DataLoader``.
        num_epochs: Number of training epochs.

    The fine-tuned weights are written to ``save_models/{index}_model``.
    """
    # 1. Load the corpus. read_text_files expects (source, target) order.
    source_texts, target_texts = read_text_files(src_path, tgt_path)

    # 2. Load the tokenizer that matches the pretrained model.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # The dataset pads every example to a fixed length, which requires a pad
    # token; some causal-LM tokenizers (e.g. Llama) define none, so fall back
    # to the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 3./4. Build the dataset and its shuffled dataloader.
    translation_dataset = TranslationDataset(source_texts, target_texts, tokenizer)
    train_dataloader = DataLoader(translation_dataset, batch_size=batch_size, shuffle=True)

    # 5. Load the model and create its optimizer.
    model = AutoModelForCausalLM.from_pretrained(model_path)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # 6. Train.
    fine_tune_model(model, train_dataloader, num_epochs, optimizer)

    # 7. Save the fine-tuned model.
    model.save_pretrained(f"save_models/{index}_model")


# Correctly spelled alias so that callers using `fine_tuning_model` resolve
# to the same function.
fine_tuning_model = fine_turning_model

# Hyperparameters shared by every fine-tuning run below.
# NOTE(review): 1e-2 is a very large AdamW learning rate for fine-tuning a
# pretrained LLM — confirm this value is intended.
lr = 1e-2
batch_size = 64
num_epochs = 100

# One entry per fine-tuning run: target file, source file and base model.
# NOTE(review): several entries are exact duplicates (e.g. items 1 and 3);
# each still produces its own save_models/{index}_model — confirm intended.
train_info = [
    {'tgt_path': 'ch.txt', 'src_path': 'eng.txt', 'model_path': 'meta-llama/Llama-2-13b'},
    {'tgt_path': 'ch.txt', 'src_path': 'ru.txt', 'model_path': 'EleutherAI/gpt-neox-20b'},
    {'tgt_path': 'tu.txt', 'src_path': 'eng.txt', 'model_path': 'meta-llama/Llama-2-13b'},
    {'tgt_path': 'ch.txt', 'src_path': 'ru.txt', 'model_path': 'EleutherAI/gpt-neox-20b'},
    {'tgt_path': 'tu.txt', 'src_path': 'eng.txt', 'model_path': 'meta-llama/Llama-2-13b'},
    {'tgt_path': 'tu.txt', 'src_path': 'eng.txt', 'model_path': 'EleutherAI/gpt-neox-20b'},
    {'tgt_path': 'ch.txt', 'src_path': 'ru.txt', 'model_path': 'meta-llama/Llama-2-13b'},
    {'tgt_path': 'tu.txt', 'src_path': 'eng.txt', 'model_path': 'EleutherAI/gpt-neox-20b'},
    {'tgt_path': 'ch.txt', 'src_path': 'ru.txt', 'model_path': 'meta-llama/Llama-2-13b'}
]

# Run every configured job in order. The function is defined in this file
# under the name `fine_turning_model`, so call it by that name.
for index, info in enumerate(train_info):
    fine_turning_model(index, info['tgt_path'], info['src_path'], info['model_path'], lr, batch_size, num_epochs)








