import math
from PIL import Image

import torch
import torch.utils.data as data
import torch.distributed as dist
from torch.utils.data import Sampler
import csv
import os
import random
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import datetime
from sklearn.neighbors import NearestNeighbors

# Default cap on the number of patch rows kept per slide; used as the default
# ``max_size`` of TCGAKDataset_survival.
MAX_LENGTH = 5_00
# Larger cap that is not referenced in this part of the file -- presumably a
# per-case limit used elsewhere (TODO confirm).
MAX_LENGTH_CASE = 2_000

class TCGAKDataset_survival(data.Dataset):
    """Slide-level survival dataset backed by per-slide ``.pth`` feature files.

    Each CSV row (the header row is skipped) describes one slide. Columns read:
    ``row[0]`` slide id, ``row[3]`` label (only its first character is parsed
    as an int), ``row[6]`` event time (float), ``row[9]`` censorship (int).
    Rows whose feature file cannot be found are skipped, and the missing path
    is appended to a timestamped ``<YYYYmmdd_HHMMSS>_error.txt`` in the
    current working directory.

    Args:
        root: directory holding ``<slide_id>.pth`` files. If it contains
            ``"COADREAD"``, the sibling ``TCGA-COAD`` and then ``TCGA-READ``
            directories (derived by string replacement) are searched instead.
        data_path: path to the CSV split file.
        args: stored as ``self.args``; not used inside this class.
        set: split name, stored as ``self.set``.
        shuffle: stored as ``self.shuffle``; not used inside this class.
        max_size: number of rows every returned bag is zero-padded to.
        aug: if True, ``__getitem__`` returns two independently augmented
            views; otherwise a single randomly ordered, un-mixed view.
        init_graph: path to a ``torch.load``-able dict with keys
            ``'RNA_features'`` (indexed by slide id) and ``'omic_sizes'``.
            Items cannot be fetched without it.
    """

    def __init__(self, root, data_path, args, set="test",
                 shuffle=True, max_size=MAX_LENGTH, aug=True, init_graph=None):
        self.root = root
        self.slide_list = []
        self.slide_pth_list = []
        self.labels = []
        self.even_time = []  # (sic) event times; attribute name kept for callers
        self.censorship = []
        self.shuffle = shuffle
        self.max_size = max_size
        self.set = set
        self.args = args
        self.aug = aug
        # Always define these so a missing init_graph yields a clear error in
        # __getitem__ rather than an AttributeError.
        self.RNA_features = None
        self.omic_sizes = None
        if init_graph is not None:
            # NOTE: torch.load unpickles its input; only load trusted files.
            rna_dict = torch.load(init_graph, map_location='cpu')
            self.RNA_features = rna_dict['RNA_features']
            self.omic_sizes = rna_dict['omic_sizes']

        if "COADREAD" in self.root:
            self.root_1 = self.root.replace("TCGA-COADREAD", "TCGA-COAD")
            self.root_2 = self.root.replace("TCGA-COADREAD", "TCGA-READ")
            search_roots = (self.root_1, self.root_2)
        else:
            search_roots = (self.root,)

        # newline='' is what the csv module expects for correct quoting handling.
        with open(data_path, newline='') as csv_file:
            reader = csv.reader(csv_file)
            next(reader, None)  # skip the header row
            for row in reader:
                # Parse every field before the existence check so malformed
                # rows still raise, as before.
                label = int(row[3][0])
                slide_id = str(row[0])
                event_time = float(row[6])
                censorship = int(row[9])

                slide_pth = self._find_slide_file(slide_id, search_roots)
                if slide_pth is None:
                    continue

                self.slide_list.append(slide_id)
                self.slide_pth_list.append(slide_pth)
                self.labels.append(label)
                self.censorship.append(censorship)
                self.even_time.append(event_time)

    @staticmethod
    def _find_slide_file(slide_id, search_roots):
        """Return the first existing ``<root>/<slide_id>.pth`` or None.

        If no candidate exists, the last candidate path tried is appended to
        a timestamped ``<YYYYmmdd_HHMMSS>_error.txt`` in the working directory.
        """
        slide_pth = None
        for slide_root in search_roots:
            slide_pth = os.path.join(slide_root, slide_id + '.pth')
            if os.path.exists(slide_pth):
                return slide_pth
        stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        with open(f"{stamp}_error.txt", 'a', newline='') as err_file:
            err_file.write(f'{slide_pth} not exist\n')
        return None

    def _rna_feature(self, slide_id):
        """Return the RNA feature loaded from ``init_graph`` for ``slide_id``.

        Raises:
            RuntimeError: if the dataset was built without ``init_graph``.
        """
        if self.RNA_features is None:
            raise RuntimeError(
                "TCGAKDataset_survival was constructed without init_graph; "
                "RNA features are unavailable")
        return self.RNA_features[slide_id]

    def __getitem__(self, ind):
        """Load slide ``ind`` and return padded bag(s) plus survival targets.

        Returns:
            With ``aug``: ``(view1, wsi_pos1, view2, wsi_pos2, rna_feature,
            slide_id, label, even_time, censorship)``.
            Without ``aug``: ``(feature, wsi_pos, rna_feature, slide_id,
            label, even_time, censorship)``.
        """
        slide_id = self.slide_list[ind]
        label = self.labels[ind]
        censorship = self.censorship[ind]
        even_time = self.even_time[ind]
        # NOTE: torch.load unpickles its input; feature files must be trusted.
        features_dict = torch.load(self.slide_pth_list[ind], map_location='cpu')
        feature = features_dict['feature']
        coords = np.asarray(features_dict['coords'])
        num_node = min(feature.shape[0], self.max_size)
        rna_feature = self._rna_feature(slide_id)
        if self.aug:
            view1, wsi_pos1 = self.pack_data(feature, num_node, coords)
            view2, wsi_pos2 = self.pack_data(feature, num_node, coords)
            return view1, wsi_pos1, view2, wsi_pos2, rna_feature, slide_id, label, even_time, censorship
        feature, wsi_pos = self.pack_data_(feature, num_node, coords)
        return feature, wsi_pos, rna_feature, slide_id, label, even_time, censorship

    def pack_data(self, feat, num_node, patch_pos):
        """Build one augmented, zero-padded bag of patch features.

        Picks ``num_node`` distinct random rows of ``feat`` and linearly mixes
        each with another picked row (random pairing) using a per-row weight
        drawn from U[0, 1). Positions are those of the first row of each pair.

        Args:
            feat: 2-D features, one row per patch. Assumed to be a torch
                tensor, since it is indexed with torch indices and mixed with
                ``torch.rand`` (TODO confirm numpy input is never passed).
            num_node: rows to keep; must not exceed ``feat.shape[0]`` or
                ``self.max_size``.
            patch_pos: numpy array with one 2-value position row per patch.

        Returns:
            ``(wsi_feat, wsi_pos)``: a float64 array ``(max_size, D)`` and an
            array ``(max_size, 2)`` with ``patch_pos``'s dtype; rows at index
            ``num_node`` and beyond are zero.
        """
        wsi_feat = np.zeros((self.max_size, feat.shape[-1]))
        wsi_pos = np.zeros((self.max_size, 2), dtype=patch_pos.dtype)
        indices = torch.randperm(feat.shape[0])[:num_node]
        # [num_node] -> [num_node, 1] so the weight broadcasts over features.
        rand_vec = torch.rand(num_node).unsqueeze(1)
        partners = indices[torch.randperm(indices.shape[0])]
        wsi_feat[:num_node] = rand_vec * feat[indices] + (1 - rand_vec) * feat[partners]
        wsi_pos[:num_node] = patch_pos[indices]
        return wsi_feat, wsi_pos

    def pack_data_(self, feat, num_node, patch_pos):
        """Build one zero-padded bag of ``num_node`` randomly chosen rows.

        Same layout as :meth:`pack_data`, but the chosen rows are copied
        without mixing.
        """
        wsi_feat = np.zeros((self.max_size, feat.shape[-1]))
        wsi_pos = np.zeros((self.max_size, 2), dtype=patch_pos.dtype)
        indices = torch.randperm(feat.shape[0])[:num_node]
        wsi_feat[:num_node] = feat[indices]
        wsi_pos[:num_node] = patch_pos[indices]
        return wsi_feat, wsi_pos

    def __len__(self):
        """Number of slides whose feature file was found."""
        return len(self.slide_list)

    def get_weights(self):
        """Return per-sample inverse class-frequency weights (float64).

        Each sample's weight is ``1 / count(label)``. Labels must be
        non-negative ints (``np.bincount`` requirement). An empty dataset
        yields an empty array.
        """
        # Explicit int dtype so an empty label list does not become float64,
        # which np.bincount rejects.
        labels = np.asarray(self.labels, dtype=np.int64)
        counts = np.bincount(labels)
        # np.float was removed in NumPy 1.24; it was an alias for float64.
        return 1.0 / np.asarray(counts[labels], dtype=np.float64)




class DistributedWeightedSampler(data.DistributedSampler):
    """Weighted random sampler that shards one shared draw across replicas.

    On each iteration every replica seeds a private ``torch.Generator`` with
    the current epoch (set via ``set_epoch``). It draws the same index list
    with ``torch.multinomial`` and keeps every ``num_replicas``-th entry
    starting at ``rank``. Ranks therefore see disjoint slices of a single
    weighted draw, and the draw is reproducible per epoch.

    Args:
        dataset: dataset to sample from; the base class uses its length.
        weights: one non-negative weight per dataset item (need not sum to 1).
        num_replicas, rank: as for ``torch.utils.data.DistributedSampler``.
        replacement: sample with replacement (default). Without replacement,
            at most as many indices as there are non-zero weights are drawn,
            and the list is padded by repetition to ``total_size``.
    """

    def __init__(self, dataset, weights, num_replicas=None, rank=None, replacement=True):

        super(DistributedWeightedSampler, self).__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=False
            )
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.replacement = replacement

    def __iter__(self):
        # Deterministic per epoch. The generator MUST be passed to
        # multinomial so all ranks draw the identical list before striding.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.replacement:
            num_draw = self.total_size
        else:
            # Without replacement, multinomial cannot draw more indices than
            # there are non-zero weights.
            num_draw = min(self.total_size, int((self.weights > 0).sum().item()))
        indices = torch.multinomial(
            self.weights, num_draw, self.replacement, generator=g).tolist()

        # Pad by repetition so the list divides evenly across replicas.
        padding_size = self.total_size - len(indices)
        if padding_size > 0:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        assert len(indices) == self.total_size

        # subsample: this rank's strided share
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
    
