#!/usr/bin/env python3
"""
NPT (Nuisance-Prompt Tuning) experiment script that implements the proposed method
for few-shot OOD detection with learnable nuisance prompts.
"""

import argparse
import os
import sys
import torch
import time
import datetime
import numpy as np
import pandas as pd
import csv
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from torch.nn import functional as F

from utils.train_util import setup_logger, set_random_seed, AverageMeter, load_clip_to_cpu
from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
from datasets.train.imagenet import ImageNetDataset
from datasets.eval.test_loader import set_test_loader
import clip_w_local
from clip_w_local import clip
from clip_w_local.simple_tokenizer import SimpleTokenizer as _Tokenizer
from trainers import build_optimizer, build_lr_scheduler
from utils.losses import compute_accuracy
from utils.eval_util import get_and_print_results, add_results, add_overall_results, save_results_to_json
from configs.implemention import get_cfg_default
from yacs.config import CfgNode as CN

from npt_models import NPTCustomCLIP, extract_attention_weights


class NPTExperiment:
    """Class to manage NPT experiments"""
    
    def __init__(self, cfg: CN, args: argparse.Namespace):
        self.cfg = cfg
        self.args = args
        self.device = self._setup_device()
        self.model = None
        self.train_loader = None
        self.optimizer = None
        self.scheduler = None
        
    def _setup_device(self) -> torch.device:
        """Setup device"""
        if torch.cuda.is_available() and self.cfg.USE_CUDA:
            torch.backends.cudnn.benchmark = True
            return torch.device("cuda")
        return torch.device("cpu")
    
    def _validate_config(self) -> None:
        """Validate configuration"""
        if not hasattr(self.cfg, 'DATASET') or not hasattr(self.cfg.DATASET, 'ROOT'):
            raise ValueError("DATASET.ROOT must be specified in config")
        
        if not os.path.exists(self.cfg.DATASET.ROOT):
            raise ValueError(f"Dataset path does not exist: {self.cfg.DATASET.ROOT}")
            
        if not hasattr(self.cfg, 'OUTPUT_DIR'):
            raise ValueError("OUTPUT_DIR must be specified in config")
    
    def _create_train_transform(self) -> Compose:
        """Create training data transforms"""
        if hasattr(self.cfg, 'INPUT') and hasattr(self.cfg.INPUT, 'SIZE'):
            resize_size = tuple(self.cfg.INPUT.SIZE)
        else:
            resize_size = (224, 224)
            
        train_transform = []
        if "random_resized_crop" in self.cfg.INPUT.TRANSFORMS:
            train_transform.append(RandomResizedCrop(resize_size))
        if "random_flip" in self.cfg.INPUT.TRANSFORMS:
            train_transform.append(RandomHorizontalFlip())
        train_transform.append(ToTensor())
        if "normalize" in self.cfg.INPUT.TRANSFORMS:
            train_transform.append(Normalize(mean=self.cfg.INPUT.PIXEL_MEAN, std=self.cfg.INPUT.PIXEL_STD))
        
        return Compose(train_transform)
    
    def setup_data(self) -> None:
        """Setup dataset and data loaders"""
        print("Setting up data loaders...")
        
        # Prepare training dataset
        num_shots = getattr(self.cfg.DATASET, 'NUM_SHOTS', 1)
        seed = getattr(self.cfg, 'SEED', 1)
        
        train_transform = self._create_train_transform()
        
        self.train_dataset = ImageNetDataset(
            root=self.cfg.DATASET.ROOT,
            split="train",
            num_shots=num_shots,
            seed=seed,
            transform=train_transform
        )
        
        # Prepare data loader
        batch_size = self.cfg.DATALOADER.TRAIN_X.BATCH_SIZE if hasattr(self.cfg.DATALOADER, 'TRAIN_X') else 32
        num_workers = getattr(self.cfg.DATALOADER, 'NUM_WORKERS', 8)
        
        print(f"Using {num_workers} workers for data loading")
        
        self.train_loader = DataLoader(
            self.train_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=True if num_workers > 0 else False,
            prefetch_factor=2 if num_workers > 0 else None
        )
        
        print(f"Train dataset size: {len(self.train_dataset)}")
    
    def setup_model(self) -> None:
        """Setup model and optimizer"""
        print("Setting up NPT model and optimizer...")
        
        # Load CLIP model
        print(f"Loading CLIP (backbone: {self.cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(self.cfg)
        
        if self.cfg.TRAINER.LOCOOP.PREC in ["fp32", "amp"]:
            clip_model.float()
        
        # Build NPT custom CLIP model
        classnames = self.train_dataset.classnames
        print("Building NPT Custom CLIP")
        self.model = NPTCustomCLIP(self.cfg, classnames, clip_model)
        
        # Disable gradients except for prompt learner
        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)
        
        self.model.to(self.device)
        
        # Build optimizer and scheduler
        self.optimizer = build_optimizer(self.model.prompt_learner, self.cfg.OPTIM)
        self.scheduler = build_lr_scheduler(self.optimizer, self.cfg.OPTIM)
        
        print("NPT Model setup completed")
    
    def extract_patch_attention(self, image):
        """Extract attention weights for patch weighting"""
        image_input = image.type(self.model.dtype)
        return extract_attention_weights(self.model.image_encoder, image_input)
    
    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train for one epoch with NPT loss"""
        self.model.train()
        total_loss = 0
        total_acc = 0
        total_samples = 0
        batch_time = AverageMeter()
        data_time = AverageMeter()
        num_batches = len(self.train_loader)
        
        max_epoch = self.cfg.OPTIM.MAX_EPOCH
        
        # Loss tracking
        total_loss_global = 0
        total_loss_patch = 0 
        total_loss_margin = 0
        
        end = time.time()
        
        for batch_idx, batch in enumerate(self.train_loader):
            data_time.update(time.time() - end)
            images, labels, _ = batch
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            
            # Forward pass
            logits, logits_local, text_features = self.model(images)
            
            # Extract attention weights for patch weighting
            try:
                attention_weights = self.extract_patch_attention(images)
            except:
                attention_weights = None
            
            # Get image features for loss computation
            image_features, local_features = self.model.image_encoder(images.type(self.model.dtype))
            
            # Normalize features
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            local_features = local_features / local_features.norm(dim=-1, keepdim=True)
            
            # Compute NPT loss
            loss_dict = self.model.compute_npt_loss(
                image_features, local_features, text_features, labels, attention_weights
            )
            
            loss = loss_dict['loss_total']
            loss.backward()
            self.optimizer.step()

            # Calculate accuracy on class logits only
            acc = compute_accuracy(logits, labels)[0].item()
            
            # Update metrics
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            total_acc += acc * batch_size / 100.0
            total_samples += batch_size
            
            total_loss_global += loss_dict['loss_global'].item() * batch_size
            total_loss_patch += loss_dict['loss_patch'].item() * batch_size
            total_loss_margin += loss_dict['loss_margin'].item() * batch_size

            batch_time.update(time.time() - end)
            end = time.time()

            # Calculate ETA
            nb_remain = (num_batches - batch_idx - 1) + (max_epoch - epoch - 1) * num_batches
            eta_seconds = batch_time.avg * nb_remain
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))

            # Display progress
            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == num_batches:
                print(
                    f"epoch [{epoch + 1}/{max_epoch}] "
                    f"batch [{batch_idx + 1}/{num_batches}] "
                    f"time {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                    f"data {data_time.val:.3f} ({data_time.avg:.3f}) "
                    f"loss {loss.item():.4f} "
                    f"l_global {loss_dict['loss_global'].item():.4f} "
                    f"l_patch {loss_dict['loss_patch'].item():.4f} "
                    f"l_margin {loss_dict['loss_margin'].item():.4f} "
                    f"acc {acc:.2f} "
                    f"lr {self.optimizer.param_groups[0]['lr']:.4e} "
                    f"eta {eta}"
                )

        self.scheduler.step()
        
        # Calculate epoch averages
        avg_loss = total_loss / total_samples
        avg_acc = 100 * total_acc / total_samples
        avg_loss_global = total_loss_global / total_samples
        avg_loss_patch = total_loss_patch / total_samples  
        avg_loss_margin = total_loss_margin / total_samples
        
        return {
            'loss': avg_loss,
            'accuracy': avg_acc,
            'loss_global': avg_loss_global,
            'loss_patch': avg_loss_patch,
            'loss_margin': avg_loss_margin
        }
    
    def train(self) -> None:
        """Train for all epochs"""
        print("Starting NPT training...")
        max_epoch = self.cfg.OPTIM.MAX_EPOCH
        
        for epoch in range(max_epoch):
            metrics = self.train_epoch(epoch)
            print(f"Epoch [{epoch+1}/{max_epoch}] "
                  f"Loss: {metrics['loss']:.4f} "
                  f"Acc: {metrics['accuracy']:.2f} "
                  f"L_global: {metrics['loss_global']:.4f} "
                  f"L_patch: {metrics['loss_patch']:.4f} "
                  f"L_margin: {metrics['loss_margin']:.4f}")
        
        print("NPT training completed")
    
    def load_model(self, model_path: str) -> None:
        """Load model"""
        print(f"Loading NPT model from {model_path}")
        self.model.prompt_learner.load_state_dict(torch.load(model_path))
        print("NPT model loaded successfully")
    
    def test_ood_detection(self, model: torch.nn.Module, data_loader: DataLoader, T: float = 1.0) -> np.ndarray:
        """Test OOD detection with NPT model"""
        to_np = lambda x: x.data.cpu().numpy()
        concat = lambda x: np.concatenate(x, axis=0)

        mcm_score = []
        
        for batch_idx, (images, labels, *id_flag) in enumerate(tqdm(data_loader, desc="Testing OOD with NPT")):
            images = images.to(self.device)
            model.eval()
            with torch.no_grad():
                # Use class logits only (exclude nuisance)
                output, output_local, _ = model(images)
            
            output /= 100.0
            smax_global = to_np(F.softmax(output/T, dim=-1))
            mcm_global_score = -np.max(smax_global, axis=1)
            mcm_score.append(mcm_global_score)

        return concat(mcm_score)[:len(data_loader.dataset)].copy()
    
    def evaluate(self) -> None:
        """Run evaluation"""
        print("Starting NPT evaluation...")
        
        # Set model to evaluation mode
        self.model.eval()
        
        # Prepare data loader
        _, preprocess = clip_w_local.load(self.cfg.MODEL.BACKBONE.NAME)
        self.args.in_dataset = "imagenet"
        self.args.batch_size = 512
        
        id_data_loader = set_test_loader(self.args, "imagenet", preprocess)
        
        # Calculate in-distribution scores
        in_score_mcm = self.test_ood_detection(self.model, id_data_loader, 1)
        
        # Lists for evaluation
        auroc_list_mcm, fpr_list_mcm = [], []
        results_data = []
        
        # Evaluate out-of-distribution datasets
        out_datasets = ['iNaturalist', 'SUN', 'places365', 'Texture']
        
        scores_dict: Dict[str, Dict[str, np.ndarray]] = {}
        scores_dict["MCM"] = {}
        scores_dict["MCM"]["ImageNet"] = in_score_mcm
        
        for out_dataset in out_datasets:
            print(f"Evaluating OOD dataset: {out_dataset}")
            ood_loader = set_test_loader(self.args, out_dataset, preprocess)
            out_score_mcm = self.test_ood_detection(self.model, ood_loader, 1)

            # Evaluate MCM score
            print("NPT MCM score")
            mcm_results = get_and_print_results(
                self.args, in_score_mcm, out_score_mcm,
                auroc_list_mcm, fpr_list_mcm
            )
            scores_dict["MCM"][out_dataset] = out_score_mcm
            
            # Save results
            results_data = add_results(results_data, mcm_results, out_dataset)

        # Add overall results
        results_data = add_overall_results(results_data, auroc_list_mcm, fpr_list_mcm)

        # Save scores to .npz
        np.savez(f"{self.args.output_dir}/scores.npz", **scores_dict)

        # Save results to JSON
        save_results_to_json(results_data, self.args.output_dir, "results.json")
        print("NPT evaluation completed")


def run_npt_experiment(output_dir_path: str) -> None:
    """Run the full NPT pipeline with the built-in parameter set.

    Builds the argument namespace and configuration, prints the settings,
    then sets up data and model, trains (or loads weights in eval-only
    mode) and evaluates.

    Args:
        output_dir_path: Directory for logs and results; created if missing.

    Raises:
        ValueError: In eval-only mode when no model path is set, or when
            configuration validation fails.
    """
    # Fixed experiment settings
    dataset = "imagenet"
    cfg_name = "vit_b16_ep30"
    shots = 1
    nctx = 16
    csc = False
    ctp = "end"

    args = argparse.Namespace(
        # Core arguments (match the baseline)
        root="/datasets/LoCoOp",
        seed=0,
        trainer="LoCoOp",
        dataset_config_file="configs/datasets/imagenet.yaml",
        config_file="configs/trainers/LoCoOp/vit_b16_ep30.yaml",
        output_dir=output_dir_path,
        # NPT-specific hyperparameters
        lambda_patch=0.25,
        lambda_margin=0.25,
        margin=0.2,
        method="npt",
        # Additional parameters
        resume="",
        backbone="",
        head="",
        eval_only=False,
        model_dir="",
        load_epoch=None,
        no_train=False,
        T=1.0,
        sample_size=500,
        # Key/value overrides merged into the config after the YAML files
        opts=[
            "TRAINER.LOCOOP.N_CTX", str(nctx),
            "TRAINER.LOCOOP.CSC", str(csc),
            "TRAINER.LOCOOP.CLASS_TOKEN_POSITION", ctp,
            "DATASET.NUM_SHOTS", str(shots)
        ],
    )

    print("Running NPT experiment with parameters:")
    print(f"  Root: {args.root}")
    print(f"  Dataset: {dataset}")
    print(f"  Config: {cfg_name}")
    print(f"  Shots: {shots}")
    print(f"  Context tokens: {nctx}")
    print(f"  Lambda patch: {args.lambda_patch}")
    print(f"  Lambda margin: {args.lambda_margin}")
    print(f"  Margin: {args.margin}")
    print(f"  Method: {args.method}")
    print(f"  Output dir: {args.output_dir}")
    print()

    os.makedirs(args.output_dir, exist_ok=True)

    # Configuration: defaults -> trainer extensions -> YAML files -> args -> opts
    cfg = get_cfg_default()
    extend_cfg(cfg)

    for config_path in (args.dataset_config_file, args.config_file):
        if config_path:
            cfg.merge_from_file(config_path)

    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)

    # NPT hyperparameters are stored as top-level config entries.
    cfg.lambda_patch = args.lambda_patch
    cfg.lambda_margin = args.lambda_margin
    cfg.margin = args.margin

    cfg.freeze()

    # A negative seed leaves the random generators unseeded.
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)

    setup_logger(cfg.OUTPUT_DIR)

    experiment = NPTExperiment(cfg, args)
    experiment._validate_config()
    experiment.setup_data()
    experiment.setup_model()

    if args.eval_only:
        # Evaluation-only runs need weights to load.
        if not args.model_dir:
            raise ValueError("--model-dir must be specified for eval-only mode")
        experiment.load_model(args.model_dir)
    else:
        experiment.train()

    experiment.evaluate()


def extend_cfg(cfg: CN) -> None:
    """Add the LoCoOp trainer node and dataset subsampling default to ``cfg``.

    Mutates ``cfg`` in place: ``cfg.TRAINER.LOCOOP`` is replaced by a fresh
    node holding the defaults below, and ``cfg.DATASET.SUBSAMPLE_CLASSES``
    is set to ``"all"``.
    """
    locoop = CN()
    locoop.N_CTX = 16
    locoop.CSC = False
    locoop.CTX_INIT = ""
    locoop.PREC = "fp16"
    locoop.CLASS_TOKEN_POSITION = "end"
    cfg.TRAINER.LOCOOP = locoop

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"


def reset_cfg(cfg: CN, args: argparse.Namespace) -> None:
    """Copy the run arguments that were actually provided onto ``cfg``.

    String arguments are applied only when non-empty. The seed is applied
    whenever it is not ``None``.

    Args:
        cfg: Configuration node, mutated in place.
        args: Namespace carrying ``root``, ``output_dir``, ``resume``,
            ``seed``, ``trainer``, ``backbone`` and ``head``.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root
    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir
    if args.resume:
        cfg.RESUME = args.resume
    # 0 is a legitimate seed. A truthiness test would treat it as "not
    # given" and leave the config default in place, so compare against None.
    if args.seed is not None:
        cfg.SEED = args.seed
    if args.trainer:
        cfg.TRAINER.NAME = args.trainer
    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone
    if args.head:
        cfg.MODEL.HEAD.NAME = args.head


def main():
    """Parse the command line and launch the NPT experiment.

    The only option is the required ``--output-dir``, which is passed on
    to ``run_npt_experiment``.
    """
    cli = argparse.ArgumentParser(
        description="NPT experiment that implements nuisance-prompt tuning"
    )
    cli.add_argument(
        "--output-dir",
        required=True,
        type=str,
        help="Path to output directory where results will be saved",
    )
    parsed = cli.parse_args()

    run_npt_experiment(parsed.output_dir)


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()