import json
import nibabel as nib
import numpy as np
import pandas as pd
import torch

from glob import glob
from scipy import ndimage
from .base import CustomMask3DDataset
from utils.segmentations import enlarge_segmentation_torch
from .nlst_mixin import NLSTMixin

import logging
# Module-level logger; filter_data reports its filtering step through it.
log = logging.getLogger(__name__)


def get_spacing(affine: np.ndarray) -> np.ndarray:
    """Return the voxel spacing along the first three voxel axes.

    The spacing of voxel axis ``i`` is the Euclidean length of column ``i``
    of ``affine`` (taken over all of its rows).

    Args:
        affine: Voxel-to-world affine matrix (e.g. ``nib.load(...).affine``).

    Returns:
        A length-3 array ``[sx, sy, sz]``.
    """
    column_lengths = [np.linalg.norm(affine[:, axis]) for axis in range(3)]
    return np.array(column_lengths)


# Voxel spacing that preprocess_ts resamples segmentation volumes to, one
# entry per voxel axis, in the units of the image affine (presumably mm —
# TODO confirm).
TARGET_SPACING = np.array([1.0, 1.0, 2.0])

def preprocess_ts(scan):
    """Turn a label volume into a binary mask resampled to TARGET_SPACING.

    Voxels whose label lies in the inclusive range 10..14 become 1 and all
    others become 0. The mask is then resampled with ``scipy.ndimage.zoom``
    (linear interpolation, ``order=1``). The zoom factor per axis is the
    image's voxel spacing (from ``get_spacing``) divided by
    ``TARGET_SPACING``.

    NOTE(review): labels 10..14 are presumably the lung structures of the
    segmentation model that produced the file — confirm against its label map.

    Args:
        scan: A loaded image exposing ``affine`` and ``get_fdata()`` (e.g. the
            result of ``nib.load``).

    Returns:
        torch.Tensor: The resampled mask.
    """
    labels = scan.get_fdata()
    in_range = np.logical_and(labels >= 10, labels <= 14)
    mask = np.where(in_range, 1, 0)

    factors = get_spacing(scan.affine) / TARGET_SPACING
    return torch.tensor(ndimage.zoom(mask, factors, order=1))


def find_starting_point(center, max_size, half_size=32):
    """Clamp one coordinate so a ``2 * half_size`` window around it fits.

    The window ``[x - half_size, x + half_size)`` is pushed inside
    ``[0, max_size)``. Despite the name, the value returned is the clamped
    window *center*, not the window's start.

    Args:
        center: Requested center coordinate (int or float).
        max_size: Extent of the axis.
        half_size: Half of the window width. The default of 32 gives the
            64-voxel window used throughout this module.

    Returns:
        ``half_size`` if the window would start before 0, otherwise
        ``max_size - half_size`` if it would end past ``max_size``, otherwise
        ``int(center)``. If ``max_size < 2 * half_size`` the first rule wins
        and the window cannot fully fit.
    """
    if center - half_size < 0:
        return half_size
    if center + half_size > max_size:
        return max_size - half_size
    return int(center)


class NoduleInjectNLSTDataset(NLSTMixin, CustomMask3DDataset):
    """NLST dataset that yields CT volumes and masks for nodule injection.

    Each item pairs a segmentation volume (a ``*.nii.gz`` file two directory
    levels below ``path_ts``) with the CT scan whose path is derived from it.
    The rules for which base samples are served:

    - The first ``n_skip`` base samples are skipped.
    - Every base sample is repeated ``n_repeat`` times.
    - The total length is capped at ``n_samples``.
    - When ``path_predictions`` is given, only samples whose ``pred_label``
      equals ``filter_id`` are kept.

    A nodule library is read from the ``path_nodules`` JSON. It is filtered to
    entries whose ``nodule_size`` lies within
    ``[inject_args.min_size, inject_args.max_size]`` on every axis.
    ``get_nodule`` and ``prepare_injection`` use it to paste a nodule into a
    volume.

    ``path_labels`` is accepted but not used by this class.
    """

    def __init__(
            self,
            name: str,
            split: str,
            path_data: str,
            path_nodules: str,
            path_labels: str,
            path_predictions: str,
            path_ts: str,
            n_samples: int,
            filter_id: int,
            n_repeat: int,
            inject_args,
            n_skip: int = 0):
        super().__init__()
        assert split in ['train', 'validation']
        self.name = name
        self.split = split
        self.path_data = path_data
        self.n_samples = n_samples
        self.filter_id = filter_id
        self.path_ts = path_ts
        self.n_repeat = n_repeat

        # glob() returns entries in a filesystem-dependent order. Sort them so
        # a given dataset index always resolves to the same file on every
        # machine.
        # NOTE(review): __getitem__ indexes this list with indices into
        # self.data / self.predictions. Confirm those were built from the same
        # sorted file order.
        self.list_ts = sorted(glob(self.path_ts + "/*/*.nii.gz"))

        # Context managers close the JSON file handles promptly.
        with open(self.path_data) as f:
            self.data = json.load(f)
        with open(path_nodules) as f:
            nodules = json.load(f)
        # Keep only nodules whose size is within [min_size, max_size] on every axis.
        self.nodules = [
            x for x in nodules
            if min(x['nodule_size']) >= inject_args.min_size
            and max(x['nodule_size']) <= inject_args.max_size
        ]
        self.predictions = (
            pd.read_csv(path_predictions, index_col="idx") if path_predictions else None
        )
        self.length = self._capped_length(len(self.data), n_skip)
        self.map_index = lambda x: x + n_skip

        # May narrow self.length / self.map_index to the requested class.
        self.filter_data(filter_id, n_skip)

        self.inject_args = inject_args

    def _capped_length(self, n_available, n_skip):
        """Return the item count after skipping, repeating and capping.

        The result is clamped at 0. An ``n_skip`` larger than ``n_available``
        therefore yields an empty dataset instead of a negative ``__len__``,
        which Python rejects with ``ValueError``.
        """
        return max(0, min((n_available - n_skip) * self.n_repeat, self.n_samples))

    def filter_data(self, filter_id, n_skip):
        """Restrict the dataset to samples predicted as ``filter_id``.

        This is a no-op when no predictions were loaded. Otherwise:

        - Positions in ``self.data`` are kept when their ``"image"`` matches a
          prediction row with ``pred_label == filter_id``.
        - The first ``n_skip`` of the kept positions are skipped.
        - ``self.length`` and ``self.map_index`` are updated to match.
        """
        if self.predictions is None:
            return

        log.info("Filtering data based filter_id: %d", filter_id)
        sample_df = pd.Series([x["image"] for x in self.data], name="image").to_frame()
        pred_df = self.predictions[
            self.predictions["pred_label"].astype(int) == filter_id
        ]
        # Inner join on the image path. The resulting index holds the position
        # of each surviving sample in self.data.
        filtered_idx = sample_df.join(
            pred_df.set_index("image"), on="image", how="inner"
        ).index

        self.length = self._capped_length(len(filtered_idx), n_skip)
        self.map_index = lambda x: int(filtered_idx[x + n_skip])

    def __len__(self):
        """Number of items, including repeats, after skipping and capping."""
        return self.length

    def __getitem__(self, index):
        """Load the scan and segmentation for dataset position ``index``.

        How ``index`` is resolved:

        - It is divided by ``n_repeat``, so repeats share a base sample.
        - The result is mapped through ``self.map_index``.
        - The segmentation path is ``self.list_ts[index]``.
        - The scan path uses the segmentation's parent folder and file name
          (minus the 7-character ``.nii.gz`` suffix) under
          ``data/ct_pretraining/NLST``.

        Returns:
            ``(index, pred_label, pred_label, volume, ts)``:

            - ``index`` is the mapped index.
            - ``pred_label`` comes from the predictions table, or is
              ``self.filter_id`` when no predictions were loaded.
            - ``volume`` is the preprocessed scan.
            - ``ts`` is the resampled label-10..14 mask from
              ``preprocess_ts``.
        """
        index = index // self.n_repeat
        index = self.map_index(index)
        ts_filename = self.list_ts[index]
        folder = ts_filename.split("/")[-2]
        # Strip the ".nii.gz" extension (7 characters) from the file name.
        filename = ts_filename.split("/")[-1][:-7]
        scan_filename = "/".join(["data/ct_pretraining/NLST", folder, filename])

        scan = nib.load(scan_filename)
        # NOTE(review): preprocess_scan is neither defined nor imported in this
        # file as shown. Confirm it is provided elsewhere; otherwise this line
        # raises NameError.
        volume = preprocess_scan(scan)

        ts = nib.load(ts_filename)
        ts = preprocess_ts(ts)

        if self.predictions is not None:
            pred_label = int(self.predictions.loc[index, "pred_label"])
        else:
            pred_label = self.filter_id

        # pred_label is returned twice, presumably to fill both a label slot
        # and a predicted-label slot expected by the caller. TODO confirm.
        return index, pred_label, pred_label, volume, ts

    def prepare_injection(self, volume, ts, nodule_mask, nodule_true_region, nodule_size, center):
        """Paste a nodule into a copy of ``volume`` and crop around ``center``.

        The nodule box starts at ``center - int(nodule_size / 2)`` and spans
        ``nodule_size`` voxels per axis. Steps:

        1. With ``inject_args.cut_to_lung_seg``, multiply the nodule mask by
           the matching region of ``ts``.
        2. Inside the box, zero the voxels where ``nodule_mask == 1`` and add
           ``nodule_mask * nodule_true_region``.
        3. Build a full-size mask of the pasted nodule.
        4. With ``inject_args.enlarge_mask``, pass that mask through
           ``enlarge_segmentation_torch`` using ``inject_args.enlarge_radius``.
           If ``cut_to_lung_seg`` is also set, multiply it by ``ts`` again.

        Returns:
            ``(i2sb_region, i2sb_mask, inj_pos)``:

            - ``i2sb_region`` is the injected scan cropped to ``center ± 32``
              on each axis, with a leading singleton dimension, as float.
            - ``i2sb_mask`` is the mask cropped the same way.
            - ``inj_pos`` is the start of the nodule box.
        """
        inj_pos = center - (nodule_size / 2).astype(int)
        inj_end_pos = inj_pos + np.array(nodule_size).astype(int)
        box = (
            slice(inj_pos[0], inj_end_pos[0]),
            slice(inj_pos[1], inj_end_pos[1]),
            slice(inj_pos[2], inj_end_pos[2]),
        )

        # NOTE(review): .cuda() hard-codes a CUDA device. nodule_mask and
        # nodule_true_region must already be on that device for the arithmetic
        # below to work; confirm with callers.
        area_to_inj = volume[box].cuda()

        if self.inject_args.cut_to_lung_seg:
            area_lung_mask = ts[box].cuda()
            nodule_mask = nodule_mask * area_lung_mask

        # Zero the voxels covered by the nodule, then add the nodule intensities.
        injected = area_to_inj * torch.where(nodule_mask == 1, 0, 1) + nodule_mask * nodule_true_region

        injected_scan = volume.clone()
        injected_scan[box] = injected
        crop = (
            slice(center[0] - 32, center[0] + 32),
            slice(center[1] - 32, center[1] + 32),
            slice(center[2] - 32, center[2] + 32),
        )
        i2sb_region = injected_scan[crop].unsqueeze(0).float()

        injected_mask = torch.zeros_like(injected_scan)
        injected_mask[box] = nodule_mask

        if self.inject_args.enlarge_mask:
            injected_mask = enlarge_segmentation_torch(injected_mask, self.inject_args.enlarge_radius)
            if self.inject_args.cut_to_lung_seg:
                injected_mask = injected_mask * ts

        i2sb_mask = injected_mask[crop].unsqueeze(0).float()

        # The region and mask crops must line up voxel-for-voxel.
        assert i2sb_region.shape == i2sb_mask.shape

        return i2sb_region, i2sb_mask, inj_pos

    def get_nodule(self, index=None):
        """Load one nodule from the filtered nodule library.

        Args:
            index: Position in ``self.nodules``. When ``None``, a nodule is
                drawn uniformly at random with ``np.random.choice``.

        Returns:
            ``(nodule_mask, nodule_true_region, nodule_size)``:

            - ``nodule_mask`` and ``nodule_true_region`` are the arrays stored
              in ``mask.npz`` and ``true_region.npz`` in the nodule's folder,
              converted to tensors.
            - ``nodule_size`` is the nodule's size as a NumPy array.
        """
        if index is not None:
            nodule_meta = self.nodules[index]
        else:
            nodule_meta = np.random.choice(self.nodules)

        folder = nodule_meta['nodule_folder']
        # Let np.load open the file from its path instead of leaking an
        # unclosed open() handle.
        nodule_mask = torch.tensor(np.load(folder + "/mask.npz"))
        nodule_true_region = torch.tensor(np.load(folder + "/true_region.npz"))
        nodule_size = np.array(nodule_meta['nodule_size'])

        return nodule_mask, nodule_true_region, nodule_size

    def get_guidance_classes(self, config, fabric, batch_labels, batch_pred_labels):
        """Return ``config.exp.guide_id`` repeated once per entry of ``batch_labels``.

        ``fabric`` and ``batch_pred_labels`` are accepted for interface
        compatibility and are not used.
        """
        guide_id = config.exp.guide_id
        return torch.tensor([guide_id] * len(batch_labels))

    def get_random_center(self, ts, volume):
        """Pick a random voxel where ``ts == 1`` and clamp it with get_correct_center.

        The voxel is drawn uniformly with ``np.random.randint``.

        Raises:
            ValueError: If ``ts`` contains no voxel equal to 1.
        """
        coords = np.argwhere(ts.cpu().numpy() == 1)
        if len(coords) == 0:
            raise ValueError("get_random_center: segmentation contains no voxel equal to 1")
        center = coords[np.random.randint(0, len(coords))]
        return self.get_correct_center(center, volume.shape)

    def get_correct_center(self, center, volume_shape):
        """Clamp each coordinate so a 64-voxel window around it fits the volume.

        Each axis is clamped with ``find_starting_point`` against the
        matching entry of ``volume_shape``. The result is an int array.
        """
        return np.array([find_starting_point(x, s) for x, s in zip(center, volume_shape)], dtype=int)