import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

from gla.model import GLA_Model
from utils import GLADataset, float_discretization
from gla.data_preprocess import data_prepare_gla


def train_gla(model, dataloader, optimizer, epochs=1, phase=1, prefix=''):
    """Run the GLA training loop and save the model when all epochs finish.

    The model is moved to CUDA, then `dataloader` is iterated `epochs`
    times. Each batch is unpacked into six fields; the last two are ignored.
    Every 200 steps (including step 0 of each epoch) the largest loss seen
    since the previous report and the current loss are printed.

    Args:
        model: Object called as
            ``model(graph_data, edge_index, language_input, target_response=...)``
            whose result exposes ``.loss``; it must also provide ``cuda()``,
            ``train()`` and ``save(model_path=..., tokenizer_path=...)``.
        dataloader: Iterable yielding 6-element batches.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of passes over `dataloader`.
        phase: Training phase label, used in the log line and output file names.
        prefix: Run label, used in the log line and output file names.

    Returns:
        None. The model and tokenizer are written to
        ``./opengla_{phase}_{prefix}.p`` and
        ``./opengla_tokenizer_{phase}_{prefix}``.
    """
    model.cuda()

    for epoch in range(epochs):
        model.train()
        window_peak = 0.0  # largest loss since the last progress report
        for step, batch in enumerate(tqdm(dataloader)):
            node_features, edges, prompts, targets, _, _ = batch
            # Only the first element of each text field is used
            # (presumably because the loader batches one sample at a time).
            prompt, target = prompts[0], targets[0]
            # Cast to long, move to GPU and swap the last two axes.
            edges = edges.long().cuda().permute(0, 2, 1)
            node_features = node_features.float().cuda()
            result = model(node_features, edges, prompt, target_response=target)
            loss = result.loss

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            window_peak = max(window_peak, loss_value)
            if step % 200 == 0:
                print(f"Phase {phase}, Epoch {epoch}, prefix={prefix}, max loss within period: {window_peak:.4f}, loss: {loss_value:.4f}")
                window_peak = 0.0

    model.save(model_path=f'./opengla_{phase}_{prefix}.p', tokenizer_path=f'./opengla_tokenizer_{phase}_{prefix}')


def train_llm(model, dataloader, optimizer, epochs=1, prefix=''):
    """Run the language-model-only training loop and save the model at the end.

    The model is moved to CUDA, then `dataloader` is iterated `epochs`
    times. Each batch is unpacked into four fields; the last two are ignored.
    Every 200 steps (including step 0 of each epoch) the largest loss seen
    since the previous report and the current loss are printed.

    Args:
        model: Object called as ``model(language_input, target_response=...)``
            whose result exposes ``.loss``; it must also provide ``cuda()``,
            ``train()`` and ``save(model_path=..., tokenizer_path=...)``.
        dataloader: Iterable yielding 4-element batches.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of passes over `dataloader`.
        prefix: Accepted but not used by this function.

    Returns:
        None. The model and tokenizer are written to ``./llm.p`` and
        ``./llm_tokenizer``.
    """
    model.cuda()

    for epoch in range(epochs):
        model.train()
        window_peak = 0.0  # largest loss since the last progress report
        for step, batch in enumerate(tqdm(dataloader)):
            prompts, targets, _, _ = batch
            # Only the first element of each text field is used
            # (presumably because the loader batches one sample at a time).
            prompt, target = prompts[0], targets[0]
            result = model(prompt, target_response=target)
            loss = result.loss

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            window_peak = max(window_peak, loss_value)
            if step % 200 == 0:
                print(f"Epoch {epoch}, max loss within period: {window_peak:.4f}, loss: {loss_value:.4f}")
                window_peak = 0.0

    model.save(model_path='./llm.p', tokenizer_path='./llm_tokenizer')

if __name__ == '__main__':
    # Script entry point: build the GLA dataset once, then train a GLA_Model
    # in two consecutive phases, each starting from a freshly built model.

    # Networks and tasks handed to data_prepare_gla; (un)comment to select.
    networks = [
        # 'IEEE14',
        # 'IEEE39',
        # 'IEEE57',
        # 'SG126',
        # 'IEEE300',   # remove if train 8B model
        'Texas2000',
    ]
    gcn_task = 'opf'  # NOTE(review): not referenced anywhere else in this script
    tasks = [
        'opf',
        # 'fault_detect',
        # 'state_est',
        # 'lmp_pred',
        # 'transient_pred',
    ]

    # GLA model / training configuration
    use_lora = True
    float_disc = True
    epochs = [1, 1]  # epochs for phase 1 and phase 2, in that order
    # llm_model = 'google/gemma-3-1b-pt'  # meta-llama/Llama-3.2-1b or Llama-3.1-8B or google/gemma-3-1b-pt
    llm_model = 'meta-llama/Llama-3.2-1b'  # meta-llama/Llama-3.2-1b or Llama-3.1-8B or google/gemma-3-1b-pt
    graph_input_dim = 64
    graph_hidden_dim = 1024
    data_amount = None
    # Run label embedded in checkpoint file names. Only the part after the
    # last "/" of the model id is kept, so ids without a "/" work as well.
    model_short_name = llm_model.rsplit("/", 1)[-1]
    prefix = f'{model_short_name}-{graph_input_dim}-{graph_hidden_dim}-D{data_amount}-T{tasks[0]}'

    samples = data_prepare_gla(networks, tasks, data_amount=data_amount, float_disc=float_disc,
                               # root_path='/workspace/RL4Grid/'
                               )
    # batch_size=1: train_gla reads only element 0 of the text fields.
    train_loader = DataLoader(samples, batch_size=1, shuffle=False)

    # 1-pretrain W and Graph Encoder, freeze LLM; 2-end2end.
    # To run only phase 2, iterate over (2,) instead; the phase-1 checkpoint
    # must then already exist at the path loaded below.
    for training_phase in (1, 2):
        model = GLA_Model(
            graph_input_dim=graph_input_dim,
            graph_hidden_dim=graph_hidden_dim,
            language_model_name=llm_model,
            use_lora=use_lora,
            float_disc=float_disc,
            phase=training_phase,
        ).cuda()
        if training_phase == 2:
            # Resume from the checkpoint phase 1 just wrote. These paths must
            # match the ones train_gla passes to model.save for phase=1.
            model.load(
                model_path=f'./opengla_1_{prefix}.p',
                tokenizer_path=f'./opengla_tokenizer_1_{prefix}'
            )
        # Phase 2 uses a 10x smaller learning rate than phase 1.
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4 if training_phase == 1 else 2e-5)
        train_gla(model, train_loader, optimizer, epochs=epochs[training_phase - 1],
                  phase=training_phase, prefix=prefix)

        # Drop this phase's model and optimizer so their GPU memory can be
        # reused before the next phase builds its own model.
        del model, optimizer
        torch.cuda.empty_cache()