"""Entry point for evaluating and (optionally) fine-tuning a CLIP model.

Loads a model via ``pkgs.openai.clip.load``, builds data loaders from ``srcV2``,
then runs evaluation plus the regular and/or TAC training loops, either in a
single process or with one DistributedDataParallel process per GPU launched
through ``torch.multiprocessing.spawn``.
"""
import os
# The environment tweaks below run before wandb (and the rest) are imported,
# so the settings are already in place when those modules load.
# Blank the W&B API key in this process's environment.
# NOTE(review): presumably to avoid using an inherited key — confirm how wandb auth is expected to work.
os.environ["WANDB_API_KEY"] = ""
# Silence TensorFlow's C++ logging in case TensorFlow is imported transitively.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 
import sys
# Put the current working directory on the import path so the project-local
# `pkgs` and `srcV2` packages below resolve (assumes the script is launched
# from the repository root).
current_directory = os.getcwd()
sys.path.insert(1,current_directory)
import time
import random
import wandb
import torch
import logging
import warnings
import numpy as np
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP

from pkgs.openai.clip import load as load_model

from srcV2.train import train, TAC_train
from srcV2.evaluate import evaluate, evaluate_benign, Finetune
from srcV2.data import load as load_data
from srcV2.data import get_clean_train_dataloader, calculate_scores, get_TAC_train_dataloaer
from srcV2.parser import parse_args
from srcV2.scheduler import cosine_scheduler
from srcV2.logger import get_logger, set_logger

# Force the "spawn" start method for every torch.multiprocessing use in this
# process; forking after CUDA has been initialised is not safe.
mp.set_start_method("spawn", force = True)
# Suppress all Python warnings for the whole run.
warnings.filterwarnings("ignore")


def gathered_elements_to_list(gather_elements):
    """Flatten one level of nesting: concatenate every iterable in *gather_elements*.

    Typically used on the per-rank results of a gather operation.

    Args:
        gather_elements: an iterable of iterables.

    Returns:
        A new list with the items of each element, in order.
    """
    # A single pass instead of repeated `output + list(...)`, which copied the
    # growing accumulator on every iteration (quadratic in the total size).
    return [item for element in gather_elements for item in element]

def progressive_removal(options, model, processor, data, epoch):
    """Re-score the training data and replace ``data["train"]`` with a cleaned loader.

    Progressively removes anomalous samples from the training data:
    ``calculate_scores`` runs over the current training loader
    (``data["train"]`` if set, otherwise ``data["TAC_train"]``) and returns a
    path. ``get_clean_train_dataloader`` then builds the new ``data["train"]``
    from that path.

    Args:
        options: run configuration. Reads ``num_devices``, ``distributed``,
            ``master`` and ``device``; ``options.train_data`` is overwritten
            with the path returned by ``calculate_scores``.
        model: model forwarded to ``calculate_scores``.
        processor: processor forwarded to ``get_clean_train_dataloader``.
        data: dict of loaders; ``data["train"]`` is replaced in place.
        epoch: current epoch, forwarded to ``calculate_scores``.

    Returns:
        The (mutated) ``(options, data)`` pair.

    Raises:
        ValueError: if neither ``data["train"]`` nor ``data["TAC_train"]`` is set.
            (Previously this only printed a message and then crashed later with an
            UnboundLocalError on ``path``.)
    """
    if data["train"] is not None:
        source_loader = data["train"]
    elif data["TAC_train"] is not None:
        source_loader = data["TAC_train"]
    else:
        raise ValueError("Not found data['train'] or data['TAC_train']!")

    path = calculate_scores(options, model, source_loader, epoch)

    # all_gather_object is a collective call, so every rank must execute it in
    # distributed mode. The gathered paths are not used below.
    gather_path = [None for _ in range(options.num_devices)]
    if options.distributed:
        dist.all_gather_object(gather_path, path)

    # "Master first": non-master ranks block here until the master reaches its
    # matching barrier below, i.e. after it has built its own clean loader.
    if not options.master and options.distributed:
        logging.info(f'Device inside barrier 1 {options.device}')
        torch.distributed.barrier()
        logging.info(f'Device outside barrier 1 {options.device}')

    data["train"] = get_clean_train_dataloader(options, processor, path)
    options.train_data = path

    # The master's barrier releases the waiting non-master ranks.
    if options.master and options.distributed:
        logging.info(f'Device inside barrier 2 {options.device}')
        torch.distributed.barrier()
        logging.info(f'Device outside barrier 2 {options.device}')

    return options, data

# Parameters whose names contain any of these substrings are trained with
# weight_decay=0; all other trainable parameters use the configured decay.
_NO_WEIGHT_DECAY_KEYS = ("bn", "ln", "bias", "logit_scale")


def _split_weight_decay_params(model):
    """Return ``(decay, no_decay)`` lists of ``model``'s trainable parameters.

    Frozen parameters (``requires_grad=False``) are skipped. A parameter goes
    into ``no_decay`` when its name contains any of ``_NO_WEIGHT_DECAY_KEYS``,
    otherwise into ``decay``. Both lists keep ``model.named_parameters()`` order.
    """
    decay, no_decay = [], []
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        if any(key in name for key in _NO_WEIGHT_DECAY_KEYS):
            no_decay.append(parameter)
        else:
            decay.append(parameter)
    return decay, no_decay


def _build_adamw(no_decay, decay, weight_decay, lr, betas, eps):
    """Create AdamW with two param groups: ``no_decay`` (weight_decay=0), then ``decay``.

    The group order is fixed so optimizer state dicts saved by this script can
    be loaded back.
    """
    return optim.AdamW(
        [{"params": no_decay, "weight_decay": 0}, {"params": decay, "weight_decay": weight_decay}],
        lr=lr,
        betas=betas,
        eps=eps,
    )


def _load_checkpoint(options, model, optimizer):
    """Load model (and, if present, optimizer) state from ``options.checkpoint``.

    Adds or strips the ``module.`` key prefix so that DDP and non-DDP
    checkpoints load into either kind of model. If ``options.checkpoint_finetune``
    is set, every entry whose key contains ``"visual"`` is overwritten with the
    matching entry of that fine-tuned checkpoint. Logs and returns without
    loading anything if the checkpoint file does not exist.
    """
    if not os.path.isfile(options.checkpoint):
        logging.info(f"No checkpoint found at {options.checkpoint}")
        return

    checkpoint = torch.load(options.checkpoint, map_location = options.device)
    state_dict = checkpoint["state_dict"]

    first_key = next(iter(state_dict))
    if not options.distributed and first_key.startswith("module"):
        state_dict = {key[len("module."):]: value for key, value in state_dict.items()}
    elif options.distributed and not first_key.startswith("module"):
        # hack to load a non-distributed checkpoint for distributed training
        state_dict = {"module." + key: value for key, value in state_dict.items()}

    if options.checkpoint_finetune:
        finetuned_checkpoint = torch.load(options.checkpoint_finetune, map_location = options.device)
        finetuned_state_dict = finetuned_checkpoint["state_dict"]
        for key in state_dict:
            if "visual" in key:
                # The fine-tuned checkpoint is expected to use a "model." prefix.
                # Bug fix: this previously used the stale loop variable `name`
                # instead of `key`, so every visual key was mapped to the same entry.
                ft_key = key.replace("module.", "model.") if "module" in key else f"model.{key}"
                state_dict[key] = finetuned_state_dict[ft_key]
        print('Loaded Visual Backbone from Finetuned Model')

    model.load_state_dict(state_dict)
    if optimizer is not None:
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        else:
            logging.info(f"Checkpoint '{options.checkpoint}' has no optimizer state; using a fresh optimizer")
    # .get(): checkpoints without an 'epoch' entry are supported, so do not index it.
    logging.info(f"Loaded checkpoint '{options.checkpoint}' (checkpoint epoch {checkpoint.get('epoch')})")


def worker(rank, options, logger):
    """Evaluate and optionally train a CLIP model in one process.

    Called directly for CPU / single-GPU runs, and through ``mp.spawn`` (one
    process per GPU) for distributed runs.

    Args:
        rank: process index. Rank 0 is the master, which writes params.txt,
            checkpoints and wandb logs.
        options: parsed run configuration. Mutated in place (``rank``,
            ``master``, ``device``, ``batch_size``, ``checkpoints_dir_path``,
            ``progressive_epochs``, ``TAC_neg_mode``, ...).
        logger: logger passed to ``set_logger`` for this process.
    """
    options.rank = rank
    options.master = rank == 0

    set_logger(rank = rank, logger = logger, distributed = options.distributed)
    if options.device == "cuda":
        options.device += ":" + str(options.device_ids[options.rank] if options.distributed else options.device_id)

    logging.info(f"Using {options.device} device")
    print("Using device:",options.device)

    if options.master:
        logging.info("Params:")
        with open(os.path.join(options.log_dir_path, "params.txt"), "w") as file:
            for key in sorted(vars(options)):
                value = getattr(options, key)
                logging.info(f"{key}: {value}")
                file.write(f"{key}: {value}\n")

    if options.distributed:
        dist.init_process_group(backend = options.distributed_backend, init_method = options.distributed_init_method, world_size = options.num_devices, rank = options.rank)

    # Split the configured batch size evenly across the processes.
    options.batch_size = options.batch_size // options.num_devices

    model, processor = load_model(name = options.model_name, pretrained = options.pretrained, freeze_layers = options.freeze_layers)

    if options.device == "cpu":
        model.float()
    else:
        torch.cuda.set_device(options.device_ids[options.rank] if options.distributed else options.device_id)
        model.to(options.device)
        if options.distributed:
            model = DDP(model, device_ids = [options.device_ids[options.rank]], broadcast_buffers = False, find_unused_parameters = True)

    data = load_data(options, processor)
    print("TEST 1_LoadData DONE!")

    optimizer = None
    scheduler = None
    optimizer_TAC = None
    scheduler_TAC = None

    if data["train"] is not None or data["TAC_train"] is not None:
        decay_params, no_decay_params = _split_weight_decay_params(model)
        optimizer = _build_adamw(no_decay_params, decay_params, options.weight_decay, options.lr, (options.beta1, options.beta2), options.eps)
        # The main schedule is sized from data["train"] even in TAC mode, so TAC
        # training also requires a regular training loader.
        scheduler = cosine_scheduler(optimizer, options.lr, options.num_warmup_steps, data["train"].num_batches * options.epochs)

        if data["TAC_train"] is not None:
            # Second optimizer over the same parameters with its own *_TAC
            # hyper-parameters; it is the one passed to TAC_train.
            optimizer_TAC = _build_adamw(no_decay_params, decay_params, options.weight_decay_TAC, options.lr_TAC, (options.beta1_TAC, options.beta2_TAC), options.eps_TAC)
            scheduler_TAC = cosine_scheduler(optimizer_TAC, options.lr_TAC, options.num_warmup_steps_TAC, data["TAC_train"].num_batches * options.epochs)

    # Training always starts at epoch 0. Resuming from checkpoint["epoch"] is
    # disabled; the checkpoint's epoch is only logged.
    # NOTE(review): confirm this is intended. Restoring it would also shorten
    # range(start_epoch + 1, options.epochs + 1) below.
    start_epoch = 0
    if options.checkpoint is not None:
        _load_checkpoint(options, model, optimizer)

    cudnn.benchmark = True
    cudnn.deterministic = False

    if options.wandb and options.master:
        logging.debug("Starting wandb")
        wandb.init(project = "clip-defense", notes = options.notes, tags = [], config = vars(options), entity = 'mint-adobe')
        wandb.run.name = options.name
        wandb.save(os.path.join(options.log_dir_path, "params.txt"))

    # Baseline evaluation before any training.
    evaluate(start_epoch, model, processor, data, options)
    evaluate_benign(start_epoch, model, processor, data, options)

    if data["train"] is not None or data["TAC_train"] is not None:
        options.checkpoints_dir_path = os.path.join(options.log_dir_path, "checkpoints")
        os.makedirs(options.checkpoints_dir_path, exist_ok = True)

        # GradScaler scales gradients for mixed-precision (e.g. FP16) training so
        # they do not underflow/overflow; used together with autocast.
        scaler = GradScaler()

        best_loss = np.inf

        # Progressive removal: at the listed epochs, re-score the training data
        # and rebuild data["train"] without the flagged samples.
        if options.progressive:
            options.progressive_epochs = list(map(int, options.progressive_epochs))
            if start_epoch in options.progressive_epochs:
                options, data = progressive_removal(options, model, processor, data, start_epoch)

        for epoch in range(start_epoch + 1, options.epochs + 1):
            if options.master:
                logging.info(f"Starting Epoch {epoch}")

            start = time.time()
            if data["TAC_train"] is not None:
                logging.info("TAC_training.....")
                train(epoch, model, data, optimizer, scheduler, scaler, options)
                # NOTE(review): each process draws its own mode. In distributed runs the
                # ranks may therefore build different TAC loaders; if their batch counts
                # differ, DDP could hang. Consider broadcasting the master's choice.
                options.TAC_neg_mode = random.choice(['entity','attribute','relation'])
                logging.info(f"Epoch: {epoch}, TAC_neg_mode: {options.TAC_neg_mode}")
                print("TAC neg mode:", options.TAC_neg_mode)
                data["TAC_train"] = get_TAC_train_dataloaer(options, processor)
                TAC_train(epoch, model, data, optimizer_TAC, scheduler_TAC, scaler, options)
            else:
                logging.info("Training.....")
                train(epoch, model, data, optimizer, scheduler, scaler, options)
            end = time.time()

            if options.master:
                logging.info(f"Finished Epoch {epoch}, Time Taken: {end - start:.3f}")

            metrics = evaluate(epoch, model, processor, data, options)
            evaluate_benign(epoch, model, processor, data, options)

            if options.master:
                checkpoint = {"epoch": epoch, "name": options.name, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
                if options.complete_finetune and options.save_final:
                    torch.save(checkpoint, os.path.join(options.checkpoints_dir_path, "epoch.pt"))
                else:
                    torch.save(checkpoint, os.path.join(options.checkpoints_dir_path, "epoch_now.pt"))
                if "loss" in metrics and metrics["loss"] < best_loss:
                    best_loss = metrics["loss"]
                    torch.save(checkpoint, os.path.join(options.checkpoints_dir_path, "epoch.best.pt"))

            if options.progressive:
                if epoch in options.progressive_epochs:
                    options, data = progressive_removal(options, model, processor, data, epoch)

                if epoch == options.stop_epoch:
                    # `break`, not `return`, so the process group and wandb run
                    # below are still shut down.
                    break

    if options.distributed:
        dist.destroy_process_group()

    if options.wandb and options.master:
        wandb.finish()

if __name__ == "__main__":
    options = parse_args()

    options.log_dir_path = os.path.join(options.logs, options.name)
    options.log_file_path = os.path.join(options.log_dir_path, "output.log")

    os.makedirs(options.log_dir_path, exist_ok = True)
    # Get a logger and its listener, then start the listener; it is stopped
    # in the `finally` below.
    logger, listener = get_logger(options.log_file_path)
    listener.start()

    try:
        ngpus = torch.cuda.device_count()
        if ngpus == 0 or options.device == "cpu":
            options.device = "cpu"
            options.num_devices = 1
            options.distributed = False
            worker(0, options, logger)
        elif ngpus == 1 or not options.distributed:
            options.device = "cuda"
            options.num_devices = 1
            options.distributed = False
            worker(0, options, logger)
        else:
            options.device = "cuda"
            if options.device_ids is None:
                options.device_ids = list(range(ngpus))
                options.num_devices = ngpus
            else:
                # device_ids arrives as a list whose first element is a
                # whitespace-separated string of GPU indices.
                options.device_ids = list(map(int, options.device_ids[0].split()))
                options.num_devices = len(options.device_ids)
            options.distributed = True
            os.environ["NCCL_P2P_DISABLE"] = "1"
            mp.spawn(worker, nprocs = options.num_devices, args = (options, logger))
    finally:
        # Stop the listener even when a worker raises. Assuming it is a
        # logging.handlers.QueueListener, stop() drains pending records before
        # the process exits.
        listener.stop()