import config
import torch
import torch.nn.functional as F
from torch import optim
import numpy as np
import time
import wandb
from torch_geometric.data import Data
from torch_geometric.utils import subgraph
import scipy.sparse as sp
from temperature import tune_temp
from conformal import * 
from sklearn import preprocessing, model_selection
from torch_geometric.loader import DataLoader
from utils import average_models, make_gnn_model, evaluate_model, device

def hide_graph(data, hide_ratio=0.8):
    """Hide a random share of the validation nodes of ``data``.

    A fraction ``hide_ratio`` of the nodes selected by ``data.val_mask`` is
    drawn uniformly without replacement. Every edge touching a hidden node is
    removed; node features, labels, the train mask and the test mask are
    passed through unchanged.

    Args:
        data: graph object exposing ``x``, ``y``, ``edge_index``,
            ``train_mask``, ``val_mask`` and ``test_mask``.
        hide_ratio: share of validation nodes to hide. The resulting count is
            clamped to ``[0, number of validation nodes]``.

    Returns:
        A new ``Data`` object whose ``edge_index`` only keeps edges between
        non-hidden nodes, whose ``val_mask`` marks the validation nodes that
        were not hidden, and whose ``hide_mask`` marks the hidden nodes.
    """
    num_nodes = len(data.y)
    val_ids = torch.where(data.val_mask)[0]
    hide_len = max(0, min(int(len(val_ids) * hide_ratio), len(val_ids)))

    if hide_len > 0:
        # Uniform weights + multinomial == sampling without replacement.
        unif = torch.ones(val_ids.shape[0])
        random_indices = unif.multinomial(hide_len, replacement=False)
        hide_ids = val_ids[random_indices]
    else:
        # torch.multinomial refuses to draw zero samples, so short-circuit
        # when there is nothing to hide.
        hide_ids = val_ids[:0]

    hide_mask = torch.zeros(num_nodes, dtype=torch.bool)
    hide_mask[hide_ids] = True

    # Validation nodes that survive are the original ones minus the hidden ones.
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask[val_ids] = True
    val_mask[hide_ids] = False

    # A full-length boolean mask fixes the node count explicitly; an index
    # subset would make `subgraph` infer it from edge_index, which breaks when
    # the highest-numbered nodes have no edges.
    keep_mask = (~hide_mask).to(data.edge_index.device)
    new_edges = subgraph(keep_mask, data.edge_index)[0]

    new_data = Data(
        x=data.x,
        edge_index=new_edges,
        y=data.y,
        train_mask=data.train_mask,
        val_mask=val_mask,
        test_mask=data.test_mask,
        hide_mask=hide_mask,
    )

    return new_data

def greedy_loss(pred_feats, true_feats,pred_missing,true_missing, args):
    """Feature-reconstruction loss between predicted and true missing neighbours.

    For node ``i`` and predicted neighbour slot ``pred_j`` the entry
    ``loss[i, pred_j]`` is filled (broadcast over the feature axis) with an
    MSE between the predicted feature vector and one of the node's true
    missing-neighbour feature vectors. Slots that are not predicted, and
    nodes without true missing neighbours, stay zero.

    Args:
        pred_feats: tensor indexed as ``[node][slot]`` holding predicted
            feature vectors; the returned loss has the same shape.
        true_feats: tensor or numpy array indexed as ``[node][slot]`` with the
            target feature vectors. May have more nodes than ``pred_feats``;
            only the first ``len(pred_feats)`` are read.
        pred_missing: per-node predicted number of missing neighbours.
        true_missing: per-node true number of missing neighbours.
        args: mapping providing ``"num_pred"``, the number of slots per node.

    Returns:
        Tensor shaped like ``pred_feats``, on the same device, through which
        gradients flow back to ``pred_feats``.
    """
    num_pred = args["num_pred"]
    target_device = pred_feats.device

    def _as_target(row):
        # Targets may come as numpy rows or tensor rows.
        if isinstance(row, np.ndarray):
            row = torch.tensor(row)
        return row.to(target_device).unsqueeze(0).float()

    # Counts are truncated to integers and limited to the available slots.
    pred_counts = np.clip(
        pred_missing.detach().cpu().numpy().reshape(-1).astype(np.int32), 0, num_pred)
    true_counts = np.clip(
        true_missing.detach().cpu().numpy().reshape(-1).astype(np.int32), 0, num_pred)

    # Plain buffer (no requires_grad): gradients reach pred_feats through the
    # values written below. A leaf tensor that requires grad cannot be
    # modified in place, which made the previous version fail on CPU.
    loss = torch.zeros(pred_feats.shape, device=target_device)

    for i in range(len(pred_feats)):
        n_true = int(true_counts[i])
        if n_true <= 0:
            continue
        for pred_j in range(int(pred_counts[i])):
            prediction = pred_feats[i][pred_j].unsqueeze(0).float()

            # Start from the distance to the last true neighbour ...
            loss[i, pred_j] = F.mse_loss(prediction, _as_target(true_feats[i][n_true - 1]))

            # ... then greedily replace it with other candidates.
            for true_k in range(n_true):
                loss_ijk = F.mse_loss(prediction, _as_target(true_feats[i][true_k]))
                # NOTE(review): the right-hand side sums the value broadcast
                # over the feature axis, so a scalar MSE is compared against
                # feat_dim * current value. Kept as-is to preserve numerical
                # behaviour; confirm whether a strict minimum was intended.
                if torch.sum(loss_ijk) < torch.sum(loss[i][pred_j].data):
                    loss[i, pred_j] = loss_ijk
    return loss

def _outgoing_neighbors(edge_index, num_nodes):
    """Group edge targets by source node in a single pass, keeping edge order."""
    neighbors = [[] for _ in range(num_nodes)]
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()
    for src, dst in zip(sources, targets):
        if 0 <= src < num_nodes:
            neighbors[src].append(dst)
    return neighbors


def get_generator_targets(org_data, hidden_data, unique_class_labels, args):
    """Build the supervision targets used to train a neighbour generator.

    For every node of ``org_data`` the "missing" neighbours are the targets of
    its outgoing edges in ``org_data.edge_index`` that are absent from
    ``hidden_data.edge_index``.

    Args:
        org_data: original graph (``x``, ``y``, ``edge_index``).
        hidden_data: graph with part of the edges removed (``edge_index``).
        unique_class_labels: class labels; column ``c`` of the label target
            corresponds to ``unique_class_labels[c]``.
        args: mapping providing ``"num_pred"``, the number of neighbour slots
            per node.

    Returns:
        Tuple ``(mask, counts, feats, labels)``, all moved to ``device``:
        ``mask`` is a ``(num_labels, 1)`` boolean tensor marking nodes that are
        the source of at least one edge in the hidden graph; ``counts`` is a
        ``(num_nodes, 1)`` tensor with the number of missing neighbours;
        ``feats`` is ``(num_nodes, num_pred, feat_dim)`` with the features of
        up to ``num_pred`` missing neighbours, zero padded; ``labels`` is a
        ``(num_nodes, num_classes)`` one-hot encoding of ``org_data.y``.
    """
    num_pred = args["num_pred"]
    num_nodes = org_data.x.shape[0]
    feat_shape = org_data.x.shape[1]
    num_classes = len(unique_class_labels)
    org_feat = org_data.x.cpu()

    # Explicit one-hot encoding. sklearn's label_binarize collapses the
    # two-class case into a single column, which does not fit the
    # (num_nodes, num_classes) layout expected downstream.
    labels_np = np.asarray(org_data.y.cpu()).reshape(-1)
    classes_np = np.asarray(unique_class_labels).reshape(-1)
    targets_missing_node_label = (labels_np[:, None] == classes_np[None, :]).astype(int)

    # One pass over each edge list instead of a full edge scan per node.
    org_neighbors = _outgoing_neighbors(org_data.edge_index, num_nodes)
    hid_neighbors = _outgoing_neighbors(hidden_data.edge_index, num_nodes)

    targets_missing_node_count = []
    targets_missing_node_feat = []
    for node in range(num_nodes):
        missing_ids = list(set(org_neighbors[node]).difference(set(hid_neighbors[node])))
        missing_len = len(missing_ids)

        if missing_len == 0:
            missing_feat_all = np.zeros((1, num_pred, feat_shape))
        elif missing_len <= num_pred:
            # Pad with zero rows up to num_pred slots.
            zeros = np.zeros((num_pred - missing_len, feat_shape))
            missing_feat_all = np.vstack((org_feat[missing_ids], zeros)).\
                reshape((1, num_pred, feat_shape))
        else:
            # More missing neighbours than slots: keep the first num_pred.
            missing_feat_all = np.copy(org_feat[missing_ids[:num_pred]]).\
                reshape((1, num_pred, feat_shape))

        targets_missing_node_count.append(missing_len)
        targets_missing_node_feat.append(missing_feat_all)

    hidden_graph_node_ids = hidden_data.edge_index[0].unique().clone().detach().reshape((-1,1))
    hidden_graph_node_ids_mask = torch.zeros((len(org_data.y),1), dtype=torch.bool).to(device)
    hidden_graph_node_ids_mask[hidden_graph_node_ids] = True

    targets_missing_node_count = torch.tensor(targets_missing_node_count).reshape((-1,1)).to(device)
    # Stack into one numpy array first; building a tensor from a list of
    # arrays is much slower.
    targets_missing_node_feat = np.array(targets_missing_node_feat)
    targets_missing_node_feat = torch.tensor(targets_missing_node_feat).reshape((-1, num_pred, feat_shape)).to(device)
    targets_missing_node_label = torch.tensor(targets_missing_node_label).reshape((-1, num_classes)).to(device)

    return hidden_graph_node_ids_mask, targets_missing_node_count, targets_missing_node_feat, targets_missing_node_label

def accuracy_missing(output, labels):
    """Fraction of nodes whose truncated predicted count equals the true count.

    Args:
        output: tensor of predicted missing-neighbour counts; values are
            truncated to integers before comparison.
        labels: tensor (or array-like) of true counts, one per node.

    Returns:
        ``matches / len(labels)`` as a float, or ``0.0`` when ``labels`` is
        empty.
    """
    preds = output.detach().cpu().numpy().reshape(-1).astype(int)
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    targets = np.asarray(labels).reshape(-1).astype(int)

    total = targets.shape[0]
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    # Compare only the overlapping prefix, but normalise by the label count.
    compared = min(total, preds.shape[0])
    correct = int(np.count_nonzero(preds[:compared] == targets[:compared]))
    return correct / total

def accuracy(pred,true):
    """Fraction of rows whose arg-max position agrees between ``pred`` and ``true``.

    Args:
        pred: per-row class scores (tensor, or a sequence of tensors).
        true: per-row reference scores / one-hot rows, same layout as ``pred``.

    Returns:
        ``matches / len(pred)`` as a float, or ``0.0`` when ``pred`` is empty.
    """
    if not torch.is_tensor(pred):
        pred = torch.stack(list(pred)) if len(pred) else torch.zeros((0, 1))
    if not torch.is_tensor(true):
        true = torch.stack(list(true)) if len(true) else torch.zeros((0, 1))

    total = len(pred)
    compared = min(total, len(true))
    if total == 0 or compared == 0:
        # Nothing to compare; avoid dividing by zero.
        return 0.0

    # Row-wise arg-max over the flattened row.
    pred_cls = pred[:compared].reshape(compared, -1).argmax(dim=1)
    true_cls = true[:compared].reshape(compared, -1).argmax(dim=1).to(pred_cls.device)
    return float((pred_cls == true_cls).sum().item()) / total

def train_fedgen(local_gen_list:list, optim_list:list, feat_shape:int, hidden_client_datas:list, original_client_datas:list, unique_class_labels:list, args: dict = None):
    """Train each client's local generator in place.

    For ``args["rounds_stage1"]`` epochs and for every client ``i`` the
    generator is run on the client's hidden graph and optimised on the sum of
    a smooth-L1 loss on the predicted missing-neighbour counts, a
    ``greedy_loss`` on the predicted neighbour features, a cross-entropy loss
    on the predicted node classes, and, for every other client ``j``, a
    ``greedy_loss`` against features of randomly sampled neighbours taken from
    client ``j``'s original graph. The sum is scaled by ``1 / num_client``
    before the backward pass.

    Args:
        local_gen_list: one generator per client; called as
            ``model(x, edge_index)`` and expected to return four outputs.
        optim_list: one optimiser per client, aligned with ``local_gen_list``.
        feat_shape: length of a node feature vector.
        hidden_client_datas: per-client graphs with hidden nodes removed.
        original_client_datas: per-client original graphs.
        unique_class_labels: class labels; its length is the class count.
        args: mapping read for ``"num_pred"``, ``"num_client"`` and
            ``"rounds_stage1"`` (and forwarded to the helper functions).
            Despite the ``None`` default it must be provided.

    Returns:
        None. The generators and optimisers are updated as a side effect and
        progress is printed every 50 epochs.
    """
    t=time.time()

    num_pred = args["num_pred"] 
    num_owners = args["num_client"]
    gen_epochs = args["rounds_stage1"]

    num_classes = len(unique_class_labels)
    hidden_graph_node_ids_masks, all_targets_missing, all_targets_feat, all_targets_subj = [], [], [], [] 

    # Targets do not change between epochs, so build them once per client.
    for i in range(num_owners):
        out = get_generator_targets(original_client_datas[i], hidden_client_datas[i], unique_class_labels, args)

        hidden_graph_node_ids_masks.append(out[0])
        all_targets_missing.append(out[1])
        all_targets_feat.append(out[2])
        all_targets_subj.append(out[3])

    for epoch in range(gen_epochs):
        for i in range(num_owners):
            local_model=local_gen_list[i] 
            local_model.train()
            optim_list[i].zero_grad() 

            client_data = hidden_client_datas[i].to(device)

            # I may need to convert these to dataloader objects
            # loader = DataLoader(dataset= [cli_data], batch_size=batchSize, pin_memory=False)
            input_feat = client_data.x
            input_edge = client_data.edge_index
            # node_ids = client_data.edge_index.unique()[0]
            node_len = client_data.y.shape[0]

            # Only the first three outputs are used here; the fourth is ignored.
            output_missing, output_feat, output_nc, out_softmax = local_model(input_feat, input_edge)
            output_missing = torch.flatten(output_missing)
            output_feat = output_feat.view(node_len, num_pred, feat_shape)
            # Class logits cover the real nodes plus num_pred generated slots per node.
            output_nc = output_nc.view(node_len + node_len * num_pred, num_classes)

            loss_train_missing = F.smooth_l1_loss(output_missing[client_data.train_mask].float(),
                                                  all_targets_missing[i][client_data.train_mask].reshape(-1).float())

            loss_train_feat = greedy_loss(output_feat[client_data.train_mask], all_targets_feat[i][client_data.train_mask],
                                          output_missing[client_data.train_mask], all_targets_missing[i][client_data.train_mask], args).unsqueeze(0).mean().float()

            # One-hot label targets are converted back to class indices for cross-entropy.
            true_nc_label = torch.argmax(all_targets_subj[i][client_data.train_mask], dim=1).view(-1)

            # Indexing with the train-node positions selects rows of the real nodes.
            class_predictions = output_nc[torch.where(client_data.train_mask)]
            loss_train_label = F.cross_entropy(class_predictions, true_nc_label)

            acc_train_missing = accuracy_missing(output_missing[client_data.train_mask], all_targets_missing[i][client_data.train_mask])

            acc_train_nc = accuracy(output_nc[torch.where(client_data.train_mask)], all_targets_subj[i][client_data.train_mask])

            fedgen_loss = (loss_train_missing + loss_train_feat + loss_train_label).float()
        
            # The printed loss is the local part only (cross-client terms are added below).
            if epoch % 50 == 0:
                print('Data owner ' + str(i), ' Epoch: {:04d}'.format(epoch + 1),
                    'loss_train: {:.4f}'.format(fedgen_loss.item()), 'missing_train: {:.4f}'.format(acc_train_missing),
                    'nc_train: {:.4f}'.format(acc_train_nc), 'loss_miss: {:.4f}'.format(loss_train_missing.item()),
                    'loss_nc: {:.4f}'.format(loss_train_label.item()), 'loss_feat: {:.4f}'.format(loss_train_feat.item()),
                    'time: {:.4f}s'.format(time.time() - t))
            # wandb.log({"loss_train_missing" : loss_train_missing.item(),
            #            "loss_train_label" :loss_train_label.item() , "loss_feat" : loss_train_feat.item() })
            # Cross-client feature loss: compare this client's predictions with
            # neighbour features sampled from every other client's original graph.
            for j in range(num_owners):
                if j != i:   
                    client_j_node_ids = original_client_datas[j].edge_index.unique()
                    # NOTE(review): one sample is drawn per node (len(train_mask)),
                    # while greedy_loss below only receives the train-node
                    # predictions and iterates over those rows - confirm whether
                    # sampling one target per training node was intended.
                    choice = np.random.choice(len(client_j_node_ids), len(original_client_datas[i].train_mask))
                    others_ids = client_j_node_ids[choice]
                    global_target_feat = []

                    for c_i in others_ids:
                        edges = original_client_datas[j].edge_index
                        neighbors_ids = edges[1][edges[0] == c_i].tolist()

                        # Redraw until the sampled node has at least one outgoing edge.
                        while len(neighbors_ids)==0:
                            c_i = np.random.choice(len(list(client_j_node_ids)),1)[0]
                            id_i = client_j_node_ids[c_i]
                            neighbors_ids = edges[1][edges[0] == id_i].tolist()

                        # num_pred neighbours sampled with replacement (np.random.choice default).
                        choice_i = np.random.choice(neighbors_ids, num_pred)
                        for ch_i in choice_i:
                            global_target_feat.append(original_client_datas[j].x[ch_i])

                    global_target_feat = torch.stack(global_target_feat).reshape((len(client_data.train_mask), num_pred, feat_shape))
                    loss_train_feat_other = greedy_loss(output_feat[client_data.train_mask], global_target_feat,
                                                        output_missing[client_data.train_mask], all_targets_missing[i][client_data.train_mask], args).unsqueeze(0).mean().float()
                    fedgen_loss += loss_train_feat_other

            # Scale the summed local + cross-client loss by the number of clients.
            fedgen_loss = 1.0 / num_owners * fedgen_loss
            # wandb.log({ "fedgen_loss" : fedgen_loss.item() })
            fedgen_loss.backward()
            optim_list[i].step()

    return

def fill_graph(hidden_graph:Data, original_graph:Data, missing, new_feats, feat_shape, args = None):
    """Mend a hidden graph with generated neighbour nodes.

    For every node that appears in ``hidden_graph.edge_index`` and whose
    predicted missing count (truncated to an integer) is positive, up to
    ``args["num_pred"]`` new nodes are appended. Each new node takes its
    features from ``new_feats`` and is connected to its source node by one
    edge in each direction.

    Args:
        hidden_graph: graph to mend; provides ``edge_index``, ``y`` and the
            train / val / test / hide masks.
        original_graph: provides the features of the original nodes (``x``).
        missing: per-node predicted number of missing neighbours.
        new_feats: generated features (tensor or numpy array), reshapeable to
            ``(-1, num_pred, feat_shape)``.
        feat_shape: length of a node feature vector.
        args: mapping providing ``"num_pred"``. Despite the ``None`` default
            it must be provided.

    Returns:
        ``(new_data, dif)`` where ``new_data`` holds the original feature rows
        followed by the generated ones and the extended ``edge_index`` (both
        on ``device``), with labels and masks taken from ``hidden_graph``;
        ``dif`` is the number of generated nodes.
    """
    num_pred = args["num_pred"]

    if not isinstance(new_feats, np.ndarray):
        new_feats = new_feats.cpu().detach().numpy()
    new_feats = new_feats.reshape((-1, num_pred, feat_shape))

    if torch.is_tensor(missing):
        missing = missing.detach().cpu().numpy()
    missing_counts = np.asarray(missing).reshape(-1)

    org_feats = original_graph.x
    # Count ALL original nodes, not only those that appear in an edge, so
    # that feature rows stay aligned with node ids, labels and masks and new
    # ids never collide with an isolated original node.
    num_org_nodes = org_feats.shape[0]

    generated_feats = []
    extra_sources = []
    extra_targets = []
    next_id = num_org_nodes
    for node in hidden_graph.edge_index.unique().tolist():
        num_new = min(num_pred, int(missing_counts[node]))
        if num_new <= 0:
            continue
        for offset in range(num_new):
            new_id = next_id + offset
            generated_feats.append(np.asarray(new_feats[node][offset]).reshape(-1))
            # Connect in both directions: node -> new and new -> node.
            extra_sources.extend((node, new_id))
            extra_targets.extend((new_id, node))
        next_id += num_new
    dif = len(generated_feats)

    # Original rows first, generated rows after; clone so the mended graph
    # never aliases the original feature tensor.
    fill_node_feats = org_feats.detach().cpu().reshape((num_org_nodes, -1)).clone()
    if generated_feats:
        generated = torch.as_tensor(np.stack(generated_feats), dtype=fill_node_feats.dtype)
        fill_node_feats = torch.cat((fill_node_feats, generated), dim=0)
    fill_node_feats = fill_node_feats.reshape((-1, feat_shape)).to(device)

    hidden_edges = hidden_graph.edge_index
    extra_edges = torch.tensor([extra_sources, extra_targets],
                               dtype=hidden_edges.dtype, device=hidden_edges.device)
    fill_edges = torch.cat((hidden_edges, extra_edges), dim=1).to(device)

    new_data = Data(
        x=fill_node_feats,
        edge_index=fill_edges,
        y=hidden_graph.y,
        train_mask=hidden_graph.train_mask,
        # The validation mask must come from val_mask (it was previously
        # filled with the train mask by mistake).
        val_mask=hidden_graph.val_mask,
        test_mask=hidden_graph.test_mask,
        hide_mask=hidden_graph.hide_mask,
    )

    return new_data, dif

def train_fedSagePlusClassifier(classifier_list:list, local_gen_list:list, hidden_client_datas:list, original_client_datas:list, feat_shape:int, unique_class_labels:list, args: dict = None):
    """Mend every client graph with its generator, then train the classifiers.

    Each client's generator predicts missing neighbours for its hidden graph
    and ``fill_graph`` appends them. A fresh global model is created, all
    client classifiers are initialised from it, and for
    ``args["rounds_stage1"]`` rounds every client takes one optimisation step
    on its mended graph, after which ``average_models`` is called on the
    global model and the client models. Validation and test accuracy of each
    client are printed at the end.

    Args:
        classifier_list: one classifier per client; called as
            ``model(x, edge_index)`` and expected to return three outputs, the
            second of which is fed to ``F.nll_loss``.
        local_gen_list: one trained generator per client.
        hidden_client_datas: per-client graphs with hidden nodes removed.
        original_client_datas: per-client original graphs.
        feat_shape: length of a node feature vector.
        unique_class_labels: class labels; its length is the class count.
        args: mapping read for ``"num_client"``, ``"rounds_stage1"``,
            ``"local_lr_stage1"``, ``"weight_decay"``, ``"batch_size"`` and
            ``"architecture"`` (and forwarded to ``fill_graph``). Despite the
            ``None`` default it must be provided.

    Returns:
        ``(mended_graph_loaders, mended_graph_list, global_model, differences)``
        where ``differences[i]`` is the number of nodes added to client ``i``.
    """
    num_owners = args["num_client"]
    epoch_classifier = args["rounds_stage1"]
    lr = args["local_lr_stage1"]
    weight_decay = args["weight_decay"]
    batch_size = args["batch_size"]
    arch = args["architecture"]
    num_classes = len(unique_class_labels)

    mended_graph_list = []
    mended_graph_loaders = []
    differences = []

    for owner_i in range(num_owners):
        hidden_client_data = hidden_client_datas[owner_i]

        # The generator output is only read as numbers when mending the
        # graph, so there is no need to record an autograd graph for it.
        with torch.no_grad():
            pred_missing, pred_feats, _, _ = local_gen_list[owner_i](
                hidden_client_data.x, hidden_client_data.edge_index)

        filled_data, dif = fill_graph(hidden_client_data, original_client_datas[owner_i], pred_missing, pred_feats, feat_shape, args)
        filled_data = filled_data.to(device)

        mended_graph_list.append(filled_data)
        loader = DataLoader(dataset= [filled_data], batch_size=batch_size, pin_memory=False)
        mended_graph_loaders.append(loader)
        differences.append(dif)

    global_model = make_gnn_model(architecture = arch, in_channels = feat_shape, num_classes = num_classes).to(device)

    # Every client starts from the same global weights.
    for model in classifier_list:
        model.load_state_dict(global_model.state_dict())

    optimizers = [torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) for model in classifier_list]

    for epoch in range(epoch_classifier):

        # Re-synchronise the clients with the global model before each round.
        for owner_i in range(num_owners):
            classifier_list[owner_i].load_state_dict(global_model.state_dict())

        total_loss = 0
        for i in range(num_owners):
            client_model = classifier_list[i]
            client_model.train()
            optimizers[i].zero_grad()
            client_model.to(device)

            mended = mended_graph_list[i]
            _, out, _ = client_model(mended.x, mended.edge_index)
            loss = F.nll_loss(out[torch.where(mended.train_mask)], mended.y[mended.train_mask])

            loss.backward()
            optimizers[i].step()
            total_loss += loss.item()

        # Round-level aggregate; currently not reported anywhere.
        total_loss /= num_owners

        # average params across neighbors
        average_models(global_model, classifier_list)

        # evaluate clients' training performance
        # (aggregates are computed but currently not reported anywhere)
        clients_train_loss, clients_train_acc = 0, 0
        for i in range(num_owners):
            train_loss, train_acc = evaluate_model(
                model = classifier_list[i],
                data_loader = mended_graph_loaders[i],
                mask = torch.where(mended_graph_list[i].train_mask),
            )
            clients_train_loss += train_loss
            clients_train_acc += train_acc
        clients_train_loss /= num_owners
        clients_train_acc /= num_owners

    for i in range(num_owners):
        val_msk, tst_msk = torch.where(mended_graph_list[i].val_mask), torch.where(mended_graph_list[i].test_mask)
        _, valAcc = evaluate_model(model = classifier_list[i], data_loader = mended_graph_loaders[i], mask =val_msk)
        print("stage 1 val, ", valAcc)
        _, testAcc = evaluate_model(model = classifier_list[i], data_loader = mended_graph_loaders[i], mask = tst_msk)
        print("stage 1 tst, ", testAcc)

    print("FedSage+ end!")

    return mended_graph_loaders, mended_graph_list, global_model, differences

