
import torch
import torch.nn.functional as F

from modeling import ImageClassifier
from heads import get_classification_head
from datasets.ImageNet import SequentialImagenet
from datasets.common import get_dataloader
from datasets.registry import get_dataset

def get_params(model):
    """Return all parameters of ``model`` as a single flat 1-D tensor.

    Parameters are visited in ``model.parameters()`` order, each flattened
    with ``view(-1)``, and joined end to end with ``torch.cat``.
    """
    flat_pieces = tuple(weight.view(-1) for weight in model.parameters())
    return torch.cat(flat_pieces)

def get_grads(model):
    """Return the gradients of all ``model`` parameters as one flat 1-D tensor.

    Parameters are visited in ``model.parameters()`` order (the same order
    ``get_params`` uses). A parameter whose ``.grad`` is ``None`` contributes
    a run of zeros of its own size, so every parameter always occupies the
    same slice of the result.
    """
    def _flat_grad(weight):
        current = weight.grad
        return torch.zeros_like(weight).view(-1) if current is None else current.view(-1)

    return torch.cat([_flat_grad(weight) for weight in model.parameters()])

def NTK_merging(args, pretrained_checkpoint, exam_datasets, is_imageNet=False):
    """Build a parameter mask from pairwise products of per-task gradient magnitudes.

    For each dataset in ``exam_datasets`` (plus ``'ImageNet'`` when
    ``is_imageNet`` is true), this estimates the mean absolute per-example
    gradient of the cross-entropy loss with respect to every parameter of the
    pretrained model. The estimate uses the first ``args.exp_size`` training
    examples, or fewer if the loader runs out.

    Each parameter entry is then scored by summing, over all unordered dataset
    pairs ``i < j``, the product of the two datasets' gradient magnitudes.

    Args:
        args: Namespace. Reads ``exp_size``, ``ratio``, ``data_location``,
            ``batch_size`` and, optionally, ``imagenet_location``. It is also
            forwarded to ``get_classification_head`` and ``get_dataloader``.
        pretrained_checkpoint: Path passed to ``torch.load``. The loaded
            object is moved to CUDA and called as a feature extractor.
        exam_datasets: List of dataset names.
        is_imageNet: Whether to also score ``'ImageNet'``.

    Returns:
        A 1-D int32 tensor with one entry per parameter element, in
        ``pretrained_model.parameters()`` order. Let ``k = int(args.ratio *
        numel)``. An entry is 1 when its score is strictly below the k-th
        smallest score, else 0. When ``k`` is 0 the mask is all zeros.

    Raises:
        ValueError: If ``args.ratio`` is outside ``[0, 1]``, or if a dataset
            yields no examples.
    """
    # Validate before any expensive gradient work.
    if not 0.0 <= args.ratio <= 1.0:
        raise ValueError(f"args.ratio must lie in [0, 1], got {args.ratio!r}")

    # NOTE(review): torch.load unpickles arbitrary Python objects, so only load
    # trusted checkpoints. This expects a whole pickled nn.Module. PyTorch
    # versions whose torch.load defaults to weights_only=True may need
    # weights_only=False here; confirm against the installed version.
    pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
    # NOTE(review): the model is never switched to eval(). Any dropout or
    # batch-norm layers therefore run in whatever mode was saved; confirm
    # this is intended.

    for param in pretrained_model.parameters():
        param.requires_grad_(True)

    ntk_dataset = (exam_datasets + ['ImageNet']) if is_imageNet else exam_datasets

    grad_arr = []
    for exam_dataset in ntk_dataset:
        classification_head = get_classification_head(args, exam_dataset).to('cuda')
        # Only built to borrow its ``val_preprocess`` transform for the dataset.
        model = ImageClassifier(pretrained_model, classification_head)

        if exam_dataset == 'ImageNet':
            # The historical hard-coded root stays the default; callers may override it.
            imagenet_root = getattr(args, 'imagenet_location', '/home/sunwenju/dataset/')
            # NOTE(review): 32 is presumably the batch size; confirm in SequentialImagenet.
            dataset = SequentialImagenet(imagenet_root, 32)
            dataloader, _ = dataset.get_data_loaders()
        else:
            dataset = get_dataset(
                exam_dataset,
                model.val_preprocess,
                location=args.data_location,
                batch_size=args.batch_size
            )
            dataloader = get_dataloader(
                dataset, is_train=True, args=args, image_encoder=None)

        # Running sum of |per-example gradient|, one slot per parameter element.
        grad = torch.zeros_like(get_params(pretrained_model))

        num_seen = 0
        for data in dataloader:
            # Take only what is still needed, so exactly args.exp_size examples
            # contribute rather than overshooting on the final batch.
            remaining = args.exp_size - num_seen
            inputs = data[0][:remaining].to('cuda')
            labels = data[1][:remaining].to('cuda')

            output = classification_head(pretrained_model(inputs))
            last = output.size(0) - 1
            for i in range(output.size(0)):
                pretrained_model.zero_grad()
                loss = F.cross_entropy(output[i].unsqueeze(0), labels[i].unsqueeze(0))
                # The batch shares one forward graph. Keep it alive until the
                # last example's backward pass.
                loss.backward(retain_graph=i < last)
                grad += get_grads(pretrained_model).detach().abs()

            num_seen += inputs.shape[0]
            if num_seen >= args.exp_size:
                break

        if num_seen == 0:
            # Dividing by zero would silently turn every score into NaN.
            raise ValueError(
                f"no examples were drawn for dataset {exam_dataset!r}; "
                f"check args.exp_size and the data loader")
        grad_arr.append((grad / num_seen).to('cpu'))

    # Score = sum over unordered dataset pairs (i < j) of g_i * g_j.
    NTK = torch.zeros_like(grad_arr[0])
    for i, grad_i in enumerate(grad_arr):
        for grad_j in grad_arr[i + 1:]:
            NTK += grad_i * grad_j

    num_top = int(args.ratio * NTK.numel())
    if num_top == 0:
        # No order statistic exists to threshold on, so select nothing.
        return torch.zeros_like(NTK, dtype=torch.int32)

    # The num_top-th smallest score. It equals values[numel - num_top] of a
    # descending sort, but is found without sorting the whole vector.
    threshold = torch.kthvalue(NTK.flatten(), num_top).values
    # NOTE(review): the strict ``<`` excludes the threshold entry and its ties,
    # so at most num_top - 1 entries are set. This matches the original
    # sort-based rule.
    return (NTK < threshold).int()

# TATR
def TATR_merging(args, task_vectors, pretrained_model, exam_datasets, order=1):
    """Build a TATR mask from per-task gradient magnitudes and task-vector magnitudes.

    For each dataset in ``exam_datasets`` (plus ``'ImageNet'`` when
    ``args.is_imageNet`` is true), this estimates the mean absolute
    per-example gradient ``g_i`` of the cross-entropy loss with respect to
    every parameter of ``pretrained_model``. The estimate uses exactly
    ``args.exp_size`` training examples, or fewer if the loader runs out.

    Each parameter entry is then scored as::

        Omega = sum_{i != j} g_i * |tau_j|                                 (order != 2)
        Omega = sum_{i != j} g_i * |tau_j| + 0.5 * (g_i * |tau_j|) ** 2    (order == 2)

    Here ``tau_j`` is ``task_vectors[j].vector`` flattened in
    ``pretrained_model.named_parameters()`` order.

    Args:
        args: Namespace. Reads ``is_imageNet``, ``exp_size``, ``ratio``,
            ``data_location``, ``batch_size`` and, optionally,
            ``imagenet_location``. It is also forwarded to
            ``get_classification_head`` and ``get_dataloader``.
        task_vectors: Objects whose ``.vector`` maps parameter names to
            tensors. Indexed as ``task_vectors[j]`` for every scored dataset
            ``j``. NOTE(review): presumably aligned index-for-index with the
            scored datasets, including the ImageNet entry when enabled;
            confirm with the caller.
        pretrained_model: Model whose parameters are scored. NOTE(review):
            assumed to already live on CUDA, since inputs are moved there.
        exam_datasets: List of dataset names.
        order: 2 adds the second-order term; any other value behaves like 1.

    Returns:
        A 1-D int32 tensor with one entry per parameter element. Let ``k =
        int(args.ratio * numel)``. An entry is 1 when its Omega is strictly
        below the k-th smallest Omega, else 0. When ``k`` is 0 the mask is
        all zeros.

    Raises:
        ValueError: If ``args.ratio`` is outside ``[0, 1]``, or if a dataset
            yields no examples.
    """
    # Validate before any expensive gradient work.
    if not 0.0 <= args.ratio <= 1.0:
        raise ValueError(f"args.ratio must lie in [0, 1], got {args.ratio!r}")

    ntk_dataset = (exam_datasets + ['ImageNet']) if args.is_imageNet else exam_datasets

    for param in pretrained_model.parameters():
        param.requires_grad_(True)

    grad_arr = []
    for exam_dataset in ntk_dataset:
        classification_head = get_classification_head(args, exam_dataset).to('cuda')
        # Only built to borrow its ``val_preprocess`` transform for the dataset.
        model = ImageClassifier(pretrained_model, classification_head)

        if exam_dataset == 'ImageNet':
            # The historical hard-coded root stays the default; callers may override it.
            imagenet_root = getattr(args, 'imagenet_location', '/home/sunwenju/dataset/')
            # NOTE(review): 32 is presumably the batch size; confirm in SequentialImagenet.
            dataset = SequentialImagenet(imagenet_root, 32)
            dataloader, _ = dataset.get_data_loaders()
        else:
            dataset = get_dataset(
                exam_dataset,
                model.val_preprocess,
                location=args.data_location,
                batch_size=args.batch_size
            )
            dataloader = get_dataloader(
                dataset, is_train=True, args=args, image_encoder=None)

        # Running sum of |per-example gradient|, one slot per parameter element.
        grad = torch.zeros_like(get_params(pretrained_model))

        num_seen = 0
        for data in dataloader:
            # Truncate to the *remaining* budget. Truncating to args.exp_size
            # alone lets later batches push the total past args.exp_size.
            remaining = args.exp_size - num_seen
            inputs = data[0][:remaining].to('cuda')
            labels = data[1][:remaining].to('cuda')

            output = classification_head(pretrained_model(inputs))
            last = output.size(0) - 1
            for i in range(output.size(0)):
                pretrained_model.zero_grad()
                loss = F.cross_entropy(output[i].unsqueeze(0), labels[i].unsqueeze(0))
                # The batch shares one forward graph. Keep it alive until the
                # last example's backward pass.
                loss.backward(retain_graph=i < last)
                grad += get_grads(pretrained_model).detach().abs()

            num_seen += inputs.shape[0]
            if num_seen >= args.exp_size:
                break

        if num_seen == 0:
            # Dividing by zero would silently turn every score into NaN.
            raise ValueError(
                f"no examples were drawn for dataset {exam_dataset!r}; "
                f"check args.exp_size and the data loader")
        grad_arr.append((grad / num_seen).to('cpu'))

    param_names = [name for name, _ in pretrained_model.named_parameters()]

    def _flatten_task_vector(task_vector):
        # Same element layout as get_params/get_grads: named_parameters() order.
        return torch.cat([task_vector.vector[name].view(-1) for name in param_names])

    Omega = torch.zeros_like(_flatten_task_vector(task_vectors[0]))
    # Flatten each |tau_j| once, instead of once per (i, j) pair.
    abs_task_vectors = [
        torch.abs(_flatten_task_vector(task_vectors[j])) for j in range(len(grad_arr))
    ]

    for i, grad_i in enumerate(grad_arr):
        for j, abs_tau_j in enumerate(abs_task_vectors):
            if i == j:
                continue
            Omega += grad_i * abs_tau_j
            if order == 2:
                Omega += 0.5 * abs_tau_j * grad_i * grad_i * abs_tau_j

    num_top = int(args.ratio * Omega.numel())
    if num_top == 0:
        # No order statistic exists to threshold on, so select nothing.
        return torch.zeros_like(Omega, dtype=torch.int32)

    # The num_top-th smallest score. It equals values[numel - num_top] of a
    # descending sort, but is found without sorting the whole vector.
    threshold = torch.kthvalue(Omega.flatten(), num_top).values
    # NOTE(review): the strict ``<`` excludes the threshold entry and its ties,
    # so at most num_top - 1 entries are set. This matches the original
    # sort-based rule.
    return (Omega < threshold).int()

# TATR zero-shot version
def TATR_mergingnn(args, task_vectors, pretrained_model):
    """Zero-shot TATR mask built from task-vector magnitudes only (no data).

    Each parameter entry is scored as::

        Omega = sum_{i != j} |tau_i| * |tau_j|

    Here ``tau_k`` is ``task_vectors[k].vector`` flattened in
    ``pretrained_model.named_parameters()`` order. ``pretrained_model`` is
    used only for those parameter names.

    Args:
        args: Namespace; only ``args.ratio`` is read.
        task_vectors: Non-empty sequence of objects whose ``.vector`` maps
            every parameter name of ``pretrained_model`` to a tensor.
        pretrained_model: Module supplying the parameter-name order.

    Returns:
        A 1-D int32 tensor with one entry per parameter element. Let ``k =
        int(args.ratio * numel)``. An entry is 1 when its Omega is strictly
        below the k-th smallest Omega, else 0. When ``k`` is 0 the mask is
        all zeros.

    Raises:
        ValueError: If ``args.ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= args.ratio <= 1.0:
        raise ValueError(f"args.ratio must lie in [0, 1], got {args.ratio!r}")

    param_names = [name for name, _ in pretrained_model.named_parameters()]
    # Flatten each |tau_k| once, instead of twice per (i, j) pair.
    abs_task_vectors = [
        torch.abs(torch.cat([tv.vector[name].view(-1) for name in param_names]))
        for tv in task_vectors
    ]

    Omega = torch.zeros_like(abs_task_vectors[0])
    for i, abs_tau_i in enumerate(abs_task_vectors):
        for j, abs_tau_j in enumerate(abs_task_vectors):
            if i != j:
                Omega += abs_tau_i * abs_tau_j

    num_top = int(args.ratio * Omega.numel())
    if num_top == 0:
        # No order statistic exists to threshold on, so select nothing.
        return torch.zeros_like(Omega, dtype=torch.int32)

    # The num_top-th smallest score. It equals values[numel - num_top] of a
    # descending sort, but is found without sorting the whole vector.
    threshold = torch.kthvalue(Omega.flatten(), num_top).values
    # NOTE(review): the strict ``<`` excludes the threshold entry and its ties,
    # so at most num_top - 1 entries are set. This matches the original
    # sort-based rule.
    return (Omega < threshold).int()