import random
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import os.path
from jaxtyping import jaxtyped, Shaped
from beartype import beartype as typechecker
from torch.utils import data

from conf.dataset import DatasetParams, BRATS2020Params
from data.UtilsDataset import CustomDataModule
from utils.utils import display_tensor, display_mask, read_list_from_file, write_list_to_file


class ResizeTensor(nn.Module):
    """Resample the last two dimensions of a tensor to a fixed size.

    Thin ``nn.Module`` wrapper around ``torch.nn.functional.interpolate`` that
    also accepts a 3-D (unbatched) tensor by temporarily adding a leading
    dimension.

    Args:
        size: Target size, forwarded to ``F.interpolate`` as a list.
        mode: Interpolation mode, forwarded to ``F.interpolate`` unchanged.
    """

    def __init__(self, size: Tuple[int, int], mode: str):
        super().__init__()
        self.size = size
        self.mode = mode

    @jaxtyped
    @typechecker
    def forward(self, x: Shaped[torch.Tensor, '*batch channels height width']) -> Shaped[torch.Tensor, '*batch channels height2 width2']:
        """Return ``x`` resized to ``self.size``.

        A 3-D input gets a leading dimension of size 1 for the interpolation
        call and has it removed again afterwards; any other input is handed to
        ``F.interpolate`` as is.
        """
        needs_batch_dim = x.dim() == 3
        batched = x.unsqueeze(0) if needs_batch_dim else x
        resized = F.interpolate(batched, size=list(self.size), mode=self.mode)
        return resized.squeeze(0) if needs_batch_dim else resized


class BRATSDatasetProcessed(torch.utils.data.Dataset):
    """Dataset over pre-processed BRATS slices stored as one array file each.

    Every file in ``params.root`` is expected to be named
    ``<patient id>_<slice index>.<ext>`` (see ``get_id_from_filename`` and
    ``get_slice_from_filename``) and to be loadable with ``np.load``. The
    loaded array is treated as channel-first: every channel but the last is an
    image modality and the last channel is the segmentation label map.

    Args:
        params: Dataset configuration. The attributes read by this class are
            ``root``, ``test_flag``, ``split_segmentation``, ``height``,
            ``width``, ``preprocess_func``, ``segmentation_mode`` and
            ``return_indice``.
        ids: Optional collection of patient ids to keep. ``None`` keeps every
            patient found in the directory.
    """

    def __init__(
        self,
        params: BRATS2020Params,
        ids: set[int] = None,
    ):
        super().__init__()
        print('[Loading BRATSDatasetProcessed dataset]')
        self.params = params
        self.directory = os.path.expanduser(params.root)
        self.test_flag = params.test_flag
        self.split_segmentation = params.split_segmentation

        # NOTE: seqtypes is only stored; nothing else in this class reads it.
        self.seqtypes = ['t1', 't1ce', 't2', 'flair']
        if not self.test_flag:
            self.seqtypes.append('seg')

        # get the list of all files in the directory (sorted for a stable index order)
        self.database = sorted(os.listdir(self.directory))

        # filter the database according to the ids
        if ids is not None:
            self.database = [f for f in self.database if self.get_id_from_filename(f) in ids]

        # remove the 80 bottom and 26 top slices
        # (155 is presumably the number of slices per volume -- TODO confirm)
        print('[Removing bottom and top slices]')
        bottom_remove = 80
        top_remove = 26
        first_kept = bottom_remove
        end_kept = 155 - top_remove  # exclusive upper bound
        self.database = [
            f for f in self.database
            if first_kept <= self.get_slice_from_filename(f) < end_kept
        ]

        # images are resampled bilinearly; labels use nearest so that no new
        # label values are created by interpolation
        self.resize_x = ResizeTensor((params.height, params.width), mode='bilinear')
        self.resize_y = ResizeTensor((params.height, params.width), mode='nearest')
        print('[BRATSDatasetProcessed dataset loaded]')

    def __getitem__(
        self,
        i: int,
    ):
        """Load, preprocess and return the sample stored at index ``i``.

        Returns:
            A tuple ``(t1, t1ce, t2, flair, seg)`` where each modality has a
            single channel. When ``params.split_segmentation`` is set, ``seg``
            is replaced by one single-channel tensor per segmentation channel.
            When ``params.return_indice`` is set, ``i`` is appended as the
            last element.

        Raises:
            ValueError: If ``params.preprocess_func`` or
                ``params.segmentation_mode`` has an unsupported value.
        """
        data_file = self.database[i]
        data_numpy = np.load(os.path.join(self.directory, data_file))

        # drop 8 pixels on each border of the last two dims
        # (crop to a size of (224, 224), assuming 240x240 inputs -- TODO confirm)
        xy = torch.tensor(data_numpy).float()[..., 8:-8, 8:-8]

        # every channel but the last one is an image modality
        x = xy[:-1, ...]
        if self.params.preprocess_func is None:
            pass
        elif self.params.preprocess_func == 'pf01':
            # affine map that sends -1 to 0 and 1 to 1
            x = (x + 1) / 2
        else:
            raise ValueError(f'unknown {self.params.preprocess_func=}')

        # the last channel is the label map
        y = xy[-1, ...]
        # replace the 4 label with the label 3 which is never used
        y[y == 4] = 3

        # resize images and labels to (params.height, params.width)
        x = self.resize_x(x)
        y = self.resize_y(y.unsqueeze(0)).squeeze(0)

        # one hot encode y and move the class axis first: (H, W, 4) -> (4, H, W)
        y = F.one_hot(y.long(), num_classes=4).permute(2, 0, 1)

        # keep a leading channel dimension of size 1 on every modality
        t1 = x[0, ...][None, ...]
        t1ce = x[1, ...][None, ...]
        t2 = x[2, ...][None, ...]
        flair = x[3, ...][None, ...]
        seg = y

        if self.params.segmentation_mode == 1:
            # single channel: union of the three non-background classes
            seg = seg[1:, ...]
            seg = (seg.sum(dim=0) > 0).unsqueeze(0)
        elif self.params.segmentation_mode == 2:
            # two channels: background, then union of the other classes
            seg_not_background = (seg[1:, ...].sum(dim=0) >= 1).unsqueeze(0)
            seg = torch.cat([seg[:1], seg_not_background])
        elif self.params.segmentation_mode == 3:
            # three channels: every class except background
            seg = seg[1:, ...]
        elif self.params.segmentation_mode == 4:
            # four channels: the full one-hot encoding
            pass
        else:
            raise ValueError(f'unknown segmentation_mode {self.params.segmentation_mode=}')
        seg = seg.float()

        index_suffix = (i, ) if self.params.return_indice else ()
        if self.params.split_segmentation:
            seg_split = tuple(torch.split(seg, 1, dim=0))
            return (t1, t1ce, t2, flair) + seg_split + index_suffix
        return (t1, t1ce, t2, flair, seg) + index_suffix

    def __len__(self):
        """Return the number of slices kept in the dataset."""
        return len(self.database)

    def get_id_from_filename(self, filename: str) -> int:
        """Return the patient id, i.e. the integer before the first ``_``."""
        id_ppl = filename.split('_')[0]
        return int(id_ppl)

    def get_slice_from_filename(self, filename: str) -> int:
        """Return the slice index, i.e. the integer between the first ``_`` and the next ``.``."""
        slice_ppl = filename.split('_')[1].split('.')[0]
        return int(slice_ppl)

    def get_number_of_patients(self) -> int:
        """Return the number of distinct patient ids among the kept files."""
        return len({self.get_id_from_filename(filename) for filename in self.database})

    def split_dataset(self, params: DatasetParams) -> Tuple:
        """Split the dataset patient-wise into train, validation and test sets.

        The split sizes come from ``params.train_prop`` and
        ``params.valid_prop``, interpreted according to
        ``params.proportion_mode`` (``'frac'``: fraction of the patients,
        ``'perc'``: percentage of the patients, ``'abso'``: number of
        patients). Remaining patients form the test set.

        If ``params.file_path`` points to an existing file, the patient order
        is read from it; otherwise the ids are drawn with ``random.sample``
        and, when ``params.file_path`` is set, written to that path as
        train ids, then validation ids, then test ids.

        Returns:
            ``(train_dataset, valid_dataset, test_dataset)``, three new
            ``BRATSDatasetProcessed`` instances sharing ``self.params``.

        Raises:
            NotImplementedError: If ``params.proportion_mode`` is unknown.
        """
        print('[Splitting dataset]')
        # the proportions are converted to a number of patients [int], since
        # the split is done per patient and not per slice

        patient_ids = {self.get_id_from_filename(filename) for filename in self.database}
        number_of_patients = len(patient_ids)
        print(f'{number_of_patients=}')
        if params.proportion_mode == 'frac':
            nb_train = int(params.train_prop * number_of_patients)
            nb_valid = int(params.valid_prop * number_of_patients)
        elif params.proportion_mode == 'perc':
            nb_train = int(params.train_prop / 100. * number_of_patients)
            nb_valid = int(params.valid_prop / 100. * number_of_patients)
        elif params.proportion_mode == 'abso':
            nb_train = params.train_prop
            nb_valid = params.valid_prop
        else:
            raise NotImplementedError('Not implemented proportion mode')
        nb_test = number_of_patients - nb_train - nb_valid
        print(f'{nb_train=}, {nb_valid=}, {nb_test=}')

        # region sample the ids for each set
        if params.file_path is None or not os.path.exists(params.file_path):
            # random.sample needs a sequence: passing a set raises TypeError
            # on Python >= 3.11. Sorting also makes the draw depend only on
            # the state of the random generator.
            remaining = sorted(patient_ids)

            train_list = random.sample(remaining, nb_train)
            train_ids = set(train_list)
            remaining = [pid for pid in remaining if pid not in train_ids]

            valid_list = random.sample(remaining, nb_valid)
            valid_ids = set(valid_list)

            test_list = [pid for pid in remaining if pid not in valid_ids]
            test_ids = set(test_list)

            # reaching this branch with a path set means the file does not exist yet
            if params.file_path is not None:
                ids_list = train_list + valid_list + test_list
                print(f'[BRATSDatasetProcessed] writing the list of ids to {params.file_path=}')
                write_list_to_file(file_path=params.file_path, integer_list=ids_list)

        else:
            # presumably a list of patient ids in train/valid/test order,
            # as written by the branch above -- verify against the helper
            indices = read_list_from_file(params.file_path)
            assert len(indices) == number_of_patients, f'{len(indices)=} != {number_of_patients=}'
            # sets give O(1) membership tests when the files are filtered
            train_ids = set(indices[:nb_train])
            valid_ids = set(indices[nb_train:nb_train + nb_valid])
            test_ids = set(indices[nb_train + nb_valid:])
        # endregion

        train_dataset = BRATSDatasetProcessed(self.params, ids=train_ids)
        valid_dataset = BRATSDatasetProcessed(self.params, ids=valid_ids)
        test_dataset = BRATSDatasetProcessed(self.params, ids=test_ids)

        print('[Splitting dataset done]')
        return train_dataset, valid_dataset, test_dataset


class BRATS2020DataModule(CustomDataModule):
    """Data module that provides the BRATS 2020 train/valid/test datasets."""

    def _fetch_base_dataset(self) -> Tuple[data.Dataset, data.Dataset, data.Dataset]:
        """
        Build the full BRATS dataset from ``self.p.data_params`` and return its
        train, valid and test splits, in that order.
        """
        # the dataset itself is configured by the BRATS-specific parameters,
        # while the split is driven by the enclosing parameter object
        full_dataset = BRATSDatasetProcessed(params=self.p.data_params)
        train_split, valid_split, test_split = full_dataset.split_dataset(self.p)
        return train_split, valid_split, test_split
