import argparse
import datetime
import json
import logging
import os
import pickle
import random
import time
from torch import nn

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import wandb
# wandb.init(project="iqa-ijcaj-11-23")

import rednet
import timm

from scipy import stats
from timm.utils import AverageMeter  # accuracy
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

from IQA.build_NAR import IQA_build_loader_NAR
from config_NAR import get_config
from logger import create_logger
from lr_scheduler import build_scheduler
from models.build import build_model
from optimizer import build_optimizer,build_optimizer_shadow
from utils import (
    NativeScalerWithGradNormCount,
    auto_resume_helper,
    load_checkpoint,
    load_pretrained,
    reduce_tensor,
    save_checkpoint,
)

import torch.nn.functional as F

def parse_option():
    """Parse command-line options and build the experiment config.

    Unknown command-line arguments are ignored (``parse_known_args``). The
    local rank is read from the ``LOCAL_RANK`` environment variable, which the
    distributed launcher (e.g. ``torchrun``) is expected to set; a missing
    variable raises ``KeyError``. As a side effect, wandb is forced into
    offline mode via ``WANDB_MODE``.

    Returns:
        tuple: ``(args, config)`` where ``config = get_config(args, local_rank)``.
    """
    parser = argparse.ArgumentParser(
        "Swin Transformer training and evaluation script", add_help=False
    )
    parser.add_argument(
        "--cfg",
        type=str,
        required=True,
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs="+",
    )

    # easy config modification
    parser.add_argument("--batch-size", type=int, help="batch size for single GPU")
    parser.add_argument("--data-path", type=str, help="path to dataset")
    parser.add_argument(
        "--zip",
        action="store_true",
        help="use zipped dataset instead of folder dataset",
    )
    parser.add_argument(
        "--cache-mode",
        type=str,
        default="part",
        choices=["no", "full", "part"],
        help="no: no cache, "
             "full: cache all data, "
             "part: sharding the dataset into nonoverlapping pieces and only cache one piece",
    )
    parser.add_argument(
        "--pretrained",
        help="pretrained weight from checkpoint, could be imagenet22k pretrained weight",
    )
    parser.add_argument("--resume", help="resume from checkpoint")
    parser.add_argument(
        "--accumulation-steps", type=int, help="gradient accumulation steps"
    )
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        help="Using tensorboard to track the process",
    )
    parser.add_argument(
        "--use-checkpoint",
        action="store_true",
        help="whether to use gradient checkpointing to save memory",
    )
    parser.add_argument(
        "--disable_amp", action="store_true", help="Disable pytorch amp"
    )
    parser.add_argument(
        "--amp-opt-level",
        type=str,
        choices=["O0", "O1", "O2"],
        help="mixed precision opt level, if O0, no amp is used (deprecated!)",
    )
    parser.add_argument(
        "--output",
        default="output",
        type=str,
        metavar="PATH",
        help="root of output folder, the full path is <output>/<model_name>/<tag> (default: output)",
    )
    parser.add_argument("--tag", help="tag of experiment")
    parser.add_argument("--eval", action="store_true", help="Perform evaluation only")
    parser.add_argument(
        "--throughput", action="store_true", help="Test throughput only"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Use torchinfo to show the flow of tensor in model",
    )
    parser.add_argument(
        "--repeat", action="store_true", help="Test model for publications"
    )
    parser.add_argument("--rnum", type=int, help="Repeat num")

    # Distributed settings (LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR,
    # MASTER_PORT) come from the launcher's environment. Never hard-code
    # credentials such as WANDB_API_KEY in source; export them in the shell.
    # Keep wandb runs local only (offline mode).
    os.environ["WANDB_MODE"] = "offline"

    local_rank = int(os.environ["LOCAL_RANK"])
    args, _ = parser.parse_known_args()

    config = get_config(args, local_rank)
    return args, config


def main(config):
    """Build the student model and data loaders, then evaluate it once per epoch.

    Rank 0 creates the wandb run and logs validation SRCC/PLCC to it; every
    rank builds the model, wraps it in ``DistributedDataParallel`` and takes
    part in validation. Depending on the config, the function may instead
    print a torchinfo summary and return (``DEBUG_MODE``), evaluate a resumed
    checkpoint and return (``EVAL_MODE``), or measure throughput and return
    (``THROUGHPUT_MODE``).

    Args:
        config: Experiment config returned by ``get_config``.
    """
    is_main_rank = dist.get_rank() == 0
    if is_main_rank:
        group_name = config.TAG
        wandb_name = group_name + "_" + str(config.EXP_INDEX)
        wandb_dir = os.path.join(config.OUTPUT, "wandb")
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_runner = wandb.init(
            project="",
            group=group_name,
            name=wandb_name,
            config={
                "epochs": config.TRAIN.EPOCHS,
                "batch_size": config.DATA.BATCH_SIZE,
                "patch_num": config.DATA.PATCH_NUM,
            },
            dir=wandb_dir,
            reinit=True,
        )
        wandb_runner.log({"Validating SRCC": 0.0, "Validating PLCC": 0.0, "Epoch": 0})
    else:
        wandb_runner = None

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")

    # Build the student model.
    s_model = build_model(config)
    logger.info(str(s_model))

    n_parameters = sum(p.numel() for p in s_model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(s_model, "flops"):
        flops = s_model.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    s_model.cuda()
    s_model_without_ddp = s_model

    if config.DEBUG_MODE:
        summary(
            s_model,
            (config.DATA.BATCH_SIZE, 3, config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
        )
        return

    (
        dataset_train,
        dataset_val,
        data_loader_train,
        data_loader_val,
        mixup_fn,
    ) = IQA_build_loader_NAR(config)

    s_model = torch.nn.parallel.DistributedDataParallel(
        s_model,
        device_ids=[config.LOCAL_RANK],
        broadcast_buffers=False,
        find_unused_parameters=True,
    )
    if is_main_rank:
        wandb_runner.watch(s_model)

    loss_scaler = NativeScalerWithGradNormCount()
    # NOTE(review): this script builds no optimizer / LR scheduler (the
    # training step is disabled), so None is handed to load_checkpoint.
    # Confirm load_checkpoint tolerates None here (e.g. in EVAL_MODE), or build
    # them with build_optimizer / build_scheduler if their state must be restored.
    optimizer = None
    lr_scheduler = None

    max_plcc = 0.0
    max_srcc = 0.0
    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f"auto resuming from {resume_file}")
        else:
            logger.info(f"no checkpoint found in {config.OUTPUT}, ignoring auto resume")

    if config.MODEL.RESUME:
        max_plcc, epochs = load_checkpoint(
            config, s_model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger
        )
        srcc, plcc, _ = validate(
            config, data_loader_val, s_model, epochs, len(dataset_val)
        )
        logger.info(
            f"SRCC and PLCC of the network on the {len(dataset_val)} test images: {srcc:.6f}, {plcc:.6f}"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, s_model_without_ddp, logger)
        # validate() requires epoch and val_len; the epoch only affects
        # logging step numbers, and no logger is passed here.
        srcc, plcc, _ = validate(
            config,
            data_loader_val,
            s_model,
            config.TRAIN.START_EPOCH,
            len(dataset_val),
        )
        logger.info(
            f"SRCC and PLCC of the network on the {len(dataset_val)} test images: {srcc:.6f}, {plcc:.6f}"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, s_model, logger)
        return

    writer = SummaryWriter(log_dir=config.OUTPUT) if config.TENSORBOARD else None
    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # NOTE: the train_one_epoch / save_checkpoint step is disabled in this
        # script, so every epoch only re-runs validation.
        srcc, plcc, _ = validate(
            config,
            data_loader_val,
            s_model,
            epoch,
            len(dataset_val),
            tensorboard=writer,
            wandb_runner=wandb_runner,
        )
        if writer is not None:
            writer.add_scalars(
                "Validate Metrics",
                {"SRCC": srcc, "PLCC": plcc},
                epoch,
            )
        if is_main_rank:
            wandb_runner.log(
                {"Validating SRCC": srcc, "Validating PLCC": plcc, "Epoch": epoch + 1}
            )
        logger.info(
            f"SRCC and PLCC of the network on the {len(dataset_val)} test images: {srcc:.6f}, {plcc:.6f}"
        )
        # Keep the best PLCC together with the SRCC of that same epoch, so
        # the reported pair always comes from one evaluation.
        if plcc >= max_plcc:
            max_plcc = plcc
            max_srcc = srcc
        logger.info(f"Max PLCC: {max_plcc:.6f} Max SRCC: {max_srcc:.6f}")
        if is_main_rank:
            wandb_runner.summary["Best PLCC"] = max_plcc
            wandb_runner.summary["Best SRCC"] = max_srcc

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info("Training time {}".format(total_time_str))
    # writer only exists when TensorBoard logging is enabled.
    if writer is not None:
        writer.close()
    if is_main_rank:
        wandb_runner.alert(
            title="Run Finished",
            text=f"Max PLCC: {max_plcc:.6f} Max SRCC: {max_srcc:.6f} Training time: {total_time_str}",
        )
        wandb_runner.finish()
    logging.shutdown()
    dist.barrier()


class FeatureLoss(nn.Module):
    """PyTorch version of `Masked Generative Distillation` (MGD) feature loss.

    Student and teacher token sequences are turned into square feature maps:
    1. The student's token axis is projected to the teacher's token count by a
       bias-free linear layer, and the student channels are aligned with a 1x1
       conv when they differ from the teacher channels.
    2. Spatial positions of the student map are randomly masked, and the result
       is passed through a small generator network.
    3. The generated map is compared to the teacher map with a summed MSE,
       divided by the batch size and scaled by ``alpha_mgd``.

    Args:
        student_channels (int): Feature (channel) dim of the student tokens.
        teacher_channels (int): Feature (channel) dim of the teacher tokens.
        alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002
        lambda_mgd (float, optional): Fraction of spatial positions masked.
            Defaults to 0.65
        student_tokens (int, optional): Number of student tokens (dim 1 of
            ``preds_S``). Defaults to 196.
        teacher_tokens (int, optional): Number of tokens the student is
            projected to; must be a perfect square. Defaults to 625.
        device (str | torch.device, optional): Device the sub-modules are moved
            to at construction. Defaults to ``"cuda"`` (the current CUDA device).
    """

    def __init__(
        self,
        student_channels,
        teacher_channels,
        alpha_mgd=0.00002,
        lambda_mgd=0.65,
        *,
        student_tokens=196,
        teacher_tokens=625,
        device="cuda",
    ):
        super().__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd

        # 1x1 conv mapping student channels to teacher channels when they differ.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(
                student_channels, teacher_channels, kernel_size=1, stride=1, padding=0
            ).to(device)
        else:
            self.align = None

        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
            # (modified 11.28)
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
        ).to(device)
        # Projects the student's token axis onto the teacher's token count.
        self.fc = nn.Linear(student_tokens, teacher_tokens, bias=False).to(device)
        # Currently unused by forward(); kept so the module layout is unchanged.
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _square_side(num_tokens):
        """Return the side of a square grid with ``num_tokens`` cells; raise if not square."""
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(
                f"token count {num_tokens} is not a perfect square; cannot reshape to a square map"
            )
        return side

    def forward(self, preds_S, preds_T):
        """Compute the weighted MGD loss.

        Args:
            preds_S (Tensor): Student tokens ``[B, N_s, C_s]`` with
                ``N_s == student_tokens`` (e.g. ``[B, 196, 384]``).
            preds_T (Tensor): Teacher tokens ``[B, N_t, C_t]`` with ``N_t`` a
                perfect square (e.g. ``[B, 625, 384]``).

        Returns:
            Tensor: Scalar ``alpha_mgd * sum((G(mask * S) - T) ** 2) / B``.

        Raises:
            ValueError: If a token count cannot be reshaped into a square map.
        """
        # [B, N_s, C_s] -> [B, C_s, N_s] -> fc -> [B, C_s, teacher_tokens]
        preds_S = self.fc(preds_S.transpose(1, 2))
        B, C, new_HW = preds_S.shape
        side_s = self._square_side(new_HW)
        preds_S = preds_S.reshape(B, C, side_s, side_s)

        # [B, N_t, C_t] -> [B, C_t, H, W]
        B_t, new_HW_t, C_t = preds_T.shape
        side_t = self._square_side(new_HW_t)
        preds_T = preds_T.transpose(1, 2).reshape(B_t, C_t, side_t, side_t)

        N, _, H_t, W_t = preds_T.shape
        if self.align is not None:
            preds_S = self.align(preds_S)

        device = preds_S.device
        # Spatial mask [N, 1, H_t, W_t]: 0 with probability lambda_mgd, else 1.
        mat = torch.rand((N, 1, H_t, W_t)).to(device)
        mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1).to(device)
        masked_fea = torch.mul(preds_S, mat)

        # Regenerate the teacher features from the masked student map.
        new_fea = self.generation(masked_fea)
        loss = F.mse_loss(new_fea, preds_T, reduction="sum") / N
        return loss * self.alpha_mgd

def _all_gather_cat(tensor):
    """Gather ``tensor`` from every rank and concatenate the copies along dim 0, in rank order.

    ``torch.distributed.all_gather`` requires every rank to pass a tensor of
    the same shape.
    """
    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    return torch.cat(gathered, dim=0)


@torch.no_grad()
def validate(
        config, data_loader, s_model, epoch, val_len, tensorboard=None, wandb_runner=None
):
    """Evaluate ``s_model`` and return ``(srcc, plcc, avg_loss)``.

    Only the first of the model's five outputs is scored. The flat per-sample
    predictions and targets are processed as follows:
    1. When ``torch.distributed`` is initialized, they are gathered from all
       ranks.
    2. They are truncated to ``val_len`` entries.
    3. They are averaged over consecutive groups of ``config.DATA.PATCH_NUM``.
    4. Spearman (SRCC) and Pearson (PLCC) correlations are computed on the
       averaged values.

    Args:
        config: Experiment config (reads AMP_ENABLE, PRINT_FREQ, TENSORBOARD,
            DATA.PATCH_NUM).
        data_loader: Yields ``(image, _, _, target)`` batches.
        s_model: Model returning a 5-tuple whose first element is the prediction.
        epoch (int): Only used to compute the step for per-batch loss logging.
        val_len (int): Number of gathered entries kept before patch averaging.
        tensorboard: Optional SummaryWriter for per-batch loss logging.
        wandb_runner: Optional wandb run for per-batch loss logging.

    Returns:
        tuple: ``(test_srcc, test_plcc, loss_meter.avg)``.
    """
    distributed = dist.is_available() and dist.is_initialized()
    criterion = torch.nn.SmoothL1Loss()
    s_model.eval()
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    temp_pred_scores = []
    temp_gt_scores = []
    end = time.time()
    for idx, (dis_image, _, _, targets_val) in enumerate(data_loader):
        dis_image = dis_image.cuda(non_blocking=True)
        targets_val = targets_val.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output0, _, _, _, _ = s_model(dis_image)
        targets_val = targets_val.unsqueeze(-1)
        temp_pred_scores.append(output0.view(-1))
        temp_gt_scores.append(targets_val.view(-1))
        # measure accuracy and record loss
        loss = criterion(output0, targets_val)
        if distributed:
            # Average the loss across ranks; only possible with a process group.
            loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), targets_val.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            step = epoch * len(data_loader) + idx
            if config.TENSORBOARD and tensorboard is not None:
                tensorboard.add_scalar("Validating Loss", loss_meter.val, step)
            if wandb_runner:
                wandb_runner.log(
                    {
                        "Validating Loss": loss_meter.val,
                        "Validate Batch": step,
                    }
                )
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f"Test: [{idx}/{len(data_loader)}]\t"
                f"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                f"Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t"
                f"Mem {memory_used:.0f}MB"
            )

    pred_scores = torch.cat(temp_pred_scores)
    gt_scores = torch.cat(temp_gt_scores)
    # For distributed runs, collect every rank's scores before computing metrics.
    if distributed:
        pred_scores = _all_gather_cat(pred_scores)
        gt_scores = _all_gather_cat(gt_scores)

    # NOTE(review): the [:val_len] truncation and the PATCH_NUM grouping assume
    # that the sampler keeps each image's patches contiguous (within and across
    # ranks) and that any padding sits at the tail. Verify this against
    # IQA_build_loader_NAR.
    patch_num = config.DATA.PATCH_NUM
    # reshape(...).mean(-1) is always 1-D, so tolist() yields a list even for a single image.
    final_preds = pred_scores[:val_len].reshape(-1, patch_num).mean(dim=-1).cpu().tolist()
    final_grotruth = gt_scores[:val_len].reshape(-1, patch_num).mean(dim=-1).cpu().tolist()

    test_srcc, _ = stats.spearmanr(final_preds, final_grotruth)
    test_plcc, _ = stats.pearsonr(final_preds, final_grotruth)
    logger.info(f" * SRCC@ {test_srcc:.6f} PLCC@ {test_plcc:.6f}")
    return test_srcc, test_plcc, loss_meter.avg


@torch.no_grad()
def throughput(data_loader, s_model, logger, *, warmup_iters=50, timed_iters=30):
    """Log the inference throughput (images/s) of ``s_model`` on one batch.

    Procedure:
    1. Take the first batch of ``data_loader``.
    2. Run ``warmup_iters`` untimed forward passes.
    3. Time ``timed_iters`` more passes, with CUDA synchronization before and
       after the timed section.
    4. Log ``timed_iters * batch_size / elapsed_seconds``.

    Requires CUDA. Returns None, and does nothing if the loader is empty.

    Args:
        data_loader: Yields tuples whose first element is the image batch; any
            further elements (targets, references, ...) are ignored.
        s_model: Model to benchmark.
        logger: Logger that receives the result.
        warmup_iters (int): Untimed warm-up passes. Defaults to 50.
        timed_iters (int): Timed passes. Defaults to 30.
    """
    s_model.eval()

    # Only the first batch is benchmarked. ``*_`` tolerates loaders yielding
    # extra items (the validation loader yields four per batch).
    for dis_image, *_ in data_loader:
        dis_image = dis_image.cuda(non_blocking=True)
        batch_size = dis_image.shape[0]
        for _ in range(warmup_iters):
            s_model(dis_image)
        torch.cuda.synchronize()
        logger.info(f"throughput averaged with {timed_iters} times")
        tic1 = time.perf_counter()
        for _ in range(timed_iters):
            s_model(dis_image)
        torch.cuda.synchronize()
        tic2 = time.perf_counter()
        logger.info(
            f"batch_size {batch_size} throughput {timed_iters * batch_size / (tic2 - tic1)}"
        )
        return


if __name__ == "__main__":
    args, config = parse_option()

    if config.AMP_OPT_LEVEL:
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")

    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        # -1 is init_process_group's default for both arguments.
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(
        backend="nccl", init_method="env://", world_size=world_size, rank=rank
    )
    dist.barrier()

    if args.repeat:
        # Explicit check rather than assert: asserts are stripped under -O,
        # and args.rnum is None when --rnum is omitted.
        if args.rnum is None or args.rnum <= 1:
            raise ValueError("--repeat requires --rnum greater than 1")
        num = args.rnum
    else:
        num = 1
    base_path = config.OUTPUT
    logger = logging.getLogger(name=f"{config.MODEL.NAME}")
    for i in range(num):
        if num > 1:
            # Each repetition gets its own output sub-folder and a fresh split.
            config.defrost()
            config.OUTPUT = os.path.join(base_path, str(i))
            config.EXP_INDEX = i + 1
            config.SET.TRAIN_INDEX = None
            config.SET.TEST_INDEX = None
            config.freeze()
        # Reseed non-deterministically so a new split permutation can be drawn.
        random.seed(None)

        os.makedirs(config.OUTPUT, exist_ok=True)

        # Rank 0 draws and persists a random permutation of the sample indices
        # (an existing file is reused); all ranks then read that same file.
        filename = "sel_num.data"
        sel_path = os.path.join(config.OUTPUT, filename)
        if dist.get_rank() == 0 and not os.path.exists(sel_path):
            sel_num = list(range(0, config.SET.COUNT))
            random.shuffle(sel_num)
            with open(sel_path, "wb") as f:
                pickle.dump(sel_num, f)
            del sel_num
        dist.barrier()

        # The pickle is produced by this script; never point OUTPUT at
        # untrusted files, since unpickling can execute arbitrary code.
        with open(sel_path, "rb") as f:
            sel_num = pickle.load(f)

        # 80/20 train/test split of the permuted indices.
        split = int(round(0.8 * len(sel_num)))
        config.defrost()
        config.SET.TRAIN_INDEX = sel_num[0:split]
        config.SET.TEST_INDEX = sel_num[split:len(sel_num)]
        config.freeze()

        seed = config.SEED + dist.get_rank()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        cudnn.benchmark = True

        create_logger(
            logger,
            output_dir=config.OUTPUT,
            dist_rank=dist.get_rank(),
            name=f"{config.MODEL.NAME}",
        )

        if dist.get_rank() == 0:
            path = os.path.join(config.OUTPUT, "config.json")
            with open(path, "w") as f:
                f.write(config.dump())
            logger.info(f"Full config saved to {path}")

        # print config
        logger.info(config.dump())
        logger.info(json.dumps(vars(args)))

        main(config)
        # Drop this run's handlers so the next repetition logs to its own folder.
        logger.handlers.clear()

    # Release NCCL resources once all repetitions are done.
    dist.destroy_process_group()
