import torch
import argparse

def get_args_parser():
    """Build the command-line parser for the checkpoint transformation script.

    The parser is created with ``add_help=False``, so it does not register
    ``-h``/``--help`` itself.

    Returns:
        argparse.ArgumentParser: parser exposing ``--in-chans``,
        ``--out-chans``, ``--num-heads``, ``--patch-size``, ``--num-layer``,
        ``--src-ckp``, ``--tgt-ckp`` and ``--kernel-size``.
    """
    # (flag, default, type) in registration order.
    option_table = (
        ('--in-chans', 3, int),
        ('--out-chans', 3, int),
        ('--num-heads', 9, int),
        ('--patch-size', 16, int),
        ('--num-layer', 6, int),
        ('--src-ckp', '../ckp/checkpoint.pth', str),
        ('--tgt-ckp', '../ckp/transformed_ckp6.pth', str),
        ('--kernel-size', 3, int),
    )

    cli = argparse.ArgumentParser('ckp transformation script', add_help=False)
    for flag, default_value, value_type in option_table:
        cli.add_argument(flag, default=default_value, type=value_type)
    return cli

def _conv_to_proj_weight(conv_weight, in_chans, out_chans, num_heads,
                         patch_size, kernel_size):
    """Unroll a 2-D conv kernel into a dense projection matrix.

    Each output row group belongs to one pixel ``p`` of a
    ``patch_size x patch_size`` patch. The columns are laid out as
    ``num_heads`` consecutive slots of ``patch_size * patch_size * in_chans``
    values; slot ``3 * r + c`` addresses the patch at row ``r`` / column ``c``
    of the 3x3 patch neighbourhood (the centre patch is slot 4) - presumably
    one attention head per neighbouring patch. For every kernel tap, the
    ``(out_chans, in_chans)`` slice of ``conv_weight`` is copied to the column
    range of the pixel that tap reads.

    Args:
        conv_weight: tensor indexed as ``[:, :, ki, kj]``; expected to be
            shaped ``(out_chans, in_chans, kernel_size, kernel_size)``.
        in_chans: number of channels per input pixel.
        out_chans: number of channels per output pixel.
        num_heads: number of column slots; the 3x3 neighbourhood uses slot
            indices up to 8.
        patch_size: side length of a patch, in pixels.
        kernel_size: side length of the conv kernel.

    Returns:
        torch.Tensor: matrix of shape ``(dim, num_heads * dim)`` with
        ``dim = patch_size * patch_size * in_chans``.

    Raises:
        ValueError: if the kernel reaches further than one patch away, which
            the 3x3 neighbourhood cannot represent.
    """
    # Patches per side of the neighbourhood. Patch offsets are -1/0/+1, so the
    # row stride of the slot index must be 3 regardless of ``in_chans``.
    grid = 3

    if kernel_size // 2 > patch_size:
        raise ValueError(
            'kernel_size %d reaches beyond the adjacent patch for patch_size %d'
            % (kernel_size, patch_size)
        )

    pixels_per_patch = patch_size * patch_size
    dim = pixels_per_patch * in_chans
    # NOTE(review): the row count is derived from in_chans while rows are
    # written in out_chans-sized groups, so this presumably expects
    # out_chans == in_chans - confirm against the model definition.
    wo = torch.zeros([dim, num_heads * dim])

    # Kernel tap offsets relative to the centre pixel, e.g. (-1, 0, 1) for 3.
    offsets = range(-(kernel_size - 1) // 2, (kernel_size + 1) // 2)

    for p in range(pixels_per_patch):
        row, col = divmod(p, patch_size)
        for ki, d_row in enumerate(offsets):
            # Floor division yields the patch offset (-1, 0 or +1) and the
            # wrapped pixel coordinate inside that patch.
            patch_row, pix_row = divmod(row + d_row, patch_size)
            for kj, d_col in enumerate(offsets):
                patch_col, pix_col = divmod(col + d_col, patch_size)
                location_id = (patch_row + 1) * grid + (patch_col + 1)
                visible_id = pix_row * patch_size + pix_col
                start = (location_id * pixels_per_patch + visible_id) * in_chans
                wo[p * out_chans:(p + 1) * out_chans, start:start + in_chans] = \
                    conv_weight[:, :, ki, kj]
    return wo


def main(args):
    """Replace each block's conv weight with an equivalent projection weight.

    Loads the checkpoint at ``args.src_ckp``, and for every block index below
    ``args.num_layer`` removes ``blocks.<i>.attn.conv.weight`` from
    ``checkpoint['model']`` and stores the matrix built from it under
    ``blocks.<i>.attn.proj.weight``. The modified checkpoint is written to
    ``args.tgt_ckp``.

    Args:
        args: namespace providing ``in_chans``, ``out_chans``, ``num_heads``,
            ``patch_size``, ``num_layer``, ``src_ckp``, ``tgt_ckp`` and
            ``kernel_size``.

    Raises:
        KeyError: if the checkpoint lacks ``'model'`` or an expected conv
            weight entry.
        ValueError: if the kernel reaches further than one patch away.
    """
    # Map to CPU so the conversion also works on machines without the device
    # the checkpoint was saved from.
    src_model = torch.load(args.src_ckp, map_location='cpu')
    state = src_model['model']
    print(state.keys())

    for layer in range(args.num_layer):
        block_name = 'blocks.' + str(layer)
        conv_key = block_name + '.attn.conv.weight'

        wo = _conv_to_proj_weight(
            state[conv_key],
            args.in_chans,
            args.out_chans,
            args.num_heads,
            args.patch_size,
            args.kernel_size,
        )

        del state[conv_key]
        state[block_name + '.attn.proj.weight'] = wo

    torch.save(src_model, args.tgt_ckp)
    print('transformed ckp saved')

# Script entry point: parse the command line, echo the parsed options, then
# run the checkpoint transformation.
if __name__ == '__main__':
    args = get_args_parser().parse_args()
    print(args)
    main(args)
