import torch
import qtoml as toml
import copy
import os

from .parsing import get_class

from ..training import optimizers
from ..training import scheduler as custom_lr_sched
from ..layers.model_config import ModelConfig, init_weights
from ..regimes.helpers import TrainConfig
from .. import regimes
from .. import models
from .. import datasets


def get_model(cfg, n_classes):
    """Construct the network named in ``cfg["network"]``.

    The optional ``cfg["model_cfg"]`` mapping (layer quantization settings)
    is expanded into ``ModelConfig`` keyword arguments; without it a default
    ``ModelConfig()`` is used. The network is built through ``get_class``
    from the ``models`` package with ``model_cfg`` and ``n_classes`` keyword
    arguments.

    When more than one CUDA device is visible the network is wrapped in
    ``torch.nn.DataParallel``; otherwise it is returned as built.
    """
    layer_kwargs = cfg["model_cfg"] if "model_cfg" in cfg else {}
    model_cfg = ModelConfig(**layer_kwargs)

    network = get_class(cfg["network"], models, model_cfg=model_cfg, n_classes=n_classes)

    # Replicate across GPUs only when there is more than one to use.
    if torch.cuda.device_count() <= 1:
        return network
    return torch.nn.DataParallel(network)


def get_dataset(cfg):
    """Build the dataset described by ``cfg["dataset"]``.

    The spec is resolved against the ``datasets`` package via ``get_class``
    and the result is returned unchanged; ``execute_strategy`` unpacks it as
    ``(loaders, n_classes)``.
    """
    dataset_spec = cfg["dataset"]
    return get_class(dataset_spec, datasets)


def get_optimizer(cfg_optim, model, **kwargs):
    """Build the optimizer described by ``cfg_optim`` for ``model``.

    Only parameters with ``requires_grad`` set are optimized. They are split
    into param groups so weight decay is applied selectively:

    * names ending in ``.bias``: weight decay 0;
    * names ending in ``.weight``: the configured weight decay. If the
      optimizer ``type`` contains "layca" (case-insensitive), weights whose
      name contains "batchnorm" go to a separate group with weight decay 0,
      and every group carries a ``"layca"`` flag (False for biases, True for
      weights);
    * any other trainable parameter (e.g. an RNN's ``weight_ih_l0`` or a bare
      ``nn.Parameter``): the configured weight decay, in one extra group that
      is appended only when non-empty, so models whose parameters all end in
      ``.weight``/``.bias`` get exactly the bias/weight groups listed above.

    Args:
        cfg_optim: optimizer config mapping; must contain ``"type"``. Its
            ``"weight_decay"`` entry (default 0) is applied per group instead
            of being passed to the optimizer. The mapping is not modified.
        model: module whose ``named_parameters()`` are optimized.
        **kwargs: extra optimizer options; they override keys of
            ``cfg_optim``, including ``weight_decay``.

    Returns:
        The optimizer built by ``get_class`` from the ``optimizers`` package.
    """
    use_layca = "layca" in cfg_optim["type"].lower()

    # Merge on a deep copy so neither cfg_optim nor kwargs is mutated. Popping
    # weight_decay from the merged dict makes a kwargs override take effect;
    # the decay is then set per param group below instead of globally.
    opti_params = {**copy.deepcopy(cfg_optim), **kwargs}
    wd = opti_params.pop("weight_decay", 0)

    bias_params = []
    weight_params = []  # (name, param) pairs; names are needed for the batchnorm split
    other_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are never handed to the optimizer
        if name.endswith(".bias"):
            bias_params.append(param)
        elif name.endswith(".weight"):
            weight_params.append((name, param))
        else:
            other_params.append(param)

    if use_layca:
        param_groups = [
            {"params": bias_params, "layca": False, "weight_decay": 0},
            {
                "params": [p for n, p in weight_params if "batchnorm" not in n],
                "layca": True,
                "weight_decay": wd,
            },
            {
                "params": [p for n, p in weight_params if "batchnorm" in n],
                "layca": True,
                "weight_decay": 0,
            },
        ]
    else:
        param_groups = [
            {"params": bias_params, "weight_decay": 0},
            {"params": [p for _, p in weight_params], "weight_decay": wd},
        ]

    if other_params:
        # Trainable parameters matching neither suffix would otherwise never
        # be updated. They get the configured decay, mirroring torch's default
        # of applying an optimizer's weight_decay to every parameter.
        extra_group = {"params": other_params, "weight_decay": wd}
        if use_layca:
            # NOTE(review): layca disabled for these conservatively -- confirm
            # against the layca optimizer's expectations.
            extra_group["layca"] = False
        param_groups.append(extra_group)

    return get_class(opti_params, optimizers, params=param_groups)


def get_lr_sched(cfg, optimizer):
    """Build the learning-rate scheduler from ``cfg["training"]["lr_sched"]``.

    The class named by the spec's ``"type"`` is looked up first in
    ``torch.optim.lr_scheduler``; if torch has no such attribute, the project's
    custom scheduler module is used instead. The scheduler is built through
    ``get_class`` with ``optimizer`` and ``n_epochs`` (taken from
    ``cfg["training"]["params"]["n_epochs"]``) as keyword arguments.

    If no scheduler is configured (key absent or ``None``), a no-op object is
    returned whose ``step`` accepts and ignores any arguments, so callers can
    step it unconditionally.
    """
    sched_spec = cfg["training"].get("lr_sched", None)
    if sched_spec is None:
        class EmptyLrSched:
            """No-op scheduler used when none is configured."""

            def step(self, *args, **kwargs):
                # torch schedulers' step() takes optional arguments (an epoch,
                # or metrics for ReduceLROnPlateau); accept and ignore them.
                pass

        return EmptyLrSched()

    default_lr_pack = torch.optim.lr_scheduler
    lr_pack = (
        default_lr_pack
        if hasattr(default_lr_pack, sched_spec["type"])
        else custom_lr_sched
    )
    # Hand get_class a private copy (as get_optimizer does for the optimizer
    # config) in case it modifies its spec; this factory may be called repeatedly.
    return get_class(
        copy.deepcopy(sched_spec),
        lr_pack,
        optimizer=optimizer,
        n_epochs=cfg["training"]["params"]["n_epochs"],
    )


def execute_strategy(cfg, logdir=None):
    """Build everything described by ``cfg`` and run the configured training regime.

    Steps: load data (``get_dataset``), build the model and initialize its
    weights, assemble a ``TrainConfig`` from a copy of
    ``cfg["training"]["params"]`` plus optimizer / LR-scheduler factory
    callables, save ``cfg`` as TOML to ``<logdir>/config.toml`` when
    ``logdir`` is given, then build the regime named by
    ``cfg["training"]["regime"]`` (default ``{"type": "Normal"}``) and call
    its ``start()``.

    Args:
        cfg: experiment config mapping. Keys read directly here:
            ``"training"`` with ``"params"``, ``"optim"`` and optional
            ``"regime"``; the helpers called here read further keys.
        logdir: optional output directory. When given it is added to the
            training params as ``"logdir"``, created if missing, and receives
            ``config.toml``.
    """
    # -- Data and model
    loaders, n_classes = get_dataset(cfg)
    model = get_model(cfg, n_classes)
    init_weights(model, cfg)

    # -- Train config (copied so TrainConfig never aliases the cfg dict)
    params = copy.deepcopy(cfg["training"]["params"])
    if logdir is not None:
        params["logdir"] = logdir
    cfg_optim = cfg["training"]["optim"]
    # Factories rather than instances, presumably so the regime can (re)build
    # the optimizer/scheduler itself; default arguments bind this config.
    params["get_optimizer"] = lambda model, cfg_optim=cfg_optim, **kwargs: get_optimizer(cfg_optim, model, **kwargs)
    params["get_lr_sched"] = lambda optimizer, cfg=cfg: get_lr_sched(cfg, optimizer)
    train_cfg = TrainConfig(**params)

    # -- Save config (TOML documents are UTF-8 by specification)
    if logdir is not None:
        os.makedirs(logdir, exist_ok=True)
        with open(os.path.join(logdir, "config.toml"), "w", encoding="utf-8") as cfg_fp:
            cfg_fp.write(toml.dumps(cfg))

    # -- Start training
    regime_cfg = cfg["training"].get("regime", {"type": "Normal"})
    trainer = get_class(regime_cfg, regimes, train_cfg, model, loaders)
    trainer.start()
