import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from contextlib import nullcontext # Import the null context manager

from torch.nn.utils import clip_grad_norm_

from utils.utils import accuracy, AverageMeter, ProgressMeter, masking, AttentionAnalyzer

from utils.vis_utils import visualize_masking, visualize_pred, imagenet_label
from utils.analyze_utils import LoggingAttention
from open_clip.loss import ClipLoss

# Module-level AttentionAnalyzer instance. In this file it is only read by
# `validate`, which calls `analyzer.get_summary()` at the end of each
# evaluation pass; presumably other modules record into it (TODO confirm).
analyzer = AttentionAnalyzer(hold_raw=False)

def get_image_logits(image_features, text_features, logit_scale, **kwargs):
    """Return scaled cosine-similarity logits between image and text features.

    Both feature tensors are L2-normalised along their last (feature)
    dimension, then ``logit_scale * image_features @ text_features.T`` is
    returned.

    Args:
        image_features: ``(..., D)`` image embeddings.
        text_features: ``(C, D)`` text embeddings.
        logit_scale: scalar (float or tensor) multiplier for the similarities.
        **kwargs: ignored; accepted for call-site compatibility.

    Returns:
        Tensor of shape ``(..., C)``.
    """
    # F.normalize clamps the norm at a small eps, so an all-zero embedding
    # yields zero logits instead of NaNs from a 0/0 division.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    return logit_scale * image_features @ text_features.t()




def model_forward(model, images, texts=None, tokenizer=None, drop_tokens=0.0, model_name='clip', dataset='cifar100', patch_size=32, visualize=False, freeze_backbone=False, target=None, attr=None, mask=None, old_att=None, return_features=False, is_torchvision=False, regularization=None, return_loss=False):
    """Run one forward pass, dispatching on the model family.

    Branches (checked in this order):

    * ``'clip' in model_name``: ``model(images, texts)`` under FLYP, else
      ``model.get_logits(images, texts)``. When ``texts`` is not None the raw
      model output is returned as-is (NOT the 8-tuple below).
    * ``'resnet' in model_name`` or ``is_torchvision``: applies
      ``model.prompter`` first if the model has one, then ``model(images)``.
    * ``model_name == 'dinov2'``: ``model(images)['logits']``.
    * ``'align'`` or ``'siglip' in model_name``: keyword call with
      ``pixel_values``; under FLYP also ``input_ids=texts`` and
      ``return_loss``, taking ``.loss`` and ``.logits_per_image``; otherwise
      the first element of the model output.
    * anything else: a ViT exposing ``forward_features(..., return_attention=True)``
      and ``forward_head``. A tuple from ``forward_features`` is unpacked into
      attention diagnostics. ``freeze_backbone`` runs ``forward_features``
      under ``torch.no_grad()``.

    Args:
        mask: optional column selector applied as ``output[:, mask]``. It is
            skipped only for FLYP align/siglip calls with ``return_loss=True``,
            whose logits compare images against the batch's own texts rather
            than against classes.
        return_features: also return the pre-logit ViT features (detached,
            on CPU).
        tokenizer, drop_tokens, dataset, patch_size, visualize, target, attr,
            old_att: accepted for call-site compatibility; unused here.

    Returns:
        Either the raw CLIP output (see above) or the 8-tuple
        ``(output, attentions, cross_attentions, normed_cross_attentions,
        indices, log_prob, feature, loss)``. Entries a branch does not
        produce are ``None``.
    """
    loss = None
    attentions = cross_attentions = normed_cross_attentions = indices = None
    log_prob = None
    feature = None
    apply_mask = mask is not None

    if 'clip' in model_name:
        if regularization == 'FLYP':
            output = model(images, texts)
        else:
            output = model.get_logits(images, texts)
        if texts is not None:  # contrastive / FLYP: caller consumes the raw output
            return output
    elif 'resnet' in model_name or is_torchvision:
        if hasattr(model, 'prompter'):
            images = model.prompter(images)
        output = model(images)
    elif model_name == 'dinov2':
        output = model(images)['logits']
    elif model_name == 'align' or 'siglip' in model_name:
        if regularization == 'FLYP':
            output = model(pixel_values=images, input_ids=texts, return_loss=return_loss)
            loss = output.loss
            output = output.logits_per_image
            if return_loss:
                # Training-time FLYP logits are image-vs-batch-text, so a class
                # mask does not index them.
                apply_mask = False
        else:
            output = model(pixel_values=images)[0]
    else:  # plain ViT
        grad_context = torch.no_grad() if freeze_backbone else nullcontext()
        with grad_context:
            output = model.forward_features(images, return_attention=True)
        if isinstance(output, tuple):
            output, attentions, cross_attentions, normed_cross_attentions, indices, _, log_prob = output
        if return_features:
            feature = model.forward_head(output, pre_logits=True).detach().cpu()
        output = model.forward_head(output)

    if apply_mask:
        output = output[:, mask]
    return output, attentions, cross_attentions, normed_cross_attentions, indices, log_prob, feature, loss



def train(train_loader, texts, model, tokenizer, optimizer, scheduler, criterion, scaler, epoch, args, return_attention=False, mask=None, old_model=None, regularizer=None):
    """
    Run one train epoch.

    Args:
        train_loader: yields ``(images, target)``, ``(images, target, attr)``
            or ``(images, target, attr, idx)`` batches.
        texts: optional text tokens indexed by ``target`` to build each
            batch's text input; first restricted to ``mask`` when both are set.
        model: network passed to ``model_forward``. If it has a ``prompter``
            attribute, only the prompter is put in train mode.
        tokenizer: forwarded to ``model_forward``.
        optimizer: stepped once per batch.
        scheduler: called as ``scheduler(step)`` with the global step when
            ``args.scheduler == 'cosine'``.
        criterion: loss for the default classification path.
        scaler: AMP ``GradScaler``; unused when ``args.no_scaler`` is set.
        epoch: epoch index, used for the global step and progress prefix.
        args: run configuration (reads d_pre, model, mode, regularization,
            grad_clip, no_scaler, scheduler, print_freq, ...).
        return_attention: also return ``(attentions, attr_ids)``.
        mask: optional class mask forwarded to ``model_forward``.
        old_model: unused.
        regularizer: optional callable whose result is added to the loss.

    Returns:
        ``(avg_loss, avg_acc1)``, or ``(avg_loss, avg_acc1, attentions,
        attr_ids)`` when ``return_attention`` is set.
    """
    is_torchvision = 'in1k_orig' in args.d_pre
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # align/siglip run without autocast; every other model uses CUDA autocast.
    use_autocast = args.model != 'align' and 'siglip' not in args.model

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(epoch))

    clip_loss_fn = ClipLoss()

    # switch to train mode (only the prompter trains when one is attached)
    if hasattr(model, 'prompter'):
        model.eval()
        model.prompter.train()
    else:
        model.train()

    attentions = []  # not filled in this loop; returned for API compatibility
    attr_ids = []

    num_batches_per_epoch = len(train_loader)

    end = time.time()
    if texts is not None and mask is not None:
        texts = texts[mask]
    for i, data in enumerate(tqdm(train_loader)):
        if len(data) == 4:
            images, target, attr, _ = data
        elif len(data) == 3:
            images, target, attr = data
        else:
            images, target = data
            attr = None
        if args.d_pre == 'dinov2':
            images = images['pixel_values'][0]
        if args.model == 'align' or 'siglip' in args.model:
            images = images['pixel_values'][:, 0]
        # measure data loading time
        data_time.update(time.time() - end)
        attr_ids.append(attr)

        # adjust learning rate
        step = num_batches_per_epoch * epoch + i
        if scheduler is not None and args.scheduler == 'cosine':
            scheduler(step)

        optimizer.zero_grad()

        images = images.to(device)
        target = target.to(device)
        # rows of `texts` selected by each sample's target
        text_tokens = texts[target] if texts is not None else None

        if isinstance(attr, torch.Tensor):
            attr = attr.to(device)

        # Autocast covers only the forward pass and loss computation; backward
        # and the optimizer step run outside it, as PyTorch AMP recommends.
        amp_context = autocast() if use_autocast else nullcontext()
        with amp_context:
            output = model_forward(model, images, text_tokens, tokenizer, args.drop_tokens, args.model, args.dataset,
                                   patch_size=args.patch_size, freeze_backbone=args.freeze_backbone, target=target, attr=attr,
                                   mask=mask, is_torchvision=is_torchvision, regularization=args.regularization, return_loss=True)
            text_logits = None
            if args.regularization == 'FLYP':
                if 'siglip' in args.model:
                    # model_forward's 8-tuple: logits first, model loss last
                    loss = output[-1]
                    output = output[0]
                else:
                    # raw CLIP output is splatted into ClipLoss
                    loss = clip_loss_fn(*output)
                    output = clip_loss_fn.get_logits(*output)[0]
            elif args.mode == 'contrastive':
                output, text_logits = output
                labels = torch.arange(len(images), dtype=torch.long, device=device)
                loss = F.cross_entropy(output, labels) + F.cross_entropy(text_logits, labels)
            else:
                log_prob = None  # stays None when the model returns no log-probs
                if isinstance(output, tuple):
                    output, _, _, _, _, log_prob, _, _ = output
                if log_prob is not None:
                    # Adds detached per-sample loss * mean log-prob, so log_prob
                    # receives gradients weighted by the loss (score-function term).
                    per_sample = F.cross_entropy(output, target, reduction='none')
                    loss = (per_sample + per_sample.detach() * log_prob.mean(-1)).mean()
                else:
                    loss = criterion(output, target)
                # nan to zero
                loss = loss.nan_to_num(0)

            if regularizer is not None:
                if args.mode == 'contrastive':
                    loss = loss + regularizer(images, (output, text_logits), model, text_tokens)
                else:
                    loss = loss + regularizer(images, output, model)

        if args.no_scaler:
            loss.backward()
            if args.grad_clip != 0:
                clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            if args.grad_clip != 0:
                # Gradients are still multiplied by the loss scale here; unscale
                # them first so the clip threshold applies to the true norm.
                scaler.unscale_(optimizer)
                clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

        if 'clip' in args.model:
            # Note: we clamp to 4.6052 = ln(100), as in the original paper.
            model.logit_scale.data = torch.clamp(model.logit_scale.data, 0, 4.6052)

        # measure accuracy
        acc1 = accuracy(output, target, topk=(1,))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0].item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    print(f' * Training Acc@1 {top1.avg:.3f}\t Loss {losses.avg:.3f}')

    if return_attention:
        if len(attentions) > 0 and attentions[0] is not None:
            attentions = torch.cat(attentions, 1).detach().cpu()
            attr_ids = torch.cat(attr_ids, 0)
        return losses.avg, top1.avg, attentions, attr_ids
    return losses.avg, top1.avg


def validate(val_loader, texts, model, tokenizer, criterion, args, prefix='', visualize=False, return_attention=False, mask=None, old_model=None, epoch=0):
    """Evaluate ``model`` on ``val_loader`` under ``torch.no_grad()``.

    Each batch (``(images, target)``, ``(images, target, attr)`` or
    ``(images, target, attr, idx)``) goes through ``model_forward`` and is
    scored with ``criterion``, or with a symmetric cross-entropy in
    contrastive mode. Top-1 accuracy and loss are averaged over the loader.

    Args:
        val_loader: evaluation data loader.
        texts: optional text tokens; restricted to ``mask`` when both are
            set, then indexed by ``target`` (no text tokens are passed under
            FLYP).
        model, tokenizer: forwarded to ``model_forward``.
        criterion: loss for the default classification path.
        args: run configuration (reads d_pre, model, mode, regularization,
            collect_logits, collect_failure, collect_features, visualize,
            print_freq, ...).
        return_attention: also return attentions, attr ids and key probs.
        mask: optional class mask forwarded to ``model_forward``.
        prefix, visualize, old_model, epoch: unused here.

    Returns:
        ``(acc1, loss, (predictions, ground_truths), features)``, or with
        ``return_attention`` ``(acc1, loss, (predictions, ground_truths),
        attentions, attr_ids, key_probs, features)``.
        ``predictions``/``ground_truths``/``features`` are concatenated
        tensors when their ``args.collect_*`` flag is set and something was
        collected, otherwise empty lists. ``attentions`` and ``key_probs``
        are not filled in this loop and come back as empty lists.
    """
    is_torchvision = 'in1k_orig' in args.d_pre
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1_org = AverageMeter('Original Acc@1', ':6.2f')  # never updated in this loop
    top1 = AverageMeter('Prompt Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1_org, top1],
        prefix='Validate: ')

    # switch to evaluation mode
    model.eval()
    failure_samples = []
    failure_targets = []
    predictions = []
    ground_truths = []
    attentions = []
    attr_ids = []
    key_probs = []
    features = []

    if texts is not None and mask is not None:
        texts = texts[mask]

    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(tqdm(val_loader)):
            if len(data) == 4:
                images, target, attr, _ = data
                attr_ids.append(attr)
            elif len(data) == 3:
                images, target, attr = data
                attr_ids.append(attr)
            else:
                images, target = data
                attr = None

            if args.d_pre == 'dinov2':
                images = images['pixel_values'][0]
            if args.model == 'align' or 'siglip' in args.model:
                images = images['pixel_values'][:, 0]

            images = images.to(device)
            target = target.to(device)
            # FLYP evaluation passes no per-sample text tokens.
            if texts is not None and args.regularization != 'FLYP':
                text_tokens = texts[target]
            else:
                text_tokens = None

            if isinstance(attr, torch.Tensor):
                attr = attr.to(device)
            output = model_forward(model, images, text_tokens, tokenizer, args.drop_tokens, args.model, args.dataset, patch_size=args.patch_size, target=target, visualize=args.visualize, attr=attr, mask=mask, return_features=args.collect_features, is_torchvision=is_torchvision, regularization=args.regularization)
            feature = None  # reset so a stale value from a previous batch is never reused
            if args.mode == 'contrastive':
                output, text_logits = output
                labels = torch.arange(len(images), dtype=torch.long, device=device)
                loss = F.cross_entropy(output, labels) + F.cross_entropy(text_logits, labels)
            elif args.regularization == 'FLYP' and 'siglip' in args.model:
                # model_forward returns its 8-tuple on the siglip path; keep the logits.
                if isinstance(output, tuple):
                    output = output[0]
                # No loss is requested at evaluation time; report zero.
                loss = output.new_zeros(())
            else:
                if isinstance(output, tuple):
                    # Throwaway names keep the model's extras from rebinding
                    # this function's locals.
                    output, _, _, _, _, _, feature, _ = output
                loss = criterion(output, target)
            if args.collect_logits:
                predictions.append(output.detach().cpu())
                ground_truths.append(target.detach().cpu())

            if args.collect_failure:
                failure = output.argmax(dim=1) != target
                # Keep on CPU so failures across the whole set don't pile up in GPU memory.
                failure_samples.append(images[failure].cpu())
                failure_targets.append(target[failure].cpu())
            if args.collect_features and feature is not None:
                features.append(feature)

            # measure accuracy and record loss
            losses.update(loss.item(), images.size(0))
            acc1 = accuracy(output, target, topk=(1,))
            top1.update(acc1[0].item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Original Acc@1 {top1_org.avg:.3f} Fine-Tuning Acc@1 {top1.avg:.3f}'
              .format(top1=top1, top1_org=top1_org))

    analyzer.get_summary()
    if args.collect_logits and predictions:
        predictions = torch.cat(predictions)
        ground_truths = torch.cat(ground_truths)
    if args.collect_failure and failure_samples:
        # NOTE(review): failure samples/targets are gathered but not returned.
        failure_samples = torch.cat(failure_samples, dim=0)
        failure_targets = torch.cat(failure_targets, dim=0)
    if args.collect_features and features:
        features = torch.cat(features, dim=0)

    if return_attention:
        if attr_ids:
            attr_ids = torch.cat(attr_ids)
        if key_probs:
            key_probs = torch.cat(key_probs, 1)
        if attentions:
            attentions = torch.cat(attentions, 1).cpu()
        return top1.avg, losses.avg, (predictions, ground_truths), attentions, attr_ids, key_probs, features
    return top1.avg, losses.avg, (predictions, ground_truths), features

