"""
This script finetunes and tests a Graphormer model (pretrained on PCQM4Mv2)
for graph classification on ogbg-molhiv dataset.

Paper: [Do Transformers Really Perform Bad for Graph Representation?]
(https://arxiv.org/abs/2106.05234)

This flowchart describes the main functional sequence of the provided example.
main
│
└───> train_val_pipeline
      │
      ├───> Load and preprocess dataset
      │
      ├───> Download pretrained model
      │
      ├───> train_epoch
      │     │
      │     └───> Graphormer.forward
      │
      └───> evaluate_network
            │
            └───> Graphormer.inference
"""
import argparse
import random

import torch as th
import torch.nn as nn
from accelerate import Accelerator
from dgl.data import download
from dgl.dataloading import GraphDataLoader
from models import Graphormer, GSAGE, MyDataset
from torch.optim import AdamW
from transformers.optimization import get_polynomial_decay_schedule_with_warmup
from torchtext.data.functional import to_map_style_dataset
from loss_fn_cl import *
import pickle, os, copy, json
from dgl.nn import DegreeEncoder, PathEncoder, SpatialEncoder
from rtl_mae import build_model_rtl
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel, BertConfig, BertForMaskedLM
from torch.utils.tensorboard import SummaryWriter  

# Run tag; names the TensorBoard log directory and the checkpoint directory
# that the __main__ block builds.
date='pretrain_0729'

# Expose only GPUs 0-3 to this process. CUDA reads this variable when it
# initialises, so it must be set before anything touches the GPU (to be safe,
# before the Accelerator below is constructed).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"



from accelerate import DistributedDataParallelKwargs

# Let DDP tolerate parameters that receive no gradient in a given step.
# Presumably needed because train_epoch / valid_network call the unwrapped
# `.encoder` (bypassing the DDP wrapper) for some forward passes — confirm.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Single module-level Accelerator shared by every function below (device,
# prepare, backward, unwrap_model, save, print, barriers).
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

def all_to_device(lst, device):
    """Return a lazy generator yielding ``item.to(device)`` for each item of ``lst``.

    Nothing is moved until the generator is consumed; callers unpack it
    straight into a fixed set of names.
    """
    moved = (item.to(device) for item in lst)
    return moved


# Directory holding the pickled RTL graph datasets, one file per split and
# view: rtl_{split}_{view}.pkl with view in {ori, pos, neg}.
_GRAPH_DATA_DIR = '../dataset/data_bench'
# Directory holding the text view of the RTL samples, one file per split:
# {split}.json.
_TEXT_DATA_DIR = '/home/coguest5/CircuitFusion/text_enc/bert_cl/text_dataset'


def _graph_loader(split, view, batch_size, shuffle):
    """Unpickle ``rtl_{split}_{view}.pkl`` and wrap it in a GraphDataLoader.

    The unpickled dataset's own ``collate`` method is used as ``collate_fn``.
    """
    # NOTE(review): pickle.load can execute arbitrary code — only point this
    # at trusted, locally produced files.
    with open(f'{_GRAPH_DATA_DIR}/rtl_{split}_{view}.pkl', 'rb') as f:
        dataset = pickle.load(f)
    return GraphDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=dataset.collate
    )


def _load_split(split, batch_size):
    """Build the ``(ori, pos, neg, text)`` loaders for one dataset split.

    Shuffling is disabled on purpose: the four loaders are consumed in
    lock-step with ``zip`` by train_epoch / valid_network, so shuffling them
    independently would pair unrelated samples.
    """
    shuffle_tf = False
    rtl_ori_loader = _graph_loader(split, 'ori', batch_size, shuffle_tf)
    rtl_pos_loader = _graph_loader(split, 'pos', batch_size, shuffle_tf)
    rtl_neg_loader = _graph_loader(split, 'neg', batch_size, shuffle_tf)

    ### load rtl text data ###
    with open(f'{_TEXT_DATA_DIR}/{split}.json', 'r') as f:
        text_data = json.load(f)
    text_loader = DataLoader(text_data, batch_size=batch_size, shuffle=shuffle_tf)

    return (rtl_ori_loader, rtl_pos_loader, rtl_neg_loader, text_loader)


def load_data_valid(batch_size_valid=100000):
    """Return the validation loaders ``(ori, pos, neg, text)``.

    Args:
        batch_size_valid: batch size shared by all four loaders.
    """
    return _load_split('valid', batch_size_valid)


def load_data_train(batch_size):
    """Return the training loaders ``(ori, pos, neg, text)``.

    Args:
        batch_size: batch size shared by all four loaders.
    """
    return _load_split('train', batch_size)
    
def load_param():
    """Return the hyper-parameter tuples ``(param_rtl, param_net)``.

    ``param_rtl`` configures the RTL masked autoencoder (passed to
    ``build_model_rtl``); ``param_net`` configures the "net" masked
    autoencoder (GraphSAGE encoder/decoder). Both are positional tuples, so
    the element order is part of the contract with their consumers.
    """
    # ---- RTL MAE: graph-transformer encoder + MLP decoder ----
    rtl_node_dim, rtl_edge_dim = 27, 12
    rtl_hidden, rtl_layers, rtl_heads = 256, 7, 3
    rtl_max_degree, rtl_num_spatial, rtl_max_hop = 256, 64, 5
    param_rtl = (
        accelerator.device,
        rtl_node_dim, rtl_edge_dim,
        rtl_hidden, rtl_layers, rtl_heads,
        "gelu",                      # activation
        rtl_max_degree, rtl_num_spatial, rtl_max_hop,
        True,                        # pre_layernorm
        0.0,                         # dropout
        "gt", "mlp",                 # encoder_type, decoder_type
        0.5,                         # mask_rate
        "mse",                       # loss_fn
        0.1, 2, False,               # replace_rate, alpha_l, concat_hidden
    )

    # ---- net MAE: GraphSAGE encoder + GraphSAGE decoder ----
    param_net = (
        27,                          # num_features
        256, 3,                      # num_hidden, num_layers
        "relu",                      # activation
        0.0,                         # in_drop
        False,                       # residual
        "gsage", "gsage",            # encoder_type, decoder_type
        0.3,                         # mask_rate
        None,                        # norm
        "sce",                       # loss_fn
        0.0, 0.1, 1, False,          # drop_edge_rate, replace_rate, alpha_l, concat_hidden
    )

    return param_rtl, param_net
    


def _is_cuda_oom(err):
    """Return True if ``err`` looks like a CUDA out-of-memory error.

    ``torch.cuda.OutOfMemoryError`` subclasses RuntimeError and its message
    contains "out of memory"; matching on the message also covers older torch
    versions that raised a plain RuntimeError.
    """
    return "out of memory" in str(err)


def _embed_text(text_enc, text_proj, texts, device):
    """Encode a batch of strings into L2-normalised, projected first-token embeddings.

    Uses the module-level ``tokenizer`` (set in train_val_pipeline) and the
    ``.bert`` trunk of the unwrapped ``text_enc``.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding='longest', truncation=True, max_length=512).to(device)
    outputs = accelerator.unwrap_model(text_enc).bert(inputs.input_ids, attention_mask=inputs.attention_mask, return_dict=True)
    # Hidden state of the first token ([CLS] for BertTokenizer output).
    first_token = outputs.last_hidden_state[:, 0, :]
    return F.normalize(text_proj(first_token), dim=-1)


def _rtl_to_device(batched_rtl, device):
    """Move a collated RTL graph batch to ``device``.

    The batch must hold exactly six items
    (attn_mask, node_feat, in_degree, out_degree, path_data, dist); the
    unpacking enforces that arity.
    """
    (attn_mask, node_feat, in_degree, out_degree, path_data, dist) = all_to_device(batched_rtl, device)
    return (attn_mask, node_feat, in_degree, out_degree, path_data, dist)


def _embed_rtl(encoder_rtl, batched_rtl, device):
    """Return the graph embedding produced by the unwrapped RTL model's ``.encoder``."""
    _, emb = accelerator.unwrap_model(encoder_rtl).encoder(_rtl_to_device(batched_rtl, device))
    return emb


def _forward_triplet(encoder_rtl, text_enc, text_proj, batched_ori_rtl, batched_pos_rtl,
                     batched_neg_rtl, batched_text, device):
    """Run both encoders on one aligned (ori, pos, neg) batch.

    Returns ``(emb_graph, emb_text, loss_mae)``: ``emb_graph`` and ``emb_text``
    are 3-tuples ordered (ori, pos, neg); ``loss_mae`` is the RTL
    masked-autoencoder loss, computed on the ``ori`` graphs only.
    """
    # batched_text is indexed rather than unpacked: entries past index 2 are ignored.
    texts = (batched_text[0], batched_text[1], batched_text[2])
    emb_text = tuple(_embed_text(text_enc, text_proj, t, device) for t in texts)

    ori = _rtl_to_device(batched_ori_rtl, device)
    _, emb_ori_rtl = accelerator.unwrap_model(encoder_rtl).encoder(ori)
    # The MAE loss goes through the prepared (possibly DDP-wrapped) model.
    loss_mae, _, _ = encoder_rtl(ori)
    del ori
    th.cuda.empty_cache()

    emb_pos_rtl = _embed_rtl(encoder_rtl, batched_pos_rtl, device)
    th.cuda.empty_cache()
    emb_neg_rtl = _embed_rtl(encoder_rtl, batched_neg_rtl, device)
    th.cuda.empty_cache()

    return (emb_ori_rtl, emb_pos_rtl, emb_neg_rtl), emb_text, loss_mae


def train_epoch(encoder_rtl, text_enc, text_proj, optimizer, train_rtl_loader, val_rtl_loader,
                  lr_scheduler, loss_fn, loss_fn_valid, lambda_cs, lambda_mae):
    """Train for one epoch with gradient accumulation.

    Args:
        encoder_rtl, text_enc, text_proj: models prepared by ``accelerator``.
        optimizer, lr_scheduler: stepped once every ``accumulation_steps``
            batches (module-level global set by the caller).
        train_rtl_loader: ``(ori, pos, neg, text)`` loaders, consumed in lock-step.
        val_rtl_loader, loss_fn_valid: accepted for interface compatibility;
            the in-epoch validation call is currently disabled.
        loss_fn: called as ``loss_fn(loss_mae, emb_graph, emb_text, None, None,
            lambda_cs, lambda_mae)``.
        lambda_cs, lambda_mae: loss weights forwarded to ``loss_fn``.

    Returns:
        ``(mean_loss, encoder_rtl)``: ``mean_loss`` is the average unscaled
        loss over the batches whose backward pass completed, or NaN if none did.

    A CUDA out-of-memory error skips the offending batch; any other exception
    propagates.
    """
    rtl_ori_loader, rtl_pos_loader, rtl_neg_loader, text_loader = train_rtl_loader
    encoder_rtl.train()
    device = accelerator.device

    epoch_loss = 0.0
    num_batches = 0  # batches whose backward pass completed
    j = 0            # optimizer steps taken

    batches = zip(rtl_ori_loader, rtl_pos_loader, rtl_neg_loader, text_loader)
    for idx, (batched_ori_rtl, batched_pos_rtl, batched_neg_rtl, batched_text) in enumerate(batches):
        try:
            emb_graph, emb_text, loss_mae = _forward_triplet(
                encoder_rtl, text_enc, text_proj,
                batched_ori_rtl, batched_pos_rtl, batched_neg_rtl, batched_text, device,
            )
            loss = loss_fn(loss_mae, emb_graph, emb_text, None, None, lambda_cs, lambda_mae)

            # Scale so the accumulated gradient is the mean over the window.
            accelerator.backward(loss / accumulation_steps)
            loss_value = loss.item()
            epoch_loss += loss_value
            num_batches += 1

            # Step after every `accumulation_steps` batches (idx is 0-based).
            if (idx + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()
                accelerator.print(optimizer.param_groups[0]['lr'])
                accelerator.print(f"Batch{j} train_loss={loss_value:.3f}")
                j += 1

            del emb_graph, emb_text, loss_mae, loss
            th.cuda.empty_cache()
        except RuntimeError as err:
            if not _is_cuda_oom(err):
                raise
            # NOTE(review): under multi-GPU DDP, skipping a batch on one rank
            # only can desynchronise the gradient all-reduce — confirm this
            # best-effort recovery is acceptable.
            print("OOM!")
            th.cuda.empty_cache()

    epoch_loss = epoch_loss / num_batches if num_batches else float("nan")
    return epoch_loss, encoder_rtl


def valid_network(encoder_rtl, text_enc, text_proj, val_rtl_loader, loss_fn_valid, lambda_cs, lambda_mae):
    """Return the mean validation loss over every batch of ``val_rtl_loader``.

    ``val_rtl_loader`` is an ``(ori, pos, neg, text)`` tuple of loaders
    consumed in lock-step. Puts ``encoder_rtl`` into eval mode (train_epoch
    switches it back). Returns NaN when the loaders yield no batches.
    """
    rtl_ori_loader, rtl_pos_loader, rtl_neg_loader, text_loader = val_rtl_loader
    encoder_rtl.eval()
    device = accelerator.device

    total_loss = 0.0
    num_batches = 0
    with th.no_grad():
        for batched_ori_rtl, batched_pos_rtl, batched_neg_rtl, batched_text in zip(
            rtl_ori_loader, rtl_pos_loader, rtl_neg_loader, text_loader
        ):
            emb_graph, emb_text, loss_mae = _forward_triplet(
                encoder_rtl, text_enc, text_proj,
                batched_ori_rtl, batched_pos_rtl, batched_neg_rtl, batched_text, device,
            )
            loss = loss_fn_valid(loss_mae, emb_graph, emb_text, None, None, lambda_cs, lambda_mae)
            total_loss += loss.item()
            num_batches += 1
            del emb_graph, emb_text, loss_mae, loss
            th.cuda.empty_cache()

    return total_loss / num_batches if num_batches else float("nan")

def train_val_pipeline():
    """Set up data, models, optimizer and scheduler, then run pretraining.

    Reads the module-level ``accelerator``, ``writer`` and ``model_save_dir``
    (the last two are created in the ``__main__`` block). Sets the
    module-level ``tokenizer`` and ``accumulation_steps``, which
    ``train_epoch`` relies on.

    After every epoch the unwrapped models are written twice: with
    ``accelerator.save`` (``*.acc.{epoch}.pt``) and with ``torch.save``
    (``*.{epoch}.pth``).
    """
    global tokenizer
    global accumulation_steps

    # Only consumed by the (currently disabled) validation pass.
    val_rtl_loader = load_data_valid()

    ### 1. set up model and optimizer ####
    param_rtl, _ = load_param()
    encoder_rtl = build_model_rtl(param_rtl)

    bert_model = 'bert-base-uncased'
    tokenizer = BertTokenizer.from_pretrained(bert_model)
    bert_config = BertConfig.from_json_file("/home/coguest5/CircuitFusion/text_enc/bert_cl/scr/config_bert.json")
    # NOTE(review): Hugging Face `from_pretrained` returns the model in eval
    # mode, and nothing here calls `text_enc.train()`, so BERT dropout stays
    # off during training — confirm that is intended.
    text_enc = BertForMaskedLM.from_pretrained(bert_model, config=bert_config)
    embed_dim = 256
    text_width = text_enc.config.hidden_size
    text_proj = nn.Linear(text_width, embed_dim)

    num_epochs = 128
    total_updates = num_epochs * 256 ### real batch size = 256
    warmup_updates = 20

    optimizer = AdamW([
        {'params': encoder_rtl.parameters(), 'lr': 1e-3, 'eps':1e-8, 'weight_decay':0},
        {'params': text_enc.parameters(), 'lr': 1e-3, 'eps':1e-8, 'weight_decay':0},
        {'params': text_proj.parameters(), 'lr': 1e-3, 'eps':1e-8, 'weight_decay':0}
    ])
    lr_scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_updates,
        num_training_steps=total_updates,
        lr_end=1e-4,
        power=1.0,
    )

    train_rtl_loader = load_data_train(batch_size=2)

    # NOTE(review): `train_rtl_loader` is a tuple. `accelerator.prepare` appears
    # to return non-DataLoader objects unchanged, so these loaders are NOT
    # sharded across processes and every process iterates the full training
    # set. Preparing each loader individually would shard them; confirm which
    # behaviour is intended.
    (
        encoder_rtl,
        text_enc,
        text_proj,
        optimizer,
        lr_scheduler,
        train_rtl_loader
    ) = accelerator.prepare(
        encoder_rtl, text_enc, text_proj, optimizer, lr_scheduler, train_rtl_loader
    )

    margin_num = 1.0
    lambda_rr = 1.0
    lambda_nn = 1.0
    lambda_cs = 0.4
    lambda_mae = 0.5
    criterion = CMAELoss(margin_rr=margin_num, margin_nn=margin_num, margin_rn=margin_num, margin_nr=margin_num, lamda_rr=lambda_rr, lamda_nn=lambda_nn, lamda_rn=lambda_cs, lamda_mae=lambda_mae)
    criterion_valid = CMAELoss(margin_rr=margin_num, margin_nn=margin_num, margin_rn=margin_num, margin_nr=margin_num, lamda_rr=lambda_rr, lamda_nn=lambda_nn, lamda_rn=lambda_cs, lamda_mae=lambda_mae)

    # Number of batches accumulated per optimizer step (read by train_epoch).
    accumulation_steps = 128

    accelerator.print("Training started")

    for epoch in range(num_epochs):
        accelerator.print(f"Epoch {epoch + 1}")

        # Linear schedules over training: lambda_cs rises from 0.4 towards 1.2
        # while lambda_mae falls from 0.5 towards 0.05.
        lambda_cs_epoch = (1.2-0.4) * (epoch/num_epochs) + 0.4
        lambda_mae = (0.05-0.5) * (epoch/num_epochs) + 0.5

        epoch_train_loss, encoder_rtl = train_epoch(
            encoder_rtl, text_enc, text_proj, optimizer, train_rtl_loader, val_rtl_loader,
            lr_scheduler, criterion, criterion_valid, lambda_cs_epoch, lambda_mae
        )

        # Validation is disabled; to re-enable:
        # epoch_val_loss = valid_network(encoder_rtl, text_enc, text_proj, val_rtl_loader, criterion_valid, lambda_cs, lambda_mae)

        accelerator.print(
            f"Epoch={epoch + 1} | train_loss={epoch_train_loss:.3f} | "
        )
        # Every process runs this loop; log from one of them only.
        if accelerator.is_main_process:
            writer.add_scalar('Train Loss', epoch_train_loss, epoch)

        ### save whole model each epoch
        accelerator.wait_for_everyone()
        unwrapped_encoder_rtl = accelerator.unwrap_model(encoder_rtl)
        unwrapped_text_enc = accelerator.unwrap_model(text_enc)
        unwrapped_text_proj = accelerator.unwrap_model(text_proj)

        accelerator.save(unwrapped_encoder_rtl, f"{model_save_dir}/encoder_rtl.acc.{epoch}.pt")
        accelerator.save(unwrapped_text_enc, f"{model_save_dir}/text_enc.acc.{epoch}.pt")
        accelerator.save(unwrapped_text_proj, f"{model_save_dir}/text_proj.acc.{epoch}.pt")

        # torch.save does no rank check of its own; without this guard every
        # process would write the same .pth files concurrently.
        if accelerator.is_main_process:
            th.save(unwrapped_encoder_rtl, f"{model_save_dir}/graph_enc.{epoch}.pth")
            th.save(unwrapped_text_enc, f"{model_save_dir}/text_enc.{epoch}.pth")
            th.save(unwrapped_text_proj, f"{model_save_dir}/text_proj.{epoch}.pth")
        th.cuda.empty_cache()


if __name__ == "__main__":

    log_dir = f'../log/log_{date}'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    else:
        os.system(f'rm -r {log_dir}')
        os.mkdir(log_dir)

    model_save_dir = f"../pretrain_model/{date}"
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)

    global writer
    writer = SummaryWriter(log_dir)


    global accumulation_steps
    accumulation_steps = 1
    train_val_pipeline()