from . import utils
from . import metrics
from . import evaluator
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import logging
import copy
import os
from collections import defaultdict
from tqdm import tqdm
import numpy as np


def set_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    norm_weight_decay,
    norm_classes=None,
    custom_keys_weight_decay=None,
):
    """Split the trainable parameters of ``model`` into optimizer param groups.

    Each parameter with ``requires_grad=True`` is placed in exactly one group:

    * a custom-key group, if its name matches a key from
      ``custom_keys_weight_decay`` (first matching key wins). A key without a
      ``.`` is compared to the bare parameter name (e.g. ``"bias"``). A key
      containing a ``.`` is compared to ``"<module path>.<name>"``; for
      parameters owned directly by ``model`` (empty path) the bare name is
      used instead.
    * the ``"norm"`` group, if ``norm_weight_decay`` is not ``None`` and the
      owning module is an instance of one of ``norm_classes``;
    * the ``"other"`` group otherwise.

    Args:
        model: module whose parameters are grouped (traversed depth-first,
            parent parameters before children).
        weight_decay: decay applied to the ``"other"`` group.
        norm_weight_decay: decay for the ``"norm"`` group, or ``None`` to
            treat normalization-layer parameters like any other parameter.
        norm_classes: module types treated as normalization layers; a falsy
            value selects BatchNorm/LayerNorm/GroupNorm/InstanceNorm/
            LocalResponseNorm.
        custom_keys_weight_decay: optional iterable of ``(key, decay)`` pairs.

    Returns:
        List of ``{"params": [...], "weight_decay": ...}`` dicts, one per
        non-empty group, ordered "other", "norm", then custom keys.
    """
    if not norm_classes:
        norm_classes = (
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
            torch.nn.modules.instancenorm._InstanceNorm,
            torch.nn.LocalResponseNorm,
        )
    norm_types = tuple(norm_classes)

    groups = {"other": [], "norm": []}
    decays = {"other": weight_decay, "norm": norm_weight_decay}
    ordered_keys = []
    if custom_keys_weight_decay is not None:
        for custom_key, custom_decay in custom_keys_weight_decay:
            groups[custom_key] = []
            decays[custom_key] = custom_decay
            ordered_keys.append(custom_key)

    def _matching_custom_key(param_name, prefix):
        """Return the first custom key naming this parameter, else None."""
        for custom_key in ordered_keys:
            if prefix != "" and "." in custom_key:
                candidate = f"{prefix}.{param_name}"
            else:
                candidate = param_name
            if candidate == custom_key:
                return custom_key
        return None

    # Explicit stack instead of recursion; children are pushed in reverse so
    # they are popped in declaration order (same pre-order as a recursive walk).
    pending = [(model, "")]
    while pending:
        module, prefix = pending.pop()
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            target_key = _matching_custom_key(param_name, prefix)
            if target_key is None:
                use_norm = norm_weight_decay is not None and isinstance(
                    module, norm_types
                )
                target_key = "norm" if use_norm else "other"
            groups[target_key].append(param)

        children = [
            (child, child_name if prefix == "" else f"{prefix}.{child_name}")
            for child_name, child in module.named_children()
        ]
        pending.extend(reversed(children))

    return [
        {"params": groups[group_name], "weight_decay": decays[group_name]}
        for group_name in groups
        if groups[group_name]
    ]


def calculate_fairness(old_acc, new_acc, fairness_type=None):
    """Score how evenly per-class accuracy changed between two evaluations.

    For every class present in either mapping (a missing entry counts as 0)
    the relative change ``(new - old) / old`` is computed. When the old value
    is 0, the change is defined as 0 if the new value is also 0, else 1.

    Args:
        old_acc: mapping class -> accuracy from the reference evaluation.
        new_acc: mapping class -> accuracy from the new evaluation.
        fairness_type: how the relative changes are aggregated. Only
            ``"max_min"`` (spread = max change - min change) is implemented.

    Returns:
        The negated spread, so a score closer to 0 means the classes changed
        more uniformly.

    Raises:
        ValueError: if ``fairness_type`` is unsupported, or both mappings are
            empty.
    """
    # Validate up front so a bad type fails before any work is done.
    if fairness_type != "max_min":
        raise ValueError(
            f"Invalid fairness type {fairness_type}. Only 'max_min' is supported."
        )

    classes = set(old_acc) | set(new_acc)
    if not classes:
        raise ValueError("Cannot compute fairness: no per-class accuracies given.")

    relative_differences = []
    for cls in classes:
        old_value = old_acc.get(cls, 0)
        new_value = new_acc.get(cls, 0)
        if old_value != 0:
            relative_difference = (new_value - old_value) / old_value
        else:
            # No baseline to divide by: a class that goes from 0 to anything
            # non-zero counts as a full (+1) relative change.
            relative_difference = 0 if new_value == 0 else 1
        relative_differences.append(relative_difference)

    fairness = max(relative_differences) - min(relative_differences)
    return -1 * fairness


def fairness_eval(model, test_loader, device=0):
    """Evaluate ``model`` and additionally report per-class top-1 accuracy.

    Args:
        model: classifier; its output is ranked along dim 1, so it presumably
            returns logits shaped (batch, num_classes) — confirm for new models.
        test_loader: iterable of ``(data, target)`` batches with integer class
            indices in ``target``.
        device: device the batches are moved to via ``Tensor.to``.

    Returns:
        ``(top1_acc, top5_acc, mean_loss, per_class_acc)``. Accuracies are in
        percent. ``mean_loss`` is the sample-weighted mean cross-entropy.
        ``per_class_acc`` maps each class index seen in the targets (as a
        Python ``int``, in ascending order) to its top-1 accuracy in percent.
        With fewer than 5 classes, "top-5" is computed as top-``num_classes``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct_per_class = defaultdict(int)
    total_per_class = defaultdict(int)
    correct_top1 = 0
    correct_top5 = 0
    total = 0
    total_loss = 0
    loss_fn = torch.nn.CrossEntropyLoss()

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Evaluating", leave=False):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_fn(output, target)
            total_loss += loss.item() * data.size(0)

            # topk raises if k exceeds the number of classes, so clamp it.
            k = min(5, output.size(1))
            _, pred_topk = output.topk(k, dim=1, largest=True, sorted=True)
            correct_top5 += (pred_topk == target.unsqueeze(1)).sum().item()

            # sorted=True puts the top-1 prediction in column 0.
            matches = pred_topk[:, 0] == target
            correct_top1 += matches.sum().item()
            total += target.size(0)

            # tolist() gives plain Python ints/bools (not numpy scalars), so
            # the per-class dict keys are ordinary ints.
            for label, match in zip(target.tolist(), matches.tolist()):
                correct_per_class[label] += int(match)
                total_per_class[label] += 1

    if total == 0:
        raise ValueError("fairness_eval: test_loader yielded no samples")

    test_acc_top1 = 100 * correct_top1 / total
    test_acc_top5 = 100 * correct_top5 / total
    test_loss = total_loss / total

    # Every key in total_per_class was counted at least once, so no zero guard.
    per_class_acc = {
        cls: 100 * (correct_per_class[cls] / total_per_class[cls])
        for cls in sorted(total_per_class)
    }

    return test_acc_top1, test_acc_top5, test_loss, per_class_acc


def eval(model, test_loader, device=0):
    """Evaluate ``model`` with the project's classification evaluator.

    Note: this module-level name shadows the built-in ``eval`` here.

    Args:
        model: network to evaluate; switched to eval mode first.
        test_loader: data passed to ``evaluator.classification_evaluator``.
        device: device passed to ``evaluator.classification_evaluator``.

    Returns:
        ``(top1_acc, top5_acc, loss)``, unpacked from the evaluator's
        ``"Acc"`` entry (a two-item sequence) and its ``"Loss"`` entry.
    """
    model.eval()
    run_evaluation = evaluator.classification_evaluator(test_loader, device)
    metrics_by_name = run_evaluation(model)
    acc_top1, acc_top5 = metrics_by_name["Acc"]
    return acc_top1, acc_top5, metrics_by_name["Loss"]


def _build_optimizer(parameters, lr, weight_decay, args, logger):
    """Create the optimizer named by ``args.opt_name`` (case-insensitive).

    Any name starting with ``"sgd"`` selects SGD, with Nesterov momentum
    enabled when the name also contains ``"nesterov"`` (e.g. ``"sgd_nesterov"``).
    ``"rmsprop"`` and ``"adamw"`` select those optimizers.

    Raises:
        RuntimeError: for any other name.
    """
    opt_name = args.opt_name.lower()
    if opt_name.startswith("sgd"):
        logger.info("Using SGD")
        return torch.optim.SGD(
            parameters,
            lr=lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        logger.info("Using RMSprop")
        return torch.optim.RMSprop(
            parameters,
            lr=lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
            eps=0.0316,
            alpha=0.9,
        )
    if opt_name == "adamw":
        logger.info("Using AdamW")
        return torch.optim.AdamW(parameters, lr=lr, weight_decay=weight_decay)
    raise RuntimeError(
        f"Invalid optimizer {opt_name}. Only SGD, RMSprop and AdamW are supported."
    )


def _build_lr_scheduler(
    optimizer, epochs, lr_step_size, lr_warmup_epochs, lr_decay_milestones, args, logger
):
    """Create the per-epoch LR scheduler described by ``args``, or return None.

    Returns ``None`` when ``args.lr_scheduler_name`` is ``"None"`` (or ``None``).
    Otherwise builds the main scheduler and, when ``lr_warmup_epochs > 0``,
    prepends a linear/constant warmup via ``SequentialLR``.

    Raises:
        RuntimeError: for an unknown scheduler name or warmup method.
        ValueError: if MultiStepLR is requested without ``lr_decay_milestones``.
    """
    if args.lr_scheduler_name in (None, "None"):
        logger.info("Scheduler is not used")
        return None

    lr_scheduler_name = args.lr_scheduler_name.lower()
    if lr_scheduler_name == "steplr":
        logger.info("Using steplr scheduler")
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=lr_step_size, gamma=args.lr_decay_gamma
        )
    elif lr_scheduler_name == "multisteplr":
        logger.info("Using multisteplr scheduler")
        if lr_decay_milestones is None:
            raise ValueError(
                "lr_decay_milestones (comma-separated epochs) is required for MultiStepLR"
            )
        milestones = [int(ms) for ms in lr_decay_milestones.split(",")]
        main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=args.lr_decay_gamma
        )
    elif lr_scheduler_name == "cosineannealinglr":
        logger.info("Using cosineannealinglr scheduler")
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs - lr_warmup_epochs, eta_min=args.lr_min
        )
    elif lr_scheduler_name == "exponentiallr":
        logger.info("Using exponentiallr scheduler")
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=args.lr_decay_gamma
        )
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler_name}'. Only StepLR, MultiStepLR, "
            "CosineAnnealingLR and ExponentialLR are supported."
        )

    if lr_warmup_epochs <= 0:
        return main_lr_scheduler

    # The main scheduler is constructed before the warmup one; scheduler
    # constructors adjust the optimizer's lr, so this order is kept deliberately.
    logger.info("LR warmup")
    if args.lr_warmup_method == "linear":
        logger.info("Linear LR warmup")
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=args.lr_warmup_decay,
            total_iters=lr_warmup_epochs,
        )
    elif args.lr_warmup_method == "constant":
        logger.info("constant LR warmup")
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_lr_scheduler, main_lr_scheduler],
        milestones=[lr_warmup_epochs],
    )


def _train_one_epoch(
    model, train_loader, optimizer, criterion, scaler, device, args, epoch, epochs, logger
):
    """Run one optimization pass over ``train_loader``.

    Uses autocast plus ``scaler`` when ``scaler`` is not None, and clips the
    global gradient norm to ``args.clip_grad_norm`` when that is set. Logs
    the loss every 10 iterations.
    """
    model.train()
    for i, (data, target) in enumerate(train_loader):
        # .to(device) (rather than .cuda(device)) matches the eval helpers and
        # also works when ``device`` is a CPU device.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            out = model(data)
            loss = criterion(out, target)

        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitudes.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if i % 10 == 0:
            logger.info(
                "Epoch {:d}/{:d}, iter {:d}/{:d}, loss={:.4f}".format(
                    epoch + 1, epochs, i + 1, len(train_loader), loss.item()
                )
            )


def train_model(
    model,
    epochs,
    lr,
    lr_step_size,
    lr_warmup_epochs,
    train_loader,
    test_loader,
    criterion,
    save_dir,
    device,
    args,
    pruner=None,
    lr_decay_milestones=None,
    save_every=10,
    return_best=False,
    fairness_eval_flag=False,
    initial_per_class_acc=None,
    initial_acc_top1=None,
    fairness_type=None,
    scalar=1,
):
    """Train ``model``, evaluating and checkpointing it after every epoch.

    Args:
        model: network to train (modified in place).
        epochs: number of training epochs; must be >= 1.
        lr: initial learning rate.
        lr_step_size: ``step_size`` for StepLR.
        lr_warmup_epochs: warmup length in epochs (0 disables warmup).
        train_loader / test_loader: iterables of ``(data, target)`` batches.
        criterion: loss function called as ``criterion(output, target)``.
        save_dir: directory receiving ``best_model.pth``, ``last_model.pth``
            and ``checkpoint_ep_<n>.pth`` every ``save_every`` epochs.
        device: device batches are moved to.
        args: config namespace; this function reads ``weight_decay``,
            ``bias_weight_decay``, ``norm_weight_decay``, ``opt_name``,
            ``momentum``, ``amp``, ``lr_scheduler_name``, ``lr_decay_gamma``,
            ``lr_min``, ``lr_warmup_method``, ``lr_warmup_decay`` and
            ``clip_grad_norm``, and passes it on to ``utils.save_model``.
        pruner: only its presence is checked here; when not None, all weight
            decay values are forced to 0.
        lr_decay_milestones: comma-separated epoch list for MultiStepLR.
        save_every: checkpoint period in epochs.
        return_best: return the best-accuracy model instead of the last one.
        fairness_eval_flag: also compute per-class accuracy, a fairness score
            relative to ``initial_per_class_acc`` and a fitness score
            ``(top1 - initial_acc_top1) / initial_acc_top1 + scalar * fairness``.
        initial_per_class_acc / initial_acc_top1: baselines for fairness mode.
        fairness_type: aggregation passed to ``calculate_fairness``.
        scalar: weight of the fairness score in the fitness score.

    Returns:
        ``(model, saved)``, where ``saved`` is whatever ``utils.save_model``
        returned for that model (best or last, per ``return_best``).

    Raises:
        ValueError: if ``epochs < 1``.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    logger = logging.getLogger("train_logger")
    logger2 = logging.getLogger("result_logger")

    def log_both(message):
        """Send ``message`` to both the training and the result logger."""
        logger.info(message)
        logger2.info(message)

    # With a pruner attached, every weight-decay term is set to 0.
    if pruner is None:
        weight_decay = args.weight_decay
        bias_weight_decay = args.bias_weight_decay
        norm_weight_decay = args.norm_weight_decay
    else:
        weight_decay = bias_weight_decay = norm_weight_decay = 0

    custom_keys_weight_decay = (
        [("bias", bias_weight_decay)] if bias_weight_decay is not None else None
    )
    parameters = set_weight_decay(
        model,
        weight_decay,
        norm_weight_decay=norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay,
    )

    optimizer = _build_optimizer(parameters, lr, weight_decay, args, logger)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None
    lr_scheduler = _build_lr_scheduler(
        optimizer,
        epochs,
        lr_step_size,
        lr_warmup_epochs,
        lr_decay_milestones,
        args,
        logger,
    )

    best_acc_top1 = -1
    best_epoch = -1
    best_model = None
    best_model_dict = None
    test_acc_top1_list = []
    test_acc_top5_list = []
    test_loss_list = []
    fairness_score_list = []
    test_acc_top1_change_ratio_list = []
    fitness_score_list = []

    for epoch in range(epochs):
        _train_one_epoch(
            model,
            train_loader,
            optimizer,
            criterion,
            scaler,
            device,
            args,
            epoch,
            epochs,
            logger,
        )
        if lr_scheduler is not None:
            lr_scheduler.step()

        if fairness_eval_flag:
            test_acc_top1, test_acc_top5, test_loss, per_class_acc = fairness_eval(
                model, test_loader, device=device
            )
            fairness_score = calculate_fairness(
                initial_per_class_acc, per_class_acc, fairness_type=fairness_type
            )
            test_acc_top1_change_ratio = (
                test_acc_top1 - initial_acc_top1
            ) / initial_acc_top1
            fitness_score = test_acc_top1_change_ratio + scalar * fairness_score

            log_both(
                "Epoch {:d}/{:d}, Test Acc (Top-1)={:.4f}, Fairness Score={:.4f}, Test Acc (Top-5)={:.4f}, Test Loss={:.4f}, Fitness Score={:.4f}".format(
                    epoch + 1,
                    epochs,
                    test_acc_top1,
                    fairness_score,
                    test_acc_top5,
                    test_loss,
                    fitness_score,
                )
            )
            logger.info(f"Per class accuracy: {per_class_acc}")

            fairness_score_list.append(fairness_score)
            test_acc_top1_change_ratio_list.append(test_acc_top1_change_ratio)
            fitness_score_list.append(fitness_score)
        else:
            test_acc_top1, test_acc_top5, test_loss = eval(
                model, test_loader, device=device
            )
            log_both(
                "Epoch {:d}/{:d}, Test Acc (Top-1)={:.4f}, Test Acc (Top-5)={:.4f}, Test Loss={:.4f}".format(
                    epoch + 1, epochs, test_acc_top1, test_acc_top5, test_loss
                )
            )

        test_acc_top1_list.append(test_acc_top1)
        test_acc_top5_list.append(test_acc_top5)
        test_loss_list.append(test_loss)

        if best_acc_top1 < test_acc_top1:
            best_acc_top1 = test_acc_top1
            best_epoch = epoch
            best_model = copy.deepcopy(model)
            best_model_dict = utils.save_model(
                model, os.path.join(save_dir, "best_model.pth"), args=args
            )

        # No ``epoch > 0`` guard: with save_every == 1 the first epoch is
        # checkpointed too; for save_every >= 2 epoch 0 never matches anyway.
        if (epoch + 1) % save_every == 0:
            utils.save_model(
                model,
                os.path.join(save_dir, "checkpoint_ep_{}.pth".format(epoch + 1)),
                args=args,
            )

        last_model_dict = utils.save_model(
            model, os.path.join(save_dir, "last_model.pth"), args=args
        )

    log_both("Best Acc (Top-1)=%.4f, epoch=%d" % (best_acc_top1, best_epoch + 1))
    log_both("Last Acc (Top-1)=%.4f, epoch=%d" % (test_acc_top1, epoch + 1))
    log_both("Last Acc (Top-5)=%.4f, epoch=%d" % (test_acc_top5, epoch + 1))

    logger.info(f"Test Acc Top1 List ={test_acc_top1_list}")
    logger.info(f"Test Acc Top5 List ={test_acc_top5_list}")
    logger.info(f"Test Loss List ={test_loss_list}")

    if fairness_eval_flag:
        logger.info(f"fairness_score_list={fairness_score_list}")
        logger.info(
            f"test_acc_top1_change_ratio_list={test_acc_top1_change_ratio_list}"
        )
        logger.info(f"fitness_score_list={fitness_score_list}")

    log_both("\n")

    if return_best and best_model is not None:
        logger.info("return best model")
        return best_model, best_model_dict
    if return_best:
        # Only reachable if no epoch ever beat -1 (e.g. NaN accuracies).
        logger.warning("no best model recorded; falling back to the last model")
    logger.info("return last model")
    return model, last_model_dict
