import os
import logging
from math import inf
from collections import OrderedDict
import argparse
from argparse import Namespace
import pickle
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import wandb

from util import LoadDataset, LoadModel, generate_filename, set_all_seed, generate_cycle, generate_graph
from server import get_lambda, get_K
# Module-wide logging: everything at DEBUG (numeric level 10) and above is
# emitted with a timestamp, logger name, level and originating module.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s -%(module)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S %p",
)

logger = logging.getLogger(__name__)
 
def _log_gap(loss, optimal_loss, eps=1e-16):
    """Return ``log10(|loss - optimal_loss| + eps)``, or NaN when the optimum is unknown.

    The absolute value keeps the logarithm defined when ``loss`` undershoots
    the reference value, and ``eps`` keeps it finite when they are equal.
    """
    if optimal_loss is None:
        return float("nan")
    return float(np.log10(np.abs(loss - optimal_loss) + eps))


def main(args):
    """Run one decentralized experiment and collect its loss curve.

    Builds the gossip matrix with ``generate_graph(args.n_nodes, args.p)``,
    instantiates the method with ``LoadModel`` and repeatedly calls
    ``server.step()`` until ``server.ifo_num`` exceeds ``args.ifo_orcal_num``.
    Every ``args.log_freq`` steps, if the l1 norm of the averaged iterate is
    within ``args.scale``, the loss is recorded and reported to the logger,
    TensorBoard and Weights & Biases.

    Args:
        args: namespace produced by ``parse()``. ``args.optimal_loss`` is
            optional; when it is missing or None the optimality gap is
            reported as NaN.

    Returns:
        Tuple ``(losses, tot_steps, tot_consensus_time, losses_path,
        steps_path, consensus_time_path, server)``. The three lists are
        index-aligned: entry ``i`` of each describes the same recorded point
        (entry 0 is the initial point).
    """
    _config = Namespace(
        project_name=" ",
        seed=args.seed,
        dataset=args.dataset,
        method=args.method,

        p=args.p,
        rho=args.rho,
        base_lr=args.base_lr,
        is_convex=args.is_convex,
    )

    log_path = os.path.join(args.log_dir, args.dataset, args.method)
    os.makedirs(log_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    run_name = str(args.method) + "_" + str(args.dataset)
    wandb.init(project=_config.project_name, config=_config.__dict__, name=run_name)

    data, label = LoadDataset(args.dataset, device=args.device)
    log_freq = args.log_freq
    # Reference optimum used for the optimality gap. It is optional so that a
    # run without a known optimum still works (gap reported as NaN).
    optimal_loss = getattr(args, "optimal_loss", None)

    gossip_matrix = generate_graph(args.n_nodes, args.p)
    gm_lambda = get_lambda(torch.from_numpy(gossip_matrix))
    server = LoadModel(
        method=args.method,
        x=data,
        y=label,
        scale=args.scale,
        n_nodes=args.n_nodes,
        gossip_matrix=gossip_matrix,
        rho=args.rho,
        batch_size=args.batch_size,
        base_lr=args.base_lr,
        is_convex=args.is_convex,
    )

    initial_loss = server.compute_loss().item()
    losses = [initial_loss]
    tot_steps = [0]
    tot_consensus_time = [0]
    logger.info(f"the initial loss is {losses[0]}")
    writer.add_scalar("loss along with querying", losses[0], tot_steps[-1])
    wandb.log(
        {
            "queries": tot_steps[-1],
            "train_loss": losses[0],
            "gap": _log_gap(losses[0], optimal_loss),
            "consensus_time": tot_consensus_time[-1],
        }
    )
    min_loss = inf
    consensus_time = 0

    epoch = 0
    while tot_steps[-1] <= args.ifo_orcal_num:
        server.step()
        # Running count of consensus rounds for the current method.
        if args.method == "DsgFW" or args.method == "DeFW":
            consensus_time = 2 * epoch
        elif args.method == "DVRGTFW":
            # NOTE(review): K is taken from --K; confirm it matches the number
            # of inner rounds the DVRGTFW server actually runs (cf. get_K).
            consensus_time = 2 * args.K * epoch

        if epoch % log_freq == 0:
            param_norm = server.get_average_x().abs().sum()

            # Only record points whose averaged iterate satisfies the l1 bound.
            if param_norm <= args.scale:
                loss = server.compute_loss().item()
                min_loss = min(loss, min_loss)
                gap = _log_gap(loss, optimal_loss)

                # Append to all three curves together so they stay aligned.
                losses.append(loss)
                tot_steps.append(server.ifo_num)
                tot_consensus_time.append(consensus_time)

                logger.info(
                    f"At step {epoch}, already querying {tot_steps[-1]} IFO, the loss is {loss}, the gap is {gap}."
                )
                logger.info(
                    f"At step {epoch}, the lambda is {gm_lambda}."
                )
                logger.info(
                    f"At step {epoch}, already  {consensus_time} consensus times, the loss is {loss}, the gap is {gap}."
                )

                wandb.log(
                    {
                        "queries": tot_steps[-1],
                        "train_loss": loss,
                        "gap": gap,
                        "consensus_time": consensus_time,
                    }
                )

                if optimal_loss is not None and loss < optimal_loss:
                    logger.warning("WTF, are you serious!")

                writer.add_scalar("loss along with querying", loss, server.ifo_num)

            # Stop on the oracle budget even when the norm check above failed;
            # otherwise tot_steps would never advance and the loop never ends.
            if server.ifo_num > args.ifo_orcal_num:
                break

        epoch += 1

    writer.add_hparams(
        vars(_config),
        {"min_loss": min_loss},
        # os.path.join (unlike string concatenation) also handles an absolute
        # --log_dir correctly.
        run_name=os.path.join(os.path.dirname(os.path.realpath(__file__)), log_path),
    )
    writer.close()

    losses_file, steps_file, consensus_time_file = generate_filename(vars(_config))
    losses_path = os.path.join(log_path, losses_file)
    steps_path = os.path.join(log_path, steps_file)
    consensus_time_path = os.path.join(log_path, consensus_time_file)
    return losses, tot_steps, tot_consensus_time, losses_path, steps_path, consensus_time_path, server


def parse():
    """Parse the command-line options for a single experiment run.

    Returns:
        argparse.Namespace with the fields read by ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset", type=str, required=True, choices=["rcv1", "gisette", "real-sim", "covtype.libsvm.binary.scale", "a9a"]
    )
    parser.add_argument("--method", type=str, required=True, choices=["DeFW", "DsgFW", "DVRGTFW"])

    # Constraint parameter (l1 norm). main() compares the l1 norm of the
    # averaged iterate against it on the first logged step, so it is required.
    parser.add_argument(
        "--scale", type=float, required=True,
        help="bound on the l1 norm of the averaged iterate",
    )
    parser.add_argument("--max_T", type=int, default=500)  # NOTE(review): not read by main()
    parser.add_argument("--n_nodes", type=int, default=100)

    # Batch / step-size settings forwarded to LoadModel.
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--rho", type=float, default=0.001)
    parser.add_argument("--base_lr", type=float, default=1.0)
    # store_false: is_convex defaults to True (logistic regression); passing
    # the flag sets it to False.
    parser.add_argument(
        "--is_convex", action='store_false',
        help="pass this flag to set is_convex=False (default: True)",
    )
    parser.add_argument("--p", type=float, default=0.1)

    parser.add_argument("--log_dir", type=str, default="./logs")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--log_freq", type=int, default=5)
    parser.add_argument(
        "--K", type=int, default=10,
        help="consensus-round multiplier used for DVRGTFW",
    )
    parser.add_argument("--ifo_orcal_num", type=float, default=10000000)
    parser.add_argument(
        "--optimal_loss", type=float, default=None,
        help="reference optimal loss for the log10 optimality gap (gap is NaN if omitted)",
    )

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse()

    # Seed the project's RNGs and force deterministic cuDNN kernels so runs
    # are reproducible.
    set_all_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    (method_loss, method_step, method_consensus_time,
     losses_path, steps_path, consensus_time_path, _server) = main(args)

    # Persist each recorded curve to its own pickle file in the run's log dir.
    for path, values in (
        (losses_path, method_loss),
        (steps_path, method_step),
        (consensus_time_path, method_consensus_time),
    ):
        with open(path, "wb") as f:
            pickle.dump(values, f)
    
