import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

from llm.model import LLM_Model
from utils import GLADataset, float_discretization
from llm.data_preprocess import data_prepare_llm


def train_llm(model, dataloader, optimizer, epochs=1, prefix='',
              *, log_every=200, save_dir='./saved_models'):
    """Fine-tune ``model`` on ``dataloader`` and save it once training finishes.

    Each batch is unpacked as a 4-tuple whose first two entries are the
    language input and the ground-truth response; only element ``[0]`` of
    each is used, so the loader is expected to be built with ``batch_size=1``
    (as done in this script's ``__main__``).

    Args:
        model: Object exposing ``cuda()``, ``train()``, a call
            ``model(language_input, target_response=...)`` whose result has a
            ``.loss`` tensor, and ``save(model_path=..., tokenizer_path=...)``.
        dataloader: Iterable of batches as described above.
        optimizer: Torch optimizer over the model's trainable parameters.
        epochs: Number of passes over ``dataloader``.
        prefix: Prepended to the saved file names (``{prefix}llm.p`` and
            ``{prefix}llm_tokenizer``). The default ``''`` keeps the original
            file names.
        log_every: Print progress every ``log_every`` steps (counted per epoch,
            starting at step 0).
        save_dir: Directory the model and tokenizer are written to.
    """
    model.cuda()

    for epoch in range(epochs):
        model.train()
        max_loss = 0.0  # largest loss seen since the last progress print
        for step, batch in enumerate(tqdm(dataloader)):
            language_input, ground_truth_response, _, _ = batch
            # With batch_size=1 the default collate wraps each field in a
            # length-1 container; take the single element.
            language_input, ground_truth_response = language_input[0], ground_truth_response[0]
            outputs = model(language_input, target_response=ground_truth_response)
            loss = outputs.loss

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() synchronizes with the device; call it once per step.
            loss_value = loss.item()
            max_loss = max(max_loss, loss_value)
            if step % log_every == 0:
                print(f"Epoch {epoch}, max loss within period: {max_loss:.4f}, loss: {loss_value:.4f}")
                max_loss = 0.0

    model.save(
        model_path=f'{save_dir}/{prefix}llm.p',
        tokenizer_path=f'{save_dir}/{prefix}llm_tokenizer',
    )

if __name__ == '__main__':

    # Network/case names whose samples are passed to data_prepare_llm.
    networks = [
        'IEEE14',
        'IEEE39',
        'IEEE57',
        'SG126',
        'IEEE300',  # remove if train 8B model
    ]
    # Tasks whose samples are included in the LLM training set.
    tasks = [
        'opf', 'fault_detect', 'state_est', 'lmp_pred',
        # 'transient_pred',
    ]

    # Train LLM model
    samples = data_prepare_llm(networks, tasks)
    model = LLM_Model(language_model_name='meta-llama/Llama-3.2-1B', use_lora=True).cuda()
    # With use_lora=True only some parameters are expected to be trainable;
    # hand the optimizer just those.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=2e-4)
    # NOTE(review): shuffle=False keeps samples in data_prepare_llm order —
    # confirm this is intentional for training.
    train_loader = DataLoader(samples, batch_size=1, shuffle=False)
    train_llm(model, train_loader, optimizer)
