import torch
import torch.nn as nn
import numpy as np
import math


def wrap_torch_model(torch_model, weights, start_index, exclude_modules=None, add_init_weights=True):
    """Replace the weights and biases of ``torch_model`` with slices of ``weights``.

    ``weights`` is flattened and read in sections of ``G = weights.size(1)``
    elements, starting at section ``start_index``. Every module yielded by
    ``torch_model.modules()`` that is not in ``exclude_modules`` reads its weight
    (if not ``None``) followed by its bias (if not ``None``) from consecutive
    elements. It then advances to the next section boundary, i.e. it claims
    ``ceil(n_params / G)`` whole sections.

    Element layout: ``nn.Linear`` weights are read as ``(in, out)`` and
    transposed. ``nn.Conv2d`` weights are read as ``(in, out, k1, k2)`` with the
    first two axes swapped. All other weights are read directly in their own
    shape. Modules whose ``weight``/``bias`` is ``None`` (e.g.
    ``nn.BatchNorm1d(affine=False)``, which registers them as ``None``)
    contribute nothing for that attribute.

    The original parameters are deleted and replaced by plain tensors computed
    from ``weights``, so gradients flow back into ``weights``. On the first call
    each original parameter is snapshotted as ``initial_weight`` /
    ``initial_bias`` (with ``requires_grad=False``); later calls reuse those
    snapshots.

    Args:
        torch_model: module whose parameters are replaced in place.
        weights: 2-D tensor; ``weights.size(1)`` is the section length.
        start_index: index of the first section to read.
        exclude_modules: modules to skip (membership-tested with ``in``).
        add_init_weights: if True, the frozen original value is added to
            each generated slice.

    Returns:
        List of the new weight/bias tensors, in visit order.
    """
    exclude_modules = set() if exclude_modules is None else exclude_modules
    G = weights.size(1)
    flat = weights.flatten()
    all_param_tensors = []

    def _install(m, name, generated):
        # Swap parameter `name` of `m` for `generated` (plus its frozen original
        # value when add_init_weights) and record the resulting tensor.
        initial_name = 'initial_' + name
        if not hasattr(m, initial_name):
            # Snapshot only on the first wrap; re-wrapping keeps the same base.
            initial = getattr(m, name).data.clone()
            initial.requires_grad = False
            setattr(m, initial_name, initial)
        delattr(m, name)
        if add_init_weights:
            generated = generated + getattr(m, initial_name)
        setattr(m, name, generated)
        all_param_tensors.append(generated)

    for m in torch_model.modules():
        if m in exclude_modules:
            continue
        offset = G * start_index
        weight = getattr(m, 'weight', None)
        if weight is not None:
            numel = weight.numel()
            chunk = flat[offset:offset + numel]
            if type(m) is nn.Linear:
                out_features, in_features = weight.size()
                generated = chunk.view(in_features, out_features).permute(1, 0)
            elif type(m) is nn.Conv2d:
                out_channels, in_channels, kernel_size_1, kernel_size_2 = weight.size()
                generated = (chunk.view(in_channels, out_channels, kernel_size_1, kernel_size_2)
                             .permute(1, 0, 2, 3))
            else:
                generated = chunk.view(weight.size())
            _install(m, 'weight', generated)
            offset += numel
        bias = getattr(m, 'bias', None)
        if bias is not None:
            numel = bias.size(0)
            _install(m, 'bias', flat[offset:offset + numel])
            offset += numel
        # Round up so the next module starts on a fresh section.
        start_index += int(np.ceil((offset - G * start_index) / G))
    return all_param_tensors


def wrap_torch_model_lora(torch_model, weights, start_index, exclude_modules=None, rank=64, add_init_weights=True,
                          use_large_lora_cnn=False):
    """Replace module weights with low-rank (LoRA-style) products read from ``weights``.

    ``weights`` is flattened and consumed in sections of ``G = weights.size(1)``
    elements, starting at section ``start_index``. Each module yielded by
    ``torch_model.modules()`` that is not in ``exclude_modules`` reads its factors
    (and its bias, if not ``None``) from consecutive elements. It then advances to
    the next section boundary, claiming ``ceil(n_read / G)`` sections.

    Per module:

    * ``nn.Conv2d``: factors ``a``/``b`` are shaped ``(k1*in, rank)`` /
      ``(rank, k2*out)``, or ``(k1*in, rank*k2)`` / ``(rank*k1, k2*out)`` when
      ``use_large_lora_cnn``. ``a @ b`` is viewed as ``(out, in, k1, k2)``. With
      ``add_init_weights`` the weight is ``((A0 + a) @ b).view(...) + W0``, where
      ``A0`` is a frozen kaiming-uniform matrix created on the first wrap.
    * any other module with a non-``None`` weight (which must be 2-D, e.g.
      ``nn.Linear``): ``a`` is ``(out, rank)`` and ``b`` is ``(rank, in)``. With
      ``add_init_weights`` the weight is ``W0 + a @ (B0 + b)``, where ``B0`` is a
      frozen kaiming-uniform matrix created on the first wrap; otherwise it is
      ``a @ b``.
    * a non-``None`` bias is replaced by the next slice, plus the original bias
      when ``add_init_weights``.

    Modules whose ``weight`` is ``None`` are skipped for the weight part. The
    original parameters are deleted and replaced by plain tensors, so gradients
    flow back into ``weights``. Originals are snapshotted as ``initial_weight`` /
    ``initial_bias`` only on the first call.

    Args:
        torch_model: module whose parameters are replaced in place.
        weights: 2-D tensor; ``weights.size(1)`` is the section length.
        start_index: index of the first section to read.
        exclude_modules: modules to skip (membership-tested with ``in``).
        rank: LoRA rank.
        add_init_weights: add the frozen original weight/bias (and the frozen
            random LoRA factor) to the generated values.
        use_large_lora_cnn: use the larger Conv2d factorisation described above.

    Returns:
        List of the new weight/bias tensors, in visit order.
    """
    all_param_tensors = []
    exclude_modules = set() if exclude_modules is None else exclude_modules
    G = weights.size(1)
    flat = weights.flatten()

    def _wrap_matrix_weight(m, offset):
        # LoRA for a 2-D (out, in) weight; returns the offset after both factors.
        out_features, in_features = m.weight.size()
        a_matrix = flat[offset:offset + rank * out_features]
        offset += rank * out_features
        b_matrix = flat[offset:offset + rank * in_features]
        offset += rank * in_features
        if not hasattr(m, 'initial_weight'):
            setattr(m, 'initial_weight', m.weight.data.clone())
            m.initial_weight.requires_grad = False
            B_init = torch.zeros(rank, in_features, device=m.initial_weight.device)
            nn.init.kaiming_uniform_(B_init, a=math.sqrt(5))
            setattr(m, "initial_lora_B", B_init)
        del m.weight
        a = a_matrix.view(out_features, rank)
        b = b_matrix.view(rank, in_features)
        if add_init_weights:
            m.weight = m.initial_weight + (a @ (m.initial_lora_B + b))
        else:
            m.weight = a @ b
        all_param_tensors.append(m.weight)
        return offset

    def _wrap_conv2d_weight(m, offset):
        # LoRA for a Conv2d weight; returns the offset after both factors.
        out_channels, in_channels, kernel_size_1, kernel_size_2 = m.weight.size()
        if use_large_lora_cnn:
            # Matmul inner dims are rank*kernel_size_2 vs rank*kernel_size_1,
            # so this variant only works when kernel_size_1 == kernel_size_2.
            a_shape = (kernel_size_1 * in_channels, rank * kernel_size_2)
            b_shape = (rank * kernel_size_1, kernel_size_2 * out_channels)
        else:
            a_shape = (kernel_size_1 * in_channels, rank)
            b_shape = (rank, kernel_size_2 * out_channels)
        a_size = a_shape[0] * a_shape[1]
        b_size = b_shape[0] * b_shape[1]
        if not hasattr(m, 'initial_weight'):
            setattr(m, 'initial_weight', m.weight.data.clone())
            m.initial_weight.requires_grad = False
            A_init = torch.zeros(*a_shape, device=m.initial_weight.device)
            nn.init.kaiming_uniform_(A_init, a=math.sqrt(5))
            setattr(m, "initial_lora_A", A_init)
        del m.weight
        a = flat[offset:offset + a_size].view(a_shape)
        offset += a_size
        b = flat[offset:offset + b_size].view(b_shape)
        offset += b_size
        if add_init_weights:
            a = m.initial_lora_A + a
        product = (a @ b).view(out_channels, in_channels, kernel_size_1, kernel_size_2)
        m.weight = product + m.initial_weight if add_init_weights else product
        all_param_tensors.append(m.weight)
        return offset

    def _wrap_bias(m, offset):
        # Replace a non-None bias with the next slice; returns the new offset.
        if getattr(m, 'bias', None) is None:
            return offset
        flattened_size = list(m.bias.size())[0]
        if not hasattr(m, 'initial_bias'):
            setattr(m, 'initial_bias', m.bias.data.clone())
            m.initial_bias.requires_grad = False
        del m.bias
        m.bias = flat[offset:offset + flattened_size]
        if add_init_weights:
            m.bias = m.bias + m.initial_bias
        all_param_tensors.append(m.bias)
        return offset + flattened_size

    for m in torch_model.modules():
        if m in exclude_modules:
            continue
        offset = G * start_index
        if type(m) is nn.Conv2d:
            offset = _wrap_conv2d_weight(m, offset)
        elif getattr(m, 'weight', None) is not None:
            assert len(m.weight.size()) == 2
            offset = _wrap_matrix_weight(m, offset)
        offset = _wrap_bias(m, offset)
        # Round up so the next module starts on a fresh section.
        start_index += int(np.ceil((offset - (start_index * G)) / G))
    return all_param_tensors


def get_n_sections_for_model(torch_model, gen_size, exclude_modules=None):
    """Count the ``gen_size``-element sections needed for a model's weights and biases.

    Each module yielded by ``torch_model.modules()`` that is not in
    ``exclude_modules`` needs ``ceil((weight.numel() + bias.size(0)) / gen_size)``
    sections, counting only attributes that exist and are not ``None``. Modules
    that register ``weight``/``bias`` as ``None`` contribute nothing instead of
    raising.

    Args:
        torch_model: module to inspect.
        gen_size: number of elements per section.
        exclude_modules: modules to skip (membership-tested with ``in``).

    Returns:
        Total number of sections as an ``int``.
    """
    exclude_modules = set() if exclude_modules is None else exclude_modules
    n_sections = 0
    for m in torch_model.modules():
        if m in exclude_modules:
            continue
        params_in_module = 0
        weight = getattr(m, 'weight', None)
        if weight is not None:
            params_in_module += weight.numel()
        bias = getattr(m, 'bias', None)
        if bias is not None:
            params_in_module += bias.size(0)
        n_sections += math.ceil(params_in_module / gen_size)
    return int(n_sections)


def get_n_sections_for_model_lora(torch_model, gen_size, exclude_modules=None, rank=64, use_large_lora_cnn=False):
    """Count the ``gen_size``-element sections needed for LoRA factors of a model.

    Per module (skipping ``exclude_modules``):

    * ``nn.Conv2d`` with weight ``(out, in, k1, k2)``:
      ``rank * k1 * k2 * (in + out)`` values when ``use_large_lora_cnn``, otherwise
      ``rank * (k1 * in + k2 * out)``.
    * any other module with a non-``None`` weight, which must be 2-D
      ``(out, in)``: ``rank * (out + in)`` values.
    * plus ``bias.size(0)`` values when the bias is not ``None``.

    Each module's total is rounded up to whole sections. Modules that register
    ``weight`` as ``None`` no longer raise.

    Returns:
        Total number of sections as an ``int``.
    """
    exclude_modules = set() if exclude_modules is None else exclude_modules
    n_sections = 0
    for m in torch_model.modules():
        if m in exclude_modules:
            continue
        params_in_module = 0
        weight = getattr(m, 'weight', None)
        if type(m) is nn.Conv2d:
            out_channels, in_channels, kernel_size_1, kernel_size_2 = weight.size()
            if use_large_lora_cnn:
                params_in_module += rank * kernel_size_1 * kernel_size_2 * (in_channels + out_channels)
            else:
                params_in_module += rank * (kernel_size_1 * in_channels + kernel_size_2 * out_channels)
        elif weight is not None:
            assert len(weight.size()) == 2
            out_features, in_features = weight.size()
            params_in_module += rank * (out_features + in_features)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            params_in_module += bias.size(0)
        n_sections += math.ceil(params_in_module / gen_size)
    return int(n_sections)


def wrap_params_lora(torch_model, weights, start_index, param_fields, rank=64):
    """Replace named tensors on ``torch_model`` with low-rank products read from ``weights``.

    For each name in ``param_fields``, the attribute must squeeze to a 2-D
    ``(rows, cols)`` tensor. Starting at section ``start_index`` (sections are
    ``weights.size(1)`` elements long), ``rank * rows`` values form a
    ``(rows, rank)`` factor and the next ``rank * cols`` values form a
    ``(rank, cols)`` factor. Their product is viewed back to the attribute's
    original shape and set in place of the deleted attribute. Each field then
    advances to the next section boundary. Nothing is returned.
    """
    section_len = weights.size(1)
    flat = weights.flatten()
    for field in param_fields:
        current = getattr(torch_model, field)
        target_shape = current.size()
        rows, cols = current.squeeze().size()
        begin = section_len * start_index
        a_end = begin + rank * rows
        b_end = a_end + rank * cols
        left = flat[begin:a_end].view(rows, rank)
        right = flat[a_end:b_end].view(rank, cols)
        delattr(torch_model, field)
        setattr(torch_model, field, (left @ right).view(target_shape))
        start_index += int(np.ceil((b_end - begin) / section_len))

def get_n_sections_for_params_lora(model, param_fields, gen_size, rank=64):
    """Count the ``gen_size``-element sections the LoRA factors of named params need.

    Each attribute named in ``param_fields`` must squeeze to a 2-D
    ``(rows, cols)`` tensor. Its two factors need ``rank * (rows + cols)``
    values, rounded up to whole sections.

    Returns:
        Total number of sections as an ``int``.
    """
    total = 0
    for field in param_fields:
        squeezed = getattr(model, field).squeeze()
        assert len(squeezed.size()) == 2
        rows, cols = squeezed.size()
        total += np.ceil(rank * (rows + cols) / gen_size)
    return int(total)


def wrap_params(torch_model, weights, start_index, param_fields):
    """Replace named tensors on ``torch_model`` with slices of ``weights``.

    For each name in ``param_fields``, ``numel`` consecutive values are read
    starting at section ``start_index`` (sections are ``weights.size(1)``
    elements long). They are viewed to the attribute's original shape and set in
    place of the deleted attribute, and each field then advances to the next
    section boundary.

    Previously every attribute had to squeeze to exactly 2-D even though only
    its element count is used. Tensors of any shape are now accepted, and the
    result for 2-D attributes is unchanged.
    """
    G = weights.size(1)
    flat = weights.flatten()
    for param_field in param_fields:
        p = getattr(torch_model, param_field)
        original_size = p.size()
        total_size = p.numel()
        flattened_start_index = G * start_index
        matrix = flat[flattened_start_index:flattened_start_index + total_size]
        delattr(torch_model, param_field)
        setattr(torch_model, param_field, matrix.view(original_size))
        start_index += int(np.ceil(total_size / G))


def get_n_sections_for_params(model, param_fields, gen_size):
    """Count the ``gen_size``-element sections needed to hold the named tensors.

    Each attribute named in ``param_fields`` needs ``ceil(numel / gen_size)``
    sections. Any shape is accepted; for attributes that squeeze to 2-D this
    equals the previous ``rows * cols`` count.

    Returns:
        Total number of sections as an ``int``.
    """
    n_sections = 0
    for param_field in param_fields:
        n_sections += math.ceil(getattr(model, param_field).numel() / gen_size)
    return int(n_sections)


def get_input_size_per_sections_for_model(torch_model, gen_size, exclude_modules=None):
    """Return ``(section_index, scale)`` markers for a model's generated sections.

    For every module yielded by ``torch_model.modules()`` (skipping
    ``exclude_modules``) whose ``weight`` exists and is not ``None``:

    * append ``(first_section, sqrt(2) / sqrt(weight.size(1)))`` for the
      module's first section;
    * if its bias is not ``None``, also append
      ``(first_section + weight.numel() // gen_size, 0)``, the section in which
      the bias values begin;
    * then advance by ``ceil((weight.numel() + bias.size(0)) / gen_size)``
      sections.

    Modules registering ``weight=None`` are skipped instead of raising.

    NOTE: ``weight.size(1)`` needs a weight with at least two dimensions, so
    1-D weights (e.g. affine norm layers) still raise IndexError here.
    NOTE(review): modules that have a bias but no weight do not advance the
    section counter — confirm this matches how sections are allocated elsewhere.
    """
    exclude_modules = set() if exclude_modules is None else exclude_modules
    sqrt2 = torch.sqrt(torch.as_tensor(2))
    n_sections = 0
    output = []
    for m in torch_model.modules():
        if m in exclude_modules:
            continue
        weight = getattr(m, 'weight', None)
        if weight is None:
            continue
        params_in_module_weight = weight.numel()
        params_in_module = params_in_module_weight
        output.append((n_sections, sqrt2 / torch.sqrt(torch.as_tensor(weight.size(1)))))
        bias = getattr(m, 'bias', None)
        if bias is not None:
            params_in_module += bias.size(0)
            # Bias values start inside the section that holds the last full weight chunk.
            output.append((n_sections + params_in_module_weight // gen_size, 0))
        n_sections += np.ceil(params_in_module / gen_size).astype(int)
    return output


def apply_gen_output_to_model(model, gen_size, generator_out, exclude_modules=None):
    """Copy a generator's output into ``model``'s existing weights and biases in place.

    ``generator_out`` is flattened and read in sections of ``gen_size`` elements.
    Modules are walked and sections assigned exactly as in ``wrap_torch_model``:

    * weight (if not ``None``), then bias (if not ``None``), from consecutive
      elements;
    * ``nn.Linear`` weights read as ``(in, out)`` and transposed;
    * ``nn.Conv2d`` weights read as ``(in, out, k1, k2)`` with the first two
      axes swapped;
    * every module then rounds up to the next section.

    Previously Linear/Conv2d slices were viewed in the parameter's own shape
    before permuting. That produced an ``(in, out, ...)`` tensor, which cannot be
    written into an ``(out, in, ...)`` parameter unless the layer is square. The
    values are now written with ``copy_`` rather than ``sub_(self).add_(...)``,
    which left NaN/inf in place.

    NOTE: writing in place into an ``nn.Parameter`` that requires grad raises in
    PyTorch unless the caller is under ``torch.no_grad()``; this function does
    not enter that context itself.

    Args:
        model: module whose parameters are overwritten in place.
        gen_size: number of elements per section.
        generator_out: tensor holding the generated values (any shape; it is
            flattened).
        exclude_modules: modules to skip (membership-tested with ``in``).
    """
    exclude_modules = set() if exclude_modules is None else exclude_modules
    flat = generator_out.reshape(-1)
    start_index = 0
    for m in model.modules():
        if m in exclude_modules:
            continue
        offset = gen_size * start_index
        weight = getattr(m, 'weight', None)
        if weight is not None:
            numel = weight.numel()
            chunk = flat[offset:offset + numel]
            if type(m) is nn.Linear:
                out_features, in_features = weight.size()
                src = chunk.view(in_features, out_features).permute(1, 0)
            elif type(m) is nn.Conv2d:
                out_channels, in_channels, kernel_size_1, kernel_size_2 = weight.size()
                src = (chunk.view(in_channels, out_channels, kernel_size_1, kernel_size_2)
                       .permute(1, 0, 2, 3))
            else:
                src = chunk.view(weight.size())
            weight.copy_(src)
            offset += numel
        bias = getattr(m, 'bias', None)
        if bias is not None:
            numel = bias.size(0)
            bias.copy_(flat[offset:offset + numel].view(bias.size()))
            offset += numel
        # Round up so the next module starts on a fresh section.
        start_index += int(np.ceil((offset - gen_size * start_index) / gen_size))

