import sys
sys.path.append('/YOUR_ROOT_PATH/src') 
sys.path.append('/YOUR_ROOT_PATH/src/train') 
CHACHE_DIR = '/mnt/raid10/ak-research-01/ak-research-01/codes/.cache'

import numpy as np
np.float_ = np.float64
np.complex_ = np.complex128

from train.datasets import COCOFlickrDataset, ImageNetDataset 
from CLIP_eval.eval_utils import load_clip_model, load_clip_model_convnext 

sys.path.append("open_flamingo")
import os
import shutil
import time
import string
import random

import numpy as np
import open_clip
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from train.training.scheduler import cosine_lr
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL, IMAGENET_100_CLASS_ID_TO_LABEL
from train.pgd_train import pgd
from train.apgd_train import apgd_train as apgd
import wandb
from train.utils import init_wandb, AverageMeter
from train.sam_data import SamData
from open_flamingo.eval.models.utils import unwrap_model
from train.utils import str2bool
import argparse
import torchvision
from typing import Tuple, Dict, Union, Literal
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
from torchvision.datasets import CIFAR10, CIFAR100
import torch
from copy import deepcopy
from tqdm import tqdm 

from lambda_net import LambdaNetworkFactory

# Command-line interface for this script. The resulting namespace is consumed by
# main() and train_one_epoch() below.
parser = argparse.ArgumentParser()

# --- model / data ---------------------------------------------------------------
# main() special-cases three names (the convnext hf-hub id, 'ViT-B-16-laion2B',
# 'ViT-B-32-laion2B'); any other value is passed to open_clip with pretrained='openai'.
parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32') 
parser.add_argument('--pretrained', type=str, default='openai')
parser.add_argument('--dataset', type=str, default='imagenet')
# main() accepts 'std' or 'blurry' and raises ValueError for anything else.
parser.add_argument('--template', type=str, default='std')
parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
parser.add_argument('--output_normalize', type=str2bool, default=False, help='Whether the embedding is normalized')

# --- schedule / optimisation ----------------------------------------------------
parser.add_argument('--start_step', type=int, default=0, help='Start step for training')
parser.add_argument('--optimizer_state', type=str, default='', help='Optimizer state file path')
# main() divides --steps (and --warmup, when building the scheduler) by the CUDA device count.
parser.add_argument('--steps', type=int, default=20000, help='Number of training steps')
parser.add_argument('--warmup', type=int, default=14000, help='Warmup steps')
# main() multiplies this by the CUDA device count to obtain the DataLoader batch size.
parser.add_argument('--per_device_batch_size', type=int, default=256)
parser.add_argument('--loss', type=str, default='l2', help='ce, l2, linf')
parser.add_argument('--loss_clean', type=str, default='none', help='ce, l2, linf')
parser.add_argument('--clean_weight', type=float, default=0., help='Weight for clean loss')
parser.add_argument('--trades', type=str2bool, default=False, help='Use TRADES')
parser.add_argument('--opt', type=str, default='adamw', help='Optimizer type; sgd, adamw')
parser.add_argument('--momentum_sgd', type=float, default=0.9, help='Momentum for SGD optimizer')
parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
parser.add_argument('--wd', type=float, default=1e-4, help='Weight decay')

# --- adversarial attack ---------------------------------------------------------
parser.add_argument('--attack', type=str, default='apgd', help='Adversarial attack type')
parser.add_argument('--inner_loss', type=str, default='l2', help='Inner loss function for adversarial training')
parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
parser.add_argument('--eps', type=float, default=4, help='Epsilon for adversarial perturbation')
parser.add_argument('--rho', type=float, default=1e-1, help='Final rho value for constraints (scheduled from 0 to this value)')
parser.add_argument('--iterations_adv', type=int, default=10, help='Iterations for adversarial attack')
parser.add_argument('--stepsize_adv', type=float, default=1., help='Step size for adversarial attack (no effect for apgd)')

# --- logging / output -----------------------------------------------------------
parser.add_argument('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
parser.add_argument('--experiment_name', type=str, default='')
parser.add_argument('--overwrite', type=str2bool, default=False, help='Overwrite existing directory')
parser.add_argument('--log_freq', type=int, default=1, help='Logging frequency')
parser.add_argument('--eval_freq', type=int, default=50, help='Evaluation frequency')
parser.add_argument('--output_dir', type=str, default='', help='Output directory')
parser.add_argument('--save_checkpoints', type=str2bool, default=True, help='Save 10 training checkpoints')
# NOTE(review): the help text reads inverted. In train_one_epoch, online=True attacks
# the current model and online=False attacks the original model.
parser.add_argument('--online', type=str2bool, default=True, help='Attack is done offline')
# train_one_epoch accepts 'robust' (APGD) or 'simple' (PGD).
parser.add_argument('--evaluation', type=str, default='robust', help='Evaluation method')
# When set, main() restores <checkpoint_dir>/checkpoints/final.pt and final_opt.pt and
# writes results under <checkpoint_dir>/ood_results.
parser.add_argument('--checkpoint_dir', type=str, default=None, help='checkpoint for previous model')

# new
# --- Lagrangian / lambda-network options ----------------------------------------
parser.add_argument('--kkt_helper', type=str2bool, default=False, help='Help the lambda network via KKT')
# NOTE(review): help text contains typos ("lamnda" here, "nax" on --k_iter).
parser.add_argument('--lambda_lr', type=float, default=1e-2, help='Learning rate for lamnda function')
parser.add_argument('--k_iter', type=int, default=1, help="How many min steps before nax")
parser.add_argument('--lambda_net', type=str, help='The type of network used for lagrangian', default='simple')
# main() builds helper objects for 'EMA' and 'sg_phi_delayed'; get_anchor_embedding
# additionally understands 'orig', 'sg_phi', 'proj_on_org' and 'interpolated'.
parser.add_argument('--anchor_option', type=str, default='orig', help='option of anchor model')
parser.add_argument('--lambda_optimizer_state', type=str, default='', help='Lambda optimizer state file path')
parser.add_argument('--grad_norm', type=float, default=1.0, help='Gradient norm')
parser.add_argument('--rho_scheduler', type=str, default='const', help='')
parser.add_argument('--return_clean_dis_freq', type=int, default=10, help="")
parser.add_argument('--lagrangian_type', type=str, default='scalar', help='Type of Lagrangian multiplier: scalar or vector')

# --- corrupted-ImageNet evaluation ----------------------------------------------
# Sub-directory of --imagenet_root that main() reads the data from.
parser.add_argument('--corruption_type', type=str, default='brightness', help='')
# Used by train_one_epoch to name and label the results file.
parser.add_argument('--corruption_level', type=str, default='1', help='')
# train_one_epoch leaves its evaluation loop once its step counter reaches this value.
parser.add_argument('--num_eval_batches', type=int, default=20, help="")


# parser.add_argument('--lambda_threshold', type=float, default=5e-2, help='The threshold for when we care about lambda')

# Class-id -> label mapping read by main() (prompt construction) and by the 'simple'
# evaluation mode. It is filled outside the code visible here -- presumably by
# get_dataset; verify.
classes = {}

def main(args):
    """Build models, data and optimizers, then run one evaluation pass.

    Two CLIP vision towers are created: the untouched reference (``model_orig``)
    and the model under evaluation (``model``, optionally restored from
    ``args.checkpoint_dir``). The normalized text embedding of every class prompt
    is precomputed, the optimizers/scheduler/lambda network/anchor helpers that
    ``train_one_epoch`` takes are constructed, and ``train_one_epoch`` is called
    exactly once.

    Side effects:
        * initialises wandb (disabled mode when ``args.wandb`` is false);
        * redirects ``sys.stdout`` / ``sys.stderr`` to files under
          ``<checkpoint_dir>/ood_results`` for the rest of the process;
        * mutates ``args`` (``steps``, ``batch_size``, ``total_epochs`` and,
          when an optimizer state is given, ``pretrained``).

    Args:
        args: Parsed command-line namespace from the module-level ``parser``.

    Raises:
        ValueError: Unknown prompt template, model name or optimizer type.
    """
    # setup wandb
    pin_memory = False
    if args.wandb:
        init_wandb(
            project_name='LORE',
            # NOTE(review): `finetuned_model_name` is not registered on the parser
            # visible here; presumably the caller sets it before invoking main().
            model_name=args.finetuned_model_name,
            config=vars(args)
        )
    else:
        wandb.init(mode='disabled')

    #######################################################################################
    num_gpus = torch.cuda.device_count()
    # NOTE: continue the training of a checkpoint with the same number of GPUs.
    args.steps = args.steps // num_gpus

    # From here on everything printed goes to files next to the checkpoint.
    results_root = args.checkpoint_dir
    if results_root is None:
        results_root = '/YOUR_ROOT_PATH/checkpoints/final/ViT-B-32-imagenet/ViT-B-32-pretrained'
    os.makedirs(f"{results_root}/ood_results", exist_ok=True)
    sys.stdout = open(f'{results_root}/ood_results/outputs.out', 'w')
    sys.stderr = open(f'{results_root}/ood_results/errors.err', 'w')
    #######################################################################################
    if num_gpus > 1:
        print(f'Number of GPUs available: {num_gpus}')
    else:
        print('No multiple GPUs available.')

    #######################################################################################
    main_device = 0

    # Model names that are loaded from the HF hub instead of the OpenAI weights.
    hub_ids = {
        "hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg":
            "hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg",
        "ViT-B-16-laion2B": 'hf-hub:laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
        "ViT-B-32-laion2B": 'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
    }
    hub_id = hub_ids.get(args.clip_model_name)

    # get models: reference model (also provides the text tower and the preprocessing)
    if hub_id is not None:
        model_orig, _, image_processor = open_clip.create_model_and_transforms(
            hub_id, cache_dir=CHACHE_DIR)
    else:
        model_orig, _, image_processor = open_clip.create_model_and_transforms(
            args.clip_model_name, pretrained='openai', cache_dir=CHACHE_DIR)

    if args.optimizer_state != '':
        assert args.start_step > 0
        assert str(args.start_step) in args.optimizer_state
        assert args.pretrained in ['', 'none']
        # The model weights live next to the optimizer state, minus the '_opt' suffix.
        args.pretrained = args.optimizer_state.replace('_opt', '')

    # model under evaluation
    if hub_id is not None:
        model, _, _ = load_clip_model_convnext(hub_id)
    else:
        model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)

    # Remove the Normalize transform by creating a new Compose object; normalization
    # is applied inside ClipVisionModel so attacks operate on un-normalized pixels.
    preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
    normalize = image_processor.transforms[-1]
    del image_processor
    print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')
    print(f'[normalize] {normalize}')

    args.batch_size = args.per_device_batch_size * num_gpus

    # get data -- read the requested severity so that the data matches the
    # corruption level written to the results file (previously fixed to level 1).
    dataset, dataset_eval = get_dataset(
        args.dataset, preprocessor_without_normalize,
        imagenet_root=f'{args.imagenet_root}/{args.corruption_type}/{args.corruption_level}')

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            num_workers=16, drop_last=True, pin_memory=pin_memory)
    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True,
                                 num_workers=16, drop_last=True, pin_memory=pin_memory)

    print('Datasets are loaded!')

    if args.template == 'std':
        template = 'This is a photo of a {}'
    elif args.template == 'blurry':
        template = 'This is a blurry photo of a {}'
    else:
        raise ValueError(f'Unknown template: {args.template}')
    print(f'template: {template}')
    texts = [template.format(c) for c in classes.values()]
    text_tokens = open_clip.tokenize(texts)
    model_orig.to(main_device)
    with torch.no_grad():
        embedding_text_labels_norm = []
        half = len(classes) // 2
        for el in (text_tokens[:half], text_tokens[half:]):
            # we need to split the text tokens into two batches because otherwise we run out of memory
            # note that we are accessing the model directly here, not the CustomModel wrapper
            # thus its always normalizing the text embeddings
            embedding_text_labels_norm.append(
                model_orig.encode_text(el.to(main_device), normalize=True).detach().cpu()
            )
        # Transposed so that `image_embedding @ embedding_text_labels_norm` yields logits.
        embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
        assert torch.allclose(
            F.normalize(embedding_text_labels_norm, dim=0),
            embedding_text_labels_norm
        )
        # Sanity-check the embedding width. The two laion2B ViT-B variants handled
        # above are listed too; they were previously rejected as unknown models.
        expected_embed_dims = {
            'ViT-B-32': 512,
            'ViT-B-32-quickgelu': 512,
            'ViT-B-16': 512,
            'ViT-B-16-laion2B': 512,
            'ViT-B-32-laion2B': 512,
            'ViT-L-14': 768,
            'ViT-L-14-336': 768,
            'hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg': 640,
        }
        if args.clip_model_name not in expected_embed_dims:
            raise ValueError(f'Unknown model: {args.clip_model_name}')
        expected_shape = (expected_embed_dims[args.clip_model_name], len(classes))
        assert embedding_text_labels_norm.shape == expected_shape, embedding_text_labels_norm.shape

    def to_parallel_and_cuda(model):
        """Wrap in DataParallel when several GPUs are present and move to CUDA."""
        if num_gpus > 1:
            model = torch.nn.DataParallel(model)
        return model.cuda()

    # Only the vision towers are needed from here on.
    model_orig.cpu()
    model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
    model_orig = to_parallel_and_cuda(model_orig)

    model = ClipVisionModel(model=model.visual, args=args, normalize=normalize)
    model = to_parallel_and_cuda(model)

    lambda_network = LambdaNetworkFactory.create_network(
        args.lambda_net,
        model_orig=model_orig,
        clip_model_name=args.clip_model_name,
        lagrangian_type=args.lagrangian_type
    )
    lambda_network = to_parallel_and_cuda(lambda_network)

    # set optimizer (all params have requires_grad=True)
    params = unwrap_model(model).model.parameters()
    lambda_params = unwrap_model(lambda_network).parameters()

    if args.opt == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
        # maximize=True: the lambda optimizer performs ascent.
        optimizer_lambda = torch.optim.AdamW(lambda_params, lr=args.lambda_lr,
                                             weight_decay=args.wd, maximize=True)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum_sgd,
            weight_decay=args.wd
        )
        optimizer_lambda = torch.optim.SGD(
            lambda_params,
            lr=args.lambda_lr,
            momentum=args.momentum_sgd,
            weight_decay=args.wd,
            maximize=True
        )
    else:
        # The option is called `opt`; `args.optimizer` does not exist and used to
        # turn this into an AttributeError.
        raise ValueError(f'Optimizer {args.opt} not supported.')
    if args.optimizer_state != '':
        optimizer.load_state_dict(torch.load(args.optimizer_state))
        optimizer_lambda.load_state_dict(torch.load(args.lambda_optimizer_state))

    # set scheduler
    scheduler = cosine_lr(optimizer, args.lr, args.warmup // num_gpus, args.steps)
    lambda_scheduler = None  # cosine_lr(optimizer_lambda, args.lr, args.warmup, args.steps)

    if args.checkpoint_dir is not None:
        print('*' * 50)
        # Construct paths for the model and optimizer checkpoints
        model_dir = os.path.join(args.checkpoint_dir, 'checkpoints/final.pt')
        opt_dir = os.path.join(args.checkpoint_dir, 'checkpoints/final_opt.pt')

        # Load model checkpoint
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(model_dir, map_location=device)

        # Checkpoint keys are prefixed with "model." to match the
        # `ClipVisionModel.model` attribute they are loaded into.
        adjusted_state_dict = {f"model.{k}": v for k, v in checkpoint.items()}

        unwrap_model(model).load_state_dict(adjusted_state_dict)

        # Load optimizer checkpoints
        checkpoint_opt = torch.load(opt_dir, map_location=device)
        optimizer.load_state_dict(checkpoint_opt)

        print(f"Model restored successfully from {args.checkpoint_dir}")

    # compute amount of epochs
    total_epochs = args.steps / len(dataloader)
    args.total_epochs = total_epochs

    # finetune
    step_total = args.start_step
    epoch = 0

    delayed_model = None
    ema = None

    if args.anchor_option == 'EMA':
        decay = 0.990  # Decay rate for EMA
        ema = EMA(model, decay)
    elif args.anchor_option == 'sg_phi_delayed':
        update_interval = 15
        delayed_model = DelayedModel(model, update_interval)

    # A single pass only; the surrounding `while step_total < args.steps` loop of
    # the training script is intentionally not used here.
    step_total = train_one_epoch(
        step_total,
        model=model,
        model_orig=model_orig,
        dataloader=dataloader,
        dataloader_eval=dataloader_eval,
        optimizer=optimizer,
        scheduler=scheduler,
        embedding_text_labels_norm=embedding_text_labels_norm,
        normalize=normalize,
        args=args,
        epoch=epoch,
        lambda_network=lambda_network,
        optimizer_lambda=optimizer_lambda,
        lambda_scheduler=lambda_scheduler,
        ema=ema,
        delayed_model=delayed_model,
    )


class ClipVisionModel(torch.nn.Module):
    """Vision tower wrapper that applies input normalization before encoding.

    Attributes:
        model: Callable image encoder.
        args: Configuration object stored as-is (not read by this class).
        normalize: Callable applied to the raw input before ``model``.
    """

    def __init__(self, model, args, normalize):
        super().__init__()
        self.model = model
        self.args = args
        self.normalize = normalize

    def forward(self, vision, output_normalize):
        """Encode ``vision``; L2-normalize along the last dim if ``output_normalize``."""
        normalized_input = self.normalize(vision)
        features = self.model(normalized_input)
        return F.normalize(features, dim=-1) if output_normalize else features


class ComputeLossWrapper:
    """Callable that binds the fixed arguments of ``compute_loss``.

    Instances are called as ``wrapper(embedding, targets)`` and return only the
    first element of what ``compute_loss`` produces.
    """

    def __init__(self, embedding_orig,
                  embedding_text_labels_norm, reduction='mean', loss=None,
                 logit_scale=100.):
        self.embedding_orig = embedding_orig
        self.embedding_text_labels_norm = embedding_text_labels_norm
        self.reduction = reduction
        self.loss_str = loss
        self.logit_scale = logit_scale

    def __call__(self, embedding, targets):
        """Evaluate the configured loss for ``embedding`` against ``targets``."""
        bound_kwargs = dict(
            loss_str=self.loss_str,
            embedding_orig=self.embedding_orig,
            logit_scale=self.logit_scale,
            embedding_text_labels_norm=self.embedding_text_labels_norm,
            reduction=self.reduction,
        )
        outputs = compute_loss(embedding=embedding, targets=targets, **bound_kwargs)
        # compute_loss returns a sequence; only its first entry (the loss) is exposed.
        return outputs[0]


##############################   NEW  ######################################
class EMA:
    """Keeps an exponential moving average copy of a model's parameters.

    Attributes:
        model: The live model being tracked.
        decay: Weight given to the previous average on every update.
        ema_model: Deep copy of ``model`` holding the averaged parameters.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.ema_model = deepcopy(model)
        # Detach the copied parameters in place.
        for shadow_param in self.ema_model.parameters():
            shadow_param.detach_()

    def update(self):
        """Blend the live parameters into the average: ``ema = decay*ema + (1-decay)*live``."""
        with torch.no_grad():
            shadow_by_name = dict(self.ema_model.named_parameters())
            for name, live_param in self.model.named_parameters():
                shadow = shadow_by_name[name]
                shadow.data.mul_(self.decay)
                shadow.data.add_((1 - self.decay) * live_param.data)

class DelayedModel:
    """Lagging copy of a model that is re-synchronised every few update calls."""

    def __init__(self, model, update_interval):
        """
        Create the lagging copy.

        Args:
            model (torch.nn.Module): The model to be tracked.
            update_interval (int): Number of ``update()`` calls between two
                synchronisations of the delayed copy.
        """
        self.model = model
        self.update_interval = update_interval
        self.counter = 0
        snapshot = deepcopy(model)
        snapshot.eval()  # the delayed copy is only ever used in evaluation mode
        self.delayed_model = snapshot

    def update(self):
        """
        Count one step and, once ``update_interval`` steps have accumulated,
        copy the tracked model's state dict into the delayed copy and reset
        the counter.
        """
        self.counter += 1
        if self.counter < self.update_interval:
            return
        self.delayed_model.load_state_dict(self.model.state_dict())
        self.counter = 0


def get_anchor_embedding(data, args, model_orig=None, model=None, anchor_option='orig', 
                        ema=None, delayed_model=None, alpha=0.5,
                        time_step=None, total_step=None):
    """
    Computes the anchor embedding based on the specified anchor option.

    All forward passes run under ``torch.no_grad()``; every model is called as
    ``model(vision=data, output_normalize=args.output_normalize)``.

    Args:
        data (torch.Tensor): Input data.
        args (Namespace): Configuration; only ``args.output_normalize`` is read.
        model_orig (torch.nn.Module, optional): Original model, used by 'orig',
            'proj_on_org' and 'interpolated'.
        model (torch.nn.Module, optional): Current model, used by 'sg_phi',
            'proj_on_org' and 'interpolated'.
        anchor_option (str): Which anchor to use:
            'orig'           - embedding of the original model;
            'sg_phi'         - embedding of the current model;
            'EMA'            - embedding of ``ema.ema_model``;
            'sg_phi_delayed' - embedding of ``delayed_model.delayed_model``;
            'proj_on_org'    - original embedding moved by ``alpha`` along the
                               unit direction towards the current embedding;
            'interpolated'   - ``t * original + (1 - t) * current`` with
                               ``t = time_step / total_step``.
        ema (EMA, optional): EMA object for the 'EMA' option.
        delayed_model (DelayedModel, optional): DelayedModel object for the
            'sg_phi_delayed' option.
        alpha (float): Step length used by 'proj_on_org'.
        time_step (number, optional): Current step, required by 'interpolated'.
        total_step (number, optional): Total steps, required by 'interpolated'.

    Returns:
        torch.Tensor: The anchor embedding.

    Raises:
        ValueError: Unknown ``anchor_option``, a missing helper object/model for
            the chosen option, missing ``time_step``/``total_step`` or an
            interpolation factor outside ``[0, 1]``.
    """
    with torch.no_grad():
        if anchor_option == 'orig':
            embedding_orig = model_orig(vision=data, output_normalize=args.output_normalize)

        elif anchor_option == 'sg_phi':
            embedding_orig = model(vision=data, output_normalize=args.output_normalize)

        elif anchor_option == 'EMA':
            if ema is None:
                raise ValueError("EMA model is not initialized.")
            embedding_orig = ema.ema_model(vision=data, output_normalize=args.output_normalize)

        elif anchor_option == 'sg_phi_delayed': 
            if delayed_model is None:
                raise ValueError("Delayed model is not initialized.")
            embedding_orig = delayed_model.delayed_model(vision=data, output_normalize=args.output_normalize)
        
        elif anchor_option == 'proj_on_org':
            if model_orig is None:
                raise ValueError("Original model is not provided.")
            embedding_orig = model_orig(vision=data, output_normalize=args.output_normalize)
            embedding_current = model(vision=data, output_normalize=args.output_normalize) 
            # Move the original embedding a fixed distance `alpha` towards the current one.
            difference = embedding_current - embedding_orig
            normalized_difference = F.normalize(difference, p=2, dim=-1)
            embedding_orig += alpha * normalized_difference

        elif anchor_option == 'interpolated':
            if model_orig is None:
                raise ValueError("Original model is not provided.")
            if time_step is None or total_step is None:
                raise ValueError("time_step and total_step are required for the 'interpolated' anchor.")
            # Validate the mixing factor before spending two forward passes.
            t = time_step / total_step
            if not (0 <= t <= 1):
                raise ValueError("Interpolation parameter t must be between 0 and 1.")

            # Each model is evaluated exactly once (the embeddings used to be
            # computed twice, doubling the cost of this branch).
            embedding_orig = model_orig(vision=data, output_normalize=args.output_normalize)
            embedding_current = model(vision=data, output_normalize=args.output_normalize)
            # Interpolate between the original and current embeddings
            embedding_orig = t * embedding_orig + (1 - t) * embedding_current 

        else:
            raise ValueError(f"Unknown anchor_option: {anchor_option}") 
    
    return embedding_orig

# Cache for storing precomputed progress values
_rho_scheduler_cache = {} 

def calculate_scheduled_rho(current_step, total_steps, final_rho, increasing=True, scheduler='const'):
    """
    Calculate the scheduled rho value based on the current step.
    Rho starts at 0 and either increases to final_rho or decreases to 0 as training progresses.
    This version supports different scheduling methods: cosine, sqrt, linear, and sin.
    Uses caching for efficient repeated calculations.
    
    Args:
        current_step (int): Current training step
        total_steps (int): Total number of training steps
        final_rho (float): Final value of rho
        increasing (bool): If True, rho increases from 0 to final_rho; if False, it decreases from 
                           final_rho to 0.
        scheduler (str): The type of scheduler to use ('cosine', 'sqrt', 'linear', 'sin').
        
    Returns:
        float: Scheduled rho value for the current step 
    """
    # Create a cache key based on the current step and total steps
    cache_key = (current_step, total_steps, scheduler)

    if scheduler == 'const':
        return final_rho
    
    # Check if we have a cached value
    if cache_key in _rho_scheduler_cache:
        adjusted_progress = _rho_scheduler_cache[cache_key]
    else:
        progress = min(current_step / (0.25 * total_steps), 1.0)
        
        if scheduler == 'cosine':
            adjusted_progress = 1 - np.cos(progress * np.pi / 2)
        elif scheduler == 'sqrt':
            adjusted_progress = progress ** 0.5
        elif scheduler == 'linear':
            adjusted_progress = progress
        elif scheduler == 'sin':
            adjusted_progress = np.sin(progress * np.pi / 2)
        else:
            raise ValueError(f"Unknown scheduler: {scheduler}")
        
        # Store the computed value in cache
        _rho_scheduler_cache[cache_key] = adjusted_progress
    
    if increasing:
        return adjusted_progress * final_rho
    else:
        return final_rho * (1 - adjusted_progress)

def train_one_epoch(
        step_total, model, model_orig, dataloader, optimizer, scheduler, 
        normalize, embedding_text_labels_norm, args, epoch,
        lambda_network, optimizer_lambda,lambda_scheduler, 
        ema,delayed_model,
        dataloader_eval=None,
    ):  
    """Evaluate clean and adversarial zero-shot metrics on ``dataloader_eval``.

    Despite its name (kept for interface compatibility with the training
    script) this function performs no parameter update. For at most
    ``args.num_eval_batches`` batches it crafts adversarial examples, embeds
    clean and adversarial images and accumulates:

    * ``racc_eval`` / ``acc_eval``: zero-shot accuracy on adversarial / clean images;
    * cosine similarity between the true-class text embedding and the
      adversarial / clean image embedding;
    * cosine similarity between adversarial and clean image embeddings of
      ``model``, and between clean embeddings of ``model`` and ``model_orig``.

    The batch averages are appended to
    ``<checkpoint_dir>/ood_results/outputs_<corruption_level>.txt`` and printed.

    Args:
        step_total (int): Step counter; incremented once per evaluated batch.
        model: Model under evaluation, called as ``model(x, output_normalize=...)``.
        model_orig: Reference model.
        embedding_text_labels_norm (torch.Tensor): Text embeddings laid out so
            that ``image_embedding @ embedding_text_labels_norm`` gives logits.
        args: Namespace; reads ``online``, ``evaluation``, ``norm``, ``eps``,
            ``clean_weight``, ``iterations_adv``, ``stepsize_adv``,
            ``output_normalize``, ``num_eval_batches``, ``checkpoint_dir``,
            ``corruption_type`` and ``corruption_level``.
        dataloader_eval: Iterable of ``(images, targets)`` batches. Required.
        dataloader, optimizer, scheduler, normalize, epoch, lambda_network,
        optimizer_lambda, lambda_scheduler, ema, delayed_model: Accepted for
            signature compatibility; not used here.

    Returns:
        int: The updated ``step_total``.

    Raises:
        ValueError: ``dataloader_eval`` is missing or ``args.evaluation`` is
            neither ``'robust'`` nor ``'simple'``.
        RuntimeError: ``dataloader_eval`` produced no batch.
    """
    if dataloader_eval is None:
        raise ValueError('dataloader_eval must be provided')
    if args.evaluation not in ('robust', 'simple'):
        raise ValueError('invalid evaluation mode')

    # online: attack the model being evaluated; otherwise attack the reference model
    attack_model = model if args.online else model_orig
    model.eval()
    attack_model.eval()
    model_orig.eval()

    racc_eval_list = []
    acc_eval_list = []
    cos_sim_eval_adv_list = []
    cos_sim_eval_clean_list = []
    image_cos_sim_eval_adv_list = []
    image_cos_sim_eval_clean_list = []

    # Both attack variants optimise the zero-shot cross-entropy; the wrapper is
    # stateless, so it is built once instead of once per batch.
    loss_eval_wrapper = ComputeLossWrapper(
        embedding_orig=None, embedding_text_labels_norm=embedding_text_labels_norm,
        reduction='none', loss='ce', logit_scale=100.
    )

    for batch_idx, (data_eval, targets_eval) in enumerate(tqdm(dataloader_eval)):
        data_eval, targets_eval = data_eval.cuda(), targets_eval.cuda()
        step_total += 1

        if args.evaluation == 'robust':
            data_eval_adv = apgd(
                model=attack_model,
                loss_fn=loss_eval_wrapper,
                x=data_eval,
                y=targets_eval,
                norm=args.norm,
                eps=args.eps,
                n_iter=50,
                initial_stepsize=0.05 * args.eps if args.clean_weight > 0 else None,
                verbose=False
            )
        else:  # 'simple'
            # Draw, for every sample, a random class different from its label and
            # let PGD minimise the loss towards it.
            eval_adv_targets = [] 
            for target in targets_eval:
                classes_kept = list(classes.keys())
                classes_kept.remove(int(target.item()))
                eval_adv_targets.append(int(np.random.choice(classes_kept)))
            # `targets` / `loss_inner_wrapper` do not exist in this function (they
            # raised NameError); use the evaluation targets' device and the wrapper
            # created above.
            eval_adv_targets = torch.tensor(eval_adv_targets).to(targets_eval.device)

            data_eval_adv = pgd( 
                forward=attack_model,
                loss_fn=loss_eval_wrapper,
                data_clean=data_eval,
                targets=eval_adv_targets,
                norm=args.norm,
                eps=args.eps,
                iterations=args.iterations_adv,
                stepsize=args.stepsize_adv,
                output_normalize=args.output_normalize,
                perturbation=torch.zeros_like(data_eval).uniform_(-args.eps, args.eps).requires_grad_(True),
                mode='min',
                verbose=False
                ) 

        with torch.no_grad():
            embedding_adv_eval_norm = model(data_eval_adv, output_normalize=True)  # we set output_normalize to True
            embedding_eval_norm = model(data_eval, output_normalize=True)
            embedding_orig_eval_norm = model_orig(vision=data_eval, output_normalize=True)

            # `.item()` stores plain floats so that the report shows numbers rather
            # than `tensor(..., device='cuda:0')` and no GPU memory is retained.
            image_cos_sim_eval_adv_list.append(
                F.cosine_similarity(embedding_adv_eval_norm, embedding_eval_norm, dim=1).mean().item())
            image_cos_sim_eval_clean_list.append(
                F.cosine_similarity(embedding_eval_norm, embedding_orig_eval_norm, dim=1).mean().item())

            logits_eval_adv = embedding_adv_eval_norm @ embedding_text_labels_norm
            racc_eval = compute_acc(logits_eval_adv, targets_eval)

            logits_eval = embedding_eval_norm @ embedding_text_labels_norm
            acc_eval = compute_acc(logits_eval, targets_eval)

            # Text embedding of each sample's true class, shape [embed_dim, batch_size]
            true_class_embeddings = embedding_text_labels_norm[:, targets_eval]

            # The adversarial and clean embeddings used to be swapped here, so the
            # two values were reported under each other's label.
            cos_sim_eval_adv = F.cosine_similarity(embedding_adv_eval_norm, true_class_embeddings.T, dim=1)
            cos_sim_eval_clean = F.cosine_similarity(embedding_eval_norm, true_class_embeddings.T, dim=1)

            racc_eval_list.append(racc_eval)
            acc_eval_list.append(acc_eval)
            cos_sim_eval_adv_list.append(cos_sim_eval_adv.mean().item())
            cos_sim_eval_clean_list.append(cos_sim_eval_clean.mean().item())
        
        # Release the per-batch GPU tensors before the next attack allocates its own.
        del data_eval_adv, data_eval, targets_eval, embedding_adv_eval_norm, logits_eval_adv, embedding_eval_norm, logits_eval

        # Count evaluated batches rather than `step_total`, which may start at a
        # non-zero `start_step` and would cut the evaluation short.
        if batch_idx + 1 >= args.num_eval_batches:
            break

    if not acc_eval_list:
        raise RuntimeError('dataloader_eval produced no batches; nothing to evaluate')

    def _mean(values):
        return sum(values) / len(values)

    racc_mean = _mean(racc_eval_list)
    acc_mean = _mean(acc_eval_list)
    cos_sim_adv_mean = _mean(cos_sim_eval_adv_list)
    cos_sim_clean_mean = _mean(cos_sim_eval_clean_list)
    image_cos_sim_eval_adv_mean = _mean(image_cos_sim_eval_adv_list)
    image_cos_sim_eval_clean_mean = _mean(image_cos_sim_eval_clean_list)

    results_root = args.checkpoint_dir
    if results_root is None:
        results_root = '/YOUR_ROOT_PATH/checkpoints/final/ViT-B-32-imagenet/ViT-B-32-pretrained'
    output_path = f"{results_root}/ood_results/outputs_{args.corruption_level}.txt"  

    header = f"corruption_level: {args.corruption_level}:"
    report_lines = [
        f"Imagenet-C/{args.corruption_type}:",
        f"Average racc_eval: {racc_mean}",
        f"Average acc_eval: {acc_mean}",
        f"Average text & adv image cos-sim: {cos_sim_adv_mean}",
        f"Average text & clean image cos-sim: {cos_sim_clean_mean}",
        f"Average image-image cos-sim (adv): {image_cos_sim_eval_adv_mean}",
        f"Average image-image cos-sim (clean): {image_cos_sim_eval_clean_mean}",
        '*' * 30,
    ]

    with open(output_path, "a") as f: 
        # The file is opened in append mode: terminate the separator with a
        # newline so that the next run starts on its own line.
        f.write("\n".join([header] + report_lines) + "\n")

    print(f"{header}\n")
    for line in report_lines:
        print(line)
    torch.cuda.empty_cache() 
    return step_total

@torch.no_grad()
def compute_acc(logits, targets):
    """Return the top-1 accuracy of ``logits`` w.r.t. ``targets`` in percent.

    Args:
        logits: Score tensor; the prediction is the index of the maximum along dim 1.
        targets: Tensor of reference class indices, one per row of ``logits``.

    Returns:
        float: Share of rows whose prediction equals the target, times 100.
    """
    _, predicted = logits.max(dim=1)
    predicted = predicted.detach()
    num_correct = predicted.eq(targets).sum()
    fraction_correct = num_correct / targets.shape[0]
    return fraction_correct.item() * 100

def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
                 embedding_text_labels_norm=None, reduction='mean',
                 rho=0.0, weights=None):
    """Evaluate the loss selected by `loss_str` and return its two outputs.

    Supported values of `loss_str`:
        'l2'   -> `l2(embedding, embedding_orig)`   (labelled "For Fair" in the original comments)
        'linf' -> `linf(embedding, embedding_orig)`
        'ce'   -> `ce` on logits built from `embedding` (labelled "For TeCoA")
        'kl'   -> `kl_div` between logits built from `embedding_orig` and `embedding`

    Args:
        loss_str: name of the loss, one of the values above.
        embedding: model output; if a tuple, its first element is used.
        targets: class labels, only consumed by the 'ce' loss.
        embedding_orig: reference output; if a tuple, its last element is
            used. Must not be None for 'l2', 'linf' and 'kl'.
        logit_scale: factor multiplied onto `embedding_text_labels_norm`
            before the matrix product that yields logits ('ce' and 'kl').
        embedding_text_labels_norm: text-side matrix used to build logits.
        reduction, rho, weights: forwarded unchanged to the selected loss.

    Returns:
        The `(loss, normalized_difference)` pair produced by the selected loss.

    Raises:
        ValueError: if `loss_str` is not one of the supported names.
    """
    # Losses that compare against a reference, with their assertion messages.
    missing_reference_msg = {
        'l2': "Reference logits must be provided for L2 divergence.",
        'linf': "Reference logits must be provided for L-infinity divergence.",
        'kl': "Reference logits must be provided for KL divergence.",
    }

    if loss_str not in ('l2', 'linf', 'ce', 'kl'):
        raise ValueError(f"Loss type '{loss_str}' not supported.")
    if loss_str in missing_reference_msg:
        assert embedding_orig is not None, missing_reference_msg[loss_str]

    # Tuple inputs: take the first entry of the current output and the last
    # entry of the reference output.
    if isinstance(embedding, tuple):
        embedding = embedding[0]
    if isinstance(embedding_orig, tuple):
        embedding_orig = embedding_orig[-1]

    if loss_str in ('l2', 'linf'):
        # Embedding-space losses share one call signature.
        embedding_distance = l2 if loss_str == 'l2' else linf
        loss, normalized_difference = embedding_distance(
            out=embedding,
            targets=embedding_orig,
            reduction=reduction,
            rho=rho,
            weights=weights,
        )
        return loss, normalized_difference

    # Logit-space losses: project embeddings onto the scaled text matrix.
    scaled_text = logit_scale * embedding_text_labels_norm
    logits = embedding @ scaled_text

    if loss_str == 'ce':
        loss, normalized_difference = ce(
            logits=logits,
            targets=targets,
            reduction=reduction,
            rho=rho,
            weights=weights,
        )
    else:  # 'kl'
        ref_logits = embedding_orig @ scaled_text
        loss, normalized_difference = kl_div(
            ref_logits=ref_logits,
            out_logits=logits,
            reduction=reduction,
            rho=rho,
            weights=weights,
        )

    return loss, normalized_difference

def l2(out, targets, reduction='none', rho=0.0, weights=None):
    """Squared-L2 constraint loss between `out` and `targets`.

    Per sample this evaluates
        ||out - targets||_2^2 - rho * ||targets||_2^2
    written in the code as (normalized_difference - rho) * ||targets||_2^2,
    where normalized_difference = ||out - targets||_2^2 / ||targets||_2^2.
    This is the constraint term of
        max_φ min_θ E_x[ max_δ ||φ_θ(x + δ) - φ_ref(x)||_2^2
                         + λ_φ(x)(||φ_θ(x) - φ_org(x)||_2^2 - ρ ||φ_org(x)||_2^2) ]

    Args:
        out: current model embeddings (φ_θ), 2D with more than one row.
        targets: reference embeddings (φ_org or φ_ref), same shape as `out`.
        reduction: 'none' for per-sample values, 'mean' for their average.
        rho: ρ, the tolerance relative to the squared target norm.
        weights: optional per-sample multipliers (λ_φ(x)). When given and the
            module-level `args.kkt_helper` is set, each value is first
            multiplied by its own detached ReLU.

    Returns:
        `(loss, normalized_difference)`; `normalized_difference` is always
        per-sample and unaffected by `weights` and `reduction`.

    Raises:
        ValueError: if `reduction` is neither 'mean' nor 'none'.
    """
    assert out.shape == targets.shape, f"Shape mismatch: {out.shape} != {targets.shape}"
    assert out.ndim == 2, f"Input tensors must be 2D: {out.shape}"
    assert out.size(0) > 1, "Batch size must be greater than 1."

    target_sq_norm = targets.pow(2).sum(dim=-1)
    residual_sq_norm = (out - targets).pow(2).sum(dim=1)

    normalized_difference = residual_sq_norm / target_sq_norm
    constraint = (normalized_difference - rho) * target_sq_norm

    if weights is not None:
        assert weights.shape[0] == constraint.shape[0], f"Weights shape mismatch: {weights.shape} != {constraint.shape}"
        if args.kkt_helper:
            # Gate by the detached positive part of the constraint itself.
            constraint = constraint * constraint.relu().detach()
        constraint = constraint * weights

    if reduction not in ('mean', 'none'):
        raise ValueError(f"Unsupported reduction: {reduction}")
    loss = constraint.mean() if reduction == 'mean' else constraint
    return loss, normalized_difference

def linf(out, targets, reduction='none', rho=0.0, weights=None):
    """L-infinity style constraint loss between `out` and `targets`.

    Constraint term of
        max_φ min_θ E_x[ max_δ ||φ_θ(x + δ) - φ_ref(x)||_∞
                         + λ_φ(x)(||φ_θ(x) - φ_org(x)||_∞ - ρ ||φ_org(x)||_∞) ]

    The module-level `args.lagrangian_type` selects the granularity:
      * 'vector': one constraint per embedding element,
        (|out - targets| / |targets| - rho) * |targets|.
      * anything else (scalar): one constraint per sample, using the maximum
        over dim 1 of |out - targets| and of |targets|.
    In both cases the normaliser is clamped to at least 1e-6 so the division
    cannot hit zero.

    Args:
        out: current model embeddings (φ_θ), 2D with more than one row.
        targets: reference embeddings (φ_org or φ_ref), same shape as `out`.
        reduction: 'mean' averages every constraint value; 'none' returns
            per-sample values (in vector mode the elements of a sample are
            summed over dim 1).
        rho: ρ, the tolerance relative to the (clamped) target magnitude.
        weights: optional multipliers (λ_φ(x)); must match the full constraint
            shape in vector mode, or its first dimension in scalar mode. When
            given and `args.kkt_helper` is set, each constraint is first
            multiplied by its own detached ReLU.

    Returns:
        `(loss, normalized_difference)`; `normalized_difference` is
        element-wise in vector mode and per-sample in scalar mode.

    Raises:
        ValueError: if `reduction` is neither 'mean' nor 'none'.
    """
    assert out.shape == targets.shape, f"Shape mismatch: {out.shape} != {targets.shape}"
    assert out.ndim == 2, f"Input tensors must be 2D: {out.shape}"
    assert out.size(0) > 1, "Batch size must be greater than 1."

    abs_diff = (out - targets).abs()
    per_element = args.lagrangian_type == 'vector'

    if per_element:
        # Every embedding element carries its own constraint.
        scale = targets.abs().clamp(min=1e-6)
        deviation = abs_diff
    else:
        # One constraint per sample: l-infinity norms over the embedding dim.
        scale = targets.abs().max(dim=1).values.clamp(min=1e-6)
        deviation = abs_diff.max(dim=1).values

    normalized_difference = deviation / scale
    constraint = (normalized_difference - rho) * scale

    if weights is not None:
        if per_element:
            assert weights.shape == constraint.shape, f"Vector weights shape mismatch: {weights.shape} != {constraint.shape}"
        else:
            assert weights.shape[0] == constraint.shape[0], f"Scalar weights shape mismatch: {weights.shape} != {constraint.shape}"

        if args.kkt_helper:
            constraint = constraint * constraint.relu().detach()
        constraint = constraint * weights

    if reduction == 'mean':
        return constraint.mean(), normalized_difference
    if reduction == 'none':
        # Vector mode collapses the embedding dimension to a per-sample value.
        per_sample = constraint.sum(dim=1) if per_element else constraint
        return per_sample, normalized_difference
    raise ValueError(f"Unsupported reduction: {reduction}")

def ce(logits, targets, reduction='mean', rho=0.0, weights=None):
    """Cross-entropy loss with an optional offset and per-sample weighting.

    Per sample this evaluates `cross_entropy(logits, targets) - rho`, then
    multiplies by `weights` when given, then applies `reduction`.

    Args:
        logits: model output logits, one row per sample (more than one row).
        targets: ground-truth class labels, one per row of `logits`.
        reduction: 'none' (per-sample values), 'mean' or 'sum'.
        rho: constant subtracted from every per-sample cross-entropy value.
        weights: optional per-sample multipliers; first dimension must match
            the batch size.

    Returns:
        `(loss, None)` for every reduction. The second entry is a placeholder
        so the return shape matches the other loss functions in this file,
        which return `(loss, normalized_difference)`.
    """
    assert logits.shape[0] == targets.shape[0], f"Shape mismatch: {logits.shape} vs {targets.shape}"
    assert logits.size(0) > 1, "Batch size must be greater than 1."
    assert reduction in ['none', 'mean', 'sum'], f"Unsupported reduction: {reduction}"

    # Per-sample cross-entropy, shifted by rho.
    ce_loss = F.cross_entropy(logits, targets, reduction='none') - rho

    if weights is not None:
        # The message must only reference names that exist here; the previous
        # version used an undefined name and raised NameError on mismatch.
        assert weights.shape[0] == ce_loss.shape[0], f"Weights shape mismatch: {weights.shape} != {ce_loss.shape}"
        ce_loss = ce_loss * weights

    # Return a flat (loss, None) pair for every reduction; previously 'mean'
    # and 'sum' produced a nested ((loss, None), None).
    if reduction == 'mean':
        return ce_loss.mean(), None
    if reduction == 'sum':
        return ce_loss.sum(), None
    return ce_loss, None

def kl_div(ref_logits, out_logits, reduction='mean', rho=0.0, weights=None):
    """KL-divergence constraint loss between reference and current logits.

    Per sample this evaluates
        D(p_org || p_θ) - rho * H(p_org)
    where both distributions are softmaxes over dim 1. This is the constraint
    term of
        max_φ min_θ E_x[ max_δ D(p_org(·|x) || p_θ(·|x + δ))
                         + λ_φ(x)(D(p_org(·|x) || p_θ(·|x)) - ρ H(p_org(·|x))) ]

    Args:
        ref_logits: reference model logits (p_org), more than one row.
        out_logits: current model logits (p_θ), same shape as `ref_logits`.
        reduction: 'none' (per-sample values), 'mean' or 'sum'.
        rho: ρ, the weight on the entropy of the reference distribution.
        weights: optional per-sample multipliers (λ_φ(x)). When given and the
            module-level `args.kkt_helper` is set, each value is first
            multiplied by its own detached ReLU.

    Returns:
        `(loss, None)`; the second entry is always None.
    """
    assert ref_logits.shape == out_logits.shape, f"Shape mismatch: {ref_logits.shape} vs {out_logits.shape}"
    assert ref_logits.size(0) > 1, "Batch size must be greater than 1."
    assert reduction in ['none', 'mean', 'sum'], f"Unsupported reduction: {reduction}"

    log_p_ref = F.log_softmax(ref_logits, dim=1)
    log_p_out = F.log_softmax(out_logits, dim=1)
    p_ref = log_p_ref.exp()

    # D(p_org || p_θ): F.kl_div takes the log-probs of the approximating
    # distribution first and the (non-log) target probabilities second.
    divergence = F.kl_div(log_p_out, p_ref, reduction='none', log_target=False).sum(dim=1)

    # H(p_org)
    ref_entropy = -(p_ref * log_p_ref).sum(dim=1)

    constrained_loss = divergence - rho * ref_entropy

    if weights is not None:
        assert weights.shape[0] == constrained_loss.shape[0], f"Weights shape mismatch: {weights.shape} != {constrained_loss.shape}"
        if args.kkt_helper:
            constrained_loss = constrained_loss * constrained_loss.relu().detach()
        constrained_loss = constrained_loss * weights

    if reduction == 'none':
        return constrained_loss, None
    reduce = constrained_loss.mean if reduction == 'mean' else constrained_loss.sum
    return reduce(), None

def get_dataset(dataset_name, transform, imagenet_root=None):
    """Build the training and evaluation datasets for `dataset_name`.

    As a side effect, the module-level `classes` mapping (class index ->
    label) is set for the labelled datasets: the ImageNet variants, CIFAR
    and EuroSAT. It is left untouched for 'segment_anything' and 'coco'.

    Args:
        dataset_name: one of 'imagenet', 'imagenet100', 'imagenet_c',
            'cifar10', 'cifar100', 'segment_anything', 'coco', 'eurosat'.
        transform: transform passed to every dataset constructor.
        imagenet_root: root directory for the ImageNet variants; falls back
            to `args.imagenet_root` when None.

    Returns:
        `(dataset, dataset_eval)`. `dataset_eval` is None for
        'segment_anything' and 'coco', which have no evaluation split here
        (previously these two names crashed at the return statement).

    Raises:
        ValueError: if `dataset_name` is unknown, or if no COCO directory is
            found for 'coco'.
    """
    global classes
    if imagenet_root is None:
        imagenet_root = args.imagenet_root

    # Not every branch defines an evaluation split; default to None so the
    # final return is always well-defined.
    dataset_eval = None

    if dataset_name in ('imagenet', 'imagenet100'):
        # Both variants read the same train/val layout; only the label map differs.
        dataset = ImageNetDataset(
            root=imagenet_root + '/train',
            transform=transform,
        )
        dataset_eval = ImageNetDataset(
            root=imagenet_root + '/val',
            transform=transform,
        )
        if dataset_name == 'imagenet':
            classes = IMAGENET_1K_CLASS_ID_TO_LABEL
        else:
            classes = IMAGENET_100_CLASS_ID_TO_LABEL
    elif dataset_name == 'imagenet_c':
        # The same root is used for both returned datasets.
        dataset = ImageNetDataset(
            root=imagenet_root,
            transform=transform,
        )
        dataset_eval = ImageNetDataset(
            root=imagenet_root,
            transform=transform,
        )
        classes = IMAGENET_1K_CLASS_ID_TO_LABEL
    elif dataset_name == 'cifar10':
        dataset = CIFAR10(
            root='/mnt/raid10/ak-research-01/ak-research-01/codes/.cache/cifar', train=True, download=True,
            transform=transform,
        )
        dataset_eval = CIFAR10(
            root='/mnt/raid10/ak-research-01/ak-research-01/codes/.cache/cifar', train=False, download=True,
            transform=transform,
        )
        classes = dict(enumerate(dataset.classes))
    elif dataset_name == 'cifar100':
        dataset = CIFAR100(
            root='./cifar', train=True, download=True,
            transform=transform,
        )
        dataset_eval = CIFAR100(
            root='./cifar', train=False, download=True,
            transform=transform,
        )
        classes = dict(enumerate(dataset.classes))
    elif dataset_name == 'segment_anything':
        dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=transform)
        print(len(dataset))
    elif dataset_name == 'coco':
        if os.path.exists('/mnt/datasets/coco'):
            image_dir_path = '/mnt/datasets/coco/train2017'
            annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
        elif os.path.exists('/mnt/lustre'):
            image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
            annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
        else:
            raise ValueError('COCO dataset not found')
        dataset = COCOFlickrDataset(
            image_dir_path=image_dir_path,
            annotations_path=annotations_path,
            transform=transform
        )
    elif dataset_name == 'eurosat':
        from torchvision.datasets import EuroSAT
        dataset = EuroSAT(
            root='./eurosat', download=True,
            transform=transform,
        )
        # The first constructor has already downloaded the data.
        dataset_eval = EuroSAT(
            root='./eurosat', download=False,
            transform=transform,
        )
        classes = dict(enumerate(dataset.classes))
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(f"Dataset '{dataset_name}' not supported.")

    return dataset, dataset_eval


if __name__ == '__main__':
    # Fixed seeds for the NumPy and PyTorch random number generators.
    np.random.seed(0)
    torch.manual_seed(0)

    # `args` is bound at module level on purpose: the loss functions and
    # get_dataset in this file read it as a global.
    args = parser.parse_args()

    # The attack budget and step size are divided by 255 here
    # (presumably given on the command line in 0-255 pixel units — TODO confirm).
    args.eps /= 255
    args.stepsize_adv /= 255

    # Guard against boolean flags that were left as the strings 'True'/'False'.
    stringly_bools = [
        value for value in vars(args).values()
        if isinstance(value, str) and value in ('True', 'False')
    ]
    assert not stringly_bools, f'args contains a string that should be a bool: {args}'
    assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'

    num_gpus = torch.cuda.device_count()
    gpu_report = (
        f'Number of GPUs available: {num_gpus}'
        if num_gpus > 1
        else 'No multiple GPUs available.'
    )
    print(gpu_report)

    # Run training / evaluation.
    main(args)

    # NOTE(review): closes the process-wide streams; presumably main() points
    # them at log files — confirm before relying on any output after this.
    sys.stdout.close()
    sys.stderr.close()
