import os
import random
import subprocess
from datetime import datetime

import git

# import hydra
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.data import DataLoader, Subset, TensorDataset

import device
import sample
import wandb
from logger import log_settings, log_state_dicts, log_artifacts
from utils import convert_run_name
from tqdm import tqdm
import copy
from matplotlib import pyplot as plt
import argparse


# @hydra.main(config_path="conf", config_name="config", version_base=None)
def _generate_cfg(config: DictConfig):
    """Install *config* as the module-level ``cfg`` read by ``objective``.

    This was the Hydra entry point for ./conf/config.yaml; the ``@hydra.main``
    decorator is currently commented out and the ``__main__`` block loads the
    YAML directly instead.

    Args:
        config (DictConfig): configuration loaded from ./conf/config.yaml
    """
    global cfg
    cfg = config


def _fix_seed(seed: int):
    """Seed every random number generator the experiment uses.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch (CPU,
    the current CUDA device, and all CUDA devices), in that order, then
    switches cuDNN to deterministic mode.

    Args:
        seed (int): Random seed
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True


def _disc_train(device, sampler, device_name, step):
    """Run one discriminator-training pass on a device's own dataset.

    Args:
        device: Object exposing ``disc_train(dset)`` and ``dset`` (a client
            device in this file). Note: shadows the ``device`` module here.
        sampler: Only needed by the disabled test-set evaluation below;
            kept for interface stability.
        device_name: Suffix for metric keys, e.g. ``"c0"``.
        step: Round index used as the logging step.

    Returns:
        Whatever ``device.disc_train`` returns (presumably the training loss
        — confirm against ``device.py``), so callers can log it.
    """
    loss_disc_train = device.disc_train(device.dset)
    # Validation and wandb logging are currently disabled. To re-enable:
    #   acc_disc_val, loss_disc_val = device.disc_val(device.dset)
    #   acc_disc_test, loss_disc_test = device.disc_val(sampler.dset_test)
    #   wandb.log({f"disc_train_loss_{device_name}": loss_disc_train, ...}, step=step)
    return loss_disc_train


def _server_model_path():
    """Return the checkpoint path of the final server model for the current ``cfg``.

    Used both when saving after training and when loading in eval mode, so the
    two paths cannot drift apart.
    """
    return (
        f"model/{cfg.dset.name}_{cfg.dset_s.name}_{cfg.fl.combine}"
        f"_{cfg.fl.diri_alpha}_{cfg.simul.seed}.pt"
    )


def _gan_gen_path():
    """Return the path of the pretrained server generator loaded when ``cfg.gan.model_load``."""
    return (
        f"model/{cfg.dset.name}_{cfg.dset_s.name}_diri{cfg.fl.diri_alpha}"
        f"_gan_gen_{cfg.gan.load_ep}_s{cfg.simul.seed}.pt"
    )


def _disc_ckpt_path(c, ep):
    """Return the checkpoint path of client ``c``'s discriminator after epoch ``ep``."""
    return (
        f"model/disc_net_{c}_ep_{ep}_{cfg.gan.c.model}_{cfg.dset.name}"
        f"_{cfg.fl.n_c}_s{cfg.simul.seed}_{cfg.fl.diri_alpha}_{cfg.gan.load_ep}.pt"
    )


def _save_state_dict(state_dict, path):
    """``torch.save`` *state_dict* to *path*, creating the parent directory first.

    Without this, a missing ``model/`` directory makes the save at the very end
    of training raise ``FileNotFoundError`` and the trained weights are lost.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(state_dict, path)


def _init_devices(sampler):
    """Create the client devices and the server device from *sampler*'s splits.

    In central mode a single client holds ``sampler.dset_train_c`` and
    ``cfg.fl.n_c`` is overwritten with 1. The first ``cfg.env.byzantine.n_c``
    clients get a DataLoader built by ``create_noisy_dl``.

    Returns:
        ``(devices_c, device_s)``.

    Raises:
        ValueError: if more byzantine clients are configured than exist.
    """
    if cfg.fl.central:
        devices_c = [device.TrainableDevice(cfg=cfg, dset=sampler.dset_train_c)]
        cfg.fl.n_c = 1
    else:
        devices_c = [
            device.TrainableDevice(cfg=cfg, dset=sampler.dset_chunks[i])
            for i in range(cfg.fl.n_c)
        ]

    n_byzantine = cfg.env.byzantine.n_c
    if n_byzantine > len(devices_c):
        raise ValueError(
            f"cfg.env.byzantine.n_c={n_byzantine} exceeds the number of "
            f"clients ({len(devices_c)})"
        )
    for c in range(n_byzantine):
        devices_c[c].dl = devices_c[c].create_noisy_dl(cfg.env.byzantine.ratio_c_noise)

    device_s = device.TrainableServer(cfg=cfg, dset=sampler.dset_chunk_s)
    return devices_c, device_s


def _record_git_info():
    """Store the current commit hash, branch name and commit message in ``cfg.git``."""
    repo = git.Repo(search_parent_directories=True)
    try:
        branch = repo.active_branch.name
    except TypeError:
        # GitPython raises TypeError when HEAD is detached, e.g. after
        # `git checkout <sha>` to reproduce an earlier run.
        branch = "DETACHED_HEAD"
    with open_dict(cfg):
        cfg.git.hash = repo.head.object.hexsha
        cfg.git.branch = branch
        cfg.git.message = repo.head.object.message


def _build_tags():
    """Return the run tags (consumed only by the currently disabled ``wandb.init``)."""
    tags = [
        f"alpha{cfg.fl.diri_alpha}",
        cfg.fl.combine,
        f"c_{cfg.dset.name}",
        f"s_{cfg.dset_s.name}",
        f"s{cfg.simul.seed}",
        f"n{cfg.env.ratio_c_noise}",
        # `ip` is an optional machine label; indexing os.environ directly used
        # to abort every run (eval included) when it was unset.
        os.environ.get("ip", "ip_unset"),
        f"byz{cfg.env.byzantine.n_c}",
        f"n_c{cfg.fl.n_c}",
        f"nsr{cfg.fl.n_s_ratio}",
        f"ncr{cfg.fl.n_c_ratio}",
    ]
    if cfg.wandb.convert_run_name:
        tags.append(cfg.wandb.run_name)
    return tags


def _prepare_gan(device_s, devices_c, sampler):
    """Set up GAN components before federated training, as selected by ``cfg.gan``.

    If ``cfg.gan.timing`` is truthy, the server generator is loaded from disk
    (``cfg.gan.model_load``) or trained on the server dataset, and the server's
    generator and discriminator weights are copied to every client.
    If ``cfg.gan.timing == "pre"``, each client's discriminator is then either
    loaded from its epoch ``cfg.gan.c.load_ep - 1`` checkpoint
    (``cfg.gan.c.disc_load``) or trained for ``cfg.gan.ep`` epochs with a
    checkpoint saved every 10 epochs.
    """
    if cfg.gan.timing:
        if cfg.gan.model_load:
            device_s.gan_gen.load_state_dict(torch.load(_gan_gen_path()))
        else:
            device_s.gen_train(device_s.dset)
        for device_c in devices_c:
            device_c.gan_gen.load_state_dict(device_s.gan_gen.state_dict())
            device_c.disc_net.load_state_dict(device_s.disc_net.state_dict())

    if cfg.gan.timing != "pre":
        return

    if cfg.gan.c.disc_load:
        for c, device_c in enumerate(devices_c):
            device_c.disc_net.load_state_dict(
                torch.load(_disc_ckpt_path(c, cfg.gan.c.load_ep - 1))
            )
        return

    for epoch in tqdm(range(cfg.gan.ep)):
        for c, device_c in enumerate(devices_c):
            _disc_train(device_c, sampler, f"c{c}", epoch)
            if (epoch + 1) % 10 == 0:
                _save_state_dict(device_c.disc_net.state_dict(), _disc_ckpt_path(c, epoch))


def _sample_clients(devices_c):
    """Pick this round's participating clients as a list of ``(index, device)`` pairs."""
    if cfg.fl.central:
        return [(0, devices_c[0])]
    return random.sample(
        list(enumerate(devices_c)),
        int(len(devices_c) * cfg.fl.client_sample_ratio),
    )


def _train_clients(sampled_devices_c, round_, device_s, sampler):
    """Sync each sampled client with the server weights, then train it locally."""
    for c, device_c in sampled_devices_c:
        if cfg.gan.c.disc_cls_pretrain and round_ == 0:
            pass  # no weight sync in round 0 when disc_cls_pretrain is enabled
        elif cfg.gkd.is_gkd and round_ != 0 and not cfg.gkd.avg_first:
            device_c.update(device_s.global_model.state_dict())
        else:
            device_c.update(device_s.net.state_dict())

        if cfg.simul.gpu_on_off:
            device_c.net.to(device_c.gpu)
        device_c.train()
        # These evaluations were only consumed by the (disabled) per-client wandb
        # logging. They are kept so that any RNG state val() may consume -- and
        # hence the rest of a seeded run -- is unchanged.
        device_c.val(device_c.dset)
        device_c.val(sampler.dset_test)

        if cfg.simul.gpu_on_off:
            device_c.net.to("cpu")
        torch.cuda.empty_cache()


def objective(eval=False):
    """Run one federated-learning experiment, or evaluate a saved server model.

    Reads the module-level ``cfg``. Outside eval mode, it trains for
    ``cfg.simul.max_round`` rounds. Each round samples clients, trains them
    locally, combines them into the server with ``device_s.combine``, and
    evaluates the server on the test set. The final server weights are saved to
    ``_server_model_path()``.

    Args:
        eval (bool): If truthy, load the server model saved at
            ``_server_model_path()``, evaluate it on the test set, and return
            without training.

    Returns:
        The server test accuracy from ``device_s.val``. This is measured after
        the last round, or on the loaded model in eval mode. It is ``None`` if
        ``cfg.simul.max_round`` is 0.
    """
    _fix_seed(cfg.simul.seed)

    sampler = sample.Sampler(cfg)
    devices_c, device_s = _init_devices(sampler)

    _record_git_info()
    tags = _build_tags()  # used by wandb.init below once it is re-enabled

    if eval:
        device_s.net.load_state_dict(torch.load(_server_model_path()))
        acc_test_s = device_s.val(sampler.dset_test)
        print("seed: ", cfg.simul.seed, ", test acc: ", acc_test_s)
        return acc_test_s

    # wandb.init(
    #     project="FedEDG",
    #     group=cfg.wandb.ex_name,
    #     job_type=cfg.wandb.job_type,
    #     name=cfg.wandb.run_name + f"s{cfg.simul.seed}",
    #     dir="..",
    #     tags=tags,
    #     config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
    # )
    # log_settings(cfg, sampler)

    _prepare_gan(device_s, devices_c, sampler)

    participation_history = []  # sorted client indices per round (not logged yet)
    acc_test_s = None  # remains None only when max_round == 0
    pbar = tqdm(range(cfg.simul.max_round))
    for round_ in pbar:
        sampled_devices_c = _sample_clients(devices_c)
        participation_history.append(
            np.array(sorted(c for c, _ in sampled_devices_c), dtype=np.int64)
        )
        _train_clients(sampled_devices_c, round_, device_s, sampler)

        # update server model
        if cfg.simul.gpu_on_off:
            device_s.net.to(device_s.gpu)
        device_s.combine(dict(sampled_devices_c))
        acc_test_s = device_s.val(sampler.dset_test)
        if cfg.fl.combine != "avg":
            device_s.val_em(sampler.dset_test, devices_c)

        # log model state dicts
        if (round_ + 1) % cfg.log.model_save_step == 0:
            log_state_dicts(device_s, "s", round_)
        if (round_ + 1) % cfg.log.client_model_save_step == 0:
            for c, device_c in enumerate(devices_c):
                log_state_dicts(device_c, f"c{c}", round_)

        pbar.set_postfix({"round": round_, "test_acc": acc_test_s})

        if cfg.simul.gpu_on_off:
            device_s.net.to("cpu")
        torch.cuda.empty_cache()

    print("seed: ", cfg.simul.seed, ", test acc: ", acc_test_s)
    _save_state_dict(device_s.net.state_dict(), _server_model_path())
    return acc_test_s


# Dataset names accepted by --dset_c / --dset_s.
_DSET_CHOICES = ["cifar10", "cifar100", "imagenet100"]

# Server-side aggregation methods accepted by --combine.
_COMBINE_CHOICES = [
    "avg",
    "em_entropy_soft",
    "df",
    "gan",
    "logit_var",
    "df_gkd",
    "gan_dafkd",
]


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    argparse's ``type=bool`` calls ``bool("False")``, which is True for every
    non-empty string. This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def _build_parser():
    """Build the parser for the command-line overrides of conf/config.yaml."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dset_c", type=str, default="cifar10", choices=_DSET_CHOICES)
    parser.add_argument(
        "--dset_s", dest="dset_s", type=str, default="cifar10", choices=_DSET_CHOICES
    )
    parser.add_argument(
        "--combine", type=str, dest="combine", default="gan", choices=_COMBINE_CHOICES
    )
    parser.add_argument("--diri_alpha", dest="diri_alpha", type=float, default=0.1)
    parser.add_argument("--anneal", dest="anneal", type=_str2bool, default=True)
    parser.add_argument("--gen_load", dest="gen_load", type=_str2bool, default=True)
    parser.add_argument("--diff_disc_ep", dest="diff_disc_ep", type=_str2bool, default=False)
    parser.add_argument("--disc_ep", dest="disc_ep", type=int, default=30)
    parser.add_argument("--eval", dest="eval", type=_str2bool, default=False)
    return parser


if __name__ == "__main__":
    # Load the base config, apply CLI overrides and derived settings, fix the
    # GPU, then run `objective` once per seed in cfg.simul.seeds.
    import yaml

    # Parse first so `--help` and argument errors don't require the config file.
    args = _build_parser().parse_args()

    # Module-level assignment: this is the global `cfg` read by `objective`.
    with open("conf/config.yaml", encoding="utf-8") as f:
        cfg = OmegaConf.create(yaml.safe_load(f))

    cfg.dset.name = args.dset_c
    cfg.dset_s.name = args.dset_s
    cfg.fl.combine = args.combine
    cfg.fl.diri_alpha = args.diri_alpha
    cfg.s.anneal = args.anneal
    cfg.gan.model_load = args.gen_load

    cfg.dset.path = f"data/{args.dset_c}"
    cfg.dset_s.path = f"data/{args.dset_s}"

    if args.dset_c in ("cifar100", "imagenet100"):
        cfg.dset.n_cls = 100
        cfg.model.model = "cifar100resnet18"
    elif args.dset_c == "cifar10":
        cfg.dset.n_cls = 10
        cfg.model.model = "resnet18"

    if args.dset_c == "imagenet100":
        cfg.c.ep = 10
    else:
        cfg.c.ep = 30

    if args.dset_s == "imagenet100":
        cfg.s.ep = 3
        cfg.gan.ep = 10
        cfg.gan.c.load_ep = 10
    else:
        cfg.s.ep = 10
        cfg.gan.ep = 30
        cfg.gan.c.load_ep = 30

    if args.dset_s in ("cifar100", "imagenet100"):
        cfg.dset_s.n_cls = 100
    else:
        cfg.dset_s.n_cls = 10
    if cfg.fl.combine == "df_gkd":
        cfg.gkd.is_gkd = True

    if cfg.fl.combine.startswith("gan"):
        cfg.gan.timing = "pre"

    if args.diff_disc_ep:
        cfg.gan.ep = args.disc_ep
        cfg.gan.c.load_ep = args.disc_ep

    if cfg.wandb.convert_run_name:
        cfg.wandb.run_name = convert_run_name(cfg)
    if cfg.sample_method == "iid":
        cfg.fl.diri_alpha = 1e5
    if cfg.simul.gpu is None:
        # If gpu number is not given, then use the tmux window number
        window = subprocess.getoutput("tmux display-message -p '#I'")
        try:
            gpu = int(window)
        except ValueError as err:
            raise RuntimeError(
                "cfg.simul.gpu is not set and the tmux window index could not be "
                f"read (got {window!r}); set simul.gpu in conf/config.yaml"
            ) from err
        with open_dict(cfg):
            cfg.simul.gpu = gpu

    for seed in cfg.simul.seeds:
        cfg.simul.seed = seed
        _fix_seed(cfg.simul.seed)
        print(cfg.wandb.ex_name)
        print(cfg.wandb.run_name + f"s{cfg.simul.seed}")
        objective(eval=args.eval)

#
