import argparse, os, sys 
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from torch_uncertainty.datamodules.classification import ImageNetDataModule
from torch.utils.data import DataLoader
import open_clip
from pruning.prune_model.structured import structurally_prune_attention_heads
from models.utils import get_prompts, get_text_logits
import torch 
import json 
import random

# Batch size and worker count shared by the ImageNet datamodule and the pruning DataLoader.
PRUNE_BATCH_SIZE = 32
NUM_WORKERS = 4
# One pruned checkpoint is written per seed (the "predefined" strategy stops after the first).
SEEDS = [0, 521, 902]
# (open_clip architecture name, pretrained-weights tag)
CLIP_MODEL = ("ViT-B-32", "laion2b_s34b_b79k")

def get_dataloader(seed, val_tfms):
    """Return a seeded, shuffled DataLoader over the ImageNet-1k training split.

    Args:
        seed: seed for the generator that drives the shuffling order.
        val_tfms: transform handed to the datamodule as its ``test_transform``.

    Returns:
        torch.utils.data.DataLoader yielding batches of ``PRUNE_BATCH_SIZE``.
    """
    # NOTE(review): ``val_tfms`` is passed as the *test* transform, but the
    # dataset taken below is the training split. Confirm which transform that
    # split actually applies.
    datamodule = ImageNetDataModule(
        root="data/in1k_torch_uncertainty",
        batch_size=PRUNE_BATCH_SIZE,
        test_transform=val_tfms,
        num_workers=NUM_WORKERS,
        eval_ood=False,
        pin_memory=True,
        persistent_workers=False,
    )
    datamodule.setup("fit")
    # NOTE(review): indexing with [0] presumes train_dataloader() returns a
    # sequence of loaders here. TODO confirm for the installed version.
    train_set = datamodule.train_dataloader()[0].dataset

    # Generator.manual_seed returns the generator itself, so it can be chained.
    shuffle_rng = torch.Generator().manual_seed(seed)
    return DataLoader(
        train_set,
        batch_size=PRUNE_BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        generator=shuffle_rng,
    )
    
def get_circuit_heads(circuit_path, num_heads):
    """Read the attention heads to prune from a circuit JSON file.

    The file is loaded as a JSON list of pruning steps. Steps without a
    ``"pruned"`` key are skipped. Each remaining step's ``"pruned"`` value is
    unpacked as a ``(layer, head)`` pair. Steps are consumed in file order,
    which the file is expected to list from the least to the most important
    head.

    Args:
        circuit_path: path to the JSON circuit file of pruned heads.
        num_heads: maximum number of heads to take from the file. A value
            ``<= 0`` selects no heads.

    Returns:
        dict mapping each layer index to the list of head indices to prune in
        that layer, in file order. The dict is empty when ``num_heads <= 0`` or
        the file has no ``"pruned"`` steps.
    """
    with open(circuit_path, "r", encoding="utf-8") as f:
        circuit_data = json.load(f)

    heads_to_prune = {}
    taken = 0
    for pruning_step in circuit_data:
        # Check the budget before taking a head, so num_heads <= 0 selects nothing.
        if taken >= num_heads:
            break
        if "pruned" not in pruning_step:
            continue
        layer, head = pruning_step["pruned"]
        heads_to_prune.setdefault(layer, []).append(head)
        taken += 1

    return heads_to_prune

def get_parser():
    """Build the command-line parser for the CLIP head-pruning script.

    Returns:
        argparse.ArgumentParser with the options ``--heads_to_prune``,
        ``--strategy``, ``--out_dir`` (required), ``--circuit_path`` and
        ``--prune_text_encoder``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--heads_to_prune", dict(type=int, default=0,
                                  help='Number of attention heads to prune')),
        ("--strategy", dict(type=str, default="gradient",
                            choices=['gradient', 'predefined'],
                            help='Pruning strategy to use')),
        ("--out_dir", dict(type=str, required=True,
                           help='Output directory for the pruned checkpoint')),
        ("--circuit_path", dict(type=str, default=None,
                                help='Path to the circuit json file')),
        ("--prune_text_encoder", dict(action='store_true', default=False,
                                      help='Whether to prune the text encoder of CLIP (only if model is clip)')),
    )

    cli = argparse.ArgumentParser(description="Get pruned checkpoint for CLIP")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli

def _resolve_output_dir(out_dir, pruning_strategy, circuit_path, prune_text_encoder):
    """Return the directory the pruned checkpoints go into (the directory is not created here).

    The layout is ``<out_dir>/<dataset>_pruned/<pretrained tag>[/text_pruned]``.
    ``<dataset>`` is inferred from the strategy or from the circuit file path.
    """
    if pruning_strategy == "gradient":
        dataset_name = "in1k"
    elif "in1k" in circuit_path and "cifar100" in circuit_path:
        dataset_name = "in1k_1k_cifar100"
    elif "in1k" in circuit_path:
        dataset_name = "in1k"
    elif "cifar100" in circuit_path:
        dataset_name = "cifar100"
    else:
        # A circuit path that names neither dataset gets a generic folder
        # instead of leaving dataset_name unbound.
        dataset_name = "custom"

    target = os.path.join(out_dir, f"{dataset_name}_pruned", CLIP_MODEL[1])
    if prune_text_encoder:
        target = os.path.join(target, "text_pruned")
    return target


def _checkpoint_name(num_heads_to_prune, circuit_path, pruning_strategy, seed):
    """Build the checkpoint file name from the pruning settings and the seed."""
    name = f"CLIP_Heads{num_heads_to_prune}"
    if circuit_path:
        name += f"_circuit_{os.path.splitext(os.path.basename(circuit_path))[0]}"
    return name + f"_strategy_{pruning_strategy}_seed{seed}.ckpt"


def main(args):
    """Prune attention heads of a pretrained OpenCLIP model and save checkpoints.

    For the ``gradient`` strategy one checkpoint is written per seed in
    ``SEEDS``. For ``predefined`` the heads come from ``args.circuit_path``, and
    only the first seed is used.

    Args:
        args: parsed namespace from ``get_parser()``.

    Raises:
        ValueError: if ``args.heads_to_prune`` is not positive, or if the
            ``predefined`` strategy is requested without ``args.circuit_path``.
    """
    import copy  # each seed prunes its own copy of the pretrained model

    num_heads_to_prune = args.heads_to_prune
    pruning_strategy = args.strategy
    circuit_path = args.circuit_path
    prune_text_encoder = args.prune_text_encoder

    # Use explicit exceptions rather than `assert`, which `python -O` strips.
    if num_heads_to_prune <= 0:
        raise ValueError(f"Heads to prune must be positive, instead got {num_heads_to_prune}")
    if pruning_strategy == "predefined" and not circuit_path:
        raise ValueError("--circuit_path is required when --strategy is 'predefined'")

    # The output directory does not depend on the seed, so resolve it once.
    # Re-assigning it inside the loop would nest each seed's directory in the
    # previous one.
    out_dir = _resolve_output_dir(args.out_dir, pruning_strategy, circuit_path, prune_text_encoder)
    os.makedirs(out_dir, exist_ok=True)

    model, _, val_tfms = open_clip.create_model_and_transforms(CLIP_MODEL[0],
                                                               pretrained=CLIP_MODEL[1])
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(torch_device)
    model.eval()
    tokenizer = open_clip.get_tokenizer(CLIP_MODEL[0])
    prompts = get_prompts(dataset="imagenet-1k")
    text_logits = get_text_logits(prompts, model, tokenizer, device=torch_device)

    for seed in SEEDS:
        random.seed(seed)
        if pruning_strategy == "gradient":
            train_loader = get_dataloader(seed, val_tfms)
            heads_to_prune = num_heads_to_prune
        else:
            train_loader = None
            heads_to_prune = get_circuit_heads(
                circuit_path=circuit_path,
                num_heads=num_heads_to_prune,
            )

        # It is not visible from here whether structurally_prune_attention_heads
        # modifies its input in place. Pruning a fresh copy guarantees every seed
        # starts from the same pretrained weights.
        pruned_model = structurally_prune_attention_heads(copy.deepcopy(model),
                                                          heads_to_prune,
                                                          model_name="OpenCLIP",
                                                          strategy=pruning_strategy,
                                                          dataloader=train_loader,
                                                          # Follow the detected device rather than assuming CUDA.
                                                          device=torch_device.type,
                                                          text_logits=text_logits,
                                                          prune_text_encoder=prune_text_encoder)

        # Build a fresh checkpoint per seed so no metadata leaks from a previous seed.
        ckpt = {"state_dict": pruned_model.state_dict()}
        try:
            ckpt["pruning"] = {
                "heads": pruned_model.get_active_heads("OpenCLIP") if heads_to_prune else None,
            }
        except Exception as err:
            # Head metadata is best-effort: save the weights anyway, but report why it is missing.
            print(f"Warning: could not record active heads in checkpoint ({err!r})")

        ckpt_path = os.path.join(
            out_dir, _checkpoint_name(num_heads_to_prune, circuit_path, pruning_strategy, seed)
        )
        torch.save(ckpt, ckpt_path)
        print(f"Pruned checkpoint saved to {ckpt_path}")
        if pruning_strategy == "predefined":
            # The circuit heads do not depend on the seed, so one checkpoint is enough.
            break
        print("=" * 50)

if __name__ == "__main__":
    # Script entry point: parse the CLI and run the pruning pipeline.
    main(get_parser().parse_args())