import torch
import warnings
import time
import optuna
import os.path as osp
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from models import my_model
from config import args
from torch.optim import Adam
from utils import set_seed
from models.my_model import MyModel
from tasks.utils import accuracy
from datasets.graph_fl_dataset import GraphFLDataset
warnings.filterwarnings('ignore')

def _sample_std(values):
    """Return the sample standard deviation (``ddof=1``) of *values*.

    With fewer than two samples ``np.std(..., ddof=1)`` divides by zero and
    yields NaN, so 0.0 is returned in that case instead.
    """
    if len(values) < 2:
        return 0.0
    return np.std(values, ddof=1)


def _mean_softmax(*logits):
    """Average the row-wise softmax of every tensor in *logits*, detached from autograd."""
    total = 0
    for item in logits:
        total = total + F.softmax(item.detach(), 1)
    return total / len(logits)


def _fit_local_model(model, gmodel, subgraph, optimizer, epochs, homo_kd_ratio, run_device):
    """Train *model* on one client subgraph and track the best validation epoch.

    Args:
        model: Local model exposing ``homo``, ``homo_forward`` and ``hete_forward``.
        gmodel: Loaded global model; only its ``use_graph_op`` flag is read here.
        subgraph: Client subgraph providing ``y``, ``train_idx``, ``val_idx`` and ``test_idx``.
        optimizer: Optimizer built over ``model.parameters()``.
        epochs: Number of full-batch training epochs.
        homo_kd_ratio: Weight of the MSE term between the two outputs of ``homo_forward``.
        run_device: Device handed to the model's forward methods.

    Returns:
        Tuple ``(best_epoch, best_val, best_test)``: the 1-based epoch with the
        highest validation accuracy, that accuracy, and the test accuracy
        measured at the same epoch.
    """
    loss_ce_fn = nn.CrossEntropyLoss()
    loss_mse_fn = nn.MSELoss()
    train_idx = subgraph.train_idx
    labels = subgraph.y

    best_val = 0.
    best_test = 0.
    best_epoch = 0
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        if model.homo:
            local_smooth_emb, local_emb = model.homo_forward(run_device)
            loss_train = loss_ce_fn(local_smooth_emb[train_idx], labels[train_idx])
            if gmodel.use_graph_op:
                # The second head and the consistency (MSE) term only
                # contribute when the global model uses the graph operator.
                loss_train = (
                    loss_train
                    + loss_ce_fn(local_emb[train_idx], labels[train_idx])
                    + homo_kd_ratio * loss_mse_fn(local_emb, local_smooth_emb)
                )
            loss_train.backward()
            optimizer.step()

            model.eval()
            # Evaluation never backpropagates, so skip building the graph.
            with torch.no_grad():
                local_smooth_emb, local_emb = model.homo_forward(run_device)
                if gmodel.use_graph_op:
                    output = _mean_softmax(local_smooth_emb, local_emb)
                else:
                    output = _mean_softmax(local_smooth_emb)
        else:
            local_smooth_emb, local_message_propagation, local_emb = model.hete_forward(run_device)
            loss_train = (
                loss_ce_fn(local_smooth_emb[train_idx], labels[train_idx])
                + loss_ce_fn(local_message_propagation[train_idx], labels[train_idx])
                + loss_ce_fn(local_emb[train_idx], labels[train_idx])
            )
            loss_train.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                local_smooth_emb, local_message_propagation, local_emb = model.hete_forward(run_device)
                output = _mean_softmax(local_smooth_emb, local_message_propagation, local_emb)

        acc_val = accuracy(output[subgraph.val_idx], labels[subgraph.val_idx])
        acc_test = accuracy(output[subgraph.test_idx], labels[subgraph.test_idx])
        if acc_val > best_val:
            best_epoch = epoch + 1
            best_val = acc_val
            best_test = acc_test
    return best_epoch, best_val, best_test


def objective(trial):
    """Optuna objective: personalize a local model per client and score the trial.

    Hyperparameters are sampled from *trial*; the model-level ones are written
    onto the ``my_model`` module as attributes. For every client subgraph the
    saved global model is loaded and evaluated, a fresh ``MyModel`` is trained,
    and the best accuracies are accumulated weighted by the subgraph's share of
    the global node count.

    Relies on the module-level ``device`` assigned in the ``__main__`` block.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The node-weighted mean test accuracy over all clients, rounded to 4 decimals.
    """
    set_seed(args.seed)
    my_model.message_op = trial.suggest_categorical('message_op', ['last', 'over', 'mean', 'sum', 'concat'])

    homo_lr = trial.suggest_categorical('homo_lr', [1e-3, 1e-2, 1e-1])
    my_model.homo_drop = trial.suggest_discrete_uniform('homo_drop', 0.5, 0.8, 0.1)
    my_model.homo_layers = trial.suggest_int('homo_layers', 1, 2)
    homo_kd_ratio = trial.suggest_discrete_uniform('homo_kd_ratio', 0.01, 0.1, 0.02)

    hete_lr = trial.suggest_categorical('hete_lr', [1e-3, 1e-2, 1e-1])
    my_model.hete_drop = trial.suggest_discrete_uniform('hete_drop', 0.5, 0.9, 0.1)
    my_model.hete_negative_slope = trial.suggest_discrete_uniform('hete_negative_slope', 0, 0.2, 0.02)
    my_model.hete_re_smooth = trial.suggest_discrete_uniform('hete_re_smooth', 0, 1, 0.1)
    my_model.hete_re_temperature = trial.suggest_discrete_uniform('hete_re_temperature', 0.5, 1.5, 0.1)

    datasets = GraphFLDataset(
        root='./datasets',
        name='Cora',
        sampling='Noniid',
        num_clients=3,
        analysis_local_subgraph=False,
        analysis_global_graph=False
    )
    # Number of independent training repetitions per client.
    num_runs = 1
    weights_path = osp.join(
        "./model_weights",
        "{}_Client{}_{}_model.pt".format(datasets.name, datasets.num_clients, datasets.sampling),
    )

    print("| ★  Start Local Client Personalized Training")
    global_normalize_record = {"acc_val_mean": 0, "acc_val_std": 0, "acc_test_mean": 0, "acc_test_std" : 0}
    t_total = time.time()
    for i in range(len(datasets.subgraphs)):
        subgraph = datasets.subgraphs[i]
        subgraph.y = subgraph.y.to(device)
        local_normalize_record = {"acc_val": [], "acc_test": []}
        for _ in range(num_runs):
            # Evaluate the saved global model on this client's subgraph first.
            gmodel = torch.load(weights_path)
            gmodel.preprocess(subgraph.adj, subgraph.x)
            gmodel = gmodel.to(device)
            nodes_embedding = gmodel.model_forward(range(subgraph.num_nodes), device).detach().cpu()
            # NOTE(review): nodes_embedding lives on the CPU while subgraph.y was
            # moved to `device`; this relies on accuracy() accepting that mix — confirm.
            acc_val = accuracy(nodes_embedding[subgraph.val_idx], subgraph.y[subgraph.val_idx])
            acc_test = accuracy(nodes_embedding[subgraph.test_idx], subgraph.y[subgraph.test_idx])
            print("| ▷ Evaluate Nodes Embedding: Client ID: {}, Val Acc: {}, Test Acc: {}".format(i+1, round(acc_val, 4), round(acc_test, 4)))

            model = MyModel(
                prop_steps=3,
                feat_dim=datasets.input_dim,
                hidden_dim=64,
                output_dim=datasets.output_dim,
            )
            model.non_para_lp(subgraph=subgraph, nodes_embedding=nodes_embedding, x=subgraph.x, device=device)
            model.preprocess(adj=subgraph.adj)
            model = model.to(device)
            # Kept from the original: labels are moved to the device again after
            # the calls above, which receive the subgraph — TODO confirm it is still needed.
            subgraph.y = subgraph.y.to(device)

            # The learning rate and epoch budget depend on which branch the model selected.
            optimizer = Adam(model.parameters(), lr=homo_lr if model.homo else hete_lr, weight_decay=5e-4)
            epochs = 500 if model.homo else 2000
            best_epoch, best_val, best_test = _fit_local_model(
                model, gmodel, subgraph, optimizer, epochs, homo_kd_ratio, device
            )
            print("| ▶ Final Output, Model Para: {}M, Best Epoch: {}, Best Val Acc: {}, Test Acc: {}".format(model.total_trainable_params, best_epoch, round(best_val, 4), round(best_test, 4)))
            print("| ")
            local_normalize_record["acc_val"].append(best_val)
            local_normalize_record["acc_test"].append(best_test)

        # Weight every client's statistics by its share of the global node count.
        total_nodes = datasets.global_data.num_nodes
        global_normalize_record["acc_val_mean"] += np.mean(local_normalize_record["acc_val"]) * subgraph.num_nodes / total_nodes
        global_normalize_record["acc_val_std"] += _sample_std(local_normalize_record["acc_val"]) * subgraph.num_nodes / total_nodes
        global_normalize_record["acc_test_mean"] += np.mean(local_normalize_record["acc_test"]) * subgraph.num_nodes / total_nodes
        global_normalize_record["acc_test_std"] += _sample_std(local_normalize_record["acc_test"]) * subgraph.num_nodes / total_nodes

    mean_val = round(global_normalize_record["acc_val_mean"], 4)
    mean_test = round(global_normalize_record["acc_test_mean"], 4)
    print("| ★  Normalize Train Completed")
    print("| Normalize Train: {}, Total Time Elapsed: {:.4f}s".format(num_runs, time.time() - t_total))
    print("| Mean Val ± Std Val: {}±{}, Mean Test ± Std Test: {}±{}".format(mean_val, round(global_normalize_record["acc_val_std"], 4), mean_test, round(global_normalize_record["acc_test_std"], 4)))
    print("| ")
    # Degenerate suggestions (low == high) store the metrics under
    # trial.params, which the __main__ block prints for the best trial.
    trial.suggest_uniform('MeanVal', mean_val, mean_val)
    trial.suggest_uniform('MeanTest', mean_test, mean_test)
    # NOTE(review): the study maximizes the *test* accuracy; confirm this is
    # intended rather than selecting hyperparameters on acc_val_mean.
    return mean_test

if __name__ == "__main__":
    # `objective` reads `device` as a module-level global, so it has to be
    # assigned here before the study starts.
    gpu_id = 0
    device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu')
    set_seed(args.seed)

    # Optuna samplers keep their own random state, so the sampler is seeded
    # explicitly to make the sequence of suggested hyperparameters repeatable.
    # TPESampler is what create_study uses by default for a single objective.
    sampler = optuna.samplers.TPESampler(seed=args.seed)
    study = optuna.create_study(direction="maximize", sampler=sampler)
    study.optimize(objective, n_trials=200)

    print("Number of finished trials: ", len(study.trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))