import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

from baselines.gnn import GNN_Model
from baselines.canos import Canos_Model
from baselines.data_preprocess import data_prepare_gcn
from torch.utils.data._utils.collate import default_collate
import numpy as np
from pypower.api import *

def train_gcn(model, dataloader, optimizer, model_type, network, task, epochs=5):
    """Train a graph model sample-by-sample and save it afterwards.

    The model and all tensors are placed on the GPU when CUDA is available
    and on the CPU otherwise.

    Args:
        model: Network to train. It is called as ``model(x, edge_index)`` for
            the GNN-style model types and ``model(x, edge_index, edge_attr)``
            for ``'canos'``, and must provide ``train()``, ``to()`` and
            ``save(model_path=...)``.
        dataloader: Iterable of 10-element batches. Elements 0-3 are the node
            features, edge index, edge attributes and target; element 5 is the
            generator-bus index tensor; elements 6 and 7 are the max/min used
            to undo normalisation. The remaining elements are ignored.
            Each batch is expected to hold a single graph (leading axis of
            size 1), since the batch axis is squeezed away.
        optimizer: Optimizer already bound to ``model``'s parameters.
        model_type: One of ``'baselines'``, ``'gat'``, ``'deepopf'``,
            ``'gin'`` or ``'canos'``.
        network: Network name; used for logging and the checkpoint file name.
        task: Task name; used for logging and the checkpoint file name.
        epochs: Number of passes over ``dataloader``.

    Raises:
        NotImplementedError: If ``model_type`` is not one of the supported
            values (raised when the first batch is processed).
    """
    import os  # local import: only needed to prepare the checkpoint directory

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(epochs):
        model.train()
        for step, batch in enumerate(tqdm(dataloader)):
            graph_data, edge_index, edge_attr, target, _, gen_bus, float_max, float_min, _, _ = batch
            target = target.float().to(device)
            # Swap the last two axes, then drop the batch axis.
            edge_index = edge_index.long().to(device).permute(0, 2, 1).squeeze(0)
            node_feats = graph_data.float().to(device).squeeze(0)

            if model_type in ['baselines', 'gat', 'deepopf', 'gin']:
                outputs = model(node_feats, edge_index).squeeze(-1)
            elif model_type == 'canos':
                # Edge attributes are only consumed by the canos model.
                edge_feats = edge_attr.float().to(device).squeeze(0)
                outputs = model(node_feats, edge_index, edge_feats).squeeze(-1)
            else:
                raise NotImplementedError(f"Unsupported model_type: {model_type!r}")

            # The prediction has had its batch/trailing axes squeezed while the
            # target has not. When both hold the same number of elements, give
            # them the same shape so the subtraction is element-wise instead of
            # silently broadcasting to a larger tensor.
            if target.shape != outputs.shape and target.numel() == outputs.numel():
                target = target.reshape(outputs.shape)

            loss = ((outputs - target) ** 2).mean()
            if model_type == 'deepopf':
                value_max = float_max.item()
                value_min = float_min.item()
                span = value_max - value_min
                gen_bus_lst = gen_bus.squeeze(0).long().tolist()
                # Only the power-balance constraint is considered for now.
                # Undo the normalisation (presumably a [-1, 1] scaling -- confirm
                # against data_prepare_gcn) and penalise the mismatch between
                # total generator output and the total of feature column 0.
                # Kept on-device to avoid a host copy on every step.
                restored_pd = (node_feats[:, 0] + 1) / 2 * span + value_min
                restored_outputs = (outputs.squeeze(0) + 1) / 2 * span + value_min
                loss = loss + 0.1 * (restored_outputs[gen_bus_lst].sum() - restored_pd.sum()) ** 2

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 2000 == 0:
                print(f"Epoch {epoch}, model={model_type}, task={task}, network={network}, loss: {loss.item():.4f}")

    # Make sure the target directory exists so the checkpoint write cannot fail
    # on a missing folder after training has finished.
    os.makedirs('./saved_models', exist_ok=True)
    model.save(model_path=f'./saved_models/{model_type}_{network}_{task}.p')

if __name__ == '__main__':

    # Train GCN model
    # Networks to train on; commented-out entries are currently disabled.
    networks = [
        # 'IEEE14',
        # 'IEEE39',
        # 'IEEE57',
        # 'SG126',
        # 'IEEE300',
        'Texas2000'
    ]
    # Case data for every known network, keyed by network name. The case*
    # builders are expected to come from the ``pypower.api`` star import.
    ppcs = {
        'IEEE14': case14(),
        'IEEE39': case39(),
        'IEEE57': case57(),
        'IEEE118': case118(),
        'SG126': case126(),
        'IEEE300': case300(),
        'Texas2000': case2000()
    }
    # Tasks to train for; commented-out entries are currently disabled.
    tasks = [
        'opf',
        # 'state_est',
        # 'lmp_pred',
        # 'transient_pred'
    ]
    # Baselines per task:
    #   OPF: baselines, deepopf, canos, llama
    #   Fault_detect: informer, patchtst, timesnet, llama
    #   State_est: baselines, gat, sundial, llama
    #   LMP_pred: baselines, gat, sundial, llama
    #   Transient_pred: baselines, patchtst, sundial, llama
    model_types = ['deepopf', 'canos']

    hidden_dim = 128
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for model_type in model_types:
        for network in networks:
            for task in tasks:
                samples = data_prepare_gcn(network, ppcs, task, normalize=True)
                # Layer sizes are read from the first sample: element 0 gives
                # the node-feature width, element 2 the edge-feature width and
                # element 3 the output width.
                first_sample = samples[0]
                nfeature_dim = first_sample[0].shape[1]
                efeature_dim = first_sample[2].shape[1]
                output_dim = first_sample[3].shape[1]
                if model_type in ['baselines', 'gat', 'deepopf', 'gin']:
                    model = GNN_Model(nfeature_dim, hidden_dim, output_dim, type=model_type)
                elif model_type == 'canos':
                    model = Canos_Model(nfeature_dim, hidden_dim, efeature_dim, output_dim)
                else:
                    raise NotImplementedError(f"Unsupported model_type: {model_type!r}")
                # Move the model to its device before the optimizer is built,
                # and do so the same way for every model family.
                model = model.to(device)
                optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
                train_loader = DataLoader(samples, batch_size=1, shuffle=True)
                train_gcn(model, train_loader, optimizer, model_type, network, task, epochs=10)
