import os
# These environment variables are assigned before torch and the project
# modules below are imported, so they are already in place at import time.
# NOTE(review): WANDB_API_KEY is overwritten with an empty string, which also
# replaces any key already present in the caller's environment — presumably a
# scrubbed credential; confirm how the real key is meant to be supplied.
os.environ["WANDB_API_KEY"] = ""
# NOTE(review): presumably silences TensorFlow's native logging — confirm that
# TensorFlow is actually imported somewhere in the project.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 
import sys
# Insert the current working directory at index 1 of sys.path.
# NOTE(review): presumably so the `srcV2` and `pkgs` packages resolve when the
# script is launched from the repository root — verify against how it is run.
current_directory = os.getcwd()
sys.path.insert(1,current_directory)
import torch
import logging
import warnings
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from srcV2.evaluate import evaluate

from pkgs.openai.clip import load as load_model

from srcV2.data import load as load_data
from srcV2.data import get_clean_train_dataloader, calculate_scores
from srcV2.parser import parse_args
from srcV2.logger import get_logger, set_logger

# Select the "spawn" start method for torch.multiprocessing; force = True
# replaces any start method that was already set.
mp.set_start_method("spawn", force = True)
# Suppress every warning for the whole process.
# NOTE(review): this also hides deprecation warnings from torch and friends.
warnings.filterwarnings("ignore")

def gathered_elements_to_list(gather_elements):
    """Flatten an iterable of iterables into one list.

    Args:
        gather_elements: Iterable whose items are themselves iterable
            (presumably the per-rank results of a gather call — confirm
            against the callers).

    Returns:
        A new list holding every item of every element, in iteration order.
        An empty input yields an empty list.
    """
    # A single comprehension grows the result in place. The previous
    # `output = output + list(element)` copied the whole accumulated list on
    # every iteration, which is quadratic in the total number of items.
    return [item for element in gather_elements for item in element]

def progressive_removal(options, model, processor, data, epoch):
    """Recompute scores and swap in a new training dataloader.

    Calls ``calculate_scores`` on ``data["train"]``, replaces ``data["train"]``
    with whatever ``get_clean_train_dataloader`` returns for that result, and
    stores the result in ``options.train_data``.

    In distributed runs the result is exchanged with ``all_gather_object``
    (the gathered list is not read afterwards). A single barrier is then
    entered by non-master ranks *before* they build their dataloader and by
    the master rank *after* it has built its own, so the master builds first.

    Args:
        options: Options namespace; ``num_devices``, ``distributed``,
            ``master`` and ``device`` are read, ``train_data`` is written.
        model: Forwarded to ``calculate_scores``.
        processor: Forwarded to ``get_clean_train_dataloader``.
        data: Dict of dataloaders; the ``"train"`` entry is replaced.
        epoch: Forwarded to ``calculate_scores``.

    Returns:
        The ``(options, data)`` pair, both mutated in place.
    """
    scores_path = calculate_scores(options, model, data["train"], epoch)

    # One placeholder slot per device; only filled when running distributed.
    gathered_paths = [None] * options.num_devices
    if options.distributed:
        dist.all_gather_object(gathered_paths, scores_path)

    # Non-master ranks wait here until the master reaches its barrier below.
    if options.distributed and not options.master:
        logging.info(f'Device inside barrier 1 {options.device}')
        dist.barrier()
        logging.info(f'Device outside barrier 1 {options.device}')

    data["train"] = get_clean_train_dataloader(options, processor, scores_path)
    options.train_data = scores_path

    # The master only joins the barrier once its dataloader has been built.
    if options.distributed and options.master:
        logging.info(f'Device inside barrier 2 {options.device}')
        dist.barrier()
        logging.info(f'Device outside barrier 2 {options.device}')

    return options, data

def worker(rank, options, logger):
    """Per-process entry point: prepare the device and model, then evaluate.

    Steps, in order: configure logging for this rank, resolve the device
    string, dump the options to ``params.txt`` (master only), initialise the
    process group when distributed, load the model and data, run ``evaluate``
    once with epoch 0, and finally tear the process group down again.

    Args:
        rank: Index of this process; rank 0 is treated as the master.
        options: Options namespace, mutated in place. ``rank``, ``master``
            and ``batch_size`` (integer-divided by ``num_devices``) are
            written, and ``device`` gets a ``":<index>"`` suffix when it is
            ``"cuda"``.
        logger: Forwarded unchanged to ``set_logger``.
    """
    options.rank = rank
    options.master = rank == 0

    set_logger(rank = rank, logger = logger, distributed = options.distributed)
    if(options.device == "cuda"):
        # Pin this process to one GPU: the rank's entry in device_ids when
        # distributed, otherwise the single configured device_id.
        options.device += ":" + str(options.device_ids[options.rank] if options.distributed else options.device_id)

    logging.info(f"Using {options.device} device")

    if(options.master):
        # Only the master writes the parameter dump, so ranks do not race on
        # the same file.
        logging.info("Params:")
        with open(os.path.join(options.log_dir_path, "params.txt"), "w") as file:
            for key in sorted(vars(options)):
                value = getattr(options, key)
                logging.info(f"{key}: {value}")
                file.write(f"{key}: {value}\n")

    if(options.distributed):
        dist.init_process_group(backend = options.distributed_backend, init_method = options.distributed_init_method, world_size = options.num_devices, rank = options.rank)

    # Split the configured batch size evenly across the participating devices.
    options.batch_size = options.batch_size // options.num_devices

    start_epoch = 0
    model, processor = load_model(name = options.model_name, pretrained = options.pretrained)

    if(options.device == "cpu"):
        model.float()
    else:
        torch.cuda.set_device(options.device_ids[options.rank] if options.distributed else options.device_id)
        model.to(options.device)
        if(options.distributed):
            model = DDP(model, device_ids = [options.device_ids[options.rank]])

    data = load_data(options, processor)

    evaluate(start_epoch, model, processor, data, options)

    if(options.distributed):
        # Release the communication backend once evaluation is done, instead
        # of leaving it to interpreter shutdown. This is done on the success
        # path only, so an exception above still propagates immediately.
        dist.destroy_process_group()




if(__name__ == "__main__"):
    # Script entry point: parse the options, start the logging listener,
    # pick CPU / single-GPU / multi-GPU execution and launch the worker(s).
    options = parse_args()

    options.log_dir_path = os.path.join(options.logs, options.name)
    options.log_file_path = os.path.join(options.log_dir_path, "output.log")

    os.makedirs(options.log_dir_path, exist_ok = True)
    logger, listener = get_logger(options.log_file_path)

    listener.start()
    try:
        ngpus = torch.cuda.device_count()
        if(ngpus == 0 or options.device == "cpu"):
            # No GPU available, or CPU explicitly requested.
            options.device = "cpu"
            options.num_devices = 1
            options.distributed = False
            worker(0, options, logger)
        elif(ngpus == 1 or not options.distributed):
            # A single GPU: run the worker in this process.
            options.device = "cuda"
            options.num_devices = 1
            options.distributed = False
            worker(0, options, logger)
        else:
            options.device = "cuda"
            if(options.device_ids is None):
                options.device_ids = list(range(ngpus))
            else:
                # Accept both one quoted string ("0 1 2") and several separate
                # entries (0 1 2). Previously only the first entry was parsed,
                # so every further id was silently dropped.
                raw_ids = options.device_ids
                if isinstance(raw_ids, str):
                    raw_ids = [raw_ids]
                options.device_ids = [int(token) for entry in raw_ids for token in str(entry).split()]
                if not options.device_ids:
                    raise ValueError("device_ids did not contain any device id")
            options.num_devices = len(options.device_ids)
            options.distributed = True
            # Exported before spawning so the child processes inherit it.
            # NOTE(review): presumably a workaround for NCCL peer-to-peer
            # problems on the target machines — confirm it is still needed.
            os.environ["NCCL_P2P_DISABLE"] = "1"
            mp.spawn(worker, nprocs = options.num_devices, args = (options, logger))
    finally:
        # Always stop the listener, even when a worker raised; previously an
        # exception skipped this call.
        listener.stop()