from __future__ import print_function
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as F_

# Module-wide default device: CUDA when PyTorch can see a GPU, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class Env:
    """Multi-scale patch extractor around one coordinate per image.

    Holds a zero-padded batch of images and one integer (row, col) coordinate
    per image in ``current_c``. An observation is the concatenation, along the
    channel axis, of square patches centred on that coordinate at the scales
    in ``scale_ratio``. The larger patches are resized down to ``view_size``.

    ``imgs`` must be a 4-D (B, C, H, W) tensor. Observations have 3 * C
    channels, so with C == 3 an observation batch is (B, 9, 27, 27).
    """

    def __init__(self, imgs, center_init=False, coord=None):
        """Pad ``imgs``, pick starting coordinates and build the first observation.

        Args:
            imgs: (B, C, H, W) image batch. Coordinates refer to its unpadded
                H x W grid.
            center_init: if true, every image starts at (H // 2, W // 2).
            coord: optional explicit starting coordinates (see ``coord_init``);
                ignored when ``center_init`` is true. When neither is given,
                coordinates are drawn uniformly at random.
        """
        self.imgs = imgs
        self.batch_size = imgs.size(0)
        self.img_size = [imgs.size(2), imgs.size(3)]
        self.view_size = [27, 27]
        self.scale_ratio = [1, 4, 10]
        # (size - 1) // 2 "radius" of the nominal view at each scale.
        self.view = (self.view_size[0] - 1) // 2
        self.view_size_2 = [self.view_size[0] * self.scale_ratio[1]] * 2
        self.view2 = (self.view_size_2[0] - 1) // 2
        self.view_size_3 = [self.view_size[0] * self.scale_ratio[2]] * 2
        self.view3 = (self.view_size_3[0] - 1) // 2

        self.imgs = self.apply_pad(self.imgs)

        self.coord_init(center_init, coord)
        self.current_o = self.c_to_o()

    def apply_pad(self, imgs):
        """Zero-pad both spatial dims of ``imgs`` by ``view3 + 1`` on every side.

        The margin is stored in ``self.need_pad``. It equals the radius of the
        largest (271-pixel) patch, so a patch centred anywhere inside the
        original image stays inside the padded tensor.
        """
        self.need_pad = self.view3 + 1
        return F.pad(imgs, (self.need_pad,) * 4)

    def coord_init(self, center_init, coord):
        """Set ``self.current_c`` to a (B, 2) long tensor of (row, col) coordinates.

        The coordinates are placed on ``self.imgs.device`` so they can index
        the images directly.

        Args:
            center_init: if true, use the image centre for every batch element.
            coord: otherwise, explicit coordinates. Accepts anything
                ``torch.as_tensor`` accepts, holding either 2 values (one pair
                shared by the whole batch) or 2 * B values. Values are cast to
                long, so floats truncate. When ``None``, each row is drawn
                uniformly from [0, H) and each column from [0, W).
        """
        dev = self.imgs.device
        height, width = self.img_size
        if center_init:
            centre = torch.tensor([height // 2, width // 2], dtype=torch.long, device=dev)
            self.current_c = centre.repeat(self.batch_size, 1)
        elif coord is not None:
            # An in-place multiply cannot broadcast the (B, 2) buffer to a
            # larger shape, so build the tensor directly. expand() shares a
            # single pair across the batch; clone() gives current_c its own
            # writable storage.
            pairs = torch.as_tensor(coord, device=dev).long().reshape(-1, 2)
            self.current_c = pairs.expand(self.batch_size, 2).clone()
        else:
            # Draw with the CPU generator and then move. For square images a
            # single randint call consumes the same random stream as drawing
            # both axes at once.
            if height == width:
                c = torch.randint(0, height, (self.batch_size, 2))
            else:
                c = torch.stack(
                    (torch.randint(0, height, (self.batch_size,)),
                     torch.randint(0, width, (self.batch_size,))),
                    dim=1,
                )
            self.current_c = c.to(dev)

    def extract_shifted_patches(self, imgs, starting_point, shifts, patch_size):
        """Gather one square patch per batch element.

        Patch ``n`` is centred at ``(starting_point + shifts[n, 0],
        starting_point + shifts[n, 1])`` in ``imgs``. Indices wrap modulo the
        spatial size of ``imgs``.

        Args:
            imgs: (N, C, H, W) tensor.
            starting_point: integer offset added on both axes (the padding
                margin when ``shifts`` are unpadded-image coordinates).
            shifts: (N, 2) integer tensor of (row, col) offsets.
            patch_size: positive odd side length of the patch.

        Returns:
            (N, C, patch_size, patch_size) tensor with the dtype of ``imgs``.

        Raises:
            ValueError: if ``patch_size`` is not a positive odd integer.
        """
        if patch_size < 1 or patch_size % 2 == 0:
            raise ValueError(f"patch_size must be a positive odd integer, got {patch_size}")
        n, channels, height, width = imgs.shape
        r = patch_size // 2
        delta = torch.arange(-r, r + 1, device=imgs.device)

        shift_h = shifts[:, 0].reshape(n, 1, 1, 1)
        shift_w = shifts[:, 1].reshape(n, 1, 1, 1)
        # Row indices vary along dim 2 and column indices along dim 3. Advanced
        # indexing broadcasts all four index tensors to (N, C, P, P).
        idx_h = (starting_point + delta.view(1, 1, patch_size, 1) + shift_h) % height
        idx_w = (starting_point + delta.view(1, 1, 1, patch_size) + shift_w) % width
        batch_idx = torch.arange(n, device=imgs.device).view(n, 1, 1, 1)
        channel_idx = torch.arange(channels, device=imgs.device).view(1, channels, 1, 1)
        return imgs[batch_idx, channel_idx, idx_h, idx_w]

    def _multiscale_patches(self, imgs, shifts):
        """Stack the patches for every scale as an (N, 3 * C, *view_size) tensor."""
        side = self.view_size[0]
        scales = []
        for ratio in self.scale_ratio:
            size = side * ratio
            size += 1 - size % 2  # force an odd size (27, 109, 271) so there is a centre pixel
            patch = self.extract_shifted_patches(imgs, self.need_pad, shifts, size)
            if ratio != 1:
                patch = F_.resize(patch, self.view_size)
            scales.append(patch)
        return torch.cat(scales, dim=1)

    def c_to_o(self):
        """Return the (B, 3 * C, 27, 27) observation centred on ``self.current_c``."""
        return self._multiscale_patches(self.imgs, self.current_c)

    def c_to_o_cand(self, c1_cand):
        """Return observations for several candidate coordinates per image.

        Args:
            c1_cand: integer tensor reshapeable to (B, K, 2), holding K
                candidate (row, col) coordinates for each image.

        Returns:
            (B, K, 3 * C, 27, 27) tensor. Entry [b, k] is the observation of
            image b at candidate k.
        """
        cand = c1_cand.reshape(self.batch_size, -1, 2)
        n_cand = cand.size(1)
        # Row b * K + k of the repeated batch is image b, which lines up with
        # the flattened candidate list below. The C channels stay intact.
        imgs_rep = self.imgs.repeat_interleave(n_cand, dim=0)
        o = self._multiscale_patches(imgs_rep, cand.reshape(-1, 2))
        return o.view(self.batch_size, n_cand, *o.shape[1:])

    def _clamp_coords(self, c):
        """Clamp (..., 2) (row, col) coordinates into the unpadded image grid."""
        rows = c[..., 0].clamp(0, self.img_size[0] - 1)
        cols = c[..., 1].clamp(0, self.img_size[1] - 1)
        return torch.stack((rows, cols), dim=-1)

    def apply_action_and_o_extraction(self, act_to_env):
        """Return the observations reachable by applying 8 candidate offsets.

        ``act_to_env`` is added to 8 copies of ``current_c``. It must broadcast
        to (B, 8, 2) and be an integer tensor, because the addition is in place
        on a long tensor. The resulting coordinates are clamped to the image.
        ``self.current_c`` itself is not modified.

        Returns:
            (B, 8, 3 * C, 27, 27) tensor of candidate observations.
        """
        c1_cand = self.current_c.unsqueeze(1).repeat(1, 8, 1)  # (B, 8, 2)
        c1_cand += act_to_env
        c1_cand = self._clamp_coords(c1_cand)
        return self.c_to_o_cand(c1_cand)
    






