# Common
import os
import os.path as osp
import datetime
import argparse
import warnings
import socket
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist

import wandb

# network
from network.discriminator_Out import Discriminator_out, Discriminator_out2, Discriminator_out_torch, Discriminator_out_enc_torch
from network.minkUnet import MinkUNet34

from utils import common as com
from utils.logger_FADA import setup_logger

# config file
from configs.config_base import cfg_from_yaml_file
from easydict import EasyDict

warnings.filterwarnings("ignore")
from git import Repo 

import re

import pdb

from torch.nn.parallel import DistributedDataParallel as DDP
import MinkowskiEngine as ME

# torch.autograd.set_detect_anomaly(True)

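# cudnn (enabled, benchmark) benchmark notes -- presumably measured timings (lower is better; units not recorded):
# a) T,T 3.7   b) F,F 14   c) T,F 13   d) F,T 1.8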
# torch.backends.cudnn.enabled = False
# torch.backends.cudnn.benchmark = True


def single2ddp_state_dict(model_dict):
    """Convert a single-device state dict into DistributedDataParallel layout.

    Every key gets a ``module.`` prefix, e.g.
    ``"conv0p1s1.kernel"`` -> ``"module.conv0p1s1.kernel"``. The input mapping
    is not modified; a new dict with the same values and key order is returned.
    """
    for _ in range(2):
        print("ADD MODULE!!!")
    return {"module.{}".format(name): tensor for name, tensor in model_dict.items()}

def ddp2single_state_dict(model_dict):
    """Convert a DistributedDataParallel state dict back to single-device layout.

    A leading ``module.`` is removed from every key that has one, e.g.
    ``"module.conv0p1s1.kernel"`` -> ``"conv0p1s1.kernel"``; keys without the
    prefix are kept unchanged. Returns a new dict in the original key order.
    """
    for _ in range(2):
        print("DELETE MODULE!!!")
    prefix = "module."
    cut = len(prefix)

    def _strip(name):
        # Only one leading prefix is removed.
        return name[cut:] if name.startswith(prefix) else name

    return {_strip(name): tensor for name, tensor in model_dict.items()}

def change_Config_DEBUG(cfg):
    """Overwrite schedule-related config entries with their ``cfg.DEBUG`` values.

    Always copies ``T_VAL_ITER``, ``S_VAL_ITER``, ``LOG_PERIOD``,
    ``PREHEAT_STEPS`` and ``EXP_NAME`` into ``cfg.TRAIN``, copies
    ``AUX_LOSS_START_ITER`` into ``cfg.TGT_LOSS`` and sets
    ``cfg.TGT_LOSS.cal_start_iter`` to 10. Stage-specific overrides:
    ``stage_1_PCAN`` also sets ``PROTOTYPE.PROTO_UPDATE_PERIOD``; both
    ``stage_1_PCAN`` and ``stage_2_SAMLM`` set
    ``MEAN_TEACHER.T_THRE_ZERO_ITER``.

    The config is modified in place and also returned.
    """
    debug = cfg.DEBUG
    train = cfg.TRAIN
    for field in ("T_VAL_ITER", "S_VAL_ITER", "LOG_PERIOD", "PREHEAT_STEPS", "EXP_NAME"):
        setattr(train, field, getattr(debug, field))

    tgt_loss = cfg.TGT_LOSS
    tgt_loss.AUX_LOSS_START_ITER = debug.AUX_LOSS_START_ITER
    tgt_loss.cal_start_iter = 10

    stage = train.STAGE
    if stage == "stage_1_PCAN":
        cfg.PROTOTYPE.PROTO_UPDATE_PERIOD = debug.PROTO_UPDATE_PERIOD
    if stage in ("stage_1_PCAN", "stage_2_SAMLM"):
        cfg.MEAN_TEACHER.T_THRE_ZERO_ITER = debug.T_THRE_ZERO_ITER

    return cfg

def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with
          * ``config_file`` (``--cfg``): YAML config path,
          * ``checkpoint``: checkpoint to resume from, or None,
          * ``scaling_rule``: LR scaling rule name ('linear' by default),
          * ``adaptive_lr_scaling``: True when the flag is given.
    """
    cli = argparse.ArgumentParser(description='PGDA training')
    cli.add_argument('--cfg', dest='config_file', type=str, metavar='FILE',
                     default='configs/SynLiDAR2SemanticKITTI/stage_1_PCAN.yaml',
                     help='path to config file')
    cli.add_argument('--checkpoint', type=str, default=None,
                     help='Path to the checkpoint file for resuming training')
    cli.add_argument('--scaling_rule', dest='scaling_rule', type=str,
                     default='linear',  # main() also understands 'sqrt'
                     help='type of scaling rule')
    cli.add_argument('--adaptive_lr_scaling', action='store_true',
                     help='Enable adaptive learning rate scaling')
    return cli.parse_args()


def _init_distributed(num_gpus):
    """Initialise torch.distributed if this process was launched for DDP.

    DDP is used only when more than one GPU is visible *and* the launcher
    (e.g. torchrun) exported ``LOCAL_RANK``.

    Returns:
        tuple ``(using_ddp, rank, device)``. Under DDP, ``device`` is this
        process's own GPU (``cuda:<local_rank>``). Otherwise it is ``cuda``
        or ``cpu``.
    """
    if num_gpus > 1 and "LOCAL_RANK" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        torch.backends.cuda.matmul.allow_tf32 = True
        return True, dist.get_rank(), torch.device(f"cuda:{local_rank}")
    return False, 0, torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _git_branch(path):
    """Return the checked-out git branch of the repository at *path*.

    GitPython raises TypeError from ``active_branch`` on a detached HEAD. In
    that case a short commit hash is returned instead, so training can still
    start.
    """
    repo = Repo(path)
    try:
        return str(repo.active_branch)
    except TypeError:
        return 'detached@' + repo.head.commit.hexsha[:8]


def _apply_batch_scaling(cfg, args, num_gpus):
    """Scale batch size, schedule lengths and learning rates in *cfg* in place.

    The base factor is ``cfg.DATALOADER.SCALE_FACTOR``. With
    ``--adaptive_lr_scaling`` on a GPU with >= 40 GB of memory it is forced to 2.

    When the factor is not 1:
      * it is multiplied by the number of visible GPUs;
      * batch size and worker count are multiplied by it;
      * iteration/epoch schedules are integer-divided by it;
      * learning rates are multiplied by the factor (``linear`` rule) or by
        its square root (``sqrt`` rule).

    Raises:
        ValueError: unknown ``args.scaling_rule``. This is only checked when
            scaling actually happens.
    """
    import math

    scale_factor = cfg.DATALOADER.SCALE_FACTOR
    if torch.cuda.is_available() and args.adaptive_lr_scaling:
        # Query the GPU this process actually uses (set_device() under DDP).
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        total_memory_gb = props.total_memory / (1024 ** 3)
        if total_memory_gb >= 40:
            for _ in range(3):
                print(f"GPU memory: {total_memory_gb:.2f}GB detected. Increasing batch size by factor of 2.")
            scale_factor = 2  # scale_factor *= 2
        else:
            for _ in range(3):
                print("Just same as before...")

    if scale_factor == 1:
        return

    # device_count() is 0 on a CPU-only host. Multiplying by it would zero the
    # factor and the divisions below would raise ZeroDivisionError.
    scale_factor = max(num_gpus, 1) * scale_factor
    if args.scaling_rule == 'linear':
        scale_factor_lr = scale_factor
        print("linear scaling rule")
    elif args.scaling_rule == 'sqrt':
        scale_factor_lr = math.sqrt(scale_factor)
        print("sqrt scaling rule")
    else:
        raise ValueError('Unknown scaling rule: {}'.format(args.scaling_rule))

    cfg.DATALOADER.TRA_BATCH_SIZE = int(cfg.DATALOADER.TRA_BATCH_SIZE * scale_factor)
    cfg.DATALOADER.NUM_WORKERS = int(cfg.DATALOADER.NUM_WORKERS * scale_factor)
    cfg.TRAIN.MAX_ITERS = int(cfg.TRAIN.MAX_ITERS // scale_factor)
    cfg.TRAIN.MAX_EPOCHS = int(cfg.TRAIN.MAX_EPOCHS // scale_factor)
    cfg.TRAIN.PREHEAT_STEPS = int(cfg.TRAIN.PREHEAT_STEPS // scale_factor)
    cfg.TRAIN.T_VAL_ITER = int(cfg.TRAIN.T_VAL_ITER // scale_factor)
    cfg.TRAIN.S_VAL_ITER = int(cfg.TRAIN.S_VAL_ITER // scale_factor)
    cfg.OPTIMIZER.LEARNING_RATE_G = cfg.OPTIMIZER.LEARNING_RATE_G * scale_factor_lr
    if hasattr(cfg.OPTIMIZER, 'LEARNING_RATE_D'):
        cfg.OPTIMIZER.LEARNING_RATE_D = cfg.OPTIMIZER.LEARNING_RATE_D * scale_factor_lr
    print("MAX_ITERS is now: ", cfg.TRAIN.MAX_ITERS)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with one leading ``module.`` removed from keys.

    Models are always loaded *before* DDP wrapping in this script. Checkpoints
    saved from a DDP-wrapped model must therefore be un-prefixed, whatever
    mode the current run uses.
    """
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


def _merge_and_load(model, state_dict):
    """Overlay *state_dict* on *model*'s current weights and load the result.

    Entries absent from *state_dict* keep the model's current values. A strict
    load is tried first. If it raises (e.g. the checkpoint carries keys the
    model does not have), a non-strict load is used and the ignored keys are
    counted.
    """
    merged = model.state_dict()
    merged.update(state_dict)
    try:
        model.load_state_dict(merged)
    except RuntimeError:
        result = model.load_state_dict(merged, strict=False)
        print('Non-strict load: ignored {} unexpected key(s)'.format(len(result.unexpected_keys)))


def _load_pretrained(cfg, net, old_net, rank):
    """Initialise *net* and *old_net* from ``cfg.TRAIN.PRETRAINPATH``.

    Raises:
        RuntimeError: the pretrained file could not be loaded (the original
            exception is chained as the cause).
    """
    path = cfg.TRAIN.PRETRAINPATH
    print('Start loading pretrained model')
    try:
        checkpoint = torch.load(path, map_location='cpu')
    except Exception as err:
        raise RuntimeError('Failed to load pretrained model {} (stage {})'.format(
            path, cfg.TRAIN.STAGE)) from err
    if rank == 0:
        print('*** using preTrain model: %s ***' % path)
    # Strip "module." so DDP-saved weights are not silently dropped as
    # unexpected keys by the non-strict fallback.
    pretrained_dict = _strip_module_prefix(checkpoint['model_state_dict'])
    _merge_and_load(net, pretrained_dict)
    _merge_and_load(old_net, pretrained_dict)
    print('finish update pretrained parameters')


def _wrap_ddp(module, device, requires_grad=True):
    """Prepare *module* for DDP training on *device*.

    Trainable modules have torch BatchNorm layers converted to SyncBatchNorm
    and are wrapped in DistributedDataParallel. Modules that are never
    optimised (the pseudo-label / teacher network) are only moved to
    *device*, since there are no gradients to synchronise.
    """
    if not requires_grad:
        return module.to(device)
    module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
    return DDP(
        module.to(device),
        device_ids=[device.index],
        output_device=device.index,
        find_unused_parameters=True)


def _unwrap(module):
    """Return the underlying module of a DDP wrapper, or *module* itself."""
    return module.module if isinstance(module, DDP) else module


def main():
    """Build config, models and optimizers for the selected stage, then train.

    Steps:
      1. Seed everything and initialise DDP when launched with more than one
         visible GPU and ``LOCAL_RANK``.
      2. Load the YAML config and create output dirs.
      3. Set up wandb (rank 0), the logger and tensorboard.
      4. Apply batch/LR scaling, then build the network ``net``, its optimizer
         and the pseudo-label network ``old_net``.
      5. Restore a resume checkpoint or pretrained weights.
      6. Wrap models for DDP, build the stage trainer and run it.

    Raises:
        ValueError: unknown ``cfg.TRAIN.STAGE`` or scaling rule.
    """
    com.make_reproducible()  # freeze all seeds
    args = parse_args()

    # ------------------- (1) DDP initialisation -------------------
    num_gpus = torch.cuda.device_count()
    using_ddp, rank, device = _init_distributed(num_gpus)

    # load the configuration
    cfg = EasyDict()
    cfg.OUTPUT_DIR = './workspace/'
    cfg_from_yaml_file(args.config_file, cfg)

    cfg.TRAIN.config_file = args.config_file
    curPath = os.path.abspath(os.path.dirname(__file__))
    cfg.TRAIN.CURPATH = curPath

    branch = _git_branch(curPath)
    print(branch)
    wb_note = '*Path: ' + str(curPath) + '      **Git branch: ' + branch

    time_now = datetime.datetime.now().strftime("%m-%d-%H_%M")
    print('NOW IS... : ', time_now)

    # DEBUG mode: wandb offline plus the shortened DEBUG schedule.
    if cfg.TRAIN.DEBUG:
        os.environ['WANDB_MODE'] = 'dryrun'
        cfg = change_Config_DEBUG(cfg)

    # mkdir for logs and checkpoints
    project_dir = osp.join(cfg.OUTPUT_DIR, cfg.TRAIN.PROJECT_NAME)
    cfg.TRAIN.MODEL_DIR = osp.join(project_dir, 'checkpoints', cfg.TRAIN.EXP_NAME)
    os.makedirs(cfg.TRAIN.MODEL_DIR, exist_ok=True)
    cfg.TRAIN.LOG_DIR = osp.join(project_dir, 'logs', cfg.TRAIN.EXP_NAME)
    os.makedirs(cfg.TRAIN.LOG_DIR, exist_ok=True)
    cfg.TRAIN.TB_DIR = osp.join(project_dir, 'tb_dirs', cfg.TRAIN.EXP_NAME)

    # wandb runs on rank 0 only; every other rank disables it.
    if rank == 0:
        for _ in range(3):
            print("checkpoint directory: ", cfg.TRAIN.MODEL_DIR)
        hostname = socket.gethostname()
        cfg.TRAIN.HOSTNAME = hostname
        cfg.TRAIN.WANDB_ID = str(wandb.util.generate_id())
        wandb.init(name=cfg.TRAIN.EXP_NAME,
                   notes=wb_note,
                   project=cfg.TRAIN.PROJECT_NAME,
                   entity='anonymous_authors',
                   id=cfg.TRAIN.WANDB_ID,
                   settings=wandb.Settings(start_method="thread"))
        wandb.config.update(cfg, allow_val_change=True)
    else:
        os.environ["WANDB_MODE"] = "disabled"

    # Report the mode actually in use: >1 visible GPU without a DDP launch
    # still trains on a single device.
    if using_ddp:
        for _ in range(2):
            print(f"DistributedDataParallel! DistributedDataParallel! GPU NUM: {num_gpus}")
    else:
        for _ in range(2):
            print("Single GPU! Single GPU! Single GPU!")

    if rank == 0:
        print(cfg)

    # Init logger and tensorboard exactly once per process. A second setup
    # would create a second SummaryWriter on the same directory and leave the
    # first one unclosed.
    # NOTE(review): every rank gets a logger/writer, as the previous code
    # effectively did. The trainer may rely on these being non-None.
    logger = setup_logger("Trainer", cfg)
    logger.info('this Experiment is: \n %s \n \n ' % wb_note)
    tf_writer = SummaryWriter(cfg.TRAIN.TB_DIR)

    if rank == 0:
        print("Host: {}, wandb_ID: {}".format(hostname, cfg.TRAIN.WANDB_ID))

    # Scale BEFORE building the optimizers so the scaled learning rate is
    # actually used by G_optim (and by D_optim below).
    _apply_batch_scaling(cfg, args, num_gpus)

    # init network G and its optimizer
    net = MinkUNet34(cfg.MODEL_G.IN_CHANNELS, cfg.MODEL_G.NUM_CLASSES, cfg.TGT_LOSS.CAL_out).to(device)
    G_optim = optim.Adam(net.parameters(), lr=cfg.OPTIMIZER.LEARNING_RATE_G)
    # init old-network. This model is utilized to generate pseudo-label.
    old_net = MinkUNet34(cfg.MODEL_G.IN_CHANNELS, cfg.MODEL_G.NUM_CLASSES, cfg.TGT_LOSS.CAL_out).to(device)

    ###################### CHECKPOINT LOADING ######################
    checkpoint = None
    if args.checkpoint is not None:  # resume
        print('Resuming training from checkpoint: {}'.format(args.checkpoint))
        checkpoint = torch.load(args.checkpoint, map_location=device)
        # The models are DDP-wrapped only further below, so the keys must
        # never carry the "module." prefix here, whatever mode this run uses.
        state_dict = _strip_module_prefix(checkpoint['model_state_dict'])
        net.load_state_dict(state_dict)
        G_optim.load_state_dict(checkpoint['G_optim_state_dict'])
        old_net.load_state_dict(state_dict)
        start_iter = checkpoint['cur_iter']
        print(f'Successfully resumed from iteration {start_iter}')
    else:  # pretrained fine-tune or fresh start
        start_iter = 0
        if cfg.TRAIN.PRETRAINPATH is not None:
            _load_pretrained(cfg, net, old_net, rank)
        else:
            # net and old_net start from identical weights.
            old_net.load_state_dict(net.state_dict())

    # --------- (2) Model DDP wrapping ------------ #
    if using_ddp:
        net = _wrap_ddp(net, device, requires_grad=True)
        old_net = _wrap_ddp(old_net, device, requires_grad=False)

    ###################### MODEL SELECTION ######################
    if cfg.TRAIN.STAGE == "stage1_ours":
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = True

        discriminator = Discriminator_out_torch(cfg).to(device)
        # NOTE(review): the discriminator is not DDP-wrapped. Under DDP its
        # gradients are therefore not all-reduced, and each rank optimises its
        # own copy. Wrapping it (_wrap_ddp) needs checking against how the
        # trainer calls it, since DDP does not proxy custom methods/attributes.
        D_optim = optim.Adam(discriminator.parameters(), lr=cfg.OPTIMIZER.LEARNING_RATE_D)

        from trainer_stage1_ours import stage1_ours_Trainer
        trainer = stage1_ours_Trainer(cfg,
                                      net, old_net, discriminator,
                                      G_optim, D_optim,
                                      logger, tf_writer, device)
    else:
        raise ValueError('Unknown stage: {}'.format(cfg.TRAIN.STAGE))

    ###################### START ITERATION INITIALIZE ######################
    trainer.start_iter = start_iter
    trainer.c_iter = start_iter
    if args.checkpoint is not None:
        if 'D_out_model_state_dict' in checkpoint:
            _unwrap(discriminator).load_state_dict(
                _strip_module_prefix(checkpoint['D_out_model_state_dict']))
        if 'D_optim_state_dict' in checkpoint:
            D_optim.load_state_dict(checkpoint['D_optim_state_dict'])

        if 'src_centers_Proto' in checkpoint:
            trainer.out_class_center.Proto = checkpoint['src_centers_Proto'].to(device)
        if 'src_centers_Amount' in checkpoint:
            trainer.out_class_center.Amount = checkpoint['src_centers_Amount'].to(device)

        # Prototype snapshot saved next to the checkpoint for iteration start_iter - 1.
        proto_ckpt_path = os.path.join(
            os.path.dirname(args.checkpoint),
            f'cp_out_iter_{start_iter - 1}.tar'
        )
        if os.path.exists(proto_ckpt_path):
            print(f"Found prototype checkpoint: {proto_ckpt_path}")
            trainer.out_class_center.load(proto_ckpt_path, device=device)
        else:
            print(f"No prototype checkpoint found at: {proto_ckpt_path}")
        print(f'Successfully resumed from iteration {start_iter}')

    ###################### TRAIN ######################
    try:
        trainer.train()
    finally:
        # Flush pending tensorboard events even if training fails;
        # SummaryWriter.close() ignores a repeated close.
        tf_writer.close()

    for _ in range(3):
        print("checkpoint directory: ", cfg.TRAIN.MODEL_DIR)

    if using_ddp:
        dist.barrier()
        dist.destroy_process_group()

# Run training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()