import torch
import copy
import torch.nn.functional as F

from datasets.common import get_dataloader
from datasets.registry import get_dataset

from collections import defaultdict
import numpy as np


def quantile_via_sorting(tensor, q):
    """Return the ``q``-quantile of ``tensor`` by sorting all of its elements.

    The tensor is flattened and sorted in ascending order, and the element at
    position ``int(q * (n - 1))`` is returned, where ``n`` is the number of
    elements. The position is truncated toward zero and no interpolation
    between neighbouring elements is performed.

    Args:
        tensor: Tensor of any shape; it does not have to be contiguous.
        q: Quantile level. It is not validated: a value above 1 indexes past
            the end and raises ``IndexError``, a sufficiently negative value
            indexes from the end of the sorted data.

    Returns:
        A 0-dim tensor holding the selected element.
    """
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed tensors; for contiguous inputs the result is the same.
    flat = tensor.reshape(-1)
    ordered, _ = torch.sort(flat)
    position = int(q * (ordered.numel() - 1))
    return ordered[position]


# Parameters that are never re-solved: they keep the plain task-vector average.
_PASSTHROUGH_PARAM_NAMES = frozenset({
    'model.visual.class_embedding',
    'model.visual.conv1.weight',
    'model.visual.positional_embedding',
    'model.visual.proj',
})


def _keeps_averaged_value(name, vector):
    """Return True if parameter ``name`` is excluded from the feature-based solve.

    That is the case for the explicitly listed parameters, for every 1-D
    entry of ``vector`` and for every parameter whose name contains ``bias``.
    """
    return (
        name in _PASSTHROUGH_PARAM_NAMES
        or len(vector[name].shape) == 1
        or 'bias' in name
    )


def _collect_task_features(args, tv, pretrained_checkpoint, exam_dataset):
    """Run one task model on its dataset and return per-parameter features.

    The task vector ``tv`` is applied to the pretrained checkpoint, forward
    hooks record the inputs/outputs of every module that directly owns a
    parameter, and roughly ``args.exp_size`` samples are pushed through the
    model on the GPU.

    Returns:
        dict mapping parameter name to
        * ``None`` when the owning module's hook never fired,
        * ``(output - bias) / weight`` flattened to 2-D for ``ln`` weights
          other than ``ln_post``,
        * the recorded module inputs flattened to 2-D for every remaining
          parameter that is not passed through.
        Passed-through parameters, ``ln_post`` weights and parameters that
        live directly on the root module get no entry.
    """
    model = tv.apply_to(pretrained_checkpoint, scaling_coef=1).to('cuda')

    dataset = get_dataset(
        exam_dataset,
        model.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size
    )
    # NOTE(review): is_train=True is combined with val_preprocess here —
    # confirm that the training split is the intended source of samples.
    dataloader = get_dataloader(
        dataset, is_train=True, args=args, image_encoder=None)

    # Built once; the mapping is needed for every parameter in two loops.
    modules_by_name = dict(model.named_modules())

    accumulated_features = defaultdict(list)
    accumulated_activations = defaultdict(list)

    def hook_fn(module, input, output):
        # input is the tuple of positional arguments; keep the first one.
        accumulated_features[module].append(input[0].detach().cpu())
        # NOTE(review): output[0] picks the first element of a tuple output,
        # but slices along dim 0 when the module returns a plain tensor —
        # confirm this is intended for every hooked module type.
        accumulated_activations[module].append(output[0].detach().cpu())

    # One hook per module. Registering once per parameter would make modules
    # with several parameters (weight + bias) record every batch repeatedly.
    # The repetition factor would be identical for all tasks, so it cancels
    # in the merge and only costs memory and time.
    hooks = []
    hooked_names = set()
    for name, _ in model.named_parameters():
        module_name = name.rpartition('.')[0]
        if not module_name or module_name in hooked_names:
            continue
        hooked_names.add(module_name)
        hooks.append(
            modules_by_name[module_name].register_forward_hook(hook_fn))

    try:
        seen = 0
        for data in dataloader:
            inputs = data[0].to('cuda')
            if inputs.shape[0] > args.exp_size:
                inputs = inputs[:args.exp_size, ...]

            with torch.no_grad():
                model(inputs)

            seen += inputs.shape[0]
            if seen >= args.exp_size:
                break
    finally:
        # Always detach the hooks, even if a forward pass fails.
        for hook in hooks:
            hook.remove()

    model.to('cpu')
    state = model.state_dict()

    feat = {}
    for name, param in model.named_parameters():
        module_name = name.rpartition('.')[0]
        if not module_name:
            continue

        module = modules_by_name[module_name]
        input_features = accumulated_features.get(module)

        if input_features is None:
            # The owning module was never called during the forward passes.
            feat[name] = None
        elif 'ln' in name and 'weight' in name:
            if 'ln_post' in name:
                continue
            outputs = torch.cat(accumulated_activations.get(module), dim=0)
            outputs = outputs.reshape(-1, outputs.shape[-1])  # bs * f_out

            bias_name = name.rsplit('.', 1)[0] if name.endswith('.weight') else name
            bias_name += '.bias'
            bias = state[bias_name]

            # Undo the affine part of the layer; detach so the stored
            # feature does not keep an autograd graph alive.
            feat[name] = (outputs - bias) / param.detach()
        elif _keeps_averaged_value(name, tv.vector):
            continue
        else:
            inputs_flat = torch.cat(input_features, dim=0)
            feat[name] = inputs_flat.reshape(-1, inputs_flat.shape[-1])  # bs * f_in

    return feat


def opt_merging(args, task_vectors, pretrained_checkpoint, exam_datasets, cur_task_vectors=None):
    """Merge task vectors using features collected from each task's data.

    Every parameter starts from the average of ``task_vectors``. For each
    task, the matching entry of ``cur_task_vectors`` is applied to the
    pretrained checkpoint and run on that task's dataset to record features
    ``X_k``. Parameters that are not passed through are then replaced by

        T = (pinv(sum_k X_k^T X_k) @ sum_k X_k^T X_k T_k^T)^T

    where ``T_k`` is task ``k``'s entry for that parameter.

    Args:
        args: Namespace providing ``data_location``, ``batch_size`` and
            ``exp_size``; it is also forwarded to ``get_dataloader``.
        task_vectors: Sequence of task vectors that can be summed, expose a
            ``vector`` mapping and an ``apply_to`` method.
        pretrained_checkpoint: Passed unchanged to ``apply_to``.
        exam_datasets: One dataset identifier per entry of
            ``cur_task_vectors``, in the same order.
        cur_task_vectors: Task vectors used to build the per-task models and
            the targets ``T_k``; defaults to a deep copy of ``task_vectors``.

    Returns:
        The merged task vector (a new object; the inputs are not modified).
    """
    if cur_task_vectors is None:
        cur_task_vectors = copy.deepcopy(task_vectors)
    opt_tv = copy.deepcopy(sum(task_vectors))

    # Only the parameter names are needed from this model, so it is not
    # moved to the GPU and is released right away.
    reference_model = cur_task_vectors[0].apply_to(pretrained_checkpoint, scaling_coef=1)
    param_names = [name for name, _ in reference_model.named_parameters()]
    del reference_model

    # tackle the bias: every parameter (biases included) starts as the mean
    for name in param_names:
        opt_tv.vector[name] = opt_tv.vector[name] / len(task_vectors)

    feat_arr = [
        _collect_task_features(
            args, cur_task_vectors[t_id], pretrained_checkpoint, exam_dataset)
        for t_id, exam_dataset in enumerate(exam_datasets)
    ]

    reference_vector = cur_task_vectors[0].vector
    for name in param_names:
        if opt_tv.vector[name] is None:
            continue
        if _keeps_averaged_value(name, reference_vector):
            continue
        # .get: parameters without any feature entry keep the average too.
        if feat_arr[0].get(name) is None:
            continue

        if 'ln' in name and 'weight' in name:
            # NOTE(review): 1-D parameters are already skipped by the
            # pass-through test above, so this branch looks unreachable for
            # ordinary LayerNorm weights — confirm the intended ordering.
            numerator = 0.0
            denominator = 0.0
            for t_id, tv in enumerate(cur_task_vectors):
                X_k = feat_arr[t_id][name]
                X_k_squared_sum = torch.sum(X_k ** 2, dim=0)
                numerator += X_k_squared_sum * tv.vector[name]
                denominator += X_k_squared_sum
            # Guard against division by zero for never-activated channels.
            denominator = torch.where(denominator == 0, torch.tensor(1e-10), denominator)
            opt_tv.vector[name] = numerator / denominator
        else:
            sum_XTX = 0.0
            sum_XTT = 0.0
            for t_id, tv in enumerate(cur_task_vectors):
                X_k = feat_arr[t_id][name]
                # The Gram matrix feeds both sums; compute it only once.
                gram = X_k.T @ X_k
                sum_XTX += gram
                sum_XTT += gram @ tv.vector[name].T

            T_optimal = torch.linalg.pinv(sum_XTX) @ sum_XTT
            opt_tv.vector[name] = T_optimal.T

    return opt_tv


