import os, random, sys
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
from utils.flow_viz import *


class Logger_(object):
    """Tee text written to this object into a terminal stream and, optionally, a log file.

    Provides the ``write``/``flush`` methods needed to stand in for
    ``sys.stdout``, so printed output is echoed to ``stream`` and appended to
    ``filename``.

    :param filename: path of the log file, opened in append mode. If None,
        output is written to ``stream`` only.
    :param stream: terminal stream to echo to; defaults to the ``sys.stdout``
        object in effect when this class was defined.
    """

    def __init__(self, filename=None, stream=sys.stdout):
        self.terminal = stream
        # None means "terminal only" instead of failing inside open(None).
        self.log = open(filename, 'a') if filename is not None else None

    def write(self, message):
        """Write ``message`` to the terminal and, if configured, the log file."""
        self.terminal.write(message)
        if self.log is not None:
            self.log.write(message)

    def flush(self):
        """Flush both destinations so logged output is not left in buffers."""
        self.terminal.flush()
        if self.log is not None and not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open. Safe to call repeatedly."""
        if self.log is not None and not self.log.closed:
            self.log.close()


def seed_everything(seed):
    """Seed every random number generator used here and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG and
    the current CUDA device's RNG. Also records the seed in ``PYTHONHASHSEED``
    and turns off cuDNN autotuning in favour of deterministic kernels.
    """
    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class InputPadder:
    """Pad 4-D image tensors so their height and width are multiples of ``divis_by``.

    :param dims: shape of the tensors to be padded; only the last two entries
        (height, width) are used.
    :param mode: ``'sintel'`` splits the padding as evenly as possible between
        both sides of each axis; any other value still centres the width
        padding but puts all of the height padding at the bottom.
    :param divis_by: required divisor of the padded height and width (default 8).
    """
    def __init__(self, dims, mode='sintel', divis_by=8):
        self.ht, self.wd = dims[-2:]
        # Smallest non-negative amount that rounds each side up to a multiple of divis_by.
        extra_h = -self.ht % divis_by
        extra_w = -self.wd % divis_by

        left = extra_w // 2
        right = extra_w - left
        if mode == 'sintel':
            top = extra_h // 2
            bottom = extra_h - top
        else:
            top, bottom = 0, extra_h
        # Order follows F.pad for the last two dims: (left, right, top, bottom).
        self._pad = [left, right, top, bottom]

    def pad(self, *inputs):
        """Replicate-pad each [B, C, H, W] tensor; returns a list in input order."""
        assert all(t.ndim == 4 for t in inputs)
        return [F.pad(t, self._pad, mode='replicate') for t in inputs]

    def unpad(self, x):
        """Crop a padded [B, C, H, W] tensor back to the original height and width."""
        assert x.ndim == 4
        left, right, top, bottom = self._pad
        height, width = x.shape[-2:]
        return x[..., top:height - bottom, left:width - right]
    
    
def warp_image_with_flow(img, flow):
    """Backward-warp ``img`` with a dense per-pixel displacement field.

    Each output pixel ``(x, y)`` is bilinearly sampled from ``img`` at
    ``(x + flow_x, y + flow_y)``; samples falling outside the image are zero.

    :param img: floating-point image tensor of shape [B, C, H, W].
    :param flow: displacement in pixels, shape [B, 2, H, W] with channel 0 the
        horizontal (x) and channel 1 the vertical (y) component, or
        [B, 1, H, W] for horizontal-only displacement (vertical taken as zero).
        Channels beyond the first two are ignored. Any numeric dtype is
        accepted; it is converted to ``img.dtype``.
    :return: warped image of shape [B, C, H, W] with dtype ``img.dtype``.
    """
    B, C, H, W = img.size()

    if flow.size(1) == 1:
        # Horizontal-only flow (e.g. disparity): append a zero vertical component.
        flow = torch.cat([flow, torch.zeros_like(flow)], dim=1)  # [B, 2, H, W]

    # Work in img's dtype throughout: grid_sample requires input and grid to
    # share a dtype, and an integer flow must not be truncated when normalised.
    flow = flow.to(dtype=img.dtype)

    # Base grid in normalised [-1, 1] coordinates, (x, y) order as grid_sample expects.
    ys, xs = torch.meshgrid(
        torch.linspace(-1, 1, H, device=img.device, dtype=img.dtype),
        torch.linspace(-1, 1, W, device=img.device, dtype=img.dtype),
        indexing='ij',
    )
    grid = torch.stack((xs, ys), dim=2).unsqueeze(0).expand(B, H, W, 2)  # [B, H, W, 2]

    # Pixel displacement -> normalised units. With align_corners=True, -1 and 1
    # are the centres of the first/last pixels, so one pixel spans 2 / (size - 1).
    # NOTE: H == 1 or W == 1 divides by zero here (inf/nan grid); not handled.
    flow_norm = torch.stack(
        (flow[:, 0] / ((W - 1) / 2), flow[:, 1] / ((H - 1) / 2)),
        dim=-1,
    )  # [B, H, W, 2]

    sampling_grid = grid + flow_norm  # [B, H, W, 2]
    return F.grid_sample(img, sampling_grid, mode='bilinear', padding_mode='zeros', align_corners=True)


def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (e.g. frozen layers) are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

