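# Persistence helpers: save/load FedSage_Plus model weights, dump the config module, and write experiment results to JSON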
import json
from os import mkdir

import torch

from src.utils import config


def save_config_to_txt(filename='config.txt', source=None):
    """Write every non-dunder attribute of a config namespace to a text file.

    Each attribute becomes one ``key: value`` line, where the value is
    rendered with ``str()`` (so a list appears as e.g. ``[1, 2]``). Note that
    anything bound in the namespace is written, including imported modules
    or helper functions if the config module defines any.

    Args:
        filename: destination path; the file is truncated and rewritten.
        source: object whose ``vars()`` are dumped. Defaults to the
            ``config`` module imported at the top of this file.
    """
    if source is None:
        source = config

    # Skip dunder entries such as __name__/__file__/__builtins__ that every
    # module carries. f-string formatting already calls str(), so lists need
    # no special-casing.
    lines = [
        f'{key}: {value}\n'
        for key, value in vars(source).items()
        if not key.startswith('__')
    ]

    # Explicit encoding so the output does not depend on the platform locale.
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(lines)

def save_model(model, path):
    """Serialize the FedSage_Plus sub-modules into a single checkpoint file.

    The checkpoint is a plain dict that maps each component name
    (``encoder_model``, ``reg_model``, ``gen``, ``mend_graph``,
    ``classifier``) to that sub-module's ``state_dict()``. It is written
    with ``torch.save``.

    Args:
        model: FedSage_Plus model instance exposing the five components as
            attributes.
        path: destination file path handed to ``torch.save``.
    """
    component_names = ('encoder_model', 'reg_model', 'gen', 'mend_graph', 'classifier')

    checkpoint = {}
    for component in component_names:
        checkpoint[component] = getattr(model, component).state_dict()

    torch.save(checkpoint, path)


# Loading the model
def load_model(model, path, map_location=None):
    """Load FedSage_Plus sub-module weights from a checkpoint made by ``save_model``.

    Args:
        model: FedSage_Plus model instance; its ``encoder_model``,
            ``reg_model``, ``gen``, ``mend_graph`` and ``classifier``
            sub-modules are overwritten in place.
        path: path to the saved checkpoint.
        map_location: forwarded to ``torch.load``. For example, ``'cpu'``
            allows a checkpoint saved on GPU to be loaded on a CPU-only
            machine. ``None`` (default) behaves exactly like
            ``torch.load(path)``.

    Returns:
        The same ``model`` object, for call chaining.

    Raises:
        KeyError: if the checkpoint lacks one of the five component entries.
    """
    # NOTE(review): torch.load unpickles the file (on older torch versions
    # weights_only defaults to False) -- only load trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)

    for name in ('encoder_model', 'reg_model', 'gen', 'mend_graph', 'classifier'):
        getattr(model, name).load_state_dict(state_dict[name])

    # Set trainability flags after loading. These are hard-coded here, not
    # read from the checkpoint: the classifier is frozen and the other
    # trainable parts are unfrozen. `gen` keeps whatever requires_grad state
    # it already had -- NOTE(review): confirm that is intended.
    model.encoder_model.requires_grad_(True)
    model.reg_model.requires_grad_(True)
    model.mend_graph.requires_grad_(True)
    model.classifier.requires_grad_(False)

    return model


def _benign_mean(client_results, key):
    """Return the mean of ``client[key]`` over every client except the first.

    Returns 0 when there is no client after the first, instead of raising
    ZeroDivisionError.
    """
    benign = client_results[1:]
    if not benign:
        return 0
    return sum(client[key] for client in benign) / len(benign)


def save_experiment_results(config, client_results, output_path):
    """Write run config, per-client results and benign-client averages to JSON.

    Args:
        config: run configuration. Reads ``dataset``, ``num_attacker``,
            ``attack_intensity``, ``seed`` and ``plus``. (This parameter
            shadows the module-level ``config`` import inside this function.)
        client_results: list of per-client dicts. Each needs ``client_id``
            and ``test_acc``; when ``config.plus`` is true each also needs
            ``added_all`` and ``added_avg``.
        output_path: destination JSON file; truncated and rewritten.

    The first entry of ``client_results`` is excluded from every average, and
    an average is 0 when no other client exists.
    NOTE(review): this presumably assumes the single attacker is client 0;
    confirm against the caller, especially when ``num_attacker != 1``.
    """
    run_config = {
        "dataset": config.dataset,
        "num_attacker": config.num_attacker,
        "attack_intensity": config.attack_intensity,
        "seed": config.seed,
    }

    if config.plus:
        clients = [
            {
                "client_id": client["client_id"],
                "test_acc": client["test_acc"],
                "added_nodes": client["added_all"],
                "avg_added": client["added_avg"],
            }
            for client in client_results
        ]
        averages = {
            "avg_test_acc": _benign_mean(client_results, "test_acc"),
            "avg_added_nodes": _benign_mean(client_results, "added_all"),
            "avg_added": _benign_mean(client_results, "added_avg"),
        }
    else:
        clients = [
            {"client_id": client["client_id"], "test_acc": client["test_acc"]}
            for client in client_results
        ]
        averages = {"avg_test_acc": _benign_mean(client_results, "test_acc")}

    data = {"config": run_config, "clients": clients, "averages": averages}

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)