import torch

import warnings
import time
import os.path as osp
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from config import args
from models import my_model
from torch.optim import Adam
from utils import set_seed
from models.my_model import MyModel
from tasks.utils import accuracy
from datasets.graph_fl_dataset import GraphFLDataset
warnings.filterwarnings('ignore')

# ---------------------------------------------------------------------------
# Experiment configuration.
#
# Plain names below are read directly by the training script in this file.
# Assignments of the form ``my_model.<name> = ...`` set attributes on the
# imported ``models.my_model`` module; presumably MyModel reads them as
# module-level hyper-parameters -- confirm against models/my_model.py.
# ---------------------------------------------------------------------------

# Number of independent training repetitions performed per client subgraph.
main_normalize_train = 1


# NOTE(review): looks like the message aggregation operator -- confirm.
my_model.message_op = 'sum'


# Settings for the branch taken when ``model.homo`` is truthy.
# Adam learning rate used for that branch.
homo_lr = 0.01
my_model.homo_drop = 0.5
my_model.homo_layers = 2
# Weight of the MSE term between ``local_emb`` and ``local_smooth_emb`` that is
# added to the training loss when ``gmodel.use_graph_op`` is truthy.
homo_kd_ratio = 0.03


# Settings for the branch taken when ``model.homo`` is falsy.
# Adam learning rate used for that branch.
hete_lr = 0.1
my_model.hete_drop = 0.7
my_model.hete_negative_slope = 0
my_model.hete_re_smooth = 0.1
my_model.hete_re_temperature = 0.5

# Number of clients passed to GraphFLDataset when partitioning the graph.
num_clients = 3
def sample_std(values):
    """Return the sample standard deviation (ddof=1) of ``values``.

    With fewer than two samples the ddof=1 estimator divides by zero and
    yields NaN, so 0.0 is returned instead (a single run has no spread).
    """
    if len(values) < 2:
        return 0.0
    return np.std(values, ddof=1)


def mean_softmax(*logits):
    """Average the row-wise softmax of one or more logit tensors.

    The tensors are detached first, so the result never carries gradients.
    """
    probs = [F.softmax(t.detach(), 1) for t in logits]
    total = probs[0]
    for p in probs[1:]:
        total = total + p
    return total / len(probs)


def training_loss(model, subgraph, loss_ce_fn, loss_mse_fn, use_graph_op, device):
    """Run one training forward pass and return the loss to back-propagate.

    When ``model.homo`` is truthy: cross-entropy on the smoothed output and,
    if ``use_graph_op`` is truthy, additionally cross-entropy on the local
    output plus ``homo_kd_ratio`` times the MSE between both outputs.
    Otherwise: the sum of the cross-entropy losses of the three outputs
    returned by ``model.hete_forward``.
    """
    train_idx = subgraph.train_idx
    labels = subgraph.y[train_idx]
    if model.homo:
        local_smooth_emb, local_emb = model.homo_forward(device)
        loss = loss_ce_fn(local_smooth_emb[train_idx], labels)
        if use_graph_op:
            loss = (loss
                    + loss_ce_fn(local_emb[train_idx], labels)
                    + homo_kd_ratio * loss_mse_fn(local_emb, local_smooth_emb))
        return loss
    local_smooth_emb, local_message_propagation, local_emb = model.hete_forward(device)
    return (loss_ce_fn(local_smooth_emb[train_idx], labels)
            + loss_ce_fn(local_message_propagation[train_idx], labels)
            + loss_ce_fn(local_emb[train_idx], labels))


def predict_probs(model, use_graph_op, device):
    """Return class probabilities for all nodes from the current model state.

    The forward pass runs under ``torch.no_grad()`` because the result is
    only used for accuracy computation.
    """
    with torch.no_grad():
        if model.homo:
            local_smooth_emb, local_emb = model.homo_forward(device)
            if use_graph_op:
                return mean_softmax(local_smooth_emb, local_emb)
            return mean_softmax(local_smooth_emb)
        local_smooth_emb, local_message_propagation, local_emb = model.hete_forward(device)
        return mean_softmax(local_smooth_emb, local_message_propagation, local_emb)


if __name__ == "__main__":
    gpu_id = 0
    device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu')
    set_seed(args.seed)

    datasets = GraphFLDataset(
        root='./datasets',
        name='PubMed',
        sampling='Louvain',
        num_clients=num_clients,
        analysis_local_subgraph=False,
        analysis_global_graph=False,
    )

    print("| ★  Start Local Client Personalized Training")
    global_normalize_record = {"acc_val_mean": 0, "acc_val_std": 0, "acc_test_mean": 0, "acc_test_std" : 0}

    # Loop-invariant values: checkpoint location and the global node count
    # used to weight every client's contribution to the summary.
    gmodel_path = osp.join("./model_weights", "{}_Client{}_{}_model.pt".format(datasets.name, datasets.num_clients, datasets.sampling))
    total_nodes = datasets.global_data.num_nodes

    t_total = time.time()
    for subgraph in datasets.subgraphs:
        subgraph.y = subgraph.y.to(device)
        local_normalize_record = {"acc_val": [], "acc_test": []}
        for _ in range(main_normalize_train):
            # A fresh copy of the global model is loaded for every run because
            # preprocess() is applied to it with this client's graph.
            # NOTE: torch.load unpickles a full model object -- only load
            # checkpoints from a trusted source.
            gmodel = torch.load(gmodel_path)
            gmodel.preprocess(subgraph.adj, subgraph.x)
            gmodel = gmodel.to(device)
            nodes_embedding = gmodel.model_forward(range(subgraph.num_nodes), device).detach().cpu()

            model = MyModel(
                prop_steps=3,
                feat_dim=datasets.input_dim,
                hidden_dim=args.main_hidden_dim,
                output_dim=datasets.output_dim,
            )
            model.non_para_lp(subgraph=subgraph, nodes_embedding=nodes_embedding, x=subgraph.x, device=device)
            model.preprocess(adj=subgraph.adj)
            model = model.to(device)

            loss_ce_fn = nn.CrossEntropyLoss()
            loss_mse_fn = nn.MSELoss()
            # Learning rate, weight decay and epoch budget depend on which
            # branch the model selected (model.homo).
            if model.homo:
                optimizer = Adam(model.parameters(), lr=homo_lr, weight_decay=args.main_homo_weight_decay)
                epochs = args.main_homo_num_epochs
            else:
                optimizer = Adam(model.parameters(), lr=hete_lr, weight_decay=args.main_hete_weight_decay)
                epochs = args.main_hete_num_epochs

            best_val = 0.
            best_test = 0.
            best_epoch = 0
            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()
                loss_train = training_loss(model, subgraph, loss_ce_fn, loss_mse_fn, gmodel.use_graph_op, device)
                loss_train.backward()
                optimizer.step()

                model.eval()
                output = predict_probs(model, gmodel.use_graph_op, device)
                acc_val = accuracy(output[subgraph.val_idx], subgraph.y[subgraph.val_idx])
                acc_test = accuracy(output[subgraph.test_idx], subgraph.y[subgraph.test_idx])
                # Model selection: keep the test accuracy of the epoch with
                # the best validation accuracy.
                if acc_val > best_val:
                    best_epoch = epoch + 1
                    best_val = acc_val
                    best_test = acc_test
            print("| ▶ Final Output, Model Para: {}M, Best Epoch: {}, Best Val Acc: {}, Test Acc: {}".format(model.total_trainable_params, best_epoch, round(best_val, 4), round(best_test, 4)))
            print("| ")
            local_normalize_record["acc_val"].append(best_val)
            local_normalize_record["acc_test"].append(best_test)

        # Accumulate this client's statistics, weighted by its share of the
        # global node count.
        for key in ("acc_val", "acc_test"):
            runs = local_normalize_record[key]
            global_normalize_record[key + "_mean"] += np.mean(runs) * subgraph.num_nodes / total_nodes
            global_normalize_record[key + "_std"] += sample_std(runs) * subgraph.num_nodes / total_nodes
    print("| ★  Normalize Train Completed")
    # Report the repetition count that was actually used by the loop above.
    print("| Normalize Train: {}, Total Time Elapsed: {:.4f}s".format(main_normalize_train, time.time() - t_total))
    print("| Mean Val ± Std Val: {}±{}, Mean Test ± Std Test: {}±{}".format(round(global_normalize_record["acc_val_mean"], 4), round(global_normalize_record["acc_val_std"], 4), round(global_normalize_record["acc_test_mean"], 4), round(global_normalize_record["acc_test_std"], 4)))

            
