import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class qkv_super(nn.Linear):
    """Linear layer with a slimmable weight and a rank-adjustable LoRA update.

    The layer owns parameters sized for the largest configuration
    (``super_out_dim`` x ``super_in_dim`` weight, optional bias). A call to
    ``set_sample_config`` selects a smaller active configuration, and
    ``forward`` applies only the sampled sub-weight and sub-bias.

    When the sampled LoRA rank ``r`` is non-zero, the low-rank product
    ``(LoRA_a[:, :r] @ LoRA_b[:r, :]).T`` is added to the full weight before
    it is sliced.

    NOTE(review): the class name suggests a fused query/key/value projection
    whose output rows are interleaved in groups of three (see
    ``sample_weight``) -- confirm against the attention module that uses it.
    """

    def __init__(self, super_in_dim, super_out_dim, bias=True, uniform_=None, non_linear='linear', scale=False, LoRA_dim=1024):
        """Create the full-size layer and its LoRA factors.

        Args:
            super_in_dim: Input width of the largest configuration.
            super_out_dim: Output width of the largest configuration.
            bias: Forwarded to ``nn.Linear``; whether a bias is created.
            uniform_: Optional initialiser consumed only by
                ``_reset_parameters`` (which ``__init__`` does not invoke).
            non_linear: Passed to ``uniform_`` by ``_reset_parameters``.
            scale: If true, ``forward`` multiplies its output by
                ``super_out_dim / sample_out_dim``.
            LoRA_dim: Largest LoRA rank; sets the shapes of both factors.
        """
        super().__init__(super_in_dim, super_out_dim, bias=bias)

        # Sizes of the largest ("super") network.
        self.super_in_dim = super_in_dim
        self.super_out_dim = super_out_dim
        self.super_LoRA_dim = LoRA_dim

        # Currently sampled sizes; all three are filled in by
        # set_sample_config().
        self.sample_in_dim = None
        self.sample_out_dim = None
        self.sample_LoRA_dim = None

        # Cache of the sampled tensors, keyed by 'weight' and 'bias'.
        self.samples = {}

        self.scale = scale
        # _reset_parameters(bias, uniform_, non_linear) is not invoked here,
        # so weight and bias keep the initialisation done by nn.Linear.
        self.profiling = False

        # LoRA_a gets a Kaiming-uniform init while LoRA_b stays all-zero, so
        # the low-rank product is zero until LoRA_b is updated.
        self.LoRA_a = nn.Parameter(torch.zeros(super_in_dim, LoRA_dim))
        nn.init.kaiming_uniform_(self.LoRA_a, a=math.sqrt(5))
        self.LoRA_b = nn.Parameter(torch.zeros(LoRA_dim, super_out_dim))

    def profile(self, mode=True):
        """Turn profiling on or off; while on, every lookup re-samples."""
        self.profiling = mode

    def sample_parameters(self, resample=False):
        """Return the sampled tensors, rebuilding them when required.

        Args:
            resample: Force a rebuild even when profiling is off.

        Returns:
            The ``samples`` dict with ``'weight'`` and ``'bias'`` entries.
        """
        if self.profiling or resample:
            return self._sample_parameters()
        return self.samples

    def _reset_parameters(self, bias, uniform_, non_linear):
        """Re-initialise the weight (Xavier-uniform by default) and zero the bias.

        Args:
            bias: When truthy, ``self.bias`` is filled with zeros.
            uniform_: Optional callable invoked as
                ``uniform_(weight, non_linear=non_linear)`` instead of the
                Xavier-uniform default.
            non_linear: Forwarded to ``uniform_``.
        """
        if uniform_ is None:
            nn.init.xavier_uniform_(self.weight)
        else:
            uniform_(self.weight, non_linear=non_linear)
        if bias:
            nn.init.constant_(self.bias, 0.)

    def set_sample_config(self, sample_in_dim, sample_out_dim, sample_LoRA_dim):
        """Select the active sub-network and build its tensors immediately.

        Args:
            sample_in_dim: Number of input columns to keep.
            sample_out_dim: Upper bound (exclusive) on the output rows kept.
            sample_LoRA_dim: LoRA rank to use; ``0`` disables the LoRA term.
        """
        self.sample_in_dim = sample_in_dim
        self.sample_out_dim = sample_out_dim

        self.sample_LoRA_dim = sample_LoRA_dim

        self._sample_parameters()

    def _sample_parameters(self):
        """Rebuild ``self.samples`` from the current parameters and config.

        Returns:
            The refreshed ``samples`` dict.

        Raises:
            RuntimeError: If no output size has been configured yet.
        """
        if self.sample_out_dim is None:
            raise RuntimeError(
                'qkv_super: set_sample_config() must be called before parameters can be sampled')

        if self.sample_LoRA_dim != 0:
            rank = self.sample_LoRA_dim
            # (in, r) @ (r, out) -> (in, out); transposed to match the
            # (out, in) layout of nn.Linear.weight.
            low_rank_update = self.LoRA_a[:, :rank] @ self.LoRA_b[:rank, :]
            self.weight_with_LoRA = self.weight + low_rank_update.T
            full_weight = self.weight_with_LoRA
        else:
            full_weight = self.weight

        self.samples['weight'] = sample_weight(full_weight, self.sample_in_dim, self.sample_out_dim)
        self.sample_scale = self.super_out_dim / self.sample_out_dim
        # NOTE(review): the bias is sliced in natural order while the weight
        # rows are regrouped by sample_weight -- confirm this is intended.
        if self.bias is None:
            self.samples['bias'] = None
        else:
            self.samples['bias'] = sample_bias(self.bias, self.sample_out_dim)

        return self.samples

    def forward(self, x):
        """Apply the sampled linear map to ``x``.

        In training mode the sampled tensors are rebuilt on every call: they
        are copies of the parameters (not views), so a cached sample would go
        stale after an optimizer step, and its autograd graph is released by
        the previous ``backward``. In eval mode the cache from the last
        ``set_sample_config`` / ``sample_parameters(resample=True)`` call is
        reused, so refresh it after changing the weights.
        """
        cache_missing = 'weight' not in self.samples
        self.sample_parameters(resample=self.training or cache_missing)
        return F.linear(x, self.samples['weight'], self.samples['bias']) * (self.sample_scale if self.scale else 1)

    def calc_sampled_param_num(self):
        """Return the element count of the sampled weight plus sampled bias."""
        assert 'weight' in self.samples
        weight_numel = self.samples['weight'].numel()

        sampled_bias = self.samples['bias']
        bias_numel = sampled_bias.numel() if sampled_bias is not None else 0

        return weight_numel + bias_numel

    def get_complexity(self, sequence_length):
        """Return ``sequence_length`` times the sampled weight's element count.

        The bias is not included in the estimate.
        """
        return sequence_length * np.prod(self.samples['weight'].size())

def sample_weight(weight, sample_in_dim, sample_out_dim):
    """Slice a weight matrix down to the sampled input/output sizes.

    The first ``sample_in_dim`` columns are kept. Rows are then gathered in
    three groups: rows ``0, 3, 6, ...``, rows ``1, 4, 7, ...`` and rows
    ``2, 5, 8, ...``, each limited to row indices below ``sample_out_dim``,
    and the groups are concatenated in that order along dim 0.

    Args:
        weight: Tensor indexed as ``[rows, columns]``.
        sample_in_dim: Number of leading columns to keep.
        sample_out_dim: Exclusive upper bound on the row indices kept.

    Returns:
        A new tensor holding the three row groups stacked along dim 0.
    """
    narrowed = weight[:, :sample_in_dim]

    row_groups = []
    for offset in (0, 1, 2):
        # Every third row starting at `offset`, stopping before sample_out_dim.
        row_groups.append(narrowed[offset:sample_out_dim:3, :])

    return torch.cat(row_groups, dim=0)


def sample_bias(bias, sample_out_dim):
    """Return the leading ``sample_out_dim`` entries of ``bias``.

    The entries are taken in their natural order; no regrouping by stride
    (as done for the rows in ``sample_weight``) is applied here.

    Args:
        bias: Sequence or tensor sliced along its first dimension.
        sample_out_dim: Number of leading entries to keep.

    Returns:
        The slice ``bias[:sample_out_dim]``.
    """
    return bias[:sample_out_dim]
