import os
import yaml
from pathlib import Path
import torch
import argparse
import dgl
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

import json

from data_loader import load_data_multi_devices
from data_splitter import split_data
from utils import compute_loss, compute_auc, compute_mr, zero_copy, increment_dir, weighted_state_alpha, alpha_update, inference_personal, inv_d, get_rules, get_recommendations, combine_triple_inv, get_hit_rate, compute_mrr
from model import GraphSAGE, GCN, GAT, MLPPredictor, HeteroDotProductPredictor, HeteroMLPPredictor
from comms import fedAvg, fedGate
from training import local_fedAvg, test_local_models, test_personal_models, test_global_model




def train_centered(args):
    """Train one link-prediction model on the pooled graphs of all users.

    Each epoch sums ``compute_loss`` over every user's training graph and takes
    a single optimizer step on the shared GNN (chosen by ``args.model_type``)
    and ``HeteroMLPPredictor``. The mean per-user loss is logged every epoch,
    global train/test metrics every 5 epochs, and hit rates for cut-offs
    5, 10, ..., 40 after training. The final state dicts are written to the
    current working directory.

    Args:
        args: parsed CLI namespace. Reads ``dataset``, ``model_type``,
            ``log_dir``, ``device``, ``more_negative``, ``learning_rate`` and
            ``num_comms`` (used here as the number of epochs).

    Raises:
        ValueError: if ``args.dataset`` is not ``'wyze'`` (the only dataset
            with a loading path here), if ``args.model_type`` is not one of
            ``'graphsage'``, ``'gcn'``, ``'gat'``, or if the split yields no
            training users.
    """
    # Validate before any TensorBoard files are created. These cases used to
    # surface later as NameErrors on `user_graphs` / `model`.
    if args.dataset != 'wyze':
        raise ValueError(
            "centralized training only supports dataset 'wyze', got {!r}".format(args.dataset))
    model_types = ('graphsage', 'gcn', 'gat')
    if args.model_type not in model_types:
        raise ValueError(
            "unknown model_type {!r}; expected one of {}".format(args.model_type, model_types))

    def _read_json(path):
        # Context manager so the file handle is closed deterministically.
        with open(path, 'r') as fh:
            return json.load(fh)

    def _typed_scores(scores, g):
        # For every edge of `g`, select the score column of that edge's own type.
        etypes = g.edata['etype']
        return scores[list(range(len(etypes))), etypes]

    def _log_split_metrics(scope, metrics, step):
        # `metrics` is the 6-tuple returned by test_global_model.
        train_loss, train_auc, train_mr, test_loss, test_auc, test_mr = metrics
        tb_writer.add_scalar(scope + ' Train/Loss', train_loss, step)
        tb_writer.add_scalar(scope + ' Train/AUC', train_auc, step)
        tb_writer.add_scalar(scope + ' Train/POS_MR', train_mr, step)
        tb_writer.add_scalar(scope + ' Test/Loss', test_loss, step)
        tb_writer.add_scalar(scope + ' Test/AUC', test_auc, step)
        tb_writer.add_scalar(scope + ' Test/POS_MR', test_mr, step)

    tb_writer = SummaryWriter(log_dir=args.log_dir)
    try:
        # NOTE(review): this loader's return values were always overwritten by
        # the cached files read below. The call is kept in case it has side
        # effects (e.g. producing that cache); confirm and drop it if not.
        load_data_multi_devices(args.dataset, sys_device=args.device)

        user_graphs_list, _ = dgl.load_graphs("wyze_rule/usergraphs.bin")
        user_ids = list(_read_json('wyze_rule/user_id_dict.json').keys())
        # Graphs are paired positionally with the keys of user_id_dict.json.
        user_graphs = {uid: user_graphs_list[i] for i, uid in enumerate(user_ids)}
        all_trigger_actions = _read_json('wyze_rule/all_trigger_actions.json')
        all_devices = _read_json('wyze_rule/all_devices.json')

        train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs = split_data(
            user_graphs, all_trigger_actions, args.more_negative)
        num_users = len(train_gs)
        if num_users == 0:
            raise ValueError("split_data produced no training users")

        model_cls = {'graphsage': GraphSAGE, 'gcn': GCN, 'gat': GAT}[args.model_type]
        # Feature dim: number of distinct values in all_devices.
        model = model_cls(len(set(all_devices.values())), 32).to(args.device)
        pred = HeteroMLPPredictor(32, len(set(all_trigger_actions.values()))).to(args.device)
        optimizer = torch.optim.Adam(
            itertools.chain(model.parameters(), pred.parameters()), lr=args.learning_rate)

        # ----------- 4. training -------------------------------- #
        for e in range(args.num_comms):
            model.train()
            pred.train()

            loss = None
            for user_index in train_gs:
                train_g = train_gs[user_index]
                train_pos_g = train_pos_gs[user_index]
                train_neg_g = train_neg_gs[user_index]

                h = model(train_g, train_g.ndata['feat'])
                pos_score = _typed_scores(pred(train_pos_g, h), train_pos_g)
                neg_score = _typed_scores(pred(train_neg_g, h), train_neg_g)
                user_loss = compute_loss(pos_score, neg_score)
                loss = user_loss if loss is None else loss + user_loss

            mean_loss = loss.item() / num_users
            tb_writer.add_scalar('Train/Loss', mean_loss, e - 1)  # -1 since it is before backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('In epoch {}, loss: {}'.format(e - 1, mean_loss))

            # ----------- 5. check results ------------------------ #
            if (e + 1) % 5 == 0:
                global_metrics = test_global_model(
                    train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs, model, pred)
                _log_split_metrics('Global', global_metrics, e)

        torch.save(model.state_dict(), args.dataset + "central_model_" + args.model_type)
        torch.save(pred.state_dict(), args.dataset + "central_pred_" + args.model_type)

        for k in range(5, 45, 5):
            hit_rate = get_hit_rate(train_gs, test_pos_gs, model, pred, all_devices, k, args.dataset)
            tb_writer.add_scalar('Global Test/Hit Rate', hit_rate, k)
    finally:
        # Without close() the last buffered events (e.g. the hit rates) can be lost.
        tb_writer.close()

def train_federated(args):
    """Train a GNN link predictor across users with a federated-averaging variant.

    Loads the cached Wyze user graphs from ``wyze_rule/``. ``args.dataset`` is
    only used in checkpoint names. Every user gets a local copy of the global
    GNN (chosen by ``args.model_type``) and of the ``HeteroMLPPredictor``, and
    ``args.num_comms`` rounds are run. Each round:

    * clients re-sync from the global weights, except the part their
      ``fedtype`` keeps local;
    * clients take ``args.local_step`` optimizer steps;
    * the server aggregates with ``fedAvg`` or ``fedGate``.

    Local and global metrics are logged every 5 rounds, and hit rates after
    training. The final global state dicts are saved to the working directory.

    ``args.fedtype`` semantics, as implemented below:
        * ``fedavg``: GNN and predictor re-synced each round; both averaged.
        * ``lgfedavg``: local GNNs are not re-synced each round.
        * ``fedrep``: local predictors are not re-synced. Each round trains
          only the predictor on all but the last local step, and only the GNN
          on the last one.
        * ``fedgate`` / ``fedrepgate`` / ``lgfedavggate``: like fedavg / fedrep
          / lgfedavg respectively, but per-client delta tensors are subtracted
          from the gradients of the gated part, and ``fedGate`` aggregates it.
          The gated part is: both / the GNN (last local step only) / the
          predictor.

    With ``args.personaltype == 'APFL'``, each client additionally trains a
    personal GNN and predictor. Its scores come from
    ``inference_personal(personal_pred, local_pred, alpha, ...)``, and alpha
    stays fixed at ``args.personal_alpha``.

    Raises:
        ValueError: unknown ``fedtype`` or ``model_type``, or the split yields
            no training users.
    """
    # Validate first. An unknown fedtype used to skip aggregation silently,
    # and model_type used to be ignored (GraphSAGE was always trained).
    fedtypes = ('fedavg', 'fedrep', 'lgfedavg', 'fedgate', 'fedrepgate', 'lgfedavggate')
    if args.fedtype not in fedtypes:
        raise ValueError("unknown fedtype {!r}; expected one of {}".format(args.fedtype, fedtypes))
    model_types = ('graphsage', 'gcn', 'gat')
    if args.model_type not in model_types:
        raise ValueError(
            "unknown model_type {!r}; expected one of {}".format(args.model_type, model_types))

    fedtype = args.fedtype
    keep_local_models = fedtype in ('lgfedavg', 'lgfedavggate')  # GNN not re-synced
    keep_local_preds = fedtype in ('fedrep', 'fedrepgate')       # predictor not re-synced
    gate_models = fedtype in ('fedgate', 'fedrepgate')           # GNN uses deltas + fedGate
    gate_preds = fedtype in ('fedgate', 'lgfedavggate')          # predictor uses deltas + fedGate
    use_apfl = args.personaltype == "APFL"
    last_local_e = args.local_step - 1

    def _read_json(path):
        # Context manager so the file handle is closed deterministically.
        with open(path, 'r') as fh:
            return json.load(fh)

    def _typed_scores(scores, g):
        # For every edge of `g`, select the score column of that edge's own type.
        etypes = g.edata['etype']
        return scores[list(range(len(etypes))), etypes]

    def _set_requires_grad(module, flag):
        for param in module.parameters():
            param.requires_grad = flag

    def _subtract_delta(module, delta_module):
        # Gate correction applied between backward() and step(): grad -= delta.
        for p, d in zip(module.parameters(), delta_module.parameters()):
            p.grad.data.add_(-d.data)

    def _log_split_metrics(scope, metrics, step):
        # `metrics` is the 6-tuple returned by test_local_models / test_global_model.
        train_loss, train_auc, train_mr, test_loss, test_auc, test_mr = metrics
        tb_writer.add_scalar(scope + ' Train/Loss', train_loss, step)
        tb_writer.add_scalar(scope + ' Train/AUC', train_auc, step)
        tb_writer.add_scalar(scope + ' Train/POS_MR', train_mr, step)
        tb_writer.add_scalar(scope + ' Test/Loss', test_loss, step)
        tb_writer.add_scalar(scope + ' Test/AUC', test_auc, step)
        tb_writer.add_scalar(scope + ' Test/POS_MR', test_mr, step)

    tb_writer = SummaryWriter(log_dir=args.log_dir)
    try:
        user_graphs_list, _ = dgl.load_graphs("wyze_rule/usergraphs.bin")
        user_ids = list(_read_json('wyze_rule/user_id_dict.json').keys())
        # Graphs are paired positionally with the keys of user_id_dict.json.
        user_graphs = {uid: user_graphs_list[i] for i, uid in enumerate(user_ids)}
        all_trigger_actions = _read_json('wyze_rule/all_trigger_actions.json')
        all_devices = _read_json('wyze_rule/all_devices.json')

        train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs = split_data(
            user_graphs, all_trigger_actions, args.more_negative)
        num_users = len(train_gs)
        if num_users == 0:
            raise ValueError("split_data produced no training users")

        model_cls = {'graphsage': GraphSAGE, 'gcn': GCN, 'gat': GAT}[args.model_type]
        # Feature dim: number of distinct values in all_devices.
        global_model = model_cls(len(set(all_devices.values())), 32).to(args.device)
        global_pred = HeteroMLPPredictor(32, len(set(all_trigger_actions.values()))).to(args.device)

        models, predictors, optimizers = {}, {}, {}
        models_delta, predictors_delta = {}, {}
        for user_index in train_gs:
            models[user_index] = deepcopy(global_model).to(args.device)
            models[user_index].train()
            predictors[user_index] = deepcopy(global_pred).to(args.device)
            predictors[user_index].train()
            if gate_models:
                models_delta[user_index] = zero_copy(global_model)
            if gate_preds:
                predictors_delta[user_index] = zero_copy(global_pred)
            # TODO: add regularization
            optimizers[user_index] = torch.optim.Adam(
                itertools.chain(models[user_index].parameters(), predictors[user_index].parameters()),
                lr=args.learning_rate)

        personal_models, personal_predictors, personal_optimizers, alphas = {}, {}, {}, {}
        if use_apfl:
            for user_index in train_gs:
                personal_models[user_index] = deepcopy(global_model).to(args.device)
                personal_models[user_index].train()
                personal_predictors[user_index] = deepcopy(global_pred).to(args.device)
                personal_predictors[user_index].train()
                alphas[user_index] = args.personal_alpha
                personal_optimizers[user_index] = torch.optim.Adam(
                    itertools.chain(personal_models[user_index].parameters(),
                                    personal_predictors[user_index].parameters()),
                    lr=args.learning_rate)

        for e in range(args.num_comms):
            for user_index in train_gs:
                if not keep_local_models:
                    models[user_index].load_state_dict(global_model.state_dict())
                if not keep_local_preds:
                    predictors[user_index].load_state_dict(global_pred.state_dict())
                # NOTE(review): the evaluation helpers may switch modules to
                # eval mode. Re-enable training mode each round, as
                # train_centered does every epoch (a no-op if they don't).
                models[user_index].train()
                predictors[user_index].train()
                if use_apfl:
                    personal_models[user_index].train()
                    personal_predictors[user_index].train()

            step_losses = [0.0] * args.local_step
            for user_index in train_gs:
                model = models[user_index]
                pred = predictors[user_index]
                optimizer = optimizers[user_index]
                train_g = train_gs[user_index]
                train_pos_g = train_pos_gs[user_index]
                train_neg_g = train_neg_gs[user_index]

                for local_e in range(args.local_step):
                    if keep_local_preds:
                        # FedRep-style alternation: predictor-only steps, then
                        # a GNN-only step on the last local step.
                        gnn_step = local_e == last_local_e
                        _set_requires_grad(pred, not gnn_step)
                        _set_requires_grad(model, gnn_step)

                    h = model(train_g, train_g.ndata['feat'])
                    pos_score = _typed_scores(pred(train_pos_g, h), train_pos_g)
                    neg_score = _typed_scores(pred(train_neg_g, h), train_neg_g)
                    loss = compute_loss(pos_score, neg_score)
                    optimizer.zero_grad()
                    loss.backward()
                    step_losses[local_e] += loss.item()  # local model loss, not personal

                    # fedrepgate only trains the GNN on the last local step,
                    # so its correction applies only there.
                    if fedtype == 'fedgate' or (fedtype == 'fedrepgate' and local_e == last_local_e):
                        _subtract_delta(model, models_delta[user_index])
                    if gate_preds:
                        _subtract_delta(pred, predictors_delta[user_index])
                    optimizer.step()

                    if use_apfl:
                        optimizer.zero_grad()
                        personal_optimizers[user_index].zero_grad()
                        alpha = alphas[user_index]
                        personal_pred = personal_predictors[user_index]
                        h = personal_models[user_index](train_g, train_g.ndata['feat'])
                        pos_score = _typed_scores(
                            inference_personal(personal_pred, pred, alpha, train_pos_g, h), train_pos_g)
                        neg_score = _typed_scores(
                            inference_personal(personal_pred, pred, alpha, train_neg_g, h), train_neg_g)
                        personal_loss = compute_loss(pos_score, neg_score)
                        personal_loss.backward()
                        personal_optimizers[user_index].step()

            for local_e, step_loss in enumerate(step_losses):
                mean_loss = step_loss / num_users
                print('In epoch {}, local epoch {}, loss: {}'.format(e, local_e, mean_loss))
                tb_writer.add_scalar('Train/Loss', mean_loss, e * args.local_step + local_e)

            if (e + 1) % 5 == 0:
                local_metrics = test_local_models(
                    train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs, models, predictors)
                _log_split_metrics('Local', local_metrics, e)

            if use_apfl:
                personal_test_AUC, personal_test_MR = test_personal_models(
                    train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs,
                    models, predictors, personal_models, personal_predictors, alphas)
                tb_writer.add_scalar('Personal Test/AUC', personal_test_AUC, e)
                tb_writer.add_scalar('Personal Test/POS_MR', personal_test_MR, e)

            # Server aggregation: gated parts go through fedGate, the rest through fedAvg.
            if gate_models:
                model_state, models_delta = fedGate(
                    models, models_delta, train_gs, tau=args.local_step, lr=args.learning_rate)
            else:
                model_state = fedAvg(models, train_gs)
            global_model.load_state_dict(model_state)

            if gate_preds:
                pred_state, predictors_delta = fedGate(
                    predictors, predictors_delta, train_gs, tau=args.local_step, lr=args.learning_rate)
            else:
                pred_state = fedAvg(predictors, train_gs)
            global_pred.load_state_dict(pred_state)

            # ----------- 5. check results ------------------------ #
            if (e + 1) % 5 == 0:
                global_metrics = test_global_model(
                    train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs,
                    global_model, global_pred)
                _log_split_metrics('Global', global_metrics, e)

        torch.save(global_model.state_dict(), args.dataset + args.fedtype + "fed_model_" + args.model_type)
        torch.save(global_pred.state_dict(), args.dataset + args.fedtype + "fed_pred_" + args.model_type)

        for k in range(5, 45, 5):
            hit_rate = get_hit_rate(train_gs, test_pos_gs, global_model, global_pred, all_devices, k, args.dataset)
            tb_writer.add_scalar('Global Test/Hit Rate', hit_rate, k)
    finally:
        # Without close() the last buffered events (e.g. the hit rates) can be lost.
        tb_writer.close()

if __name__ == "__main__":
    # Command-line entry point: parse options, fix seeds, create a fresh log
    # directory with an args.yaml snapshot, then run centralized or federated
    # training.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset', default='wyze', type=str)
    parser.add_argument('-l', '--logdir', default='./runs', type=str)
    parser.add_argument('-lr', '--learning_rate', default=0.1, type=float)
    parser.add_argument('-i', '--local_step', default=1, type=int)  # local steps per federated round
    parser.add_argument('-c', '--num_comms', default=100, type=int)  # rounds (epochs when centralized)
    # `choices` rejects typos up front. Previously an unknown fedtype silently
    # skipped aggregation and an unknown personaltype silently disabled APFL.
    parser.add_argument('-f', '--fedtype', default='fedavg', type=str,
                        choices=['fedavg', 'fedrep', 'lgfedavg', 'fedgate', 'fedrepgate', 'lgfedavggate'])
    parser.add_argument('-per', '--personaltype', default='None', type=str,
                        choices=['None', 'APFL'])
    parser.add_argument('-alpha', '--personal_alpha', default=0.1, type=float)  # per-user APFL alpha
    parser.add_argument('-ce', '--centeralized', action='store_true')  # (sic) name kept for compatibility
    parser.add_argument('-neg', '--more_negative', action='store_true')  # forwarded to split_data
    parser.add_argument('-m', '--model_type', default='graphsage', type=str,
                        choices=['graphsage', 'gcn', 'gat'])

    # Fix the dgl / torch / numpy seeds for reproducibility.
    seed = 0
    dgl.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    args = parser.parse_args()
    if not args.centeralized and args.local_step < 1:
        parser.error('--local_step must be >= 1 for federated training')

    args.log_dir = increment_dir(Path(args.logdir) / 'exp')
    args.log_dir += args.dataset + "_" + ('center' if args.centeralized else args.fedtype + 'local_step_' + str(args.local_step))
    os.makedirs(args.log_dir)
    yaml_file = str(Path(args.log_dir) / "args.yaml")
    with open(yaml_file, 'w') as out:
        yaml.dump(args.__dict__, out, default_flow_style=False)

    # Device configuration. It is set after the YAML dump, so it is not recorded there.
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.centeralized:
        train_centered(args)
    else:
        train_federated(args)
