import os
import cv2
import torch
import pandas as pd
import numpy as np
import torch.nn.functional as F

from typing import Callable, Optional, List, Tuple
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF


class ImageDataset(Dataset):
    """Pair up two equally long datasets item by item.

    ``dataset[i]`` returns ``(dataset1[i][0], dataset2[i][0])``: only the first
    element of each underlying item is kept; any further elements (e.g. labels)
    are discarded and the two images are NOT concatenated.

    Args:
        dataset1: Indexable dataset whose items are sequences (first element used).
        dataset2: Indexable dataset of the same length as ``dataset1``.

    Raises:
        ValueError: If the two datasets have different lengths.
    """

    def __init__(self, dataset1, dataset2):
        # Explicit check rather than ``assert`` so the validation is not stripped
        # when Python runs with ``-O``.
        if len(dataset1) != len(dataset2):
            raise ValueError(
                "Datasets must have the same length "
                f"(got {len(dataset1)} and {len(dataset2)})."
            )
        self.dataset1 = dataset1
        self.dataset2 = dataset2

    def __len__(self):
        return len(self.dataset1)

    def __getitem__(self, index):
        # Take the first element of each paired item; the rest are dropped.
        image1 = self.dataset1[index][0]
        image2 = self.dataset2[index][0]
        return image1, image2

class WallerDataset(Dataset):
    """Paired (raw, label) dataset of ``.npy`` arrays indexed by a CSV file.

    The first CSV column holds file names ending in ``.jpg.tiff``; each is mapped
    to ``<stem>.npy`` and loaded from ``<path>/diffuser_images`` (raw) and
    ``<path>/ground_truth_lensed`` (label). Both arrays go through
    ``cv2.COLOR_BGR2RGB`` before the optional transforms.

    Args:
        path: Dataset root directory.
        train: Read ``dataset_train.csv`` if True, else ``dataset_test.csv``.
        transform_raw: Optional callable applied to the raw array.
        transform_lab: Optional callable applied to the label array.
    """

    def __init__(self, path, train=False, transform_raw=None, transform_lab=None):
        self.path = path
        self.transform_raw = transform_raw
        self.transform_lab = transform_lab
        csv_file = 'dataset_train.csv' if train else 'dataset_test.csv'
        self.df = pd.read_csv(os.path.join(self.path, csv_file))

    def __len__(self):
        return len(self.df)

    def _paths(self, idx):
        """Return ``(raw_path, lab_path)`` for row ``idx``.

        The ``.jpg.tiff`` -> ``.npy`` substitution is applied to the file name
        only, so directory names are never altered.
        """
        npy_name = self.df.iloc[idx, 0].replace('.jpg.tiff', '.npy')
        return (
            os.path.join(self.path, 'diffuser_images', npy_name),
            os.path.join(self.path, 'ground_truth_lensed', npy_name),
        )

    def __getitem__(self, idx):
        raw_path, lab_path = self._paths(idx)
        raw = np.load(raw_path)
        raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
        lab = np.load(lab_path)
        lab = cv2.cvtColor(lab, cv2.COLOR_BGR2RGB)

        if self.transform_raw is not None:
            raw = self.transform_raw(raw)
        if self.transform_lab is not None:
            lab = self.transform_lab(lab)

        return raw, lab


class pad_affine_crop_transform:
    def __init__(self, target_size=(260, 260), pad=0, bbox=None, M=None, interp="bicubic", align_corners=False):
        self.target_size = target_size
        self.pad = pad
        self.align_corners = align_corners

        if M is None:
            M = [[1.00125533, -0.00271183, 0.53230019],
                 [0.00271183,  1.00125533, -0.61541837]]

        self.bbox = bbox
        self.M = torch.tensor(M, dtype=torch.float)

        self.interp_mode = "bicubic" if interp == "bicubic" else "bilinear"

    def __call__(self, src):
        # 1) resize -> 2) pad
        src = TF.resize(src, self.target_size, interpolation=TF.InterpolationMode.BICUBIC)
        src = TF.pad(src, padding=[self.pad, self.pad, self.pad, self.pad])

        # 3) bbox
        if self.bbox is not None:
            y_min, x_min, y_max, x_max = self.bbox
        else:
            y_min, x_min, y_max, x_max = 0, 0, src.shape[1], src.shape[2]

        # 4) affine (OpenCV forward M을 grid_sample에 쓰려면 역맵을 사용해야 함)
        C, H, W = src.shape
        device, dtype = src.device, src.dtype
        M_pix = _ensure_tensor(self.M, device, dtype)  # (1,2,3)
        theta = _opencv_M_to_theta(
            M_pix, in_hw=(H, W), out_hw=(H, W),
            device=device, dtype=dtype,
            align_corners=self.align_corners,
            use_backward_map=True  # = invert(M) internally
        )

        grid = F.affine_grid(theta, size=(1, C, H, W), align_corners=self.align_corners)
        warped = F.grid_sample(src.unsqueeze(0), grid,
                               mode=self.interp_mode,
                               padding_mode="zeros",
                               align_corners=self.align_corners).squeeze(0)

        masked = torch.zeros_like(src)
        masked[:, y_min:y_max+1, x_min:x_max+1] = warped[:, y_min:y_max+1, x_min:x_max+1]
        return masked

def inverse_affine(src, M, interp_mode="bicubic", align_corners=False):
    """Undo a forward pixel-space affine warp by sampling through ``M`` directly.

    ``M`` is the forward matrix that was used for the warp (OpenCV ``warpAffine``
    convention, ``dst = M @ src``). Here it is passed to the theta conversion
    with ``use_backward_map=False``, i.e. treated as the output->input map, so
    ``out(p) = src(M @ p)``. This reverses a warp that sampled with ``M^-1``.

    Args:
        src: Image tensor, ``(C, H, W)`` or ``(B, C, H, W)``.
        M: ``(2, 3)`` or ``(B, 2, 3)`` matrix (tensor or nested sequence). A
            single matrix is broadcast over the batch.
        interp_mode: ``"bicubic"`` for bicubic sampling, anything else -> bilinear.
        align_corners: Forwarded to ``affine_grid`` / ``grid_sample``.

    Returns:
        Tensor with the same rank and shape as ``src``; samples falling outside
        the input are zero.
    """
    single = src.dim() != 4
    batch = src.unsqueeze(0) if single else src  # always (B, C, H, W) below

    n, c, h, w = batch.shape
    mats = _ensure_tensor(M, batch.device, batch.dtype)
    if n > 1 and mats.shape[0] == 1:
        # Share one matrix across the whole batch.
        mats = mats.expand(n, -1, -1)

    theta = _opencv_M_to_theta(
        mats,
        in_hw=(h, w),
        out_hw=(h, w),
        device=batch.device,
        dtype=batch.dtype,
        align_corners=align_corners,
        use_backward_map=False,  # M itself serves as the dst->src map
    )

    mode = "bicubic" if interp_mode == "bicubic" else "bilinear"
    grid = F.affine_grid(theta, size=(n, c, h, w), align_corners=align_corners)
    restored = F.grid_sample(
        batch,
        grid,
        mode=mode,
        padding_mode="zeros",
        align_corners=align_corners,
    )
    return restored.squeeze(0) if single else restored

import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF

def _ensure_tensor(M, device, dtype):
    if not torch.is_tensor(M):
        M = torch.as_tensor(M, device=device, dtype=dtype)
    else:
        M = M.to(device=device, dtype=dtype)
    if M.ndim == 2:
        M = M.unsqueeze(0)  # (1,2,3)
    return M  # (B,2,3)

def _norm_mats(H, W, device, dtype, align_corners=False):
    I = torch.eye(3, device=device, dtype=dtype)
    if align_corners:
        sx = 2.0 / max(W - 1, 1)
        sy = 2.0 / max(H - 1, 1)
        S = I.clone(); S[0,0]=sx; S[1,1]=sy; S[0,2]=-1.0; S[1,2]=-1.0
        S_inv = I.clone()
        S_inv[0,0]=(W - 1)/2.0; S_inv[1,1]=(H - 1)/2.0
        S_inv[0,2]=(W - 1)/2.0; S_inv[1,2]=(H - 1)/2.0
    else:
        sx = 2.0 / W
        sy = 2.0 / H
        S = I.clone()
        S[0,0]=sx; S[1,1]=sy
        S[0,2]=(1.0 / W) - 1.0
        S[1,2]=(1.0 / H) - 1.0
        S_inv = I.clone()
        S_inv[0,0]=W / 2.0; S_inv[1,1]=H / 2.0
        S_inv[0,2]=(W / 2.0) - 0.5
        S_inv[1,2]=(H / 2.0) - 0.5
    return S, S_inv  # 3x3

def _opencv_M_to_theta(M_pix, in_hw, out_hw, device, dtype, align_corners=False, use_backward_map=True):
    """Convert pixel-space 2x3 affine matrices into ``F.affine_grid`` thetas.

    Args:
        M_pix: ``(B, 2, 3)`` affine matrices in pixel coordinates.
        in_hw: ``(H, W)`` of the image being sampled.
        out_hw: ``(H, W)`` of the output grid.
        device, dtype: Placement of the helper matrices (should match ``M_pix``).
        align_corners: Normalization convention; must match the value later
            given to ``affine_grid`` / ``grid_sample``.
        use_backward_map: If True, ``M_pix`` is a forward warp (``dst = M @ src``)
            and is inverted here, because ``grid_sample`` needs the output->input
            map. If False, ``M_pix`` already is that backward map
            (``src = M @ dst``) and is used as-is.

    Returns:
        ``(B, 2, 3)`` theta mapping normalized output coordinates to normalized
        input coordinates, i.e. the top two rows of ``S_in @ M_bw @ S_out^-1``.
    """
    batch = M_pix.shape[0]
    h_in, w_in = in_hw
    h_out, w_out = out_hw

    pix_to_norm_in = _norm_mats(h_in, w_in, device, dtype, align_corners)[0]
    norm_to_pix_out = _norm_mats(h_out, w_out, device, dtype, align_corners)[1]

    # Lift the 2x3 matrices to 3x3 homogeneous form so they can be inverted and chained.
    homog = torch.eye(3, device=device, dtype=dtype).repeat(batch, 1, 1)
    homog[:, :2, :] = M_pix

    out_to_in_pix = torch.linalg.inv(homog) if use_backward_map else homog

    # normalized out -> pixel out -> pixel in -> normalized in
    chain = (
        pix_to_norm_in.unsqueeze(0).expand(batch, -1, -1)
        @ out_to_in_pix
        @ norm_to_pix_out.unsqueeze(0).expand(batch, -1, -1)
    )
    return chain[:, :2, :]


class IFINDataset(Dataset):
    """Paired (raw, label) image dataset indexed by a CSV file.

    The first CSV column holds file names ending in ``.jpg.tiff``. Each name is
    mapped to a raw image ``<path>/images/<stem>_rgb8.png`` and a label
    ``<path>/labels/<stem>.jpg``. Rows whose raw image does not exist on disk are
    dropped at construction time.

    Images are read with ``cv2.imread(..., -1)`` (IMREAD_UNCHANGED), converted
    from BGR to RGB and NOT rescaled here. Any normalization is left to the
    transforms.

    Args:
        path: Dataset root directory.
        train: Read ``dataset_train.csv`` if True, else ``dataset_test.csv``.
        transform_raw: Optional callable applied to the raw RGB array.
        transform_lab: Optional callable applied to the label RGB array.
    """

    def __init__(self, path, train=False, transform_raw=None, transform_lab=None):
        self.path = path
        self.transform_raw = transform_raw
        self.transform_lab = transform_lab

        csv_file = 'dataset_train.csv' if train else 'dataset_test.csv'
        df = pd.read_csv(os.path.join(self.path, csv_file))

        # Filter with a boolean mask: unlike rebuilding a DataFrame from row
        # Series, this keeps column names and dtypes even when no row survives.
        keep = np.array(
            [os.path.exists(self._raw_path(name)) for name in df.iloc[:, 0]],
            dtype=bool,
        )
        self.df = df.loc[keep].reset_index(drop=True)

    def _raw_path(self, name):
        """Raw image path for a CSV file name (``*.jpg.tiff`` -> ``images/*_rgb8.png``)."""
        return os.path.join(self.path, 'images', name.replace('.jpg.tiff', '_rgb8.png'))

    def _lab_path(self, name):
        """Label image path for a CSV file name (``*.jpg.tiff`` -> ``labels/*.jpg``)."""
        return os.path.join(self.path, 'labels', name.replace('.jpg.tiff', '.jpg'))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(raw, lab)`` for ``idx``.

        If either image at ``idx`` cannot be read, the following indices are
        tried in turn (wrapping around), as a best-effort substitute sample.

        Raises:
            IndexError: If ``idx`` is out of range (keeps the sequence protocol).
            RuntimeError: If no readable raw/label pair exists in the dataset.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Bounded iteration instead of recursion: the original recursive
        # fallback could exhaust the stack when many files were unreadable.
        for step in range(n):
            name = self.df.iloc[(idx + step) % n, 0]
            raw = cv2.imread(self._raw_path(name), -1)
            lab = cv2.imread(self._lab_path(name), -1)
            if raw is not None and lab is not None:
                break
        else:
            raise RuntimeError(f"No readable raw/label image pair found under {self.path!r}")

        raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
        lab = cv2.cvtColor(lab, cv2.COLOR_BGR2RGB)

        if self.transform_raw is not None:
            raw = self.transform_raw(raw)
        if self.transform_lab is not None:
            lab = self.transform_lab(lab)

        return raw, lab