import os
import argparse
import logging
import torch
import wandb
import utils.config as config
from .logger import get_logger, set_logger
from .task_vectors import TaskVector
from .evaluate import evaluate
from .data import get_eval_test_dataloader
from pkgs.openai.clip import load as load_model
from src.eval_retrieval import evaluate as eval_retrieval
from src.eval_retrieval import build_dataset, image_captions_collate_fn
from torch.utils.data import DataLoader, DistributedSampler


def eval_zt(options, logger):
    """Evaluate a pretrained model on the test set or on COCO caption retrieval.

    The model named by ``options.model_name`` is loaded with pretrained
    weights, moved to ``options.device`` and put in eval mode.

    * If ``options.retrieval`` is set, the ``mscoco_captions`` validation
      split under ``./data/COCO`` is scored with ``eval_retrieval``
      ``options.num_evaluations`` times, and the function returns.
    * Otherwise ``evaluate`` is run once on the ``eval_test`` dataloader. When
      ``options.labelDiscovery`` is set, the resulting metrics are also saved
      to ``metrics.pt`` inside ``options.log_dir_path``.

    While an evaluation is running, ``options.name`` carries a suffix for the
    evaluation mode ("ASR" when ``options.asr`` is set, "CA" otherwise). The
    original name is always restored before this function exits, including
    when an evaluation raises or the retrieval loop runs zero times.

    Args:
        options: Namespace of run settings. Attributes read here:
            ``distributed``, ``model_name``, ``device``, ``asr``, ``name``,
            ``retrieval``, ``batch_size``, ``num_workers``,
            ``num_evaluations``, ``labelDiscovery`` and ``log_dir_path``.
            It is also forwarded to ``get_eval_test_dataloader`` and
            ``evaluate``.
        logger: Forwarded unchanged to ``set_logger``.

    Returns:
        None. Metrics are printed to stdout and written to the log.
    """
    # Set the logger
    set_logger(rank = 0, logger = logger, distributed = options.distributed)
    logging.getLogger().setLevel(logging.INFO)
    logging.info("Starting evaluation for Pretrained Models.")
    logging.info(f"Options: {options}")

    # NOTE: no wandb run is initialised in this function.

    eval_epoch = 0

    # Load the model together with its input processor
    model, processor = load_model(name = options.model_name, pretrained = True)
    model.to(options.device)

    # Only the test split is populated; the other keys are passed as None.
    data = {}
    data["validation"] = None
    data["eval_train"] = None
    data["eval_test"] = get_eval_test_dataloader(options, processor)

    eval_mode = "ASR" if options.asr else "CA"
    base_name = options.name
    run_name = f"{base_name}_{eval_mode}"

    model.eval()

    try:
        # The suffixed name is visible to anything that reads options.name
        # while the evaluation is running.
        options.name = run_name

        if options.retrieval:
            dataset = build_dataset(
                "mscoco_captions",
                root = "./data/COCO",
                split = "val",
                transform = processor.process_image,
            )
            dataloader = DataLoader(
                dataset,
                batch_size = options.batch_size,
                shuffle = False,
                num_workers = options.num_workers,
                collate_fn = image_captions_collate_fn,
            )

            for idx in range(options.num_evaluations):
                options.name = f"{run_name}_eval_num_{idx}"
                metrics = eval_retrieval(
                    model,
                    dataloader,
                    processor.process_text,
                    options.device,
                    amp = False,
                    recall_k_list = [5],
                )
                print(metrics)
                logging.info(f"Metrics for {run_name}_num_eval_{idx} : {metrics}")
            return

        metrics = evaluate(eval_epoch, model, processor, data, options)
        print(metrics)
        logging.info(f"Metrics for {run_name} : {metrics}")
    finally:
        # Never leak the temporary run name back to the caller, even when the
        # retrieval loop runs zero times or an evaluation raises.
        options.name = base_name

    if options.labelDiscovery:
        # Save metrics to the logging folder as a .pt file
        metrics_path = os.path.join(options.log_dir_path, "metrics.pt")
        torch.save(metrics, metrics_path)
        logging.info(f"Metrics saved to {metrics_path}")

    return

def build_parser(default_logs_dir):
    """Build the command-line parser for the evaluation script.

    Args:
        default_logs_dir: Default value used for ``--logs``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='Evaluate zeroshot accuracies for a model.')

    parser.add_argument("--name", type=str, default="unlearn_default", help="Name of the experiment.")
    parser.add_argument("--model_name", type=str, default="RN50", choices=["RN50", "RN101", "RN50x4", "ViT-B/32", "ViT-L/14"], help="Name of the model architecture.")
    parser.add_argument("--logs", type = str, default = default_logs_dir, help = "Logs directory path")
    parser.add_argument("--eval_data_type", type=str, default="ImageNet1K", help="Type of evaluation dataset.")
    parser.add_argument("--eval_test_data_dir", type=str, required=True, help="Directory of the test dataset.")
    parser.add_argument("--device", type = str, default = None, choices = ["cpu", "gpu"], help = "Specify device type to use (default: gpu > cpu)")
    parser.add_argument("--device_id", type = int, default = 0, help = "Specify device id if using single gpu")
    parser.add_argument("--distributed", action = "store_true", default = False, help = "Use multiple gpus if available")
    parser.add_argument("--num_workers", type = int, default = 8, help = "Number of workers per gpu")
    parser.add_argument("--batch_size", type = int, default = 1024, help = "Batch size")
    parser.add_argument("--patch_type", default = None, type = str, help = "patch type of backdoor")
    parser.add_argument("--patch_location", default = None, type = str, help = "patch location of backdoor")
    parser.add_argument("--patch_size", default = 16, type = int, help = "patch size of backdoor")
    parser.add_argument("--patch_path", default = None, type = str, help = "path to patch")
    parser.add_argument("--add_backdoor", default = False, action = "store_true", help = "add backdoor or not")
    parser.add_argument("--backdoor_sufi", action = "store_true", default = False, help = "backdoor sufi")
    parser.add_argument("--asr", default = False, action = "store_true", help = "Calculate Attack Success Rate (ASR)")
    parser.add_argument("--csv_path", type = str, default = None, help = "path to where you want to save the csv file")
    # Registered as an optional flag: a positional "wandb" cannot be toggled
    # from the command line with a store_true action.
    parser.add_argument("--wandb", action = "store_true", default = False, help = "Use wandb for logging")
    parser.add_argument("--labelDiscovery", default = False, action = "store_true", help = "Evaluate label discovery")
    parser.add_argument("--target_class", type = int, default=954, help = "Target class for unlearning")
    parser.add_argument("--retrieval", action = "store_true", default = False, help = "Use retrieval evaluation")
    parser.add_argument("--num_evaluations", type = int, default = 1, help = "Number of evaluations to perform")
    return parser


def resolve_device(requested, cuda_available):
    """Pick the torch device string for this run.

    Args:
        requested: Value of ``--device`` ("cpu", "gpu" or None).
        cuda_available: Whether CUDA can be used on this machine.

    Returns:
        "cpu" when the CPU was requested explicitly or CUDA is unavailable,
        otherwise "cuda".
    """
    if requested == "cpu" or not cuda_available:
        return "cpu"
    return "cuda"


def main():
    """Parse the command line, set up logging, and run ``eval_zt``."""
    parser = build_parser(os.path.join(config.root, "logs/"))

    # Parse arguments
    options = parser.parse_args()

    options.log_dir_path = os.path.join(options.logs, options.name)
    options.log_file_path = os.path.join(options.log_dir_path, "output.log")
    # Honour an explicit "--device cpu"; otherwise prefer CUDA when available.
    # NOTE: --device_id is parsed but not applied in this script.
    options.device = resolve_device(options.device, torch.cuda.is_available())
    rank = 0
    options.master = rank == 0
    # eval_zt never initialises a wandb run, so wandb stays disabled
    # regardless of the --wandb flag.
    options.wandb = False
    os.makedirs(options.log_dir_path, exist_ok = True)
    logger, listener = get_logger(options.log_file_path)

    listener.start()
    try:
        with open(os.path.join(options.log_dir_path, "params.txt"), "w") as file:
            for key in sorted(vars(options)):
                value = getattr(options, key)
                logging.info(f"{key}: {value}")
                file.write(f"{key}: {value}\n")
        eval_zt(options, logger)
    finally:
        # Always stop the listener, even if evaluation fails.
        listener.stop()


if __name__ == "__main__":
    main()