# run_adaptation_benchmark.py

import os
import copy
import logging
import csv
import datetime
import math
import gc
import ctypes
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Internal Imports
import scripts.config as config
import scripts.data_setup as data_setup
import scripts.evaluation as evaluation

# Quantization Imports
from quantization.apply import (
    # PTQ
    apply_simple_ptq, apply_smoothquant, apply_igq_vit, 
    apply_qwt_ptq, apply_apq_ptq, apply_rotation_ptq, 
    apply_outlier_aware_ptq, apply_qvlm_ptq,
    # QAT
    apply_quantization_aware_training, 
    apply_learned_step_size_quantization,
    apply_rotation_lsq
)
from quantization.modules import QuantizedLinearLayer, QuantizedConv2d

# Prefix each record with its real level name so warnings and errors are not
# mislabeled as "INFO"; INFO-level output keeps the exact same "INFO: ..." shape.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
# Module-level logger used by the helpers in this script.
logger = logging.getLogger(__name__)

# Device used by this script for model placement and evaluation; read once
# from the project configuration.
ACTIVE_DEVICE = config.TARGET_DEVICE

# Datasets that take turns as the in-distribution (ID) adaptation target.
TARGET_DATASETS = [
    'cifar100',
    'imagenet1kval',
]

# Column order of the results CSV; every row written by main() uses these keys.
CSV_FIELDS = [
    # Run identification
    "Model_Key",
    "Adaptation_Mode",
    "Target_ID_Dataset",
    "OOD_Forgetting_Dataset",
    # Quantization setting
    "Quant_Method",
    "W_Bits",
    "A_Bits",
    # In-distribution metrics
    "ID_Acc",
    "ID_ECE",
    "ID_Acc_Degradation",
    # Out-of-distribution metrics
    "OOD_Acc",
    "Catastrophic_Forgetting_Score",
]

# --- SHARED UTILS ---

def cleanup():
    """Aggressively release memory between benchmark runs.

    Runs the Python garbage collector, then (when CUDA is available) empties
    PyTorch's CUDA cache and collects CUDA IPC memory. Finally it makes a
    best-effort call to glibc's ``malloc_trim(0)``; that step is skipped
    silently on platforms where ``libc.so.6`` or the symbol is unavailable.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # OSError: libc.so.6 cannot be loaded (e.g. non-glibc platform).
        # AttributeError: the library loaded but does not export malloc_trim.
        # Trimming is only an optimisation, so both cases are ignored; any
        # other exception is unexpected and is allowed to surface.
        pass

# --- METHOD REGISTRIES ---

# Post-training quantization registry.
# Maps the method's display name (as listed in config.ADAPTATION_QUANT_METHODS
# and written to the "Quant_Method" CSV column) to a tuple of
# (apply function, extra keyword arguments). main() unpacks the keyword
# arguments into the apply call alongside the shared arguments.
ALL_PTQ_METHODS = {
    "Simple PTQ": (apply_simple_ptq, config.SIMPLE_PTQ_KWARGS),
    "SmoothQuant PTQ": (apply_smoothquant, {}),
    "IGQ-ViT PTQ": (apply_igq_vit, config.IGQ_KWARGS),
    "QwT PTQ": (apply_qwt_ptq, config.QWT_KWARGS),                 
    "APQ-ViT PTQ": (apply_apq_ptq, config.APQ_KWARGS),             
    "Rotation PTQ": (apply_rotation_ptq, {}),                      
    "OutlierAware PTQ": (apply_outlier_aware_ptq, config.OUTLIER_AWARE_KWARGS), 
    "Q-VLM": (apply_qvlm_ptq, config.QVLM_KWARGS)
}

# Quantization-aware training registry; same (apply function, kwargs) layout
# as ALL_PTQ_METHODS. main() additionally passes a training dataloader and a
# frozen teacher model to these functions.
ALL_QAT_METHODS = {
    "QAT": (apply_quantization_aware_training, config.QAT_KWARGS),
    "LSQ": (apply_learned_step_size_quantization, config.LSQ_KWARGS),
    # Reuses the LSQ keyword arguments.
    "Rotation + LSQ": (apply_rotation_lsq, config.LSQ_KWARGS),
}

# --- ROBUST LORA IMPLEMENTATION ---

class LoRALayer(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) update.

    The forward pass returns ``original_layer(x) + scaling * x @ A^T @ B^T``
    where ``A`` has shape ``(rank, in_features)`` and ``B`` has shape
    ``(out_features, rank)``. ``B`` starts at zero, so a freshly wrapped layer
    reproduces the original layer's output exactly. The wrapped layer's weight
    and bias are frozen.

    Args:
        original_layer: Layer to wrap; must expose ``in_features``,
            ``out_features``, ``weight`` and ``bias``.
        rank: Rank of the low-rank update; must be positive.
        alpha: Numerator of the scaling factor ``alpha / rank``.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, original_layer, rank=4, alpha=4):
        super().__init__()
        if rank <= 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError(f"LoRA rank must be positive, got {rank}")

        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        in_dim = original_layer.in_features
        out_dim = original_layer.out_features

        # Create the LoRA parameters on the same device and dtype as the
        # wrapped layer's weight.
        w_param = original_layer.weight
        device = w_param.device
        dtype = w_param.dtype

        self.lora_A = nn.Parameter(torch.zeros(rank, in_dim, device=device, dtype=dtype))
        self.lora_B = nn.Parameter(torch.zeros(out_dim, rank, device=device, dtype=dtype))

        # A is randomly initialised while B stays zero, so the initial
        # low-rank update is exactly zero.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        # Only the LoRA matrices are meant to be trained.
        self.original_layer.weight.requires_grad = False
        if self.original_layer.bias is not None:
            self.original_layer.bias.requires_grad = False

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update."""
        base_out = self.original_layer(x)
        # Compute the LoRA branch in the LoRA parameters' dtype, then cast the
        # result back to the dtype produced by the wrapped layer.
        x_casted = x.to(self.lora_A.dtype)
        lora_out = (x_casted @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return base_out + lora_out.to(base_out.dtype)

    def __getattr__(self, name):
        """Fall back to the wrapped layer for attributes this wrapper lacks.

        This lets callers keep reading attributes such as ``in_features`` or
        ``weight`` from the wrapper.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "original_layer":
                # The wrapped layer is not registered (e.g. the instance was
                # created without __init__); delegating would look up
                # `original_layer` again and recurse forever.
                raise
            return getattr(self.original_layer, name)

def apply_lora(model, rank, alpha):
    """Wrap the encoders' ``nn.Linear`` layers in ``LoRALayer`` and freeze the rest.

    Target submodules are ``model.visual`` plus ``model.transformer`` (or
    ``model.text`` when there is no ``transformer``); if none of these exist
    the whole model is used. Linear layers nested inside an
    ``nn.MultiheadAttention`` are left untouched.

    After wrapping, every parameter whose name contains ``lora_`` is made
    trainable and all others are frozen, except ``model.logit_scale`` which is
    made trainable when present.

    Args:
        model: Model to modify in place.
        rank: LoRA rank passed to each ``LoRALayer``.
        alpha: LoRA alpha passed to each ``LoRALayer``.

    Returns:
        Tuple ``(model, params_to_train)`` where ``params_to_train`` lists the
        trainable parameter tensors.
    """
    params_to_train = []

    # Collect everything that lives inside an nn.MultiheadAttention so it can
    # be skipped below.
    # NOTE(review): presumably because MultiheadAttention reads its projection
    # weights directly instead of calling the sub-layer — confirm.
    ignored_modules = set()
    for m in model.modules():
        if isinstance(m, nn.MultiheadAttention):
            ignored_modules.update(child for child in m.modules() if child is not m)

    targets = []
    if hasattr(model, 'visual'):
        targets.append(model.visual)
    if hasattr(model, 'transformer'):
        targets.append(model.transformer)
    elif hasattr(model, 'text'):
        targets.append(model.text)

    if not targets:
        targets = [model]

    for submodule in targets:
        # Snapshot first: the module tree is mutated while replacing layers.
        modules_to_replace = {
            name: module
            for name, module in submodule.named_modules()
            if module not in ignored_modules and isinstance(module, nn.Linear)
        }

        for name, module in modules_to_replace.items():
            if not name:
                # The target itself is a Linear layer; there is no parent
                # inside this target on which it could be rebound.
                continue

            parent_name, _, child_name = name.rpartition('.')
            # Direct children of the target have an empty parent path; they
            # used to be skipped silently and are now wrapped as well.
            parent = submodule.get_submodule(parent_name) if parent_name else submodule

            if isinstance(parent, LoRALayer):
                # This is the inner layer of an existing LoRALayer (the
                # function was already applied); do not wrap it again.
                continue

            setattr(parent, child_name, LoRALayer(module, rank, alpha))

    for n, p in model.named_parameters():
        if 'lora_' in n:
            p.requires_grad = True
            params_to_train.append(p)
        else:
            p.requires_grad = False

    # The temperature is trained together with the LoRA matrices.
    if hasattr(model, 'logit_scale'):
        model.logit_scale.requires_grad = True
        params_to_train.append(model.logit_scale)

    logger.info(f"LoRA Applied. Training {len(params_to_train)} parameter tensors.")
    return model, params_to_train

def merge_lora(model):
    """Fold each LoRA update into its wrapped layer, then drop the wrappers.

    First, for every ``LoRALayer`` in ``model``, ``scaling * (B @ A)`` is added
    in place to the wrapped layer's weight. Then every ``LoRALayer`` that is a
    child of some module is replaced by the layer it wraps. The model is
    modified in place and returned.
    """
    # Step 1: bake the low-rank update into the wrapped weights.
    with torch.no_grad():
        for layer in model.modules():
            if not isinstance(layer, LoRALayer):
                continue
            low_rank_update = (layer.lora_B @ layer.lora_A) * layer.scaling
            layer.original_layer.weight.add_(low_rank_update)

    # Step 2: walk the module tree and swap each wrapper for its inner layer.
    pending = [model]
    while pending:
        parent = pending.pop()
        for child_name, child in list(parent.named_children()):
            if isinstance(child, LoRALayer):
                setattr(parent, child_name, child.original_layer)
            else:
                pending.append(child)

    return model

# --- CONTRASTIVE TRAINING ---

def train_contrastive(model, loader, tokenizer, device, steps, lr, mode="full"):
    """Adapt ``model`` with an image-to-text contrastive classification loss.

    For every step a batch of ``(images, labels)`` is drawn from ``loader``
    (restarting the loader when it is exhausted). One prompt
    ``"a photo of a <class>"`` is built for each distinct label in the batch,
    and a cross-entropy loss is computed over the scaled similarities between
    the normalized image features and those prompt features. Only parameters
    with ``requires_grad=True`` are optimised (AdamW, weight decay 0.01).

    Args:
        model: Model exposing ``encode_image``, ``encode_text`` and
            ``logit_scale``.
        loader: Data loader whose dataset (or ``dataset.dataset``) exposes a
            ``classes`` sequence indexed by label id.
        tokenizer: Callable turning a list of prompts into a tensor.
        device: Device the batches are moved to.
        steps: Number of optimisation steps.
        lr: Learning rate.
        mode: Label shown in the progress bar; it does not change which
            parameters are trained.

    Returns:
        The model, switched to eval mode.

    Raises:
        ValueError: If no class names can be found on the loader's dataset, or
            if the loader yields no batches.
    """
    # Resolve class names up front so a misconfigured dataset fails with a
    # clear message rather than deep inside the training loop.
    dataset = loader.dataset
    class_names = getattr(dataset, 'classes', None)
    if class_names is None:
        # Wrapped datasets keep the underlying dataset under `.dataset`.
        class_names = getattr(getattr(dataset, 'dataset', None), 'classes', None)
    if class_names is None:
        raise ValueError(
            "train_contrastive needs class names: neither loader.dataset nor "
            "loader.dataset.dataset exposes a 'classes' attribute."
        )

    model.train()

    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.01)
    scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)
    loss_fn = nn.CrossEntropyLoss()

    template = "a photo of a {}"

    pbar = tqdm(range(steps), desc=f"Adapting ({mode})")
    iter_loader = iter(loader)

    for _ in pbar:
        try:
            batch = next(iter_loader)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            iter_loader = iter(loader)
            batch = next(iter_loader, None)
            if batch is None:
                raise ValueError("train_contrastive received an empty data loader.") from None

        images, labels = batch
        images = images.to(device, non_blocking=True)
        labels = labels.to(device)

        # `unique_labels` holds the sorted distinct class ids in the batch and
        # `batch_targets` maps every sample to its index in that list, i.e. to
        # the row of the matching prompt in `text_features`.
        unique_labels, batch_targets = torch.unique(labels.reshape(-1), return_inverse=True)

        current_prompts = [template.format(class_names[l_idx]) for l_idx in unique_labels.cpu().tolist()]
        tokenized_text = tokenizer(current_prompts).to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            image_features = F.normalize(model.encode_image(images), dim=-1)
            text_features = F.normalize(model.encode_text(tokenized_text), dim=-1)

            logit_scale = model.logit_scale.exp()
            logits = logit_scale * image_features @ text_features.t()
            loss = loss_fn(logits, batch_targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    return model.eval()

# --- EVALUATION HELPER ---

def evaluate_id_ood(model, id_key, ood_key, tokenizer, eval_suite):
    """Evaluate ``model`` on the ID and OOD entries of ``eval_suite``.

    For each of the two selected datasets, one prompt per class is built from
    the entry's template, encoded in chunks of 256 and L2-normalized, and the
    resulting text features are passed to
    ``evaluation.run_comprehensive_evaluation`` together with the entry's
    loader and reference model.

    Returns:
        Tuple ``(id_metrics, ood_metrics)``.

    Raises:
        KeyError: If ``id_key`` or ``ood_key`` is missing from ``eval_suite``.
    """
    chunk_size = 256

    def encode_prompts(prompt_list):
        # Encode prompts chunk by chunk and normalize each feature vector.
        normalized_chunks = []
        with torch.no_grad():
            for start in range(0, len(prompt_list), chunk_size):
                tokens = tokenizer(prompt_list[start:start + chunk_size]).to(ACTIVE_DEVICE)
                feats = model.encode_text(tokens)
                normalized_chunks.append(feats / feats.norm(dim=-1, keepdim=True))
        return torch.cat(normalized_chunks)

    wanted = (id_key, ood_key)
    metrics_by_dataset = {}

    for name, entry in eval_suite.items():
        if name in wanted:
            loader, class_names, template = entry['loader'], entry['class_names'], entry['template']
            text_feats = encode_prompts([template.format(c) for c in class_names])
            metrics_by_dataset[name] = evaluation.run_comprehensive_evaluation(
                model, entry['ref_model'], loader, text_feats, ACTIVE_DEVICE
            )

    return metrics_by_dataset[id_key], metrics_by_dataset[ood_key]

# --- MAIN ---

def _append_result_row(csv_path, row):
    """Append one result row to ``csv_path`` using the ``CSV_FIELDS`` columns."""
    with open(csv_path, 'a', newline='') as f:
        csv.DictWriter(f, fieldnames=CSV_FIELDS).writerow(row)


def main():
    """Run the adaptation, forgetting and quantization benchmark.

    Environment variables ``TARGET_MODEL_KEY``, ``TARGET_QUANT_SCOPE`` and
    ``SLURM_JOB_ID`` label the run and name the results CSV (written under
    ``results/``).

    For every dataset in ``TARGET_DATASETS`` (the ID dataset, paired with the
    other dataset as OOD) and for every adaptation mode (``LoRA``,
    ``Baseline``, ``Full_FT``):

    1. a fresh copy of the base model is adapted (``Baseline`` is left as is);
    2. the adapted model is evaluated on ID and OOD data and logged as the
       ``FP32`` row;
    3. for each method in ``config.ADAPTATION_QUANT_METHODS`` found in a
       registry, and each ``(weight_bits, act_bits)`` pair in
       ``config.BIT_WIDTHS_TO_TEST``, a copy of the adapted model is
       quantized, evaluated and logged. A failure in one configuration is
       reported and the benchmark continues with the next one.

    Rows are appended to the CSV as soon as they are available, so partial
    results survive a crash.
    """
    model_key = os.environ.get("TARGET_MODEL_KEY", "Unknown_Model")
    quant_scope = os.environ.get("TARGET_QUANT_SCOPE", "ALL")
    job_id = os.environ.get("SLURM_JOB_ID", f"PID{os.getpid()}")

    safe_model_key = model_key.replace(" ", "_").replace("/", "-")
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    os.makedirs("results", exist_ok=True)
    csv_filename = f"results/forgetness_results_{safe_model_key}_{quant_scope}_{timestamp}_{job_id}.csv"

    print(f"\n{'='*60}")
    print("STARTING ADAPTATION & FORGETTING BENCHMARK")
    print(f"Logging to: {csv_filename}")
    # NOTE(review): this prints config.ADAPTATION_TARGETS, but the loops below
    # iterate the module-level TARGET_DATASETS — confirm the two agree.
    print(f"Targets: {config.ADAPTATION_TARGETS}")
    print(f"{'='*60}\n")

    with open(csv_filename, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=CSV_FIELDS)
        writer.writeheader()

    base_model, tokenizer, preprocess = data_setup.get_model_and_tokenizer()
    base_model.to(ACTIVE_DEVICE).eval()

    # ID dataset -> dataset used to measure OOD behavior.
    dataset_map = {
        'cifar100': 'imagenet1kval',
        'imagenet1kval': 'cifar100'
    }

    eval_suite = {}
    print("Pre-loading Evaluation Datasets...")
    for d_name in TARGET_DATASETS:
        _, test_loader, class_names, template = data_setup.get_dataset_loaders(d_name, preprocess, get_train=False)
        eval_suite[d_name] = {
            'loader': test_loader, 'class_names': class_names, 'template': template,
            # Unadapted reference copy, kept on the CPU.
            'ref_model': copy.deepcopy(base_model).cpu()
        }

    for target_id in TARGET_DATASETS:
        if target_id not in dataset_map:
            print(f"Skipping {target_id}, not in defined ID-OOD map.")
            continue

        target_ood = dataset_map[target_id]
        print(f"\n>>> BENCHMARK: ID={target_id} | OOD (Forgetting)={target_ood}")

        loaders, _, _, _ = data_setup.get_dataset_loaders(target_id, preprocess, get_train=True)
        if loaders[1] is None:
            print(f"Error: Could not load training data for {target_id}")
            continue

        # Index 0 is used for calibration, index 1 for training.
        train_loader = loaders[1]
        calib_loader = loaders[0]

        modes = ["LoRA", "Baseline", "Full_FT"]

        for mode in modes:
            print(f"\n--- Processing Mode: {mode} ---")

            # A. Prepare Model State (every mode starts from the base model)
            current_model = copy.deepcopy(base_model).to(ACTIVE_DEVICE)

            if mode == "Full_FT":
                for p in current_model.parameters():
                    p.requires_grad = True
                current_model = train_contrastive(
                    current_model, train_loader, tokenizer, ACTIVE_DEVICE,
                    config.ADAPTATION_CONFIG['steps'], config.ADAPTATION_CONFIG['full_ft_lr']
                )

            elif mode == "LoRA":
                current_model, _ = apply_lora(
                    current_model,
                    config.ADAPTATION_CONFIG['lora_rank'],
                    config.ADAPTATION_CONFIG['lora_alpha']
                )
                current_model = train_contrastive(
                    current_model, train_loader, tokenizer, ACTIVE_DEVICE,
                    config.ADAPTATION_CONFIG['steps'], config.ADAPTATION_CONFIG['lora_lr'],
                    mode="lora"
                )
                # NOTE(review): apply_lora freezes every non-LoRA parameter and
                # merge_lora does not unfreeze them, so the merged model handed
                # to the QAT methods has frozen weights (unlike the other
                # modes) — confirm the QAT functions set requires_grad
                # themselves.
                current_model = merge_lora(current_model)

            # B. FP32 Evaluation
            id_metrics_fp32, ood_metrics_fp32 = evaluate_id_ood(
                current_model, target_id, target_ood, tokenizer, eval_suite
            )
            fp32_id_acc = id_metrics_fp32['Zero-Shot Accuracy']
            fp32_ood_acc = ood_metrics_fp32['Zero-Shot Accuracy']

            print(f"[{mode} FP32] ID Acc: {fp32_id_acc:.4f} | OOD Acc: {fp32_ood_acc:.4f}")

            _append_result_row(csv_filename, {
                "Model_Key": model_key, "Adaptation_Mode": mode,
                "Target_ID_Dataset": target_id, "OOD_Forgetting_Dataset": target_ood,
                "Quant_Method": "FP32", "W_Bits": 32, "A_Bits": 32,
                "ID_Acc": fp32_id_acc, "ID_ECE": id_metrics_fp32['ECE'], "ID_Acc_Degradation": 0.0,
                "OOD_Acc": fp32_ood_acc, "Catastrophic_Forgetting_Score": "N/A"
            })

            # C. Quantization Loop
            for q_method in config.ADAPTATION_QUANT_METHODS:

                # Check both registries; unknown method names are skipped.
                is_ptq = q_method in ALL_PTQ_METHODS
                is_qat = q_method in ALL_QAT_METHODS

                if not (is_ptq or is_qat):
                    continue

                for w_bit, a_bit in config.BIT_WIDTHS_TO_TEST:

                    # Each configuration quantizes its own copy so the adapted
                    # FP32 model stays untouched.
                    q_student = copy.deepcopy(current_model)

                    # QAT teacher (only if needed): the current FP32 adapted
                    # model, frozen, so that the adaptation is preserved.
                    local_teacher = None
                    if is_qat:
                        local_teacher = copy.deepcopy(current_model)
                        local_teacher.eval()
                        for p in local_teacher.parameters():
                            p.requires_grad = False

                    # Bound before the try so the cleanup below never depends
                    # on how far the try block got.
                    q_model = None

                    try:
                        # 1. Apply Method
                        if is_ptq:
                            apply_fn, kwargs = ALL_PTQ_METHODS[q_method]
                            q_model = apply_fn(
                                q_student,
                                calibration_dataloader=calib_loader,
                                target_device=ACTIVE_DEVICE,
                                tokenizer=tokenizer,
                                prompts=eval_suite[target_id]['class_names'],
                                weight_bits=w_bit, act_bits=a_bit,
                                **kwargs
                            )
                        else:
                            apply_fn, kwargs = ALL_QAT_METHODS[q_method]
                            q_model = apply_fn(
                                q_student,
                                training_dataloader=train_loader,
                                calibration_dataloader=calib_loader,
                                target_device=ACTIVE_DEVICE,
                                tokenizer=tokenizer,
                                prompts=eval_suite[target_id]['class_names'],
                                teacher=local_teacher,
                                weight_bits=w_bit, act_bits=a_bit,
                                **kwargs
                            )

                        # 2. Eval Quantized
                        id_m_q, ood_m_q = evaluate_id_ood(q_model, target_id, target_ood, tokenizer, eval_suite)

                        q_id_acc = id_m_q['Zero-Shot Accuracy']
                        q_ood_acc = ood_m_q['Zero-Shot Accuracy']
                        degradation = fp32_id_acc - q_id_acc

                        # 3. Log (both deltas are measured against the FP32
                        # result of the same adaptation mode)
                        _append_result_row(csv_filename, {
                            "Model_Key": model_key, "Adaptation_Mode": mode,
                            "Target_ID_Dataset": target_id, "OOD_Forgetting_Dataset": target_ood,
                            "Quant_Method": q_method, "W_Bits": w_bit, "A_Bits": a_bit,
                            "ID_Acc": q_id_acc, "ID_ECE": id_m_q['ECE'],
                            "ID_Acc_Degradation": degradation,
                            "OOD_Acc": q_ood_acc,
                            "Catastrophic_Forgetting_Score": fp32_ood_acc - q_ood_acc
                        })

                    except Exception as e:
                        # One failing configuration must not abort the whole
                        # benchmark; report it and move on. The traceback is
                        # logged so the failure can actually be diagnosed.
                        print(f"Error in {mode} {q_method}: {e}")
                        logger.exception(
                            "Quantization run failed (mode=%s, method=%s, W%s/A%s)",
                            mode, q_method, w_bit, a_bit,
                        )

                    # Cleanup: drop every reference to the per-configuration
                    # models before releasing cached memory.
                    del q_model, q_student, local_teacher
                    cleanup()

            del current_model
            cleanup()

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()