import torch
import torch.nn.functional as F
import random
import math


class FlowAugmentorTensor:
    """Random per-sample affine augmentation for batched optical-flow tensors.

    Every sample in the batch gets its own random translation, isotropic scale
    and rotation. The same warp is applied to ``img1``, ``img2``,
    ``pseudo_img2``, ``flow`` and (optionally) ``valid``; the flow *vectors*
    are additionally mapped through the inverse of the warp's linear part so
    they stay consistent with the warped image pair.

    NOTE(review): the flow correction assumes flow channel 0 is the horizontal
    (x) displacement and channel 1 the vertical (y) displacement, both in
    pixels -- confirm against the data pipeline.
    """

    def __init__(self, max_shift=24, max_scale=0.05, max_rotate_deg=3):
        """Store the augmentation ranges.

        Args:
            max_shift: translation bound. The shift is drawn as
                ``±max_shift / W`` (x) and ``±max_shift / H`` (y) in
                ``affine_grid``'s normalized [-1, 1] coordinates, which is
                roughly ``±max_shift / 2`` pixels.
                NOTE(review): confirm whether ±max_shift *pixels* was intended.
            max_scale: scale factor is drawn from [1 - max_scale, 1 + max_scale].
            max_rotate_deg: rotation angle is drawn from ±max_rotate_deg degrees.
        """
        self.max_shift = max_shift
        self.max_scale = max_scale
        self.max_rotate_rad = math.radians(max_rotate_deg)

    def _sample_params(self, H, W):
        """Draw one sample's (tx, ty, scale, angle) from torch's global RNG.

        The four draws happen in a fixed order, so a seeded run reproduces the
        same augmentation sequence.
        """
        tx = torch.empty(1).uniform_(-self.max_shift / W, self.max_shift / W).item()
        ty = torch.empty(1).uniform_(-self.max_shift / H, self.max_shift / H).item()
        scale = 1 + torch.empty(1).uniform_(-self.max_scale, self.max_scale).item()
        angle = torch.empty(1).uniform_(-self.max_rotate_rad, self.max_rotate_rad).item()
        return tx, ty, scale, angle

    @staticmethod
    def _warp(x, grid, mode='bilinear'):
        """Resample ``x`` at ``grid`` (align_corners=True, zeros outside the frame)."""
        # grid_sample requires the input and the grid to share a dtype.
        return F.grid_sample(x, grid.to(x.dtype), mode=mode, align_corners=True)

    def spatial_transform(self, img1, img2, pseudo_img2, flow, valid=None):
        """Warp every tensor of the batch with one random affine map per sample.

        Args:
            img1, img2, pseudo_img2: image tensors of shape (B, C, H, W)
                (shape taken from ``img1``).
            flow: flow tensor, presumably (B, 2, H, W) in pixels; exactly two
                channels are required for the vector correction.
            valid: optional 4-D mask, e.g. (B, 1, H, W); ``None`` to skip.

        Returns:
            Tuple ``(img1, img2, pseudo_img2, flow, valid)`` of warped tensors.
            ``valid`` is returned as float (nearest-neighbour sampled, 0
            outside the source frame) or ``None``.
        """
        B, C, H, W = img1.shape
        device = img1.device

        # With align_corners=True the normalized [-1, 1] range spans W-1 (resp.
        # H-1) pixels, so a rotation in normalized space is not a pure rotation
        # in pixel space when H != W; the flow cross terms carry this ratio.
        # Degenerate 1-pixel axes fall back to 1.0 to avoid dividing by zero.
        aspect_xy = (W - 1) / (H - 1) if H > 1 and W > 1 else 1.0
        aspect_yx = 1.0 / aspect_xy

        thetas, flow_mats = [], []
        for _ in range(B):
            tx, ty, scale, angle = self._sample_params(H, W)
            cos_a, sin_a = math.cos(angle), math.sin(angle)
            # affine_grid convention: theta maps output normalized coordinates
            # to the input normalized coordinates that get sampled.
            thetas.append([
                [scale * cos_a, -scale * sin_a, tx],
                [scale * sin_a, scale * cos_a, ty],
            ])
            # Inverse of theta's linear part, expressed in pixel units. A flow
            # vector f read at the sampled input location becomes M^-1 f in
            # the output frame (the translation cancels because both images
            # share the same warp).
            flow_mats.append([
                [cos_a / scale, sin_a * aspect_xy / scale],
                [-sin_a * aspect_yx / scale, cos_a / scale],
            ])

        theta = torch.tensor(thetas, dtype=torch.float32, device=device)
        flow_mat = torch.tensor(flow_mats, dtype=torch.float32, device=device)
        # One batched grid; grid_sample processes batch entries independently.
        grid = F.affine_grid(theta, size=(B, C, H, W), align_corners=True)

        img1_out = self._warp(img1, grid)
        img2_out = self._warp(img2, grid)
        pseudo_img2_out = self._warp(pseudo_img2, grid)
        flow_out = self._warp(flow, grid)
        flow_out = torch.einsum('bij,bjhw->bihw', flow_mat.to(flow_out.dtype), flow_out)

        valid_out = None
        if valid is not None:
            # Nearest sampling keeps the mask binary; out-of-frame pixels get 0.
            valid_out = self._warp(valid.float(), grid, mode='nearest')

        return img1_out, img2_out, pseudo_img2_out, flow_out, valid_out

    def __call__(self, img1, img2, pseudo_img2, flow, valid=None):
        """Augment a batch.

        ``valid`` may be passed as (B, H, W) or (B, 1, H, W). When given, it is
        returned as a float tensor with dim 1 squeezed, i.e. (B, H, W).
        """
        if valid is not None and valid.ndim == 3:
            valid = valid.unsqueeze(1)
        img1, img2, pseudo_img2, flow, valid = self.spatial_transform(img1, img2, pseudo_img2, flow, valid)
        if valid is not None:
            valid = valid.squeeze(1)
        return img1, img2, pseudo_img2, flow, valid
