import torch
from torch import nn
from einops import rearrange
from models.shiftresnet import ShiftConv, ShiftResNet20, Shift3x3
from torch.linalg import svd

class UnevenShift3x3(nn.Module):
    """3x3 shift whose nine direction groups may hold different channel counts.

    ``shift_dist`` gives the number of channels per group, in the order
    right-down, down, left-down, right, center, left, right-up, up, left-up.
    Its cumulative sum is stored and used as the channel slice boundaries.
    Pixels shifted in from outside the feature map are zero.
    """

    def __init__(self, shift_dist):
        super().__init__()
        # Running totals turn per-group counts into slice boundaries.
        self.shift_dist = torch.cumsum(shift_dist, dim=0)

    @staticmethod
    def _axis_slices(step, size):
        """Return (destination, source) slices for a one-pixel move along an axis."""
        if step > 0:  # move towards larger indices (down / right)
            return slice(1, None), slice(None, size - 1)
        if step < 0:  # move towards smaller indices (up / left)
            return slice(None, size - 1), slice(1, None)
        return slice(None), slice(None)

    def forward(self, x):
        return self.shift(x)

    def shift(self, x):
        """Shift each channel group of a 4-D input by its own one-pixel offset."""
        _, _, height, width = x.size()
        e1, e2, e3, e4, e5, e6, e7, e8, _ = self.shift_dist

        # Consecutive boundaries delimit the groups; the last group is
        # open-ended and takes every remaining channel.
        lower = (None, e1, e2, e3, e4, e5, e6, e7, e8)
        upper = (e1, e2, e3, e4, e5, e6, e7, e8, None)

        # (dy, dx) with +1 = down/right and -1 = up/left, in group order.
        moves = [(dy, dx) for dy in (1, 0, -1) for dx in (1, 0, -1)]

        shifted = torch.zeros_like(x)
        for (dy, dx), lo, hi in zip(moves, lower, upper):
            dst_rows, src_rows = self._axis_slices(dy, height)
            dst_cols, src_cols = self._axis_slices(dx, width)
            shifted[:, lo:hi, dst_rows, dst_cols] = x[:, lo:hi, src_rows, src_cols]
        return shifted

class ShiftPruning(ShiftConv):
    """ShiftConv replacement with a narrower bottleneck obtained by truncated SVD.

    For each of the 9 shift groups, the matching slices of the two pointwise
    (1x1) convolutions of ``shift_conv_module`` are multiplied into a single
    ``c_out x c_in`` matrix. That matrix is factorised with SVD and only the
    strongest singular directions are kept, which gives new, smaller weights
    for ``conv1`` and ``conv2``.

    NOTE(review): this treats the bottleneck channels as 9 equal, contiguous
    groups (one per shift direction) and ignores anything between the two
    convolutions other than the shift -- confirm against ShiftConv / Shift3x3.

    Args:
        shift_conv_module: trained block providing ``conv1`` and ``conv2``
            (and ``bn1`` / ``bn2`` when ``fold_bn`` is set).
        density: fraction of the per-group bottleneck channels to keep.
        fold_bn: scale each conv's weights by
            ``bn.weight / sqrt(bn.running_var + bn.eps)`` before factorising.
            Only this scale is folded; BN bias and running mean are not.
        uneven: if False, every shift group keeps the same number of singular
            values. If True, the globally largest singular values are kept, so
            groups may end up with different widths, and ``shift2`` is replaced
            by an ``UnevenShift3x3`` built from the per-group counts.

    Raises:
        ValueError: if the bottleneck width is not a multiple of 9, if more
            singular values are requested than the factorised matrices have,
            or (uneven mode) if the selection contains exactly-zero singular
            values so that fewer channels than requested would remain.
    """

    def __init__(self, shift_conv_module, density=0.5, fold_bn=False, uneven=False):
        k2 = 9  # number of shift groups for 3x3 Shift
        pw1 = shift_conv_module.conv1  # 1st pw
        pw2 = shift_conv_module.conv2  # 2nd pw
        c_e, c_in, _, _ = pw1.weight.shape  # c_e = bottleneck (expanded) width
        c_out = pw2.weight.shape[0]

        if c_e % k2 != 0:
            raise ValueError(
                f"bottleneck width {c_e} is not divisible by the {k2} shift groups"
            )
        group_width = c_e // k2

        num_svs = int(group_width * density)  # num of channels per shift after pruning
        new_c_e = int(num_svs * k2)
        # Reduced SVD of a (c_out x c_in) matrix yields min(c_out, c_in) values.
        max_rank = min(c_out, c_in)
        if num_svs > max_rank:
            raise ValueError(
                f"density={density} asks for {num_svs} singular values per shift "
                f"group, but only {max_rank} are available"
            )
        print("New bottleneck width: ", new_c_e)
        # NOTE(review): the parent presumably recovers the bottleneck width
        # from c_out * expansion; confirm that it reproduces new_c_e exactly.
        new_e = new_c_e / c_out
        print("New expansion rate: ", new_e)

        # Pure weight surgery: no autograd graph is needed for the factorisation.
        with torch.no_grad():
            pw1_4d_weight = pw1.weight
            pw2_4d_weight = pw2.weight

            # fold BN to pw
            if fold_bn:
                print('fold bn to conv...')
                bn1 = shift_conv_module.bn1
                bn1_scale = bn1.weight / torch.sqrt(bn1.running_var + bn1.eps)
                pw1_4d_weight = pw1_4d_weight * bn1_scale.reshape(pw1.out_channels, 1, 1, 1)

                bn2 = shift_conv_module.bn2
                bn2_scale = bn2.weight / torch.sqrt(bn2.running_var + bn2.eps)
                pw2_4d_weight = pw2_4d_weight * bn2_scale.reshape(pw2.out_channels, 1, 1, 1)

            vt = pw1_4d_weight.reshape(k2, group_width, c_in)  # k^2, c_e//k2, c_in
            u = (
                pw2_4d_weight.transpose(0, 1)
                .reshape(k2, group_width, c_out)
                .transpose(2, 1)
            )  # k^2, c_out, c_e//k2
            A = u @ vt  # k^2, c_out, c_in

            # Batched SVD, one factorisation per shift group.
            U, S, Vh = svd(A, full_matrices=False)

            # Split every singular value evenly between the two factors.
            sqrt_s = torch.sqrt(S)  # k^2, r
            pw1_groups = sqrt_s.unsqueeze(-1) * Vh  # k^2, r, c_in
            pw2_groups = U * sqrt_s.unsqueeze(1)  # k^2, c_out, r

            if not uneven:
                # evenly select from each shift group
                pw1_2d_weight = rearrange(
                    pw1_groups[:, :num_svs, :], 's i j->(s i) j'
                )  # new_c_e, c_in
                pw2_2d_weight = rearrange(
                    pw2_groups[:, :, :num_svs], 's i j->i (s j)'
                )  # c_out, new_c_e
            else:
                # select across shift group
                S_cut = self.topk_sv(S, new_c_e)
                shift_dist = torch.count_nonzero(S_cut, dim=1)
                if int(shift_dist.sum()) != new_c_e:
                    raise ValueError(
                        f"only {int(shift_dist.sum())} non-zero singular values "
                        f"are available, cannot keep {new_c_e} channels"
                    )
                # One mask drives both factors and shift_dist, so the kept
                # rows of pw1 always match the kept columns of pw2.
                keep = (S_cut != 0).flatten()
                pw1_2d_weight = rearrange(pw1_groups, 's i j->(s i) j')[keep]  # new_c_e, c_in
                pw2_2d_weight = rearrange(pw2_groups, 's i j->i (s j)')[:, keep]  # c_out, new_c_e
                print(f'Shift distribution: {shift_dist}')

        super().__init__(c_in, c_out, stride=pw2.stride[0], expansion=new_e)
        self.conv1.weight.data = pw1_2d_weight.reshape(new_c_e, c_in, 1, 1)
        self.conv2.weight.data = pw2_2d_weight.reshape(c_out, new_c_e, 1, 1)

        if uneven:
            # Groups now have different widths, so the shift must follow them.
            self.shift2 = UnevenShift3x3(shift_dist)

    def topk_sv(self, x, k):
        """Return a tensor shaped like ``x`` keeping its ``k`` largest entries.

        The ``k`` largest values (taken over all elements of ``x``) stay at
        their original positions; every other entry is zero.
        """
        x_f = x.flatten()
        v, idx = x_f.topk(k)
        x_new = torch.zeros_like(x_f)
        x_new[idx] = v
        return x_new.reshape(x.size())

if __name__ == "__main__":
    # Smoke test: load a trained ShiftResNet20, replace every ShiftConv with a
    # pruned ShiftPruning block, and report parameter counts and output shape.
    from collections import OrderedDict

    n_classes = 10
    model = ShiftResNet20(expansion=9, num_classes=n_classes)
    ckpt_dir = 'output/cifar10/shiftnet/shiftresnet20/lr005_epoch200/ckpt.pth'
    # The model stays on CPU here, so remap tensors that were saved from a GPU.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_dir, map_location='cpu')
    state_dict = ckpt['model']

    def strip_wrapper_prefix(key, prefix='module.'):
        """Drop leading ``module.`` wrapper prefixes from a state-dict key."""
        # Only the leading prefix is removed, so parameter names that merely
        # contain "module." further inside the key are left intact.
        while key.startswith(prefix):
            key = key[len(prefix):]
        return key

    new_state_dict = OrderedDict(
        (strip_wrapper_prefix(k), v) for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)

    print(model)

    def count_parameters(model):
        """Return the number of trainable parameters of ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Number of params before pruning: ", count_parameters(model))

    def replace_layers(model, old, new, density=0.5, fold_bn=False, uneven=False):
        """Recursively swap every ``old``-typed child for ``new(child, ...)``."""
        # Iterate over a snapshot because children are reassigned in the loop.
        for n, module in list(model.named_children()):
            if len(list(module.children())) > 0:
                replace_layers(module, old, new, density=density, fold_bn=fold_bn, uneven=uneven)

            if isinstance(module, old):
                setattr(model, n, new(module, density=density, fold_bn=fold_bn, uneven=uneven))

    expansion = 4.5
    replace_layers(model, ShiftConv, ShiftPruning, density=(expansion / 9), fold_bn=False, uneven=True)

    print(model)
    print("Number of params after pruning: ", count_parameters(model))
    x = torch.rand(10, 3, 32, 32)
    # Only the output shape is checked: run in eval mode without autograd so
    # the random input does not overwrite the BatchNorm running statistics.
    model.eval()
    with torch.no_grad():
        out = model(x)
    print(out.shape)
