import torch
import math
import copy


def iso_c(task_vectors, device=None):
    """Merge task vectors with Iso-C: sum each weight matrix and flatten its spectrum.

    For every key of ``task_vectors[0].vector``:

    * 2-D tensors whose key does not contain ``"text_projection"`` are replaced
      by ``U @ diag(mean(S)) @ Vh``, where ``U, S, Vh`` is the thin SVD of the
      *sum* of the task matrices.  The result is moved to CPU.
    * Every other tensor is replaced by the element-wise mean over tasks.

    Args:
        task_vectors: non-empty sequence of objects exposing a ``.vector`` dict
            of tensors and supporting ``sum(...)`` (``0 + tv`` and ``tv + tv``).
            Keys are taken from ``task_vectors[0]``; presumably all task
            vectors share the same keys and shapes.
        device: device on which to run the SVD.  When ``None``, ``"cuda"`` is
            used if available; otherwise the SVD runs wherever the tensor lives.

    Returns:
        A deep copy of ``sum(task_vectors)`` whose ``.vector`` entries have been
        replaced as described above.
    """
    print("Computing SVD...")
    if device is None and torch.cuda.is_available():
        device = "cuda"
    num_tasks = len(task_vectors)
    with torch.no_grad():
        # Deep copy so that rebinding entries of new_vector.vector can never
        # affect the inputs, whatever object sum() hands back.
        new_vector = copy.deepcopy(sum(task_vectors))
        for key in task_vectors[0].vector:
            summed = new_vector.vector[key]
            if task_vectors[0].vector[key].dim() == 2 and "text_projection" not in key:
                # SVD of the plain sum (no averaging).  Tensor.to is not
                # in-place, so the moved tensor must be the one decomposed.
                matrix = summed if device is None else summed.to(device)
                U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)
                # An isotropic spectrum diag(mean(S)) is just mean(S) * I.
                new_vector.vector[key] = (S.mean() * (U @ Vh)).to("cpu")
            else:
                new_vector.vector[key] = summed / num_tasks

    return new_vector


def _iso_cts_common_space_size(min_dim, num_tasks, fraction):
    """Split ``min_dim`` directions into a common space plus equal per-task shares.

    Returns ``(common_size, dims_per_task)`` such that
    ``common_size + num_tasks * dims_per_task == min_dim`` and both are >= 0.
    The common size starts at ``int(min_dim * fraction)`` and is adjusted so the
    remaining directions divide evenly among the tasks.
    """
    common_size = int(min_dim * fraction)
    task_specific_total = round((min_dim - common_size) / num_tasks) * num_tasks
    # round() can overshoot min_dim (small matrices / many tasks), and a
    # fraction above 1 makes the total negative; either breaks the layout.
    task_specific_total = max(0, min(task_specific_total, (min_dim // num_tasks) * num_tasks))
    return min_dim - task_specific_total, task_specific_total // num_tasks


def _iso_cts_orthogonalize(matrix, device):
    """Return ``U @ Vh`` from the thin SVD of ``matrix`` (its orthogonal polar factor), on ``device``."""
    u, _, vh = torch.linalg.svd(matrix.to(device), full_matrices=False)
    return u @ vh


@torch.no_grad()
def iso_cts(task_vectors, args, device=None):
    """Merge task vectors with Iso-CTS: common + task-specific subspaces, isotropic spectrum.

    Tensors that are not 2-D, or whose key contains ``"text_projection"`` or
    ``"token_embedding"``, are replaced by their element-wise mean over tasks.
    Every other 2-D entry of shape ``(m, n)``, with ``k = min(m, n)``, is merged
    as follows:

    1. The leading ``common_size`` singular directions of the *sum* of the task
       matrices form the common space.  ``common_size`` comes from
       ``args.common_space_fraction``, adjusted so the rest splits evenly.
    2. Each task matrix is projected off the common left subspace.  The top
       ``(k - common_size) / len(task_vectors)`` singular directions of the
       residual become that task's specific directions.
    3. Task-specific and common directions are stacked into ``k`` left columns
       and ``k`` right rows, and each stack is replaced by its orthogonal polar
       factor.
    4. The merged matrix is ``mean(stacked singular values) * U @ Vh`` (all
       singular values equal).  It is returned on CPU.

    Args:
        task_vectors: non-empty sequence of objects exposing a ``.vector`` dict
            of tensors and supporting ``sum(...)``.  Keys are taken from
            ``task_vectors[0]``.
        args: object with a ``common_space_fraction`` attribute (presumably in
            ``[0, 1]``; values above 1 are treated like 1).
        device: device for the SVDs.  Defaults to ``"cuda"`` when available,
            else ``"cpu"``.

    Returns:
        A deep copy of ``sum(task_vectors)`` with merged ``.vector`` entries.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    num_tasks = len(task_vectors)
    new_vector = copy.deepcopy(sum(task_vectors))

    print("Computing SVD...")
    for key in task_vectors[0].vector:
        shape_ = task_vectors[0].vector[key].shape

        is_2d_matrix = (len(shape_) == 2) and ("text_projection" not in key) and ("token_embedding" not in key)
        if not is_2d_matrix:
            print(f"Combining by avg {key}...")
            # Running mean, updated in place on a clone of the first task's tensor.
            for i, task_vector in enumerate(task_vectors):
                vec = task_vector.vector[key]
                if i == 0:
                    new_vector.vector[key] = vec.clone()
                else:
                    new_vector.vector[key] += (vec - new_vector.vector[key]) / (i + 1)
            continue

        print(f"Computing common space using sum for {key}...")
        common_size, n_dims_per_task = _iso_cts_common_space_size(
            min(shape_), num_tasks, args.common_space_fraction
        )

        # Common space: leading singular directions of the summed task matrices.
        combined_w = sum(task_vector.vector[key] for task_vector in task_vectors)
        u, s, vh = torch.linalg.svd(combined_w.to(device), full_matrices=False)
        common_u = u[:, :common_size].cpu()
        common_s = s[:common_size].cpu()
        common_vh = vh[:common_size, :].cpu()
        del combined_w, u, s, vh

        # Task-specific spaces: the top directions of each task matrix after its
        # component in the common left subspace has been removed.
        u_parts, s_parts, vh_parts = [], [], []
        if n_dims_per_task > 0:
            for task_vector in task_vectors:
                w = task_vector.vector[key].cpu()
                # U (U^T W) avoids materialising the m x m projector U U^T.
                w_ts = w - common_u @ (common_u.T @ w)
                u_ts, s_ts, vh_ts = torch.linalg.svd(w_ts.to(device), full_matrices=False)
                u_parts.append(u_ts[:, :n_dims_per_task].cpu())
                s_parts.append(s_ts[:n_dims_per_task].cpu())
                vh_parts.append(vh_ts[:n_dims_per_task, :].cpu())

        # Layout [task 0 | task 1 | ... | common]: min(shape_) directions in total.
        combined_u = torch.cat(u_parts + [common_u], dim=1)
        combined_s = torch.cat(s_parts + [common_s])
        combined_vh = torch.cat(vh_parts + [common_vh], dim=0)

        ortho_u = _iso_cts_orthogonalize(combined_u, device)
        ortho_vh = _iso_cts_orthogonalize(combined_vh, device)

        # Isotropic spectrum: diag(mean(s)) is mean(s) * I, so scale U @ Vh directly.
        s_mean = combined_s.mean().to(device)
        new_vector.vector[key] = (s_mean * (ortho_u @ ortho_vh)).to("cpu")

    return new_vector