import time
from copy import deepcopy
from functools import wraps
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP

from . import logger as log
from . import utils
from .logger import TrainingMetrics, ValidationMetrics
from .models.common import EMA

from lbmqt.utils import GOR, get_memory_usage, compute_tensor_bytes, exp_recorder
from lbmqt import config


# Byte counts used to convert raw memory readings into MB / GB.
MB = 1 << 20  # bytes per mebibyte
GB = 1 << 30  # bytes per gibibyte

# class Executor:
#     def __init__(
#         self,
#         model: nn.Module,
#         loss: Optional[nn.Module],
#         cuda: bool = True,
#         memory_format: torch.memory_format = torch.contiguous_format,
#         amp: bool = False,
#         scaler: Optional[torch.cuda.amp.GradScaler] = None,
#         divide_loss: int = 1,
#         ts_script: bool = False,
#     ):
#         assert not (amp and scaler is None), "Gradient Scaler is needed for AMP"

#         def xform(m: nn.Module) -> nn.Module:
#             if cuda:
#                 m = m.cuda()
#             m.to(memory_format=memory_format)
#             return m

#         self.model = xform(model)
#         if ts_script:
#             self.model = torch.jit.script(self.model)
#         self.ts_script = ts_script
#         self.loss = xform(loss) if loss is not None else None
#         self.amp = amp
#         self.scaler = scaler
#         self.is_distributed = False
#         self.divide_loss = divide_loss
#         self._fwd_bwd = None
#         self._forward = None

#     def distributed(self, gpu_id):
#         self.is_distributed = True
#         s = torch.cuda.Stream()
#         s.wait_stream(torch.cuda.current_stream())
#         with torch.cuda.stream(s):
#             self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)
#         torch.cuda.current_stream().wait_stream(s)

#     def _fwd_bwd_fn(
#         self,
#         input: torch.Tensor,
#         target: torch.Tensor,
#     ) -> torch.Tensor:
#         with autocast(enabled=self.amp):
#             loss = self.loss(self.model(input), target)
#             loss /= self.divide_loss

#         self.scaler.scale(loss).backward()
#         return loss

#     def _forward_fn(
#         self, input: torch.Tensor, target: torch.Tensor
#     ) -> Tuple[torch.Tensor, torch.Tensor]:
#         with torch.no_grad(), autocast(enabled=self.amp):
#             output = self.model(input)
#             loss = None if self.loss is None else self.loss(output, target)

#         return output if loss is None else loss, output

#     def optimize(self, fn):
#         return fn

#     @property
#     def forward_backward(self):
#         if self._fwd_bwd is None:
#             if self.loss is None:
#                 raise NotImplementedError(
#                     "Loss must not be None for forward+backward step"
#                 )
#             self._fwd_bwd = self.optimize(self._fwd_bwd_fn)
#         return self._fwd_bwd

#     @property
#     def forward(self):
#         if self._forward is None:
#             self._forward = self.optimize(self._forward_fn)
#         return self._forward

#     def train(self):
#         self.model.train()
#         if self.loss is not None:
#             self.loss.train()

#     def eval(self):
#         self.model.eval()
#         if self.loss is not None:
#             self.loss.eval()


class LbmqtTrainer:
    """Bundle a model, its optimizer and its loss function for training.

    Exposes per-iteration ``train_step`` / ``validation_step`` callables,
    ``train`` / ``eval`` mode switches, and a ``state_dict`` for checkpointing.
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_fn,
        cuda: bool = True,
        memory_format: torch.memory_format = torch.contiguous_format,
        amp: bool = False,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
    ):
        """
        Args:
            model: module to train. It is used as-is; it is not moved to CUDA
                or converted to ``memory_format`` here.
            optimizer: optimizer over ``model``'s parameters.
            loss_fn: callable ``loss_fn(output, target)``. ``train()`` and
                ``eval()`` tolerate ``None``; the step methods do not.
            cuda: currently unused; kept for interface compatibility.
            memory_format: currently unused; kept for interface compatibility.
            amp: run the training forward pass under ``autocast`` and scale
                the loss with ``scaler``.
            scaler: gradient scaler; required when ``amp`` is True.

        Raises:
            AssertionError: if ``amp`` is True and ``scaler`` is None.
        """
        assert not (amp and scaler is None), "Gradient Scaler is needed for AMP"

        # TODO: move the initial quantization after transforming the model
        # (e.g. to ``memory_format``); until then the model is used untouched.
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.amp = amp
        self.scaler = scaler

    def train(self):
        """Put the model (and the loss module, if any) in training mode."""
        self.model.train()
        if self.loss_fn is not None:
            self.loss_fn.train()

    def eval(self):
        """Put the model (and the loss module, if any) in evaluation mode."""
        self.model.eval()
        if self.loss_fn is not None:
            self.loss_fn.eval()

    # TODO: currently, our training framework is not compatible with DistributedDataParallel
    # def distributed(self, gpu_id):
    #     s = torch.cuda.Stream()
    #     s.wait_stream(torch.cuda.current_stream())
    #     with torch.cuda.stream(s):
    #         self.model.module.model = DDP(self.model.module.model, device_ids=[gpu_id], output_device=gpu_id)
    #     torch.cuda.current_stream().wait_stream(s)

    def train_step(self, input, target, step=None):
        """Run one forward/backward/optimizer step and return the loss.

        Args:
            input: input batch. ``requires_grad`` is set on it before the
                forward pass, and its ``.grad`` is cleared after the step.
            target: targets passed to ``loss_fn``.
            step: global step index; accepted for interface compatibility
                and currently unused.

        Returns:
            The (non-detached) loss tensor of this step.

        When ``config.debug_memory_model`` is set, a single memory-profiling
        step runs instead, results are dumped to ``mem_results.tsv``, and
        the process exits.
        """
        if config.debug_memory_model:
            self._debug_memory_step(input, target)  # exits the process

        input.requires_grad = True
        # Clear gradients from the previous step; otherwise .backward()
        # would accumulate them across iterations.
        self.optimizer.zero_grad()

        with autocast(enabled=self.amp):
            output = self.model(input)
            loss = self.loss_fn(output, target)

        if self.amp:
            # Scale the loss so reduced-precision gradients do not underflow.
            # The scaler unscales before stepping and updates its scale factor.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            self.optimizer.step()

        input.grad = None
        # Make the caller's wall-clock timing include queued GPU work;
        # skipped on CPU-only hosts where synchronize() would fail.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        return loss

    def _debug_memory_step(self, input, target):
        """Measure memory around one training step, record it, and exit."""
        import sys

        print("========== Init Data Loader ===========")
        init_mem = get_memory_usage(True)
        exp_recorder.record("data_loader", init_mem / GB - exp_recorder.val_dict['model_only'], 2)

        input.requires_grad = True

        # TODO: is the measurement correct?
        output = self.model(input)
        loss = self.loss_fn(output, target)
        print("========== Before Backward ===========")
        before_backward = get_memory_usage(True)
        # Memory growth since the batch was loaded, excluding the loss and
        # output tensors themselves (reported as activation memory).
        act_mem = get_memory_usage() - init_mem - compute_tensor_bytes([loss, output])
        loss.backward()
        self.optimizer.step()
        del loss
        print("========== After Backward ===========")
        state_mem = after_backward = get_memory_usage(True)
        total_mem = before_backward + (after_backward - init_mem)
        res = "Batch size: %d\tTotal Mem: %.2f MB\tState Mem: %.2f MB" % (
                len(output), total_mem / MB, state_mem / MB)
        print(res)
        exp_recorder.record("batch_size", len(output))
        exp_recorder.record("total", total_mem / GB, 2)
        exp_recorder.record("state", state_mem / GB, 2)
        exp_recorder.record("activation", act_mem / GB, 2)
        exp_recorder.dump('mem_results.tsv')
        sys.exit()

    def validation_step(self, input, target):
        """Compute ``(loss, output)`` for a batch without tracking gradients."""
        with torch.no_grad():
            output = self.model(input)
            loss = self.loss_fn(output, target)
        return loss, output

    def state_dict(self) -> dict:
        """Return the model and optimizer state for checkpointing."""
        res = {
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

        return res


def train(
    train_step,
    train_loader,
    lr_scheduler,
    log_fn,
    timeout_handler,
    prof=-1,
    step=0,
):
    """Run one training epoch over ``train_loader``.

    Args:
        train_step: callable ``train_step(input, target, step=...)`` that
            returns the loss tensor of the step.
        train_loader: iterable of ``(input, target)`` batches.
        lr_scheduler: callable taking the iteration index within the epoch;
            its return value is logged as ``lr``.
        log_fn: callback receiving per-iteration throughput, timing, lr and
            loss as keyword arguments.
        timeout_handler: object whose ``interrupted`` flag is polled every
            20 iterations.
        prof: if positive, stop after this many iterations.
        step: global step of the first iteration; ``step + i`` is passed to
            ``train_step``.

    Returns:
        True if the epoch stopped early because ``timeout_handler`` reported
        an interruption, False otherwise.
    """
    interrupted = False

    end = time.time()

    if config.debug_memory_model:
        print("========== Model Only ===========")
        usage = get_memory_usage(True)
        exp_recorder.record("model_only", usage / GB, 2)

    for i, (input, target) in enumerate(train_loader):
        GOR.start_iteration()
        bs = input.size(0)
        lr = lr_scheduler(i)
        data_time = time.time() - end

        loss = train_step(input, target, step=step + i)
        it_time = time.time() - end
        compute_time = it_time - data_time

        with torch.no_grad():
            reduced_loss = loss.detach()
            if torch.distributed.is_initialized():
                reduced_loss = utils.reduce_tensor(reduced_loss)

        log_fn(
            compute_ips=utils.calc_ips(bs, compute_time),
            total_ips=utils.calc_ips(bs, it_time),
            data_time=data_time,
            compute_time=compute_time,
            lr=lr,
            loss=reduced_loss.item(),
        )

        end = time.time()
        # Close the iteration before any early exit so that every
        # GOR.start_iteration() is paired with a GOR.end_iteration().
        GOR.end_iteration()

        if prof > 0 and (i + 1 >= prof):
            time.sleep(5)
            break
        if ((i + 1) % 20 == 0) and timeout_handler.interrupted:
            time.sleep(5)
            interrupted = True
            break

    return interrupted


def validate(val_step, val_loader, log_fn, prof=-1, with_loss=True):
    """Evaluate ``val_step`` over ``val_loader`` and log per-batch metrics.

    Args:
        val_step: ``val_step(input, target) -> (loss, output)`` when
            ``with_loss`` is True, otherwise ``val_step(input) -> output``.
        val_loader: iterable of ``(input, target)`` batches.
        log_fn: callback receiving throughput and timing values, plus
            ``top1`` / ``top5`` (and ``loss`` when ``with_loss``) as
            ``(value, batch_size)`` pairs.
        prof: if positive, stop after this many batches.
        with_loss: whether ``val_step`` also returns a loss.

    Returns:
        ``get_val()`` of the running top-1 ``AverageMeter``.
    """
    meter = log.AverageMeter()
    tic = time.time()

    for idx, (batch_input, target) in enumerate(val_loader):
        batch_size = batch_input.size(0)
        load_time = time.time() - tic

        if with_loss:
            loss, output = val_step(batch_input, target)
        else:
            output = val_step(batch_input)

        with torch.no_grad():
            acc1, acc5 = utils.accuracy(output.data, target, topk=(1, 5))
            distributed = torch.distributed.is_initialized()
            if with_loss:
                reduced_loss = loss.detach()
                if distributed:
                    reduced_loss = utils.reduce_tensor(reduced_loss)
            if distributed:
                # Combine the accuracies across processes.
                acc1 = utils.reduce_tensor(acc1)
                acc5 = utils.reduce_tensor(acc5)

        acc1 = acc1.item()
        acc5 = acc5.item()
        metrics = {
            "top1": (acc1, batch_size),
            "top5": (acc5, batch_size),
        }
        if with_loss:
            metrics["loss"] = (reduced_loss.item(), batch_size)

        torch.cuda.synchronize()
        elapsed = time.time() - tic
        compute_time = elapsed - load_time

        meter.record(acc1, batch_size)

        log_fn(
            compute_ips=utils.calc_ips(batch_size, compute_time),
            total_ips=utils.calc_ips(batch_size, elapsed),
            data_time=load_time,
            compute_time=compute_time,
            **metrics,
        )

        tic = time.time()
        if prof > 0 and idx + 1 >= prof:
            time.sleep(5)
            break

    return meter.get_val()


# Train loop {{{
def lbmqt_train_loop(
    trainer: LbmqtTrainer,
    lr_scheduler,
    train_loader,
    train_loader_len,
    val_loader,
    logger,
    best_prec1=0,
    start_epoch=0,
    end_epoch=0,
    early_stopping_patience=-1,
    prof=-1,
    skip_training=False,
    skip_validation=False,
    save_checkpoints=True,
    checkpoint_dir="./",
    checkpoint_filename="checkpoint.pth.tar",
    keep_last_n_checkpoints=0,
):
    """Train and validate ``trainer`` for epochs in ``[start_epoch, end_epoch)``.

    Each epoch does the following:
      * trains (unless ``skip_training``);
      * validates (unless ``skip_validation``);
      * saves a checkpoint on the main process when ``save_checkpoints``;
      * applies early stopping when ``early_stopping_patience`` is positive.

    The loop also ends after an epoch whose training was interrupted by the
    timeout handler.

    When validation is skipped, ``best_prec1`` is reset to 0 every epoch and
    no checkpoint is marked as best.
    """
    checkpointer = utils.Checkpointer(
        last_filename=checkpoint_filename,
        checkpoint_dir=checkpoint_dir,
        keep_last_n=keep_last_n_checkpoints,
    )
    train_metrics = TrainingMetrics(logger)
    val_metrics = ValidationMetrics(logger, 'val')
    step_fn = trainer.train_step
    eval_fn = trainer.validation_step

    patience_enabled = early_stopping_patience > 0
    stale_epochs = 0

    def wrap(loader, mode):
        # Route iteration through the logger only when one is configured.
        if logger is None:
            return loader
        return logger.iteration_generator_wrapper(loader, mode=mode)

    print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
    with utils.TimeoutHandler() as timeout_handler:
        interrupted = False
        for epoch in range(start_epoch, end_epoch):
            if logger is not None:
                logger.start_epoch()
                GOR.start_epoch()

            if not skip_training:
                train_iter = wrap(train_loader, "train")
                trainer.train()
                interrupted = train(
                    step_fn,
                    train_iter,
                    lambda i: lr_scheduler(trainer.optimizer, i, epoch),
                    train_metrics.log,
                    timeout_handler,
                    prof=prof,
                    step=epoch * train_loader_len,
                )

            is_best = False
            if skip_validation:
                best_prec1 = 0
            else:
                trainer.eval()
                val_iter = wrap(val_loader, "val")
                prec1, _ = validate(
                    eval_fn,
                    val_iter,
                    val_metrics.log,
                    prof=prof,
                )
                if prec1 > best_prec1:
                    is_best = True
                    best_prec1 = prec1

            if logger is not None:
                logger.end_epoch()
                GOR.end_epoch()

            # Only the main process (or a non-distributed run) writes checkpoints.
            if save_checkpoints and (
                not torch.distributed.is_initialized()
                or torch.distributed.get_rank() == 0
            ):
                checkpointer.save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "best_prec1": best_prec1,
                        **trainer.state_dict(),
                    },
                    is_best,
                    filename=f"checkpoint_{epoch:04}.pth.tar",
                )

            if patience_enabled:
                stale_epochs = 0 if is_best else stale_epochs + 1
                if stale_epochs >= early_stopping_patience:
                    break
            if interrupted:
                break


# }}}
