import os
import argparse
import logging
import torch
import wandb
import utils.config as config
from .logger import get_logger, set_logger
from .task_vectors import TaskVector
from .evaluate import evaluate
from .data import get_eval_test_dataloader
from pkgs.openai.clip import load as load_model
from src.eval_retrieval import evaluate as eval_retrieval
from src.eval_retrieval import build_dataset, image_captions_collate_fn
from torch.utils.data import DataLoader, DistributedSampler



def _strip_wrapper_prefixes(state_dict):
    """Return a new dict whose keys have every "module." and "model." substring removed.

    Values are kept as-is (no copy); only the keys are rewritten.
    """
    return {
        key.replace("module.", "").replace("model.", ""): value
        for key, value in state_dict.items()
    }


def _run_evaluations(options, base_name, alpha, eval_mode, run_once):
    """Call ``run_once`` ``options.num_evaluations`` times and return the last metrics.

    For the duration of each call ``options.name`` is set to a per-run name
    (``<base>_alpha_<alpha>_<mode>_eval_num_<idx>``) and is always restored to
    ``base_name`` afterwards, even if ``run_once`` raises.

    Returns:
        The metrics of the last run, or ``None`` when no run was performed.
    """
    metrics = None
    for idx in range(options.num_evaluations):
        run_name = f"{base_name}_alpha_{alpha}_{eval_mode}_eval_num_{idx}"
        options.name = run_name
        try:
            metrics = run_once()
        finally:
            # Never leave the per-run suffix on the shared options object.
            options.name = base_name
        print(metrics)
        logging.info(f"Metrics for {run_name} : {metrics}")
    return metrics


def eval_unlearn(options, logger):
    """Evaluate a model after applying a negated task vector to a checkpoint.

    The task vector is built from ``options.start_checkpoint`` (pretrained) and
    ``options.end_checkpoint`` (fine-tuned), optionally normalised
    (``options.use_norm``), negated (and negated again when
    ``options.reverse_sign`` is set), then applied with a scaling coefficient to
    the state dict of ``options.eval_checkpoint`` (or of the start checkpoint
    when no eval checkpoint is given). The patched model is evaluated either on
    COCO retrieval (``options.retrieval``), over a grid of coefficients
    (``options.eval_grid`` / ``options.grid``) or for the single coefficient
    ``options.alpha``.

    Args:
        options: Parsed command-line namespace; ``options.name`` is temporarily
            rewritten for every evaluation run and restored afterwards.
        logger: Logger object forwarded to ``set_logger``.

    Returns:
        None. Metrics are printed and logged; with ``options.labelDiscovery``
        the metrics of the last evaluation are also saved to
        ``<options.log_dir_path>/metrics.pt`` (not in retrieval mode).

    Raises:
        FileNotFoundError: If a configured checkpoint path is not a file.
    """
    # Set the logger
    set_logger(rank = 0, logger = logger, distributed = options.distributed)
    logging.getLogger().setLevel(logging.INFO)
    logging.info("Starting evaluation for task unlearning.")
    logging.info(f"Options: {options}")

    # NOTE: wandb logging is not initialised by this function.

    # Validate checkpoint paths
    if not os.path.isfile(options.start_checkpoint):
        raise FileNotFoundError(f"Start checkpoint file not found: {options.start_checkpoint}")
    if not os.path.isfile(options.end_checkpoint):
        raise FileNotFoundError(f"End checkpoint file not found: {options.end_checkpoint}")
    if options.eval_checkpoint is not None and not os.path.isfile(options.eval_checkpoint):
        raise FileNotFoundError(f"Eval checkpoint file not found: {options.eval_checkpoint}")

    # Load checkpoints and drop "module." / "model." from the state-dict keys
    start_checkpoint = torch.load(options.start_checkpoint, map_location = options.device)
    end_checkpoint = torch.load(options.end_checkpoint, map_location = options.device)
    start_checkpoint["state_dict"] = _strip_wrapper_prefixes(start_checkpoint["state_dict"])
    end_checkpoint["state_dict"] = _strip_wrapper_prefixes(end_checkpoint["state_dict"])
    if options.eval_checkpoint is not None:
        eval_checkpoint = torch.load(options.eval_checkpoint, map_location = options.device)
        eval_checkpoint["state_dict"] = _strip_wrapper_prefixes(eval_checkpoint["state_dict"])
    else:
        # Without a dedicated eval checkpoint the task vector is applied to the start checkpoint.
        eval_checkpoint = start_checkpoint

    eval_epoch = eval_checkpoint['epoch']

    # Load the model and its processor
    model, processor = load_model(name = options.model_name, pretrained = True)
    model.to(options.device)
    if options.eval_zt:
        # Use the freshly loaded pretrained weights as both task-vector base and
        # evaluation base. state_dict() tensors share storage with the live
        # parameters, and load_state_dict() below copies into them in place, so
        # the snapshot must be cloned or every later alpha would be applied on
        # top of the previously patched weights.
        zero_shot_state = {key: value.detach().clone() for key, value in model.state_dict().items()}
        start_checkpoint["state_dict"] = zero_shot_state
        eval_checkpoint["state_dict"] = zero_shot_state

    # Load the data
    data = {
        "validation": None,
        "eval_train": None,
        "eval_test": get_eval_test_dataloader(options, processor),
    }

    # Build the task vector, optionally normalise it, then negate it
    task_vector = TaskVector(pretrained_state_dict=start_checkpoint["state_dict"], finetuned_state_dict=end_checkpoint["state_dict"])
    if options.use_norm:
        task_vector = task_vector / task_vector.norm()
    task_vector = -task_vector
    if options.reverse_sign:
        task_vector = -task_vector

    eval_mode = "ASR" if options.asr else "CA"
    base_name = options.name

    def load_patched_weights(alpha):
        # Apply the task vector with the given coefficient and switch to eval mode.
        patched_state = task_vector.apply_to(pretrained_state_dict=eval_checkpoint["state_dict"], scaling_coef=alpha)
        model.load_state_dict(patched_state)
        model.eval()

    if options.retrieval:
        dataset = build_dataset("mscoco_captions", root = "./data/COCO", split = "val", transform = processor.process_image)
        dataloader = DataLoader(
            dataset,
            batch_size = options.batch_size,
            shuffle = False,
            num_workers = options.num_workers,
            collate_fn = image_captions_collate_fn)

        load_patched_weights(options.alpha)
        _run_evaluations(
            options, base_name, options.alpha, eval_mode,
            lambda: eval_retrieval(model, dataloader, processor.process_text, options.device, amp = False, recall_k_list = [5]))
        return

    # Stays None when no evaluation is executed (e.g. num_evaluations == 0).
    metrics = None
    if options.eval_grid:
        for alpha in options.grid:
            load_patched_weights(alpha)
            metrics = _run_evaluations(
                options, base_name, alpha, eval_mode,
                lambda: evaluate(eval_epoch, model, processor, data, options))
    else:
        load_patched_weights(options.alpha)
        metrics = _run_evaluations(
            options, base_name, options.alpha, eval_mode,
            lambda: evaluate(eval_epoch, model, processor, data, options))

    if options.labelDiscovery:
        if metrics is None:
            logging.warning("No evaluation was run; skipping metrics save.")
        else:
            # Save the metrics of the last evaluation to the logging folder as a .pt file
            metrics_path = os.path.join(options.log_dir_path, "metrics.pt")
            torch.save(metrics, metrics_path)
            logging.info(f"Metrics saved to {metrics_path}")
    return

if __name__ == "__main__":
    # Command-line entry point: parse options, prepare the log directory,
    # dump the parameters and run the unlearning evaluation.
    parser = argparse.ArgumentParser(description='Evaluate task unlearning for a model.')

    parser.add_argument("--name", type=str, default="unlearn_default", help="Name of the experiment.")
    parser.add_argument("--model_name", type=str, default="RN50", choices=["RN50", "RN101", "RN50x4", "ViT-B/32", "ViT-L/14"], help="Name of the model architecture.")
    parser.add_argument("--logs", type = str, default = os.path.join(config.root, "logs/"), help = "Logs directory path")
    parser.add_argument("--start_checkpoint", type=str, required=True, help="Path to the checkpoint to start from.")
    parser.add_argument("--end_checkpoint", type=str, required=True, help="Path to the final fine-tuned checkpoint.")
    parser.add_argument("--eval_checkpoint", type=str, default = None, help="Path to the ckpt to which you apply the task vector")
    parser.add_argument("--alpha", type=float, default=0, help="Scaling coefficient for the task vector.")
    parser.add_argument("--eval_data_type", type=str, default="ImageNet1K", help="Type of evaluation dataset.")
    parser.add_argument("--eval_test_data_dir", type=str, required=True, help="Directory of the test dataset.")
    parser.add_argument("--device", type = str, default = None, choices = ["cpu", "gpu"], help = "Specify device type to use (default: gpu > cpu)")
    parser.add_argument("--device_id", type = int, default = 0, help = "Specify device id if using single gpu")
    parser.add_argument("--distributed", action = "store_true", default = False, help = "Use multiple gpus if available")
    parser.add_argument("--num_workers", type = int, default = 8, help = "Number of workers per gpu")
    parser.add_argument("--batch_size", type = int, default = 1024, help = "Batch size")
    parser.add_argument("--patch_type", default = None, type = str, help = "patch type of backdoor")
    parser.add_argument("--patch_location", default = None, type = str, help = "patch location of backdoor")
    parser.add_argument("--patch_size", default = 16, type = int, help = "patch size of backdoor")
    parser.add_argument("--patch_path", default = None, type = str, help = "path to patch")
    parser.add_argument("--add_backdoor", default = False, action = "store_true", help = "add backdoor or not")
    parser.add_argument("--backdoor_sufi", action = "store_true", default = False, help = "backdoor sufi")
    parser.add_argument("--asr", default = False, action = "store_true", help = "Calculate Attack Success Rate (ASR)")
    parser.add_argument("--eval_grid", default = False, action = "store_true", help = "used for evaluating a grid of alpha values for unlearning")
    parser.add_argument("--grid", type=lambda x: list(map(float, x.split(','))), default = [0,1,2,3,4], help = "define the grid of alpha values")
    parser.add_argument("--csv_path", type = str, default = None, help = "path to where you want to save the csv file")
    parser.add_argument("--use_norm", default = False, action = "store_true", help = "Normalize the task vector")
    parser.add_argument("--retrieval", default = False, action = "store_true", help = "Evaluate retrieval")
    parser.add_argument("--labelDiscovery", default = False, action = "store_true", help = "Evaluate label discovery")
    parser.add_argument("--num_evaluations", type = int, default = 1, help = "Number of evaluations to perform")
    parser.add_argument("--reverse_sign", default = False, action = "store_true", help = "Reverse the sign of the task vector")
    parser.add_argument("--target_class", type = int, default=954, help = "Target class for unlearning")
    parser.add_argument("--eval_zt", default = False, action = "store_true", help = "Evaluate on zt")
    # Must be an optional flag: a positional "store_true" argument consumes no
    # command-line token and therefore can never be toggled by the user.
    parser.add_argument("--wandb", action = "store_true", default = False, help = "Use wandb for logging")

    # Parse arguments
    options = parser.parse_args()

    options.log_dir_path = os.path.join(options.logs, options.name)
    options.log_file_path = os.path.join(options.log_dir_path, "output.log")
    # Honour an explicit "--device cpu"; otherwise prefer CUDA when it is available.
    use_cuda = options.device != "cpu" and torch.cuda.is_available()
    options.device = "cuda" if use_cuda else "cpu"
    rank = 0
    options.master = rank == 0
    # The flag is accepted on the command line, but wandb is forced off for this script.
    options.wandb = False
    os.makedirs(options.log_dir_path, exist_ok = True)
    logger, listener = get_logger(options.log_file_path)

    listener.start()
    try:
        # Record the effective parameters next to the log file.
        with open(os.path.join(options.log_dir_path, "params.txt"), "w") as file:
            for key in sorted(vars(options)):
                value = getattr(options, key)
                logging.info(f"{key}: {value}")
                file.write(f"{key}: {value}\n")
        eval_unlearn(options, logger)
    finally:
        # Stop the log listener even when the evaluation fails.
        listener.stop()

     