#import yaml
import argparse
import wandb
from dcg.Simulation import *
from dcg import Optimization
import pandas as pd
import os

from dcg.data import data_binary_mnist
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader



# Weights & Biases project name; passed as `project=` to wandb.init() in wandb_job.
project="DualClippedGossip"


def generate_topology(params, wandb_run = None, seed=0, verbose=True):
    """Build the communication graph selected by ``params`` and report its spectrum.

    Args:
        params: Mapping-like configuration. Reads ``"topology"`` (one of
            ``"ErdosRenyi"``, ``"UltraTorus"``, ``"Clique"``,
            ``"TorusWithPrincipal"``, ``"TwoWorlds"``), ``"nb_honest"``,
            ``"nb_byzantine_neighbors"`` and, for the topologies that take a
            hyper-parameter, ``"topology_hyper"``.
        wandb_run: Optional run object. When given, the spectral gap and the
            breakdown ratio are stored in ``wandb_run.config``.
        seed: Seed forwarded to the ``ErdosRenyi`` constructor only.
        verbose: When true, print a summary of the generated graph.

    Returns:
        A ``(topology, eig_max)`` tuple, where ``eig_max`` is the largest
        eigenvalue of ``topology.laplacian_honest()``.

    Raises:
        ValueError: If ``params["topology"]`` is not one of the supported names.
    """
    topology_name = params["topology"]

    if topology_name == "ErdosRenyi":
        topology = ErdosRenyi(
            params["nb_honest"], p=params["topology_hyper"], seed=seed)
    elif topology_name == "UltraTorus":
        topology = UltraTorus(
            params["nb_honest"], connective_range=params["topology_hyper"], opposed_neighbors=0)
    elif topology_name == "Clique":
        topology = Clique(params["nb_honest"])
    elif topology_name == "TorusWithPrincipal":
        # This branch used to repeat the "ErdosRenyi" test and was therefore
        # unreachable; it is now keyed on the name of the class it builds.
        topology = TorusWithPrincipal(params["nb_honest"])
    elif topology_name == "TwoWorlds":
        topology = TwoWorlds(params["nb_honest"], params["topology_hyper"])
    else:
        raise ValueError(str(topology_name) + " is not an appropriate topology")

    topology.add_byzantine_neighbors(params["nb_byzantine_neighbors"])

    # The graph statistics are only needed for reporting; query them once.
    if verbose or wandb_run is not None:
        spectral_gap = topology.spectral_gap()
    if verbose:
        min_honest_neighbors = topology.nb_honest_neighbors_min()
        print(f"minimal number of honest neighbors : {min_honest_neighbors} spectral gap : {spectral_gap:.3}")

    # eigvalsh returns the eigenvalues in ascending order, so a single call
    # gives both the second-smallest and the largest one.
    eigenvalues = np.linalg.eigvalsh(topology.laplacian_honest())
    eig_min_plus = eigenvalues[1]
    eig_max = eigenvalues[-1]
    breakdown_ratio = params["nb_byzantine_neighbors"]*2/eig_min_plus

    if verbose:
        print(f"min neighbors honest: {min_honest_neighbors}",
              f"spectral gap : {spectral_gap:.3}, smallest eig value: {eig_min_plus:.2},",
              f"breakdown ratio : {breakdown_ratio}")
    if wandb_run is not None:
        wandb_run.config["spectral gap"] = spectral_gap
        wandb_run.config["breakdown ratio"] = breakdown_ratio

    return topology, eig_max

def save_experiment(name, dict_outputs, folder=None):
    """Write the recorded metrics of an experiment to ``<folder>/<name>.csv``.

    Args:
        name: File name without extension.
        dict_outputs: Data accepted by ``pandas.DataFrame`` (here a dict of
            per-iteration metric sequences); each key becomes a column.
        folder: Destination directory, created if missing. Defaults to
            ``"results-data"``.

    The CSV gets a leading ``iteration`` column counting rows from 0 and is
    written without the DataFrame index.
    """
    table = pd.DataFrame(dict_outputs)
    # Explicit row counter as the first column.
    table.insert(0, "iteration", range(len(table)))

    target_dir = "results-data" if folder is None else folder
    os.makedirs(target_dir, exist_ok=True)

    destination = os.path.join(target_dir, name + ".csv")
    table.to_csv(destination, index=False)

    

def wandb_job(config=None):
    """Run one simulation inside a Weights & Biases run.

    Args:
        config: Configuration forwarded to ``wandb.init``.

    The simulation parameters are read back from ``wandb.config`` once the run
    is initialised, not taken directly from the ``config`` argument.
    """
    run_context = wandb.init(project=project, config=config)
    with run_context as run:
        Simulation_task(params=wandb.config, wandb_run=run, verbose=True)


def main(config_path="./config/unitary_test.yaml"):
    """Load a YAML experiment configuration and launch it as a W&B job.

    Args:
        config_path: Path of the YAML configuration file. Defaults to the
            unit-test configuration that used to be hard-coded here.
    """
    with open(config_path) as file:
        # NOTE(review): Loader/CLoader is PyYAML's full (non-safe) loader and can
        # build arbitrary Python objects. Acceptable for a trusted local config;
        # switch to a safe loader if configs may come from elsewhere.
        config = yaml.load(file, Loader = Loader)

    wandb_job(config=config)


def Simulation_task(
        params, wandb_run = None, verbose=True, save_csv=True
        ):
    """Run one robust decentralized optimization experiment described by ``params``.

    Args:
        params: Mapping-like configuration. Always reads ``"task"``
            (``"regression"`` or ``"ACP"``), ``"communication_rule"``,
            ``"attack"``, ``"topology"``, ``"nb_honest"``,
            ``"nb_byzantine_neighbors"``, ``"nb_iterations"`` and ``"seed"``.
            The regression task additionally reads ``"model"``
            (``"lin_reg"`` or ``"logistic"``), ``"relat_ridge_penality"`` and
            ``"relat_lr"``. ``"log_every_n_step"`` (default 10) and
            ``"optimizing_factor"`` (default: ``log_every_n_step``) are optional.
        wandb_run: Optional run object forwarded to the topology generator and
            to the simulation.
        verbose: When true, print progress and final metrics.
        save_csv: When true, store the metrics with ``save_experiment``.

    Returns:
        A dict with keys ``"accuracy_data"``, ``"loss_train"``, ``"loss_test"``
        and ``"variance"``, or ``None`` when a result file for the same
        experiment name already exists in ``results-data``.

    Raises:
        ValueError: If the task, model, communication rule or attack is not
            supported, or if the regression task lacks ``"relat_ridge_penality"``.
    """
    # Single-quoted keys throughout: reusing the enclosing quote character
    # inside an f-string field is a SyntaxError before Python 3.12.
    name_expe = (
        f"{params['task']}_{params['communication_rule']}_{params['attack']}"
        f"_{params['topology']}_{params['nb_honest']}h"
        f"_{params['nb_byzantine_neighbors']}byz_{params['nb_iterations']}iter"
        f"_{params['seed']}"
    )
    path =  os.path.join("results-data", name_expe+".csv")
    if os.path.isfile(path):
        print("Experiment exists already")
        return

    seed = params['seed']
    # Seed before any setup so that random draws made while building the task
    # (e.g. the ACP initial values) are reproducible for a given seed.
    np.random.seed(seed)

    ######### topology ##################

    topology, _ = generate_topology(params, wandb_run=wandb_run, seed=seed, verbose=verbose)

    ###################### Optimizating task ###################
    if params["task"]=="regression":
        if params["model"]=="lin_reg":
            optim = Optimization.LinearRegressionClassif
        elif params["model"]=="logistic":
            optim = Optimization.LogisticRegression
        else:
            raise ValueError(f"{params['model']} is not an appropriate model for the regression task")

        # Fail with a clear message instead of an unbound-variable error later.
        if "relat_ridge_penality" not in params:
            raise ValueError("'relat_ridge_penality' must be provided in params for the regression task")
        relat_ridge = params["relat_ridge_penality"]

        # NOTE(review): relies on the order of the values of the mapping
        # returned by data_binary_mnist -- confirm against its definition.
        nb_honest, covariate_train_batches, target_train_batches, target_test, covariate_test =\
                data_binary_mnist(
                    nb_honest=params["nb_honest"]
                    ).values()

        smoothness = 0
        strong_convexity = np.inf
        for i_honest in range(nb_honest):
            batch = covariate_train_batches[i_honest]
            # Ascending eigenvalues of the batch's normalized Gram matrix.
            eig = np.linalg.eigvalsh(batch.T @ batch / batch.shape[0])
            if params["model"]=="logistic":
                eig = eig/4

            smoothness = max(smoothness, eig[-1])
            # NOTE(review): uses the second-smallest eigenvalue (eig[1]), not
            # the smallest (eig[0]) -- kept as in the original, confirm intent.
            strong_convexity = min(strong_convexity, eig[1])

        ridge_penalty = smoothness * relat_ridge
        smoothness = ridge_penalty + smoothness
        strong_convexity = max(strong_convexity + ridge_penalty, ridge_penalty)

        step_size_dual = topology.auto_step_size() * ridge_penalty
        step_size_primal = 1/smoothness * params["relat_lr"]

        if verbose:
            print(f"smoothness = {smoothness:.3}, strong_convexity = {strong_convexity:.3},",
                f"penalization = {ridge_penalty:.3}, condition number:{smoothness/strong_convexity:.3}")

        print("defining optimization task....", end="")

        optimization_task = optim(
        n_honest=nb_honest, labels_train=target_train_batches, labels_test=target_test,
        covariate_train=covariate_train_batches, covariate_test=covariate_test,
        learning_rate = step_size_primal, ridge_penalty=ridge_penalty
        )
        print("done!")

    elif params["task"]=="ACP":
        optim = Optimization.AverageConsensus

        nb_honest=params["nb_honest"]
        step_size_dual = topology.auto_step_size()
        step_size_primal = None # should not be used

        x_honest_init = np.random.normal(0, 1, size=(nb_honest, 5))

        print("defining optimization task....", end="")
        optimization_task = optim(n_honest=nb_honest, x_honest_init=x_honest_init, learning_rate=step_size_primal)
        print("done!")
    else:
        raise ValueError(f"{params['task']} is not an appropriate task")

    ###################### Rule considered ####################

    step_size_communication = step_size_dual
    if params["communication_rule"]=="ClippedGossip":
        rule = CommunicationRule.LOCAL_CLIPPING_HE
        # The primal rule uses the topology step size without the ridge scaling.
        step_size_communication = topology.auto_step_size()
        algorithm_duality = AlgorithmDuality.PRIMAL
    elif params["communication_rule"]=="LocalClipping":
        rule = CommunicationRule.LOCAL_CLIPPING_OURS
        algorithm_duality = AlgorithmDuality.DUAL
    elif params["communication_rule"]=="LocalTrimming":
        rule = CommunicationRule.LOCAL_TRIMMING
        algorithm_duality = AlgorithmDuality.DUAL
    elif params["communication_rule"]=="GlobalClipping":
        rule = CommunicationRule.GLOBAL_CLIPPING
        algorithm_duality = AlgorithmDuality.DUAL
    elif params["communication_rule"]=="LocalClippingSym":
        rule = CommunicationRule.LOCAL_CLIPPING_SYM
        algorithm_duality = AlgorithmDuality.DUAL
    else:
        raise ValueError(f"{params['communication_rule']} is not an appropriate communication rule")

    ###################### Attack considered ####################

    if params["attack"]=="FOE":
        attack = AttackType.FOE
    elif params["attack"]=="ALIE":
        attack = AttackType.ALIE
    elif params["attack"]=="SP Heterogeneity":
        attack = AttackType.DISSENSION_SPECTRAL
    elif params["attack"]=="No Attack":
        attack = AttackType.NO_ATTACK_ATTACK
    elif params["attack"]=='Dissension':
        attack = AttackType.DISSENSION_HE
    else:
        raise ValueError(f"{params['attack']} is not an appropriate attack")

    ############ SIMULATION  ###############

    if verbose:
        print("Current experiment ...", attack.value, rule.value)

    # Re-seed right before the run so the simulation sees the same random
    # stream regardless of how many draws the setup above consumed.
    np.random.seed(seed)

    log_every_n_step = params["log_every_n_step"] if "log_every_n_step" in params else 10
    optimizing_factor = params['optimizing_factor'] if "optimizing_factor" in params else log_every_n_step

    simu = RobustDecentralizedOptimSimulation(
        topology=topology, communication_rule=rule, attack=attack,
        optimization_task=optimization_task, algorithm_duality=algorithm_duality,
        step_size_communication=step_size_communication, nb_iterations=params['nb_iterations'],
        wandb_run=wandb_run, log_every_n_step=log_every_n_step, optimizing_factor=optimizing_factor
    )
    res = simu.run()
    accuracy_data = res['accuracy']
    loss_train = res['loss_train']
    loss_test = res['loss_test']
    variance = res['variance']

    if verbose:
        print(f"{rule.value} Final accuracy: {accuracy_data[-1]:.3}",
              f"; Final train loss: {loss_train[-1]:.3}",
              f"; Final test loss: {loss_test[-1]:.3}")

    out = {
        "accuracy_data": accuracy_data,
        "loss_train":loss_train,
        "loss_test":loss_test,
        'variance':variance
    }
    if save_csv:
        save_experiment(
            name = name_expe,
            dict_outputs=out)

    return out


# Run the experiment only when this file is executed as a script, so that
# importing the module does not start a W&B run as a side effect.
if __name__ == "__main__":
    main()
    
    
    

