from __future__ import print_function
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as F_

# Module-wide default compute device: the current CUDA device when PyTorch
# reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


class Env:
    """Batch environment that observes images through multi-scale square views.

    Each image in the batch carries ``n_landmarks`` integer (row, col)
    positions in unpadded image coordinates (``self.current_c``, created by
    ``coord_init``). ``c_to_o`` turns every position into a
    ``len(scale_ratio)``-channel observation: channel k is the square patch of
    side ``view_size[0] * scale_ratio[k]`` (bumped to the next odd number)
    centred on the position, resized to ``view_size``.
    """

    def __init__(self, imgs):
        """Store ``imgs`` and replace it with a padded copy.

        Args:
            imgs: image tensor indexed as [B, C, H, W]. ``c_to_o`` additionally
                requires C == 1.
        """
        self.imgs = imgs
        self.batch_size = imgs.size(0)
        self.img_size = [imgs.size(2), imgs.size(3)]  # [H, W] before padding
        self.view_size = [27, 27]
        self.scale_ratio = [1, 4, 10]
        self.view = int((self.view_size[0] - 1) / 2)
        self.view_size_2 = [self.view_size[0] * self.scale_ratio[1],
                            self.view_size[0] * self.scale_ratio[1]]
        self.view2 = int((self.view_size_2[0] - 1) / 2)
        self.view_size_3 = [self.view_size[0] * self.scale_ratio[2],
                            self.view_size[0] * self.scale_ratio[2]]
        self.view3 = int((self.view_size_3[0] - 1) / 2)

        # Padded to [B, C, H + 2 * need_pad, W + 2 * need_pad]
        # (need_pad == 135 with the default sizes above).
        self.imgs = self.apply_pad(self.imgs)

    def apply_pad(self, imgs):
        """Pad the H and W axes of ``imgs`` by ``self.need_pad`` on every side.

        ``need_pad`` is set here as a side effect to ``view3 + 1``. With the
        default sizes this covers the radius of the largest patch taken by
        ``c_to_o``, so patches around in-image positions never need to wrap.

        Returns:
            The padded tensor; the added border is filled with the constant 3.
        """
        self.need_pad = self.view3 + 1
        # NOTE(review): 3 is presumably chosen to lie outside the normal pixel
        # range so the border is distinguishable from image content — confirm.
        return F.pad(imgs, (self.need_pad,) * 4, mode='constant', value=3)

    def coord_init(self, center_init, coord, n_landmarks=29):
        """Create ``self.current_c``, a long tensor of shape [n_landmarks, B, 2].

        Positions are (row, col) in unpadded image coordinates. They live on
        the same device as ``self.imgs``, which ``extract_shifted_patches``
        requires.

        Args:
            center_init: if truthy, every landmark starts at (H // 2, W // 2).
            coord: used when ``center_init`` is falsy and ``coord`` is not
                None. It is reshaped to [-1, B, 2] and broadcast over
                landmarks, so it must hold either one or ``n_landmarks``
                positions per image.
            n_landmarks: number of positions tracked per image.

        When neither applies, positions are drawn uniformly at random inside
        the H x W image.
        """
        self.n_landmarks = n_landmarks
        target = self.imgs.device
        shape = (n_landmarks, self.batch_size, 2)
        height, width = self.img_size
        if center_init:
            centre = torch.tensor([height // 2, width // 2], dtype=torch.long, device=target)
            self.current_c = centre.repeat(n_landmarks, self.batch_size, 1)
        elif coord is not None:
            self.current_c = torch.ones(shape, dtype=torch.long, device=target)
            self.current_c *= coord.view(-1, self.batch_size, 2).to(target)
        else:
            # Draw each axis within its own image extent (rows < H, cols < W).
            rows = torch.randint(0, height, shape[:2])
            cols = torch.randint(0, width, shape[:2])
            self.current_c = torch.stack((rows, cols), dim=-1).to(target)

    def extract_shifted_patches(self, imgs, starting_point, shifts, patch_size, batch_index=None):
        """Gather one square patch per row of ``shifts``.

        Patch ``i`` is centred at row ``starting_point + shifts[i, 0]`` and
        column ``starting_point + shifts[i, 1]`` of image ``batch_index[i]``.
        When ``batch_index`` is None, patch ``i`` comes from image ``i``.
        Row/column indices wrap modulo H/W.

        Args:
            imgs: tensor indexed as [B, C, H, W].
            starting_point: integer offset added to both coordinates (the pad
                width when ``imgs`` is padded).
            shifts: integer tensor of shape [N, 2] holding (row, col) offsets.
            patch_size: odd side length P of the square patch.
            batch_index: optional integer tensor of shape [N] selecting the
                source image of each patch. When omitted, N must equal B.

        Returns:
            Tensor of shape [N, C, P, P].

        Raises:
            ValueError: if ``patch_size`` is even, or ``batch_index`` is
                omitted while N != B.
        """
        if patch_size % 2 != 1:
            raise ValueError("patch_size must be odd, got {}".format(patch_size))
        n_imgs, n_chans, height, width = imgs.shape
        n_patches = shifts.shape[0]
        if batch_index is None:
            if n_patches != n_imgs:
                raise ValueError(
                    "expected {} shifts (one per image), got {}".format(n_imgs, n_patches))
            batch_index = torch.arange(n_patches, device=imgs.device)
        else:
            batch_index = batch_index.to(imgs.device)

        radius = patch_size // 2
        delta = torch.arange(-radius, radius + 1, device=imgs.device)                 # [P]
        rows = (starting_point + shifts[:, 0].unsqueeze(1) + delta) % height            # [N, P]
        cols = (starting_point + shifts[:, 1].unsqueeze(1) + delta) % width             # [N, P]
        chans = torch.arange(n_chans, device=imgs.device)

        # The four index tensors broadcast together to [N, C, P, P].
        return imgs[batch_index.view(n_patches, 1, 1, 1),
                    chans.view(1, n_chans, 1, 1),
                    rows.view(n_patches, 1, patch_size, 1),
                    cols.view(n_patches, 1, 1, patch_size)]

    def c_to_o(self):
        """Build the multi-scale observation for every landmark of every image.

        Assumes single-channel images and a square ``view_size``.

        Returns:
            Tensor of shape [n_landmarks, B, len(scale_ratio), view_size[0],
            view_size[0]]; with the defaults, [n_landmarks, B, 3, 27, 27].
        """
        side = self.view_size[0]
        n_patches = self.n_landmarks * self.batch_size
        centres = self.current_c.view(-1, 2)
        # current_c is [n_landmarks, B, 2], so flattened row i belongs to image
        # i % B; indexing with that avoids copying self.imgs n_landmarks times.
        source = torch.arange(n_patches, device=self.imgs.device) % self.batch_size

        channels = []
        for ratio in self.scale_ratio:
            size = side * ratio
            if size % 2 == 0:
                size += 1  # patches need a centre pixel, so round even sides up
            patch = self.extract_shifted_patches(
                self.imgs, self.need_pad, centres, size, batch_index=source)
            if size != side:
                patch = F_.resize(patch, self.view_size)
            channels.append(patch)

        o = torch.cat(channels, dim=1)
        return o.view(self.n_landmarks, self.batch_size, len(channels), side, side)

    def _clamp_to_image(self, coords):
        """Return (row, col) ``coords`` clamped to [0, H - 1] x [0, W - 1]."""
        upper = torch.tensor([self.img_size[0] - 1, self.img_size[1] - 1],
                             dtype=coords.dtype, device=coords.device)
        return torch.max(torch.min(coords, upper), torch.zeros_like(upper))

    def apply_action(self, act_to_env):
        """Move every landmark by ``act_to_env`` and refresh the observation.

        ``act_to_env`` is added to ``self.current_c`` in place. It must
        therefore broadcast to [n_landmarks, B, 2] and have a dtype that can be
        added into a long tensor in place. The moved positions are clamped to
        the image, and ``self.current_o`` is recomputed via ``c_to_o``.
        """
        self.current_c += act_to_env
        self.current_c = self._clamp_to_image(self.current_c)
        self.current_o = self.c_to_o()

