import os
import sys
import time
import logging
from argparse import Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler

from timm.utils import AverageMeter

sys.path.append(os.pardir)
from utils import (
    FeatureStorage, 
    register_feature_cache_hooks, 
    accuracy,
    reset_optimizer
)

# wandb is optional: if it cannot be imported, record that instead of failing.
# `except Exception` (not a bare `except`) keeps this best-effort for broken
# installs while still letting KeyboardInterrupt / SystemExit propagate.
try:
    import wandb
    wandb_enabled = True
except Exception:
    wandb_enabled = False

logger = logging.getLogger(__name__)


def mse_loss_norm(x1, x2, eps=1e-8):
    """Mean squared error of ``x1`` vs ``x2``, divided by the mean square of ``x2``.

    ``eps`` is added to the denominator so an all-zero ``x2`` does not divide by zero.
    """
    squared_error = (x1 - x2).pow(2).mean()
    reference_power = x2.pow(2).mean()
    return squared_error / (reference_power + eps)

def kl_div_with_logits(student_logits, teacher_logits, temp=1.0, reduction="sum"):
    """Temperature-scaled KL(teacher || student) computed directly from logits.

    Both logit tensors are divided by ``temp`` and turned into log-probabilities
    with ``log_softmax`` over the last dimension. The divergence is then
    multiplied by ``temp ** 2``, the conventional distillation rescaling.

    Args:
        student_logits: student predictions; softmax is taken over the last dim.
        teacher_logits: teacher predictions, same shape as ``student_logits``.
        temp: softmax temperature.
        reduction: forwarded to ``F.kl_div``. The default ``"sum"`` sums over
            every element, so the value grows with the batch size;
            ``"batchmean"`` divides that sum by the first dimension instead.

    Returns:
        The reduced divergence (a scalar for ``"sum"``/``"mean"``/``"batchmean"``).
    """
    student_log_probs = F.log_softmax(student_logits / temp, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits / temp, dim=-1)
    return temp ** 2 * F.kl_div(
        input=student_log_probs,
        target=teacher_log_probs,
        log_target=True,  # target is given as log-probabilities too
        reduction=reduction,
    )


# Registries mapping config names (args.base_loss / args.output_kd_loss /
# args.feat_kd_loss) to loss callables.
BASE_LOSS_FUNCTIONS = {
    'cross_entropy': F.cross_entropy,
    'mse_loss': F.mse_loss,
}

OUTPUT_KD_LOSS_FUNCTIONS = {
    'kl_div': kl_div_with_logits,
    'mse_loss': F.mse_loss,
}

FEAT_KD_LOSS_FUNCTIONS = {
    'mse_loss': F.mse_loss,
    'mse_loss_norm': mse_loss_norm,
    # Historical misspelling kept as an alias so existing configs keep working.
    'mse_loss_norn': mse_loss_norm,
}


def log_train_stats(train_stats):
    """Log every ``name -> value`` pair of ``train_stats`` in scientific notation.

    The entries are framed by two separator rules; names are capitalized.
    """
    separator = '-' * 10
    logger.info(separator)
    for name, value in train_stats.items():
        logger.info(f"{name.capitalize()}: {value:.2e}")
    logger.info(separator)


def log_eval_stats(eval_stats):
    """Log the validation loss and ``acc1`` (multiplied by 100 and shown with a % sign).

    Expects ``eval_stats`` to contain the keys ``'loss'`` and ``'acc1'``.
    """
    rule = '-' * 10
    logger.info(rule)
    logger.info(f"Loss: {eval_stats['loss']:.2e}")
    acc1_percent = 100 * eval_stats['acc1']
    logger.info(f"Acc1: {acc1_percent:.2f}%")
    logger.info(rule)


@torch.no_grad()
def evaluate(model, val_loader, loss_fn, device, amp):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: the network; it is switched to ``eval()`` mode (and left there).
        val_loader: iterable of ``(inputs, targets)`` batches.
        loss_fn: called as ``loss_fn(outputs, targets)``; must return a tensor
            supporting ``.item()``.
        device: where batches are moved. May be a string such as ``'cuda'`` /
            ``'cuda:0'`` or a ``torch.device``.
        amp: enables ``torch.autocast`` for the forward pass and loss.

    Returns:
        ``{'loss': ..., 'acc1': ...}``: batch-size-weighted averages of the loss
        and of what ``accuracy`` returns per batch.
    """
    model.eval()
    # torch.autocast needs a bare device type ('cuda'/'cpu'), not 'cuda:0' or a
    # torch.device object.
    device_type = torch.device(device).type
    # create meters
    loss_m = AverageMeter()
    acc1_m = AverageMeter()  # TODO add other options

    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        with torch.autocast(device_type=device_type, enabled=amp):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
        # get accuracies
        acc1 = accuracy(outputs.float(), targets.float())
        # update stats, weighting each batch by its size
        batch_size = len(inputs)
        acc1_m.update(acc1, batch_size)
        loss_m.update(loss.item(), batch_size)

    return {'loss': loss_m.avg, 'acc1': acc1_m.avg}


class Trainer:
    """Training loop with optional pruning and knowledge distillation.

    Every step first steps all pruners, then runs a forward/backward pass whose
    loss is ``lambda_base * base + lambda_kd_output * output_kd +
    lambda_kd_feat * feature_kd``. The KD terms are only computed when a
    teacher model is given. Mixed precision is controlled by ``args.amp``.

    ``args`` is read for ``amp``, ``feat_names`` (only with a teacher),
    ``base_loss``, ``output_kd_loss``, ``feat_kd_loss``, ``lambda_base``,
    ``lambda_kd_output``, ``lambda_kd_feat``, ``eval_before_training``,
    ``log_wandb``, ``num_train_epochs``, ``num_train_steps``,
    ``logging_steps``, ``eval_steps`` and ``output_dir``.
    """

    def __init__(
        self,
        model: nn.Module,
        args: Namespace,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer,
        pruners=None,
        lr_scheduler=None,
        teacher_model=None,
        device='cuda'
    ):
        """Store the training components and register feature hooks.

        Args:
            model: student model being trained.
            args: run configuration (see class docstring).
            train_loader: yields ``(inputs, targets)`` training batches.
            val_loader: yields ``(inputs, targets)`` validation batches.
            optimizer: torch optimizer for ``model``.
            pruners: objects exposing ``is_update_step()`` and ``step()``;
                ``None`` means no pruning.
            lr_scheduler: optional scheduler, stepped once per train step.
            teacher_model: optional teacher; enables the distillation losses.
            device: device string (e.g. ``'cuda'``, ``'cuda:0'``) or ``torch.device``.
        """
        self.model = model
        self.args = args
        # data
        self.train_loader = train_loader
        self.val_loader = val_loader
        # optimization
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.scaler = GradScaler(enabled=args.amp)
        # pruning (None default avoids a list shared between instances)
        self.pruners = pruners if pruners is not None else []
        # distillation
        self.teacher_model = teacher_model
        self.device = device
        # init feature storages
        self.features = FeatureStorage()
        self.teacher_features = FeatureStorage()
        self.hooks = None
        self.teacher_hooks = None
        # register hooks; explicit None check rather than module truthiness
        # (container modules define __len__, so an empty one is falsy)
        if self.teacher_model is not None:
            self.hooks = register_feature_cache_hooks(model, args.feat_names, self.features)
            self.teacher_hooks = register_feature_cache_hooks(teacher_model, args.feat_names, self.teacher_features)

        self.build_loss_functions()

    def build_loss_functions(self):
        """Resolve the base, output-KD and feature-KD loss callables from ``args`` by name."""
        self.base_loss_fn = BASE_LOSS_FUNCTIONS[self.args.base_loss]
        self.output_kd_loss_fn = OUTPUT_KD_LOSS_FUNCTIONS[self.args.output_kd_loss]
        self.feat_kd_loss_fn = FEAT_KD_LOSS_FUNCTIONS[self.args.feat_kd_loss]

    def _autocast(self):
        """Return an autocast context for ``self.device``, enabled by ``args.amp``."""
        # torch.autocast needs a bare device type ('cuda'/'cpu'); self.device may
        # be 'cuda:0' or a torch.device object.
        device_type = torch.device(self.device).type
        return torch.autocast(device_type=device_type, enabled=self.args.amp)

    @staticmethod
    def _detached(value):
        """Return ``value`` detached from autograd if it is a tensor, else unchanged."""
        return value.detach() if torch.is_tensor(value) else value

    def train_step(self, batch):
        """Run one optimization step on ``batch`` (an ``(inputs, targets)`` pair).

        Returns:
            dict with ``loss``, ``loss_base``, ``loss_kd_output`` and
            ``loss_kd_feat``. Tensor values are detached from the autograd graph
            so callers can accumulate them freely. The KD entries are the int
            ``0`` when there is no teacher.
        """
        self.model.train()
        inputs, targets = batch
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        # KD terms stay 0 when there is no teacher
        train_loss_kd_output = 0
        train_loss_kd_feat = 0

        # turn on feature caching; always turned off again in `finally`
        self.features.enabled = True
        self.teacher_features.enabled = True
        try:
            with self._autocast():
                outputs = self.model(inputs)
                # loss inside autocast, as in evaluate() and the AMP recipe
                train_loss_base = self.base_loss_fn(outputs, targets)

                if self.teacher_model is not None:
                    with torch.no_grad():
                        teacher_outputs = self.teacher_model(inputs)
                    train_loss_kd_output = self.output_kd_loss_fn(outputs, teacher_outputs)

                    for feat_name in self.features:
                        train_loss_kd_feat += self.feat_kd_loss_fn(
                            self.features[feat_name],
                            self.teacher_features[feat_name]
                        )

            # Compute total loss
            train_loss = (
                self.args.lambda_base * train_loss_base +
                self.args.lambda_kd_output * train_loss_kd_output +
                self.args.lambda_kd_feat * train_loss_kd_feat
            )

            # backward pass through the (possibly disabled) grad scaler
            self.scaler.scale(train_loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
        finally:
            # turn off feature caching even if the step raised
            self.features.enabled = False
            self.teacher_features.enabled = False

        return dict(
            loss=self._detached(train_loss),
            loss_base=self._detached(train_loss_base),
            loss_kd_output=self._detached(train_loss_kd_output),
            loss_kd_feat=self._detached(train_loss_kd_feat)
        )

    @torch.no_grad()
    def evaluate(self):
        """Evaluate the model on ``val_loader`` with the base loss; see module-level ``evaluate``."""
        return evaluate(self.model, self.val_loader, self.base_loss_fn, self.device, self.args.amp)

    def _log_wandb(self, metrics, step):
        """Send ``metrics`` to W&B at ``step`` when ``args.log_wandb`` is set."""
        if self.args.log_wandb:
            wandb.log(metrics, step=step)

    def _evaluate_and_log(self, header, step):
        """Log ``header``, run validation, then log the results locally and to W&B at ``step``."""
        logger.info(header)
        eval_stats = self.evaluate()
        log_eval_stats(eval_stats)
        self._log_wandb({f'val/{k}': v for k, v in eval_stats.items()}, step)
        return eval_stats

    def _run_pruners(self, global_step):
        """Step every pruner; reset optimizer state if any pruner reported an update step."""
        any_pruner_updated = False
        t_s = time.perf_counter()
        for pruner in self.pruners:
            is_pruner_updated = pruner.is_update_step()
            any_pruner_updated = any_pruner_updated or is_pruner_updated
            pruner.step()
        t_e = time.perf_counter()

        if any_pruner_updated:
            logger.info(f'Pruning step took {(t_e - t_s):.2f} seconds')
            logger.info(f'Resetting optimizer after pruning on step {global_step}')
            reset_optimizer(self.optimizer)

    def train(self):
        """Run the full training loop.

        Trains for ``args.num_train_epochs`` epochs. It stops early when the step
        counter reaches exactly ``args.num_train_steps``; a value that is never
        reached (e.g. ``None`` or ``0``) disables early stopping. Every
        ``args.logging_steps`` steps it logs the training losses averaged since
        the last log, plus the learning rate. Every ``args.eval_steps`` steps
        (skipping step 0) and once at the end it runs validation. Finally it
        saves the model weights to ``<output_dir>/last_checkpoint.pth``.
        """
        # TODO add checkpoint resuming
        global_step = 0

        if self.args.eval_before_training:
            self._evaluate_and_log('Evaluation before training', global_step)

        # init loss meters
        loss_meters = dict(
            loss=AverageMeter(),
            loss_base=AverageMeter(),
            loss_kd_output=AverageMeter(),
            loss_kd_feat=AverageMeter()
        )

        for _ in range(0, self.args.num_train_epochs):
            for batch in self.train_loader:
                # pruning is done before train step
                self._run_pruners(global_step)

                train_stats = self.train_step(batch)
                for loss_name, loss_meter in loss_meters.items():
                    loss_meter.update(train_stats[loss_name])

                if global_step % self.args.logging_steps == 0:
                    avg_losses = {name: meter.avg for name, meter in loss_meters.items()}
                    logger.info(f'Train stats on step {global_step}')
                    log_train_stats(avg_losses)
                    # Log lr in the SAME call as the losses. A step-less
                    # wandb.log() advances W&B's internal step, so the following
                    # step=global_step call would be rejected as out of order.
                    metrics = {'lr': self.optimizer.param_groups[0]['lr']}
                    metrics.update({f'train/{name}': avg for name, avg in avg_losses.items()})
                    self._log_wandb(metrics, global_step)
                    # reset meters
                    for loss_meter in loss_meters.values():
                        loss_meter.reset()

                if global_step > 0 and global_step % self.args.eval_steps == 0:
                    self._evaluate_and_log(f'Evaluation on step {global_step}', global_step)

                global_step += 1
                # first exit statement
                if global_step == self.args.num_train_steps:
                    break
            # second exit statement
            if global_step == self.args.num_train_steps:
                break

        self._evaluate_and_log('Evaluation after training', global_step)

        # save last model; create output_dir so a finished run is not lost
        output_dir = self.args.output_dir
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(output_dir, 'last_checkpoint.pth'))
