import os
import random
import time
import argparse
import datetime
import numpy as np
import subprocess
import utils

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from timm.utils import AverageMeter

from config import get_config
from models import build_model
from data_loader import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from models.frequency_loss import FrequencyLoss
from utils import load_checkpoint_student_teacher, save_checkpoint_student_teacher, get_grad_norm, auto_resume_helper, reduce_tensor


try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None



def parse_option():
    """Parse command-line arguments and build the experiment config.

    Returns:
        tuple: ``(args, config)`` -- the raw ``argparse.Namespace`` and the config
        object produced by ``get_config(args)``.
    """
    parser = argparse.ArgumentParser('MFM pre-training script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    # Loss weights are multiplicative coefficients, so fractional values must be
    # accepted (type=int rejected e.g. `--weight_mfm_loss 0.5`). The defaults stay 1.
    parser.add_argument('--weight_mfm_loss', default=1, type=float, help="WEIGHT_MFM_LOSS")
    parser.add_argument('--weight_distillation_loss', default=1, type=float, help="WEIGHT_DISTILLATION_LOSS")
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--distillation_version', type=str, default='DistillationHeaderV1',
                        choices=['DistillationHeaderV1', 'DistillationHeaderV2', 'DistillationHeaderV11', 'DistillationHeaderV21',
                                 'DistillationHeaderV3', 'DistillationHeaderV2CLS', 'DistillationHeaderV31', 'DistillationHeaderV21CLS',
                                 'DistillationHeaderV3CLS', 'DistillationHeaderV1CLS', 'DistillationHeaderV31CLS', 'DistillationHeaderV11CLS',
                                 'DINOHead'],
                        help='version of distillation header')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--dis_loss', default='1-cosin', choices=['1-cosin', '_cosin', 'softmax', '_cosin_dino', 'one_cosin_dino', 'ibot'], help='distillation loss')
    parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag, help="Whether to use batch normalizations in projection head (Default: False)")
    parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
                        help="""Whether or not to weight normalize the last layer of the DINO head.
                        Not normalizing leads to better performance but can make the training unstable.
                        In our experiments, we typically set this parameter to False with vit_small and True with vit_base.""")
    parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
        the DINO head output. For complex and large datasets large values (like 65k) work well.""")
    # distributed training
    parser.add_argument("--distributed", action="store_true", help="Using distributed")
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    # default=None so the launcher can fall back to $MASTER_PORT and then 29500.
    # A non-None default made that environment-variable fallback unreachable.
    parser.add_argument('--port', type=int, default=None, help='port only works when launcher=="slurm"')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training')

    args = parser.parse_args()
    # NOTE(review): the attribute name is misspelled ("studnet"), but it must match
    # whatever get_config reads, so it is left unchanged here.
    args.multi_head_studnet_strategy = True

    config = get_config(args)

    return args, config




def get_model_info(model, name):
    """Log the trainable-parameter count (and GFLOPs, if exposed) of a wrapped model.

    Args:
        model: a wrapper (DistributedDataParallel in this script) exposing the
            underlying network as ``model.module``.
        name: label used as a prefix in the log lines, e.g. ``'Student'``.

    Returns:
        The unwrapped network ``model.module``.
    """
    inner = model.module

    trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
    logger.info(f"{name}: number of params: {sum(trainable_sizes)}")

    # Only some architectures implement a flops() estimator.
    if not hasattr(inner, 'flops'):
        return inner

    gflops = inner.flops() / 1e9
    logger.info(f"{name}: number of GFLOPs: {gflops}")
    return inner
    

def main(config, log_writer):
    """Build the student/teacher pair and run the pre-training loop.

    Args:
        config: frozen experiment config. It is only read here, except that
            auto-resume may overwrite ``config.MODEL.RESUME``.
        log_writer: TensorBoard ``SummaryWriter`` on rank 0, ``None`` on other ranks.

    Relies on the module-level ``logger`` and on an already-initialised default
    ``torch.distributed`` process group (DDP and ``dist.get_rank()`` are used).
    """
    data_loader_train = build_loader(config, logger, is_pretrain=True)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    student_model = build_model(config, is_pretrain=True, is_student=True)
    teacher_model = build_model(config, is_pretrain=True, is_student=False)

    student_model.cuda()
    teacher_model.cuda()

    logger.info('Distillation Header Version')
    logger.info(config.MODEL.DISTILLATION_HEADER_VERSION)

    logger.info("student model is:")
    logger.info(str(student_model))

    logger.info("teacher model is:")
    logger.info(str(teacher_model))

    # Only the student is optimised; the teacher is updated by EMA in train_one_epoch.
    student_optimizer = build_optimizer(config, student_model, logger, is_pretrain=True)

    criterion_freq = FrequencyLoss(
            loss_gamma=config.MODEL.FREQ_LOSS.LOSS_GAMMA,
            matrix_gamma=config.MODEL.FREQ_LOSS.MATRIX_GAMMA,
            patch_factor=config.MODEL.FREQ_LOSS.PATCH_FACTOR,
            ave_spectrum=config.MODEL.FREQ_LOSS.AVE_SPECTRUM,
            with_matrix=config.MODEL.FREQ_LOSS.WITH_MATRIX,
            log_matrix=config.MODEL.FREQ_LOSS.LOG_MATRIX,
            batch_matrix=config.MODEL.FREQ_LOSS.BATCH_MATRIX).cuda()

    criterion_student = nn.CosineSimilarity(dim=1).cuda()

    if config.AMP_OPT_LEVEL != "O0":
        student_model, student_optimizer = amp.initialize(student_model, student_optimizer, opt_level=config.AMP_OPT_LEVEL)

    student_model = torch.nn.parallel.DistributedDataParallel(student_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=True)
    teacher_model = torch.nn.parallel.DistributedDataParallel(teacher_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)

    # Teacher and student start with the same weights. Load into the *unwrapped*
    # teacher: the DDP wrapper's own state_dict keys carry a "module." prefix, so
    # feeding it the student's un-prefixed keys with strict=False copies nothing.
    # Only entries that exist in the teacher with an identical shape are copied,
    # because the two networks are built with different is_student flags and their
    # heads may differ.
    teacher_state = teacher_model.module.state_dict()
    shared_state = {
        key: value
        for key, value in student_model.module.state_dict().items()
        if key in teacher_state and value.shape == teacher_state[key].shape
    }
    teacher_model.module.load_state_dict(shared_state, strict=False)
    logger.info(f"Teacher initialised from student: {len(shared_state)}/{len(teacher_state)} state entries copied")

    student_model_without_ddp = get_model_info(student_model, 'Student')
    teacher_model_without_ddp = get_model_info(teacher_model, 'Teacher')

    # The teacher never receives gradients; it is only moved by the EMA update.
    for p in teacher_model.parameters():
        p.requires_grad = False
    student_lr_scheduler = build_scheduler(config, student_optimizer, len(data_loader_train))
    # Per-iteration EMA momentum, cosine-annealed from MOMENTUM_TEACHER up to 1.
    momentum_schedule = cosine_scheduler(config.TRAIN.MOMENTUM_TEACHER, 1, config.TRAIN.EPOCHS, len(data_loader_train))

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT, logger)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        load_checkpoint_student_teacher(config, student_model_without_ddp, teacher_model_without_ddp,
                                        student_optimizer, student_lr_scheduler, logger)

    logger.info("Start training")
    logger.info(f"config.TRAIN.WEIGHT_MFM_LOSS: {config.TRAIN.WEIGHT_MFM_LOSS},  config.TRAIN.WEIGHT_DISTILLATION_LOSS: {config.TRAIN.WEIGHT_DISTILLATION_LOSS}")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # Reshuffle the distributed sampler differently every epoch.
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, student_model, teacher_model, data_loader_train, student_optimizer, epoch,
                        student_lr_scheduler, criterion_freq, criterion_student, momentum_schedule,
                        teacher_model_without_ddp, log_writer)
        # Only rank 0 writes checkpoints: every SAVE_FREQ epochs and after the final epoch.
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint_student_teacher(config, epoch, student_model_without_ddp, teacher_model_without_ddp,
                                            0., student_optimizer, student_lr_scheduler, logger)
        if dist.get_rank() == 0:
            if log_writer is not None:
                log_writer.flush()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))


def get_norm(student_optimizer, student_model):
    """Return the student's gradient norm, clipping it first if CLIP_GRAD is set.

    Under apex AMP (any opt level other than "O0"), the optimizer's master params are
    used; otherwise the model's own parameters are used. Reads the module-level
    ``config`` (assigned in the ``__main__`` section).
    """
    use_apex = config.AMP_OPT_LEVEL != "O0"
    params = amp.master_params(student_optimizer) if use_apex else student_model.parameters()

    max_norm = config.TRAIN.CLIP_GRAD
    if max_norm:
        # clip_grad_norm_ returns the total norm measured before clipping.
        return torch.nn.utils.clip_grad_norm_(params, max_norm)
    return get_grad_norm(params)


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warmup, then cosine annealing.

    Args:
        base_value: value reached at the end of warmup / start of the cosine phase.
        final_value: value the cosine phase anneals towards.
        epochs: number of epochs covered by the schedule.
        niter_per_ep: iterations per epoch.
        warmup_epochs: epochs of linear warmup from ``start_warmup_value``.
        start_warmup_value: first warmup value.

    Returns:
        numpy array of length ``epochs * niter_per_ep``, indexed by global iteration.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    progress = np.pi * steps / len(steps)
    cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(progress))

    schedule = np.concatenate((warmup_part, cosine_part))
    assert len(schedule) == total_iters
    return schedule


def get_distillation_loss(config, criterion_student, teacher_cls_res, teacher_distillation_res,
                          student_cls_res, student_distillation_res):
    """Compute the student-vs-teacher distillation loss selected by ``config.TRAIN.DISS_LOSS``.

    Both operands are L2-normalised along dim 1 and passed to ``criterion_student``
    (``nn.CosineSimilarity(dim=1)`` in this script); the per-sample similarities are
    then averaged.

    Supported values:
        'one_cosin_dino': 1 - sim(teacher_distillation_res, student_distillation_res)
        '_cosin_dino':    -sim(teacher_distillation_res, student_distillation_res)
        '1-cosin':        1 - sim(teacher_cls_res, student_distillation_res)
        '_cosin':         -sim(teacher_cls_res, student_distillation_res)

    ``student_cls_res`` is currently unused; it is kept for interface compatibility.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: for any other value (including the CLI choices 'softmax' and
            'ibot', which have no implementation here). Previously such values silently
            produced a Python float 0.0, which crashed later on ``.item()``.
    """
    loss_type = config.TRAIN.DISS_LOSS

    if loss_type in ('one_cosin_dino', '_cosin_dino'):
        teacher_target = teacher_distillation_res
    elif loss_type in ('1-cosin', '_cosin'):
        teacher_target = teacher_cls_res
    else:
        raise ValueError(
            f"Unsupported distillation loss {loss_type!r}; expected one of "
            "'one_cosin_dino', '_cosin_dino', '1-cosin', '_cosin'")

    similarity = criterion_student(F.normalize(teacher_target, p=2, dim=1),
                                   F.normalize(student_distillation_res, p=2, dim=1)).mean()

    # "1 - sim" variants are non-negative; "-sim" variants lie in [-1, 1].
    if loss_type in ('one_cosin_dino', '1-cosin'):
        return 1 - similarity
    return -similarity

def train_one_epoch(config, student_model, teacher_model, data_loader, student_optimizer, epoch,
                    student_lr_scheduler, criterion_freq, criterion_student, momentum_schedule,
                    teacher_model_without_ddp, log_writer):
    """Train the student for one epoch and EMA-update the teacher after every iteration.

    Per iteration:
      1. Forward student and teacher on ``(img, mask)``.
      2. Compute the MFM loss as ``criterion_freq(student_mfm_res, teacher_input)``,
         plus the distillation loss.
      3. Back-propagate their weighted sum through the student (apex AMP optional). With
         ``ACCUMULATION_STEPS > 1``, the backward loss is divided by the number of steps,
         and the optimizer/LR scheduler advance once per window.
      4. Move each teacher parameter towards its positionally paired student parameter
         with momentum ``momentum_schedule[epoch * len(data_loader) + idx]``.
      5. Update meters, TensorBoard (rank 0) and the periodic log line.

    Relies on the module-level ``logger`` and on ``get_norm``, ``get_distillation_loss``
    and ``reduce_tensor``.
    """
    student_model.train()
    student_optimizer.zero_grad()

    accumulation_steps = config.TRAIN.ACCUMULATION_STEPS
    accumulate = accumulation_steps > 1

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    mfm_loss_meter = AverageMeter()
    distillation_loss_meter = AverageMeter()
    student_loss_meter = AverageMeter()
    student_norm_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (img, img_lq, mask, _) in enumerate(data_loader):
        img = img.cuda(non_blocking=True)
        if img_lq is not None:
            img_lq = img_lq.cuda(non_blocking=True)
        if mask is not None:
            mask = mask.cuda(non_blocking=True)

        # Each model returns four outputs; the names follow the unpacking below.
        student_input, student_cls_res, student_mfm_res, student_distillation_res = student_model(img, mask)
        teacher_input, teacher_cls_res, teacher_mfm_res, teacher_distillation_res = teacher_model(img, mask)

        mfm_loss = criterion_freq(student_mfm_res, teacher_input).mean()
        distillation_loss = get_distillation_loss(config, criterion_student,
                                                  teacher_cls_res, teacher_distillation_res,
                                                  student_cls_res, student_distillation_res)
        student_loss = (config.TRAIN.WEIGHT_MFM_LOSS * mfm_loss
                        + config.TRAIN.WEIGHT_DISTILLATION_LOSS * distillation_loss)

        # student and headers update
        if accumulate:
            # Average gradients over the accumulation window. Only the backward pass
            # uses the scaled value; student_loss stays un-scaled for logging.
            backward_loss = student_loss / accumulation_steps
        else:
            student_optimizer.zero_grad()
            backward_loss = student_loss

        if config.AMP_OPT_LEVEL != "O0":
            with amp.scale_loss(backward_loss, student_optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            backward_loss.backward()

        student_grad_norm = get_norm(student_optimizer, student_model)

        # Without accumulation every iteration is an update; with it, only the last
        # micro-batch of each window is.
        if not accumulate or (idx + 1) % accumulation_steps == 0:
            student_optimizer.step()
            if accumulate:
                student_optimizer.zero_grad()
            student_lr_scheduler.step_update(epoch * num_steps + idx)

        # EMA update for the teacher. zip() pairs parameters by position and stops at the
        # shorter sequence. This assumes both models enumerate their shared parameters
        # in the same order (any extra student-only params must come last).
        # NOTE(review): confirm against build_model.
        with torch.no_grad():
            m = momentum_schedule[idx + num_steps * epoch]  # momentum parameter
            for param_q, param_k in zip(student_model.module.parameters(), teacher_model_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        #  logging
        torch.cuda.synchronize()
        student_loss_meter.update(student_loss.item(), img.size(0))
        mfm_loss_meter.update(mfm_loss.item(), img.size(0))
        distillation_loss_meter.update(distillation_loss.item(), img.size(0))
        student_norm_meter.update(student_grad_norm)

        batch_time.update(time.time() - end)
        end = time.time()

        student_lr = student_optimizer.param_groups[0]["lr"]
        # Called on every rank (not only the logging rank). reduce_tensor presumably
        # performs a collective all-reduce, which every rank must enter. TODO confirm in utils.
        mfm_loss_value_reduce = reduce_tensor(mfm_loss).item()
        distillation_loss_value_reduce = reduce_tensor(distillation_loss).item()
        student_loss_value_reduce = reduce_tensor(student_loss).item()

        if log_writer is not None and (not accumulate or (idx + 1) % accumulation_steps == 0):
            # Use epoch_1000x as the TensorBoard x-axis so curves from runs with
            # different batch sizes line up.
            epoch_1000x = int((idx / num_steps + epoch) * 1000)
            log_writer.add_scalar('student_train_loss', student_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('distillation_train_loss', distillation_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('mfm_train_loss', mfm_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('student_grad_norm', student_grad_norm, epoch_1000x)
            log_writer.add_scalar('student_lr', student_lr, epoch_1000x)

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))}\t'
                f'student_lr {student_lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'mfm_loss {mfm_loss_meter.val:.4f} ({mfm_loss_meter.avg:.4f})\t'
                f'distillation_loss {distillation_loss_meter.val:.4f} ({distillation_loss_meter.avg:.4f})\t'
                f'student_loss {student_loss_meter.val:.4f} ({student_loss_meter.avg:.4f})\t'
                f'student_norm {student_norm_meter.val:.4f} ({student_norm_meter.avg:.4f})\t'
                f'momentum_schedule {m}\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


if __name__ == '__main__':
    args, config = parse_option()

    # apex is imported optionally at the top of the file; it is required only when a
    # mixed-precision opt level is requested. Raise explicitly, since an assert
    # disappears under `python -O`.
    if config.AMP_OPT_LEVEL != "O0" and amp is None:
        raise ImportError("amp not installed!")

    cuda_version = torch.version.cuda

    # Check for cuDNN version
    cudnn_version = torch.backends.cudnn.version()

    print(f"CUDA Version: {cuda_version}")
    print(f"cuDNN Version: {cudnn_version}")

    ## initialize slurm distributed training environment
    # Reads SLURM_PROCID / SLURM_NTASKS / SLURM_NODELIST, so this must run under srun.
    proc_id = int(os.environ['SLURM_PROCID'])
    print(f"proc_id {proc_id}")
    ntasks = int(os.environ['SLURM_NTASKS'])
    print(f"ntasks {ntasks}")
    node_list = os.environ['SLURM_NODELIST']
    print(f"node_list {node_list}")
    num_gpus = torch.cuda.device_count()
    print(f"num_gpus {num_gpus}")
    print(f"torch.cuda.is_available() {torch.cuda.is_available()}")
    if num_gpus == 0:
        # Without this check, `proc_id % num_gpus` below fails with an opaque ZeroDivisionError.
        raise RuntimeError("No CUDA device is visible to this process; NCCL training needs at least one GPU.")
    # Map this task to a local GPU. Assumes tasks on a node have consecutive proc ids (TODO confirm).
    torch.cuda.set_device(proc_id % num_gpus)
    # The first hostname in the allocation is used as the rendezvous address.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if args.port is not None:
        os.environ['MASTER_PORT'] = str(args.port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend='nccl')
    # Total number of processes across all nodes (not just the GPUs on this node).
    world_size = dist.get_world_size()
    torch.distributed.barrier()

    # Per-rank seed so ranks draw different random streams.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    print(f"seed {seed}")
    # random.seed(seed)
    cudnn.benchmark = True

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also need to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    # `logger` and `config` are module-level names; other functions in this file read them.
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
    logger.info(f'World Size: {world_size}')
    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        # NOTE(review): config.dump() presumably emits YAML (yacs-style CfgNode),
        # despite the .json file name. Confirm before parsing this file as JSON.
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")
        log_writer = SummaryWriter(log_dir=config.OUTPUT)
    else:
        log_writer = None

    # print config
    logger.info(config.dump())

    main(config, log_writer)