import logging
import dataclasses
import pprint

import torch
from torch import optim
import torch.cuda
import torch.utils.data
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.collect_env import get_pretty_env_info
from torch.utils.tensorboard import SummaryWriter
from pyhocon import ConfigTree

from codebase.config import Args
from codebase.data import DATA
from codebase.models import MODEL
from codebase.optimizer import OPTIMIZER
from codebase.scheduler import SCHEDULER
from codebase.criterion import CRITERION
from codebase.engine import train_one_epoch, evaluate_one_epoch, evaluate_last_epoch
from codebase.resnet50 import *

from codebase.torchutils.common import set_cudnn_auto_tune, set_reproducible, generate_random_seed, disable_debug_api
from codebase.torchutils.common import set_proper_device, get_device
from codebase.torchutils.common import unwarp_module
from codebase.torchutils.common import compute_nparam, compute_flops
from codebase.torchutils.common import StateCheckPoint
from codebase.torchutils.common import MetricsStore
from codebase.torchutils.common import patch_download_in_cn
from codebase.torchutils.common import only_master
from codebase.torchutils.distributed import distributed_init, is_dist_avail_and_init, is_master, world_size
from codebase.torchutils.metrics import EstimatedTimeArrival
from codebase.torchutils.logging_ import init_logger, create_code_snapshot


_logger = logging.getLogger(__name__)


def excute_pipeline(
    only_evaluate: bool,
    start_epoch: int,
    max_epochs: int,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    model: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    scheduler: optim.lr_scheduler._LRScheduler,
    metric_store: MetricsStore,
    use_amp: bool,
    accmulated_steps: int,
    device: str,
    memory_format: str,
    log_interval: int,
    writer: SummaryWriter,
    state_ckpt: StateCheckPoint,
    states: dict,
    landa: float,
    loss_normalizer_h: list,
    output_dir: str
):
    """Run evaluation only, or the full train/evaluate/checkpoint loop.

    When ``only_evaluate`` is true a single ``evaluate_one_epoch`` pass is run
    on ``val_loader`` and the function returns. Otherwise, for every epoch in
    ``start_epoch + 1 .. max_epochs`` (inclusive) it trains one epoch,
    evaluates, writes the latest metrics to ``writer``, saves a checkpoint via
    ``state_ckpt`` and logs the best metrics so far.

    Epochs with ``epoch % 9`` in ``{0, 1, 2}`` are run with
    ``is_score_training=True`` and a dedicated Adam optimizer (lr=1e-6); all
    other epochs use ``is_score_training=False`` and an Adam optimizer with
    lr=1e-3.

    Args:
        only_evaluate: Skip training and evaluate once.
        start_epoch: Number of epochs already completed; training resumes at
            ``start_epoch + 1``.
        max_epochs: Index of the last epoch to run.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate_one_epoch``.
        model: Network being trained/evaluated.
        optimizer: NOTE(review): this argument is immediately replaced by a
            freshly built Adam optimizer, so the object passed by the caller
            is never stepped here. Confirm this is intended, in particular
            with respect to what ``states`` checkpoints.
        criterion: Loss module forwarded to the epoch functions.
        scheduler: LR scheduler forwarded to the epoch functions.
        metric_store: Accumulates the metrics returned by each epoch function
            via ``+=``.
        use_amp, accmulated_steps, device, memory_format, log_interval:
            Forwarded unchanged to the epoch functions.
        writer: Receives one scalar per metric per epoch.
        state_ckpt: Checkpoint helper; ``save`` is called after every epoch.
        states: Objects handed to ``state_ckpt.save``.
        landa, loss_normalizer_h, output_dir: Forwarded unchanged to the
            epoch functions.
    """
    scores_optimizer = optim.Adam(model.parameters(), lr=0.000001)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Keyword arguments that are identical for every epoch-level call.
    shared_kwargs = dict(
        model=model,
        criterion=criterion,
        scheduler=scheduler,
        use_amp=use_amp,
        accmulated_steps=accmulated_steps,
        device=device,
        memory_format=memory_format,
        log_interval=log_interval,
        landa=landa,
        loss_normalizer_h=loss_normalizer_h,
        output_dir=output_dir,
    )

    if only_evaluate:
        metric_store += evaluate_one_epoch(
            is_last=False,
            epoch=0,
            loader=val_loader,
            optimizer=optimizer,
            is_score_training=False,
            **shared_kwargs
        )
        return

    eta = EstimatedTimeArrival(max_epochs)

    for epoch in range(start_epoch+1, max_epochs+1):
        is_last = epoch == max_epochs

        if is_dist_avail_and_init():
            # Only samplers that actually expose `set_epoch` are notified of
            # the new epoch; any other sampler is left untouched instead of
            # raising AttributeError.
            for loader in (train_loader, val_loader):
                set_epoch = getattr(getattr(loader, "sampler", None), "set_epoch", None)
                if callable(set_epoch):
                    set_epoch(epoch)

        # Score-training phase: epochs whose index modulo 9 is 0, 1 or 2.
        is_score_training = epoch % 9 in (0, 1, 2)
        metric_store += train_one_epoch(
            is_last=is_last,
            epoch=epoch,
            loader=train_loader,
            optimizer=scores_optimizer if is_score_training else optimizer,
            is_score_training=is_score_training,
            **shared_kwargs
        )

        metric_store += evaluate_one_epoch(
            is_last=is_last,
            epoch=epoch,
            loader=val_loader,
            optimizer=optimizer,
            is_score_training=False,
            **shared_kwargs
        )

        for k, v in metric_store.get_last_metrics().items():
            writer.add_scalar(k, v, epoch)

        state_ckpt.save(metric_store=metric_store, states=states, epoch=epoch)

        eta.step()

        best_metrics = metric_store.get_best_metrics()
        _logger.info(f"Epoch={epoch:04d} complete, best val top1-acc={best_metrics['eval/top1_acc']*100:.2f}%, "
                     f"top5-acc={best_metrics['eval/top5_acc']*100:.2f}% (epoch={metric_store.best_epoch+1}), {eta}")



def prepare_for_training(conf: ConfigTree, output_dir: str, local_rank: int):
    """Construct the model, data loaders and training helpers from ``conf``.

    Any checkpoint found by ``StateCheckPoint(output_dir)`` is restored into
    the metric store and the model/optimizer/scheduler states before the model
    is (optionally) wrapped in ``DistributedDataParallel``.

    Returns:
        A 10-tuple ``(model, train_loader, val_loader, criterion, optimizer,
        scheduler, state_ckpt, writer, metric_store, states)``.
    """
    # The config-driven path (`MODEL.build_from(conf.get("model"))` followed by
    # an optional `load_from` state dict) is disabled; the network is
    # hard-wired to `resnet50(pretrained=True)` instead.
    model = resnet50(pretrained=True)

    sync_bn_wanted = is_dist_avail_and_init() and conf.get_bool("sync_batchnorm")
    if sync_bn_wanted:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    image_size = conf.get_int('data.image_size')
    n_params_m = compute_nparam(model) / 1e6
    flops_m = compute_flops(model, (1, 3, image_size, image_size)) / 1e6
    _logger.info(f"Model details: n_params={n_params_m:.2f}M, "
                 f"flops={flops_m:.2f}M.")

    loaders = DATA.build_from(conf.get("data"), dict(local_rank=local_rank))
    train_loader, val_loader = loaders

    criterion = CRITERION.build_from(conf.get("criterion"))

    # Scale the configured learning rate by (total batch size / basic_bs);
    # `basic_bs` is removed from the optimizer config before it is consumed.
    optimizer_config: dict = conf.get("optimizer")
    basic_bs = optimizer_config.pop("basic_bs")
    base_lr = optimizer_config["lr"]
    total_bs = conf.get("data.batch_size") * world_size()
    optimizer_config["lr"] = base_lr * (total_bs / basic_bs)
    optimizer = OPTIMIZER.build_from(optimizer_config, dict(params=model.named_parameters()))
    _logger.info(f'Set lr={optimizer_config["lr"]:.4f} with batch size={total_bs}')

    scheduler = SCHEDULER.build_from(conf.get("scheduler"), dict(optimizer=optimizer))

    if torch.cuda.is_available():
        target_device = get_device()
        # `memory_format` in the config names an attribute of the torch module.
        layout = getattr(torch, conf.get("memory_format"))
        model = model.to(device=target_device, memory_format=layout)
        criterion = criterion.to(device=target_device)

    writer = only_master(SummaryWriter(output_dir))

    metric_store = MetricsStore(dominant_metric_name="eval/top1_acc")
    states = dict(
        model=unwarp_module(model),
        optimizer=optimizer,
        scheduler=scheduler,
    )
    state_ckpt = StateCheckPoint(output_dir)

    # Restore happens before DDP wrapping, on the model referenced by `states`.
    state_ckpt.restore(metric_store, states, device=get_device())

    if is_dist_avail_and_init():
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            find_unused_parameters=True,
        )

    return (model, train_loader, val_loader, criterion, optimizer, scheduler,
            state_ckpt, writer, metric_store, states)


def _init(local_rank: int, ngpus_per_node: int, args: Args):
    """Per-process setup: device, logging, seeding, code snapshot, process group.

    Args:
        local_rank: Index of this process on the current node.
        ngpus_per_node: Number of processes/GPUs per node; used together with
            ``args.node_rank`` to derive the global rank.
        args: Parsed run arguments (``node_rank``, ``output_dir``,
            ``dist_backend``, ``dist_url`` and ``world_size`` are read here).
    """
    set_proper_device(local_rank)
    rank = args.node_rank*ngpus_per_node+local_rank
    # `filenmae` (sic) is the keyword this call site has always used for
    # `init_logger`; it is kept verbatim so the call keeps working.
    init_logger(rank=rank, filenmae=args.output_dir/"default.log")

    patch_download_in_cn()

    if StateCheckPoint(args.output_dir).is_ckpt_exists():
        _logger.info("-"*30+"Resume from the last training checkpoints."+"-"*30)

    # The previous condition tested the imported `set_reproducible` function
    # object itself, which is always truthy, so the cudnn auto-tune branch
    # could never run. The flag is now read from `args`.
    # NOTE(review): assumes `Args` may expose a `set_reproducible` field --
    # confirm. When it does not, the default (True) keeps the old behaviour.
    reproducible = getattr(args, "set_reproducible", True)
    if reproducible:
        set_reproducible(generate_random_seed())
    else:
        set_cudnn_auto_tune()
        disable_debug_api()

    create_code_snapshot(name="code", include_suffix=[".py", ".conf"],
                         source_directory=".", store_directory=args.output_dir)

    _logger.info("Collect envs from system:\n" + get_pretty_env_info())
    _logger.info("Args:\n" + pprint.pformat(dataclasses.asdict(args)))

    distributed_init(dist_backend=args.dist_backend, init_method=args.dist_url,
                     world_size=args.world_size, rank=rank)


def main_worker(local_rank: int,
                ngpus_per_node: int,
                args: Args,
                conf: ConfigTree):
    """Entry point of one worker process: initialise, build, then run the pipeline.

    Args:
        local_rank: Index of this process on the current node.
        ngpus_per_node: Number of processes/GPUs per node.
        args: Parsed run arguments; ``output_dir`` is read here.
        conf: Experiment configuration tree.
    """
    _init(local_rank=local_rank, ngpus_per_node=ngpus_per_node, args=args)

    prepared = prepare_for_training(conf, args.output_dir, local_rank)
    (model, train_loader, val_loader, criterion, optimizer,
     scheduler, state_ckpt, writer, metric_store, states) = prepared

    # Preset rows; the integer config entry `loss_normalizer_h` selects one.
    normalizer_presets = [
        [7, 5, 3, 0],
        [7, 6, 5, 0],
        [8, 6, 4, 0],
        [9, 6, 3, 0],
        [9, 8, 5, 0],
    ]
    preset_index = conf.get_int("loss_normalizer_h")
    loss_normalizer_h = normalizer_presets[preset_index]

    pipeline_kwargs = dict(
        only_evaluate=conf.get_bool("only_evaluate"),
        # Resume from the number of epochs already recorded in the store.
        start_epoch=metric_store.total_epoch,
        max_epochs=conf.get_int("max_epochs"),
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        metric_store=metric_store,
        use_amp=conf.get_bool("use_amp"),
        accmulated_steps=conf.get_int("accmulated_steps"),
        device=get_device(),
        memory_format=getattr(torch, conf.get("memory_format")),
        log_interval=conf.get_int("log_interval"),
        writer=writer,
        state_ckpt=state_ckpt,
        states=states,
        landa=conf.get_float("landa"),
        loss_normalizer_h=loss_normalizer_h,
        output_dir=args.output_dir,
    )
    excute_pipeline(**pipeline_kwargs)




def main(args: Args):
    """Launch training: one spawned worker per visible GPU, or a single in-process worker.

    Workers are spawned when ``args.world_size`` is greater than 1; otherwise
    ``main_worker`` is called directly with local rank 0.
    """
    run_distributed = args.world_size > 1
    ngpus_per_node = torch.cuda.device_count()

    if not run_distributed:
        # Single-process run: call the worker directly as local rank 0.
        main_worker(0, ngpus_per_node, args, args.conf)
        return

    # `mp.spawn` supplies the process index as the first positional argument.
    worker_args = (ngpus_per_node, args, args.conf)
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=worker_args)
