from __future__ import print_function
import cv2
import numpy as np
import argparse, os, time, random
from tqdm import tqdm
import logging
import torch, torchvision
import torch.backends.cudnn as cudnn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torchvision.datasets import *
from replace import clip
from models import prompters
from models.prompters import TokenPrompter,NullPrompter
from models.model import *
from attacks import *
import copy
from utils import accuracy, AverageMeter, ProgressMeter, save_checkpoint
from utils import cosine_lr, convert_models_to_fp32, refine_classname
from utils import load_train_dataset, load_val_datasets, get_text_prompts_train, \
    get_text_prompts_val

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from attention_map import *

def parse_option():
    """Parse the command-line options and derive the run name.

    Returns:
        argparse.Namespace: the parsed options. ``filename`` holds the value
        given with ``--filename`` or, when that is omitted, a name built from
        the dataset, model, architecture and main hyper-parameters. It is
        used by ``main`` for the log file and the checkpoint folder.
    """
    parser = argparse.ArgumentParser('Adapting CLIP for attack detection')
    parser.add_argument('--print_freq', type=int, default=50, help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50, help='save frequency')
    parser.add_argument('--validate_freq', type=int, default=2, help='validate frequency')
    parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
    parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate')
    parser.add_argument("--weight_decay", type=float, default=0, help="weight decay")
    parser.add_argument("--warmup", type=int, default=1000, help="number of steps to warmup for")
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--patience', type=int, default=1000)
    parser.add_argument('--model', type=str, default='clip')
    parser.add_argument('--arch', type=str, default='vit_b32')
    parser.add_argument('--root', type=str, default='./data',
                        help='dataset root directory')
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
    parser.add_argument('--image_size', type=int, default=32, help='image size')
    parser.add_argument('--seed', type=int, default=0, help='seed for initializing training')
    parser.add_argument('--model_dir', type=str, default='./save/models',
                        help='path to save models')
    parser.add_argument('--filename', type=str, default=None,
                        help='filename to save (default: derived from the options)')
    parser.add_argument('--resume', type=str, default=None,
                        help='path to resume from checkpoint')
    parser.add_argument('--Distance_metric', type=str, default='l2', choices=['cos', 'l2', 'l1'],
                        help='Select the distance measure in the loss function')
    parser.add_argument('--cos', type=float, default=1, help='cosine loss')
    parser.add_argument('--L0', type=float, default=0.05, help='L0 loss')
    parser.add_argument('--testdata', type=str, nargs='+')

    # NOTE(review): the options below are read by main()/train()/validate()
    # in this file but were never registered, so every run ended in an
    # AttributeError. The defaults are reviewer guesses -- confirm them.
    parser.add_argument('--gpu', type=int, default=0, help='CUDA device index')
    parser.add_argument('--adaptation_method', type=str, default='FT',
                        help="'VPT' trains visual prompts with SGD, 'Adam' fine-tunes the "
                             "vision encoder with Adam, any other value fine-tunes the "
                             "vision encoder with SGD")
    parser.add_argument('--method', type=str, default=None,
                        help='prompter name looked up in models.prompters '
                             '(required with --adaptation_method VPT)')
    parser.add_argument('--add_prompt_size', type=int, default=0,
                        help='number of extra prompt tokens (used with VPT)')
    parser.add_argument('--last_num_ft', type=int, default=0,
                        help='fine-tune only the last N parameter tensors of the vision '
                             'encoder (0 = all of them)')
    # The PGD budget is given in a 1/255 scale as a guess; check which units
    # attacks.attack_pgd expects.
    parser.add_argument('--train_eps', type=float, default=1 / 255,
                        help='PGD epsilon passed to attack_pgd during training')
    parser.add_argument('--train_numsteps', type=int, default=2,
                        help='number of PGD iterations during training')
    parser.add_argument('--train_stepsize', type=float, default=1 / 255,
                        help='PGD step size during training')
    parser.add_argument('--test_stepsize', type=float, default=1 / 255,
                        help='PGD step size at evaluation time')
    parser.add_argument('--VPbaseline', action='store_true',
                        help='skip the PGD attack during training')
    parser.add_argument('--debug', action='store_true',
                        help='stop each validation pass at the first logged batch')
    args = parser.parse_args()

    if args.adaptation_method == 'VPT' and args.method is None:
        parser.error('--method is required with --adaptation_method VPT')

    # The previous template had eleven placeholders for ten values, which
    # made str.format raise IndexError; the template now matches the values
    # one to one. An explicit --filename is no longer overwritten.
    if args.filename is None:
        args.filename = '{}_{}_{}_lr-{}_decay-{}_bsz-{}_warmup-{}_cos-{}_L0-{}_distance-{}'. \
            format(args.dataset, args.model, args.arch, args.learning_rate,
                   args.weight_decay, args.batch_size, args.warmup, args.cos,
                   args.L0, args.Distance_metric)
    return args

def main():
    """Fine-tune CLIP: build model, optimizer and data, then train/validate.

    Sets the module-level globals ``best_acc1``, ``device`` and ``logger``,
    which ``train`` and the loss criteria in this file read. Every epoch is
    checkpointed; training stops early once ``args.patience`` consecutive
    epochs pass without a new best validation accuracy.

    The options ``gpu``, ``adaptation_method``, ``add_prompt_size``,
    ``method`` and ``last_num_ft`` must be present on the namespace returned
    by ``parse_option``.
    """
    global best_acc1, device, logger
    best_acc1 = 0.0
    args = parse_option()

    # "cuda:None" is not a valid device string, so a missing GPU index means
    # "default CUDA device"; without CUDA fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda' if args.gpu is None else 'cuda:{}'.format(args.gpu))
    else:
        device = torch.device('cpu')

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    log_dir = './save/loggers/'
    os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(log_dir, f'{args.filename}.log'))
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s [%(filename)s] => %(message)s")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    for key, value in vars(args).items():
        print(f'{key}: {value}')
        logger.info(f'{key}: {value}')

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # ---- create model ----
    if args.adaptation_method == 'VPT':
        add_prompt_len = args.add_prompt_size
    else:
        add_prompt_len = 0
    print(" create model")
    model, preprocess = clip.load('ViT-B/32', device, jit=False, prompt_len=add_prompt_len)

    convert_models_to_fp32(model)
    model = model.to(device)
    frozen_model = copy.deepcopy(model).to(device)

    model.eval()
    frozen_model.eval()

    # ---- define prompters and optimizer ----
    if args.adaptation_method == 'VPT':
        prompter = prompters.__dict__[args.method](args).to(device)
        add_prompter = TokenPrompter(args.add_prompt_size).to(device)
        optimizer = torch.optim.SGD(list(prompter.parameters()) + list(add_prompter.parameters()),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.adaptation_method == 'Adam':
        prompter = NullPrompter().to(device)
        add_prompter = TokenPrompter(0).to(device)
        optimizer = torch.optim.Adam(model.visual.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),  # Default values
                                     eps=1e-8,  # Default value
                                     weight_decay=args.weight_decay)
    else:
        prompter = NullPrompter().to(device)
        add_prompter = TokenPrompter(0).to(device)
        if args.last_num_ft == 0:
            trainable = model.visual.parameters()
        else:
            trainable = list(model.visual.parameters())[-args.last_num_ft:]
        optimizer = torch.optim.SGD(trainable,
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # ---- optionally resume from a checkpoint ----
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Map the stored tensors onto the device selected above.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']

            if 'vision_encoder_state_dict' in checkpoint.keys():
                model.visual.load_state_dict(checkpoint['vision_encoder_state_dict'], strict=False)
            else:
                prompter.load_state_dict(checkpoint['state_dict'])
                add_prompter.load_state_dict(checkpoint['add_prompter'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            logger.info("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    template = 'This is a photo of a {}'
    print(f'template: {template}')

    # ---- training data ----
    # DataLoader rejects shuffle=True combined with a sampler, and no sampler
    # was ever created, so only shuffle is passed.
    train_dataset = load_train_dataset(args)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                              shuffle=True)

    # ---- validation data ----
    if args.testdata is None:
        val_dataset_name = ['tinyImageNet']
    else:
        val_dataset_name = args.testdata

    # NOTE(review): val_loader_list, texts_train and texts_list were used
    # below without ever being created. They are built here with the helpers
    # this file imports from utils, assuming the signatures
    #   load_val_datasets(args, names)
    #   get_text_prompts_train(args, dataset, template=...)
    #   get_text_prompts_val(datasets, names, template=...)
    # -- confirm against utils.py.
    val_dataset_list = load_val_datasets(args, val_dataset_name)
    val_loader_list = [
        DataLoader(each, batch_size=args.batch_size, pin_memory=True, shuffle=False)
        for each in val_dataset_list
    ]
    texts_train = get_text_prompts_train(args, train_dataset, template=template)
    texts_list = get_text_prompts_val(val_dataset_list, val_dataset_name, template=template)

    scaler = GradScaler()
    total_steps = len(train_loader) * args.epochs
    scheduler = cosine_lr(optimizer, args.learning_rate, args.warmup, total_steps)

    # Autotuning picks kernels non-deterministically; enable it only when the
    # run was not seeded for reproducibility above.
    if args.seed is None:
        cudnn.benchmark = True
    args.model_folder = os.path.join(args.model_dir, args.filename)
    os.makedirs(args.model_folder, exist_ok=True)

    epochs_since_improvement = 0

    # ---- training ----
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_loader, texts_train, model, frozen_model, prompter, add_prompter,
              optimizer, scheduler, scaler, epoch, args)

        # evaluate on validation set; the argument list follows validate()
        # as defined in this file (it takes no prompters)
        if epoch % args.validate_freq == 0:
            acc1_mean = validate(val_loader_list, val_dataset_name, texts_list, model,
                                 frozen_model, optimizer, device, args)

        # remember best acc@1 and save checkpoint (on epochs without a
        # validation pass the most recent accuracy is reused)
        is_best = acc1_mean > best_acc1
        best_acc1 = max(acc1_mean, best_acc1)

        save_checkpoint({
            'epoch': args.start_epoch + epoch + 1,
            'state_dict': prompter.state_dict(),
            'add_prompter': add_prompter.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'vision_encoder_state_dict': model.visual.state_dict(),
        }, args, is_best=is_best)

        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print(f"There's no improvement for {epochs_since_improvement} epochs.")
            logger.info(f"There's no improvement for {epochs_since_improvement} epochs.")
            if epochs_since_improvement >= args.patience:
                print("The training halted by early stopping criterion.")
                logger.info("The training halted by early stopping criterion.")
                break

"""train function"""
def train(train_loader, texts, model,frozen_model, prompter, add_prompter,
          optimizer, scheduler, scaler, epoch,  args):
    """Train for one epoch and write a checkpoint.

    The optimized loss is ``args.cos * loss_cos + args.L0 * loss_L0``
    computed on the clean images. Unless ``args.VPbaseline`` is set, a PGD
    perturbation is also generated and its losses are logged, but they do
    not contribute to the gradient.

    Args:
        train_loader: iterable of ``(images, target)`` batches.
        texts: class prompts, tokenized with ``clip.tokenize``.
        model: CLIP model whose ``visual`` encoder is put in train mode.
        frozen_model: unused here; kept for interface compatibility.
        prompter: module applied to the preprocessed images.
        add_prompter: put in train mode and checkpointed; it is not fed to
            the model because ``img_embedding`` passes no prompt token.
        optimizer: optimizer stepped through ``scaler``.
        scheduler: callable taking the global step, adjusts the learning rate.
        scaler: ``GradScaler`` used for mixed-precision training.
        epoch: zero-based epoch index within this run.
        args: parsed options (``start_epoch``, ``train_stepsize``,
            ``train_numsteps``, ``train_eps``, ``VPbaseline``, ``cos``,
            ``L0``, ``print_freq``).

    Returns:
        tuple: ``(average loss, average top-1 accuracy)`` over the epoch.

    Relies on the module globals ``device``, ``logger`` and ``best_acc1``
    set by ``main``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(args.start_epoch + epoch))

    # switch to train mode
    prompter.train()
    add_prompter.train()
    model.visual.train()
    num_batches_per_epoch = len(train_loader)

    alpha = args.train_stepsize
    attack_iters = args.train_numsteps

    # The prompts do not change between batches: tokenize them once.
    text_tokens = clip.tokenize(texts).to(device)

    end = time.time()
    for i, (images, target) in enumerate(tqdm(train_loader)):

        # measure data loading time
        data_time.update(time.time() - end)

        # adjust learning rate
        step = num_batches_per_epoch * epoch + i
        scheduler(step)

        optimizer.zero_grad()

        images = images.to(device)
        target = target.to(device)

        # with automatic mixed precision
        with autocast():
            # build adversarial example
            if not args.VPbaseline:
                delta = attack_pgd(prompter, model,add_prompter,images,
                                target, text_tokens, alpha, attack_iters, 'l_inf',
                                device=device, args=args, epsilon=args.train_eps)
                tmp = clip_img_preprocessing(images + delta,device)
            else:
                tmp = clip_img_preprocessing(images,device)

            prompted_images = prompter(tmp)
            clean_images = prompter(clip_img_preprocessing(images,device))

            # img_embedding takes five arguments and criterion_L0 returns five
            # values (loss_cos, loss_L0, loss_L0_raw, logits_L0, logits_L0_raw).
            output_org, _, text_features, img_features = img_embedding(
                model, clean_images, text_tokens, target, device)
            loss_cos_org, loss_L0_org, _, _, logits_L0_org_raw = criterion_L0(
                output_org, target, img_features, text_features)

            # The adversarial pass is only reported in the log, so it does not
            # need an autograd graph.
            with torch.no_grad():
                output, _, text_features_adv, img_features_adv = img_embedding(
                    model, prompted_images, text_tokens, target, device)
                loss_cos, loss_L0, _, _, _ = criterion_L0(
                    output, target, img_features_adv, text_features_adv)

            loss = args.cos*(loss_cos_org) + args.L0*(loss_L0_org)

        # Backward pass and optimizer step run outside autocast, as the AMP
        # documentation recommends.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # keep the learned temperature in [exp(0), exp(4.6052)] ~ [1, 100]
        model.logit_scale.data = torch.clamp(model.logit_scale.data, 0, 4.6052)

        # measure accuracy
        acc1 = accuracy(args.cos*output_org+args.L0*logits_L0_org_raw, target, topk=(1,))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0].item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            entries = progress.display(i)
            logger.info(entries)

            logger.info("Loss function is: loss= loss_L0_org")

            logger.info("cos Loss: %f, L0 Loss: %f,L0 Loss org %f", loss_cos, loss_L0,loss_L0_org)

    save_checkpoint({
        'epoch': args.start_epoch + epoch + 1,
        'state_dict': prompter.state_dict(),
        'add_prompter': add_prompter.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
        'vision_encoder_state_dict':model.visual.state_dict(),
        }, args)
    return losses.avg, top1.avg


def validate(val_loader_list, val_dataset_name, texts_list, model,frozen_model,optimizer, device, args):
    """Measure clean top-1 accuracy on every validation dataset.

    Args:
        val_loader_list: one loader of ``(images, target)`` batches per dataset.
        val_dataset_name: dataset names, parallel to ``val_loader_list``.
        texts_list: class prompts per dataset, parallel to ``val_loader_list``.
        model: CLIP model being evaluated (put in eval mode).
        frozen_model: only switched to eval mode here.
        optimizer: unused; kept for interface compatibility.
        device: device the batches and text tokens are moved to.
        args: parsed options (``cos``, ``L0``, ``print_freq``; ``debug`` is
            optional and defaults to off).

    Returns:
        Mean over the datasets of the per-dataset average top-1 accuracy.
    """
    acc_all = []
    # --debug may not be registered by the parser; treat a missing flag as off.
    debug = getattr(args, 'debug', False)

    for cnt, val_loader in enumerate(val_loader_list):
        texts = texts_list[cnt]
        dataset_name = val_dataset_name[cnt]

        batch_time = AverageMeter('Time', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1_org = AverageMeter('Original Acc@1', ':6.2f')

        progress = ProgressMeter(
            len(val_loader),
            [batch_time, losses, top1_org],
            prefix=dataset_name + '_Validate: ')

        model.eval()
        model.zero_grad()
        frozen_model.eval()

        # The prompts are fixed per dataset: tokenize them once.
        text_tokens = clip.tokenize(texts).to(device)

        end = time.time()
        for i, (images, target) in enumerate(tqdm(val_loader)):
            images = images.to(device)
            target = target.to(device)

            with autocast(), torch.no_grad():
                # clean images; img_embedding takes five arguments and
                # criterion_L0 returns five values
                output_org, _, text_features_org, img_features_org = img_embedding(
                    model, clip_img_preprocessing(images,device), text_tokens, target, device)

                loss_cos_org, loss_L0_org, _, _, logits_L0_raw_org = criterion_L0(
                    output_org, target, img_features_org, text_features_org)

                acc1 = accuracy(args.cos*output_org+ args.L0*logits_L0_raw_org, target, topk=(1,))
                top1_org.update(acc1[0].item(), images.size(0))

                loss = args.cos*(loss_cos_org) + args.L0*(loss_L0_org)
                # The loss meter is displayed by `progress`, so feed it.
                losses.update(loss.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                entries = progress.display(i)
                logger.info(entries)
                if debug:
                    break
        torch.cuda.empty_cache()

        acc_all.append(top1_org.avg)
    return np.mean(acc_all)



def criterion_L0(output, target, img_features, text_features,
                 tau=0.75, temperature=0.5, margin=0.2, scale=50):
    """Cross-entropy on the given logits plus a smooth-L0 similarity loss.

    For every (image, class) pair the element-wise absolute difference of
    the L2-normalized features is compared with a per-pair threshold
    (``tau`` times the mean difference). A sigmoid turns "difference above
    threshold" into a soft count; the similarity is the fraction of feature
    dimensions that are *not* counted.

    Args:
        output: logits of shape (bsz, num_classes) for the first loss.
        target: integer class indices of shape (bsz,).
        img_features: image features of shape (bsz, dim).
        text_features: class text features of shape (num_classes, dim).
        tau: multiplier on the mean difference that gives the threshold.
        temperature: sigmoid temperature of the soft indicator.
        margin: value added to the similarity of the target class.
        scale: factor that turns similarities into logits.

    Returns:
        tuple: ``(loss_cos, loss_L0, loss_L0_raw, logits_L0, logits_L0_raw)``.
        The ``*_raw`` values are computed without the margin.
    """
    img_features = F.normalize(img_features, dim=1)
    text_features = F.normalize(text_features, dim=1)
    # F.cross_entropy equals nn.CrossEntropyLoss() with default settings and
    # needs no module-level `device`.
    loss_cos = F.cross_entropy(output, target)

    dim_size = img_features.shape[1]

    # Broadcast to all pairs: (bsz, 1, dim) - (1, num_classes, dim)
    diff = torch.abs(img_features.unsqueeze(1) - text_features.unsqueeze(0))
    thresholds = tau * diff.mean(dim=2, keepdim=True)

    # Soft count of the dimensions whose difference exceeds the threshold.
    smooth_indicator = torch.sigmoid((diff - thresholds) / temperature)
    l0_approximation = torch.sum(smooth_indicator, dim=2)

    scale_factor = 1 / dim_size
    similarity_scores = (dim_size - l0_approximation) * scale_factor

    # One-hot mask of the target class of each sample.
    one_hot = torch.zeros_like(similarity_scores)
    one_hot.scatter_(1, target.view(-1, 1), 1)

    # NOTE(review): the margin is *added* to the target class here, while
    # criterion_KL subtracts it. Adding it lowers the loss instead of
    # enforcing a margin -- confirm which sign is intended.
    similarity_with_margin = similarity_scores + (margin * one_hot)

    logits_L0 = similarity_with_margin * scale
    logits_L0_raw = similarity_scores * scale

    loss_L0 = F.cross_entropy(logits_L0, target)
    loss_L0_raw = F.cross_entropy(logits_L0_raw, target)

    return loss_cos, loss_L0, loss_L0_raw, logits_L0, logits_L0_raw

def img_embedding(clip_model, images, text_tokens, target, device, prompt_token=None):
    """Run the CLIP model and return the logits and both embeddings.

    Args:
        clip_model: callable ``clip_model(images, text_tokens, prompt_token)``
            returning ``(image_embedding, scaled_text_embedding)``.
        images: batch of preprocessed images.
        text_tokens: tokenized class prompts.
        target: unused; kept for interface compatibility.
        device: unused; kept for interface compatibility.
        prompt_token: optional prompt token forwarded to the model. Callers
            in this file pass it as a sixth positional argument; the default
            ``None`` reproduces the previous behavior.

    Returns:
        tuple: ``(logits_per_image, logits_per_text, scale_text_embed,
        img_embed)``.
    """
    img_embed, scale_text_embed = clip_model(images, text_tokens, prompt_token)
    logits_per_image = img_embed @ scale_text_embed.t()
    logits_per_text = scale_text_embed @ img_embed.t()
    return logits_per_image, logits_per_text, scale_text_embed, img_embed


def criterion_KL(output, target, img_features, text_features, margin=0.2, scale=50):
    """Cross-entropy on the given logits plus a KL-divergence similarity loss.

    The similarity of an (image, class) pair is ``1 - KL`` as computed by
    ``compute_kl_divergence`` on the L2-normalized features.

    Args:
        output: logits of shape (bsz, num_classes) for the first loss.
        target: integer class indices of shape (bsz,).
        img_features: image features of shape (bsz, dim).
        text_features: class text features of shape (num_classes, dim).
        margin: value subtracted from the similarity of the target class.
        scale: factor that turns similarities into logits.

    Returns:
        tuple: ``(loss_cos, loss_KL, loss_KL_raw, logits_KL, logits_KL_raw)``.
        The ``*_raw`` values are computed without the margin.
    """
    img_features = F.normalize(img_features, dim=1)
    text_features = F.normalize(text_features, dim=1)
    # F.cross_entropy equals nn.CrossEntropyLoss() with default settings and
    # needs no module-level `device`.
    loss_cos = F.cross_entropy(output, target)

    # Broadcast to all pairs: (bsz, 1, dim) against (1, num_classes, dim)
    kl_div = compute_kl_divergence(img_features.unsqueeze(1), text_features.unsqueeze(0))
    similarity_scores = 1 - kl_div  # (bsz, num_classes)

    # One-hot mask of the target class of each sample.
    one_hot = torch.zeros_like(similarity_scores)
    one_hot.scatter_(1, target.view(-1, 1), 1)

    # NOTE(review): the margin is *subtracted* from the target class here,
    # while criterion_L0 adds it -- confirm which sign is intended.
    similarity_with_margin = similarity_scores - (margin * one_hot)

    logits_KL = similarity_with_margin * scale
    logits_KL_raw = similarity_scores * scale

    loss_KL = F.cross_entropy(logits_KL, target)
    loss_KL_raw = F.cross_entropy(logits_KL_raw, target)

    # Diagnostics were printed on every call; each .item() forces a device
    # synchronization, so they are now emitted only at DEBUG level.
    log = logging.getLogger(__name__)
    if log.isEnabledFor(logging.DEBUG):
        log.debug("Before margin - min: %s, max: %s",
                  similarity_scores.min().item(), similarity_scores.max().item())
        log.debug("After margin - min: %s, max: %s",
                  similarity_with_margin.min().item(), similarity_with_margin.max().item())
        log.debug("Logits min: %s, max: %s, mean: %s",
                  logits_KL.min().item(), logits_KL.max().item(), logits_KL.mean().item())

    return loss_cos, loss_KL, loss_KL_raw, logits_KL, logits_KL_raw

def compute_kl_divergence(p, q):
    """Return KL(softmax(p) || softmax(q)) for every broadcast pair.

    Both inputs are turned into distributions with a softmax over dim 2 and
    the divergence is summed over that dimension.

    Args:
        p: tensor of shape (bsz, 1, dim).
        q: tensor of shape (1, num_classes, dim).

    Returns:
        Tensor of shape (bsz, num_classes).
    """
    # Clamp away from zero so that the logarithm stays finite.
    epsilon = 1e-10
    p_dist = torch.clamp(F.softmax(p, dim=2), min=epsilon, max=1.0)
    q_dist = torch.clamp(F.softmax(q, dim=2), min=epsilon, max=1.0)

    # Broadcasting pairs every p with every q without materializing a
    # repeated copy of p; log p - log q avoids the intermediate division.
    return torch.sum(p_dist * (torch.log(p_dist) - torch.log(q_dist)), dim=2)



# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
