import functools
import logging
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.cuda.amp import autocast, GradScaler

from codebase.torchutils.distributed import world_size
from codebase.torchutils.metrics import AccuracyMetric, AverageMetric, EstimatedTimeArrival
from codebase.torchutils.common import GradientAccumulator
from codebase.torchutils.common import ThroughputTester, time_enumerate

import smtplib
from email.message import EmailMessage

# Module-level logger used by the epoch runner in this file.
_logger = logging.getLogger(__name__)

# Process-wide gradient scaler. It stays None until _run_one_epoch creates it
# on its first call; every later call reuses that same instance.
scaler = None


def trainable_network_weights(net, requires_grad=False):
    """Toggle train mode and gradient tracking for the non-score parameters.

    ``net`` is first switched with ``net.train(mode=requires_grad)``. After
    that, every parameter whose qualified name does not contain a capital
    ``'S'`` has its ``requires_grad`` flag set to ``requires_grad``.
    Parameters whose name contains ``'S'`` keep their current flag.
    """
    net.train(mode=requires_grad)
    for param_name, tensor in net.named_parameters():
        if 'S' in param_name:
            # Name matches the score-parameter filter; leave it untouched.
            continue
        tensor.requires_grad = requires_grad


def trainable_score_weights(net, requires_grad=False):
    """Set gradient tracking for the score parameters only.

    A parameter is treated as a score parameter when its qualified name
    contains a capital ``'S'``. All other parameters, and the module's
    train/eval mode, are left as they are.
    """
    score_params = [tensor for param_name, tensor in net.named_parameters()
                    if 'S' in param_name]
    for tensor in score_params:
        tensor.requires_grad = requires_grad


def _run_one_epoch(is_training: bool,
                   is_last: bool,
                   epoch: int,
                   model: nn.Module,
                   loader: data.DataLoader,
                   criterion: nn.modules.loss._Loss,
                   optimizer: optim.Optimizer,
                   scheduler: optim.lr_scheduler._LRScheduler,
                   use_amp: bool,
                   accmulated_steps: int,
                   device: str,
                   memory_format: str,
                   log_interval: int,
                   is_score_training: bool,
                   landa: float,
                   loss_normalizer_h: list,
                   output_dir: str):
    """Run one training or evaluation pass over ``loader``.

    Args:
        is_training: selects the phase ("train" or "eval"): model train mode,
            gradient tracking, autocast and the gradient accumulator are all
            enabled only when this is true.
        is_last: together with ``not is_training`` it triggers
            ``save_and_print`` on the scores returned by the model.
        epoch: epoch number, used in log lines and passed to
            ``scheduler.step``.
        model: called as ``model(inputs, is_score_training=flag)`` and expected
            to return ``(scores, outputs)``.
        loader: yields ``(inputs, targets)`` batches.
        criterion: loss applied to ``(outputs, targets)``.
        optimizer: its first param group supplies the logged learning rate; it
            is also handed to the gradient accumulator.
        scheduler: stepped once with ``epoch`` at the start of a training
            pass; may be ``None``.
        use_amp: enables autocast (training only) and the gradient scaler.
        accmulated_steps: forwarded to ``GradientAccumulator(steps=...)``.
        device: target device for inputs and targets.
        memory_format: forwarded to ``inputs.to(memory_format=...)``.
        log_interval: a progress line is logged every ``log_interval``
            iterations and on the last iteration.
        is_score_training: when true the score parameters are trainable, the
            remaining weights are frozen, and an L1 penalty on every score
            tensor is added to the loss; when false the roles are swapped and
            no penalty is added.
        landa: multiplier of the L1 penalty.
        loss_normalizer_h: four exponents; ``2 ** h`` becomes the penalty
            coefficient of the 9, 12, 18 and 9 score tensors that follow the
            first one (which has coefficient 1).
        output_dir: forwarded to ``save_and_print``, which joins it with
            ``/`` — so it must support that operator (e.g. ``pathlib.Path``).

    Returns:
        dict with ``"<phase>/lr"``, ``"<phase>/loss"``, ``"<phase>/top1_acc"``
        and ``"<phase>/top5_acc"``.
    """
    phase = "train" if is_training else "eval"
    # The model receives the inverted flag as an int: 0 while the scores are
    # being trained, 1 otherwise.
    score_training = 0 if is_score_training else 1
    trainable_network_weights(model, not is_score_training)
    trainable_score_weights(model, is_score_training)
    # trainable_network_weights also switches train mode; the phase decides
    # the final mode, so this call has to come last.
    model.train(mode=is_training)

    global scaler
    if scaler is None:
        # The scaler is created once and reused by every later call, so its
        # `enabled` flag must not depend on the phase of the first call:
        # otherwise an evaluation that happens to run first would cache a
        # disabled scaler and AMP training would silently lose loss scaling.
        scaler = GradScaler(enabled=use_amp)

    gradident_accumulator = GradientAccumulator(steps=accmulated_steps, enabled=is_training)

    num_batches = len(loader)
    time_cost_metric = AverageMetric("time_cost")
    loss_metric = AverageMetric("loss")
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    eta = EstimatedTimeArrival(num_batches)
    speed_tester = ThroughputTester()

    if is_training and scheduler is not None:
        scheduler.step(epoch)

    lr = optimizer.param_groups[0]['lr']
    _logger.info(f"{phase.upper()} start, epoch={epoch:04d}, lr={lr:.6f}")

    # One penalty coefficient per score tensor: 1 for the first, then
    # 2 ** h[k] repeated for the 3*3, 4*3, 6*3 and 3*3 tensors of each group.
    loss_coefficients = [1] + [2 ** loss_normalizer_h[0]] * 3 * 3 + [2 ** loss_normalizer_h[1]] * 4 * 3 + [2 ** loss_normalizer_h[2]] * 6 * 3 + [2 ** loss_normalizer_h[3]] * 3 * 3

    # Stays None when the loader yields no batch, so the final save below can
    # be skipped instead of failing on an unbound name.
    scores = None
    for time_cost, iter_, (inputs, targets) in time_enumerate(loader, start=1):
        inputs = inputs.to(device=device, non_blocking=True, memory_format=memory_format)
        targets = targets.to(device=device, non_blocking=True)

        with torch.set_grad_enabled(mode=is_training):
            with autocast(enabled=use_amp and is_training):
                scores, outputs = model(inputs, is_score_training=score_training)
                loss: torch.Tensor = criterion(outputs, targets)
                if is_score_training:
                    # L1 sparsity penalty on every score tensor, weighted per
                    # tensor by loss_coefficients.
                    for index, score in enumerate(scores):
                        l1_norm = torch.abs(torch.norm(score, p=1))
                        loss += l1_norm * landa * loss_coefficients[index]

        gradident_accumulator.backward_step(model, loss, optimizer, scaler)

        time_cost_metric.update(time_cost)
        loss_metric.update(loss)
        accuracy_metric.update(outputs, targets)
        eta.step()
        speed_tester.update(inputs)

        if iter_ % log_interval == 0 or iter_ == num_batches:
            _logger.info(", ".join([
                phase.upper(),
                f"epoch={epoch:04d}",
                f"iter={iter_:05d}/{num_batches:05d}",
                f"fetch data time cost={time_cost_metric.compute()*1000:.2f}ms",
                f"fps={speed_tester.compute()*world_size():.0f} images/s",
                f"{loss_metric}",
                f"{accuracy_metric}",
                f"{eta}",
            ]))
            time_cost_metric.reset()
            speed_tester.reset()

        # NOTE(review): the 128 threshold presumably identifies the final,
        # partial batch — confirm against the loader's batch size. With a
        # batch size below 128 this fires (and e-mails) on every iteration.
        if is_last and not is_training and len(inputs) < 128:
            save_and_print(scores, output_dir, landa, loss_normalizer_h, email=True)

    if is_last and not is_training:
        if scores is None:
            _logger.warning("EVAL produced no batches, skipping score report")
        else:
            # NOTE(review): this repeats the save done inside the loop (without
            # the e-mail) whenever the in-loop condition matched.
            save_and_print(scores, output_dir, landa, loss_normalizer_h, email=False)

    return {
        f"{phase}/lr": lr,
        f"{phase}/loss": loss_metric.compute(),
        f"{phase}/top1_acc": accuracy_metric.at(1).rate,
        f"{phase}/top5_acc": accuracy_metric.at(5).rate,
    }


def save_and_print(scores, output_dir, landa, loss_normalizer_h, email: bool):
    """Pickle the scores, print the estimated pruning ratios, optionally e-mail them.

    Every entry of ``scores`` is reduced with ``torch.sum`` and the result is
    used as the number of kept filters of the corresponding layer. These
    counts are compared with hard-coded per-layer tables (kernel area, full
    filter count, spatial size) covering 49 convolution layers, followed by a
    1000-way classifier term.

    Args:
        scores: sequence of tensors, one per layer, in layer order. Must hold
            between 1 and 49 entries.
        output_dir: directory object supporting the ``/`` operator (e.g.
            ``pathlib.Path``); ``Res50_scores.p`` is written into it.
        landa: only used in the e-mail subject.
        loss_normalizer_h: only used in the e-mail subject.
        email: when true, the ratios and the pickle file are sent through
            Gmail's SMTP-over-SSL endpoint.

    Raises:
        IndexError: if ``scores`` is empty or longer than the layer tables.
    """
    scores_path = output_dir / "Res50_scores.p"
    # Close the handle explicitly: the same file is read back below for the
    # e-mail attachment, so it must be fully flushed by then.
    with open(scores_path, "wb") as fh:
        pickle.dump(scores, fh)

    sum_scores = [torch.sum(s).item() for s in scores]
    kernels = [49] + [1, 9, 1] * 16
    filters = [64] + [64, 64, 256] * 3 + [128, 128, 512] * 4 + [256, 256, 1024] * 6 + [512, 512, 2048] * 3
    # NOTE(review): these spatial sizes do not follow the usual ResNet-50
    # 56/28/14/7 stage progression — confirm against the model definition.
    image_sizes = [224] + [56] * 3 * 3 + [56] * 4 * 3 + [28] * 6 * 3 + [14] * 3 * 3

    # Per layer: cost = in_channels * out_channels * kernel_area, and the FLOP
    # estimate multiplies that by the squared spatial size. The first layer
    # has 3 input channels; later layers take the previous layer's width.
    params = 0
    original_params = 0
    flops = 0
    original_flops = 0
    prev_kept, prev_full = 3, 3
    for i, kept in enumerate(sum_scores):
        kept_cost = prev_kept * kept * kernels[i]
        full_cost = prev_full * filters[i] * kernels[i]
        params += kept_cost
        original_params += full_cost
        flops += kept_cost * image_sizes[i] * image_sizes[i]
        original_flops += full_cost * image_sizes[i] * image_sizes[i]
        prev_kept, prev_full = kept, filters[i]

    # Classifier term: 1000 outputs fed by the last layer.
    params += 1000 * sum_scores[-1]
    original_params += 1000 * filters[-1]
    flops += 1000 * sum_scores[-1] * 7 * 7
    original_flops += 1000 * filters[-1] * 7 * 7

    pruned_params = 1. - params / original_params
    pruned_flops = 1. - flops / original_flops
    print('Pruned Params:', pruned_params)
    print('Pruned Flops:', pruned_flops)

    if email:
        msg = EmailMessage()
        msg['Subject'] = f'Report Resnet50:lambda= {landa}, loss_normalizer= {loss_normalizer_h}'
        msg['From'] = "MIT server"
        msg['To'] = "pmkiasari@gmail.com"
        msg.set_content(f'Pruned Params: {pruned_params} \n Pruned Flops: {pruned_flops}')
        with open(scores_path, "rb") as f:  # The report file
            msg.add_attachment(f.read(), maintype="application", subtype="p", filename=f.name)

        # NOTE(review): the account credentials are hard-coded in source and
        # should move to configuration / secret storage. The envelope sender
        # and recipient passed to sendmail also differ from the From/To
        # headers set above — confirm which addresses are intended.
        with smtplib.SMTP_SSL('smtp.gmail.com', 465) as s:
            s.login('butunulker@gmail.com', "Temp@com123")
            s.sendmail("The report", ["Z.Babaiee@gmail.com"], msg.as_string())


# Public entry points: _run_one_epoch with the phase flags pre-bound as keyword
# arguments. is_training is the first positional parameter of _run_one_epoch,
# so callers have to pass the remaining arguments by keyword as well.
# train_one_epoch and evaluate_one_epoch leave is_last unbound; callers must
# supply it.
train_one_epoch = functools.partial(_run_one_epoch, is_training=True)
evaluate_one_epoch = functools.partial(_run_one_epoch, is_training=False)
# Evaluation variant with is_last=True, which makes _run_one_epoch call
# save_and_print on the model's scores.
evaluate_last_epoch = functools.partial(_run_one_epoch, is_last=True, is_training=False)

