import torch 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import pdb
from torch.utils.data import SubsetRandomSampler
from einops import rearrange, reduce, repeat, einsum

class MultiDat(Dataset):
    """Container holding one dataset per condition ("a dataset of datasets").

    Attributes
    ----------
    pdata : list of Tensor
        One tensor of samples per condition; the last axis of ``pdata[0]``
        is taken as the data dimension ``x_dim``. In ``rival="guided"`` mode
        an extra final entry holding all samples pooled and shuffled is
        appended.
    c_list : Tensor, shape (num_cs, c_dim)
        Condition vectors; row ``k`` is paired with ``pdata[k]``. In
        ``rival="guided"`` mode a zero column is appended to every row and a
        one-hot "unconditional" row ``[0, ..., 0, 1]`` is added at the end.
    c_dim : int
        Width of ``c_list`` (including the guided extra column, if any).
    x_dim : int
        ``pdata[0].shape[-1]``.
    c_num : int
        Number of conditions passed in (excludes the guided extra entry).
    rival : str or None
        Sampling mode; ``"guided"`` and ``"lfm"`` use the given condition set
        directly instead of Dirichlet interpolation.
    """

    def __init__(self, pdata, c_list, rival=None, **kwargs):
        self.rival = rival
        # Copy so that appending the pooled set never mutates the caller's list.
        self.pdata = list(pdata)
        self.c_list = c_list
        self.c_dim = c_list.shape[1]
        self.x_dim = pdata[0].shape[-1]
        if rival == "guided":
            alldata = torch.cat(self.pdata, dim=0)
            alldata = alldata[torch.randperm(len(alldata))]
            self.pdata.append(alldata)

            # Extra column flags the unconditional entry; build it on c_list's
            # own dtype/device so the concatenation cannot mismatch.
            self.c_dim = self.c_dim + 1
            self.c_list = torch.cat(
                (self.c_list, self.c_list.new_zeros(len(self.c_list), 1)), dim=1
            )
            uncond_ctensor = self.c_list.new_zeros(1, self.c_dim)
            uncond_ctensor[0, -1] = 1.
            self.c_list = torch.cat((self.c_list, uncond_ctensor), dim=0)
        self.c_num = len(pdata)

    def prepare_tcsampler(self, nT, samplesize=None, device='cpu'):
        """Build the samplers used by :meth:`create_tcsets`.

        Parameters
        ----------
        nT : int
            Not used here; kept for interface compatibility.
        samplesize : int, optional
            Number of boundary conditions to interpolate between. Required
            unless ``rival`` is ``"guided"`` or ``"lfm"``.
        device : str or torch.device
            Device for the Dirichlet concentration and identity matrix.

        Raises
        ------
        ValueError
            If interpolation is needed but ``samplesize`` is None.
        """
        if self.rival not in ('guided', 'lfm'):
            if samplesize is None:
                raise ValueError(
                    "samplesize is required unless rival is 'guided' or 'lfm'"
                )
            self.dirich = torch.distributions.dirichlet.Dirichlet(
                torch.tensor([1.] * samplesize, device=device)
            )
        # Only needed (and only computable) when a sample size is known.
        if samplesize is not None:
            self.bdryeye = torch.eye(samplesize, device=device)

    def create_tcsets(self, nT, cbdry, nCsamples, device='cpu'):
        """Sample a set of conditions and a set of times.

        Parameters
        ----------
        nT : int
            Number of time values to return.
        cbdry : Tensor, shape (num_cbdries, c_dim)
            Boundary conditions.
        nCsamples : int
            Total number of conditions to return (interpolation mode only).
        device : str or torch.device
            Device for ``tset`` and, in guided/lfm mode, for ``cset``.

        Returns
        -------
        cset : Tensor
            guided/lfm: ``cbdry`` moved to ``device``. Otherwise
            ``nCsamples`` rows: the boundary conditions themselves followed by
            random Dirichlet-weighted convex combinations of them.
        tset : Tensor, shape (nT,)
            For ``nT >= 2``: ``[1, 0, u_2, ..., u_{nT-1}]`` with uniform
            ``u_i``; for ``nT < 2``: ``nT`` uniform samples.
        """
        # Check nT first: torch.rand(nT - 2) would raise for nT < 2.
        if nT < 2:
            tset = torch.rand(nT, device=device)
        else:
            tset = torch.zeros(nT, device=device)
            tset[0] = 1.
            tset[2:] = torch.rand(nT - 2, device=device)

        if self.rival in ['guided', 'lfm']:
            cset = cbdry.to(device)
        else:
            num_cbdries = cbdry.shape[0]
            assert nCsamples-num_cbdries >= 0, f"nCsamples={nCsamples} should be bigger than num_cbdries={num_cbdries}"
            interpolations = self.dirich.sample((nCsamples - num_cbdries,))
            # Prepend the identity so the boundary points themselves are included.
            interpolations = torch.cat([self.bdryeye, interpolations])
            cbdry = cbdry.to(interpolations.device)
            cset = torch.einsum('n c, m n -> m c', cbdry, interpolations)

        return cset, tset

class ManualDat(MultiDat):
    """MultiDat with a manual (DataLoader-free) batch sampler.

    Attributes
    ----------
    pdata : list of Tensor
        One tensor of samples per condition (inherited from MultiDat).
    c_list : Tensor
        Condition vectors, one row per condition (inherited from MultiDat).
    batchbox : Tensor or None
        Reusable output buffer of shape (num_selected, batchsize, x_dim),
        allocated lazily by :meth:`manual_batch`.
    """

    def __init__(self, pdata, c_list):
        super().__init__(pdata, c_list)
        self.batchbox = None

    def manual_batch(self, batchsize=20, sampled_cidx=None, device='cpu'):
        """Draw ``batchsize`` samples (with replacement) for each selected condition.

        Parameters
        ----------
        batchsize : int
            Samples drawn per condition.
        sampled_cidx : sequence of int, optional
            Indices into ``pdata`` of the conditions to sample from.
        device : str or torch.device
            Device of the output buffer.

        Returns
        -------
        Tensor, shape (len(sampled_cidx), batchsize, x_dim)
            The internal buffer; it is overwritten by the next call.
        """
        if sampled_cidx is None:
            sampled_cidx = []
        shape = (len(sampled_cidx), batchsize, self.x_dim)
        # Reallocate when the requested shape/device differs from the cached buffer.
        if (
            self.batchbox is None
            or tuple(self.batchbox.shape) != shape
            or self.batchbox.device.type != torch.device(device).type
        ):
            self.batchbox = torch.zeros(shape, device=device)

        for k, cidx in enumerate(sampled_cidx):
            data_k = self.pdata[cidx]
            xidx_k = torch.randint(len(data_k), size=(batchsize,))
            self.batchbox[k] = data_k[xidx_k, :]

        return self.batchbox
        
            
class MultiDatLoader(object):
    """Wrap every per-condition tensor of a dataset in its own DataLoader.

    Parameters
    ----------
    dataset : MultiDat
        Source of ``pdata`` (list of per-condition tensors) and ``c_list``.
    batchsize : int
        Batch size of every DataLoader.
    num_workers : int
        Accepted for interface compatibility; not forwarded to DataLoader.
    seed : int, optional
        If given, loader ``k`` shuffles with a ``torch.Generator`` seeded with
        ``seed + k``, making the per-loader order reproducible.
    shuffle : bool
        Whether each DataLoader reshuffles every epoch.
    **kwargs
        Ignored here; subclasses may read their own options from it.
    """

    def __init__(self, dataset: "MultiDat", batchsize: int, num_workers=1, seed=None, shuffle=True, **kwargs):
        self.dataset = dataset
        self.data = dataset.pdata
        self.batchsize = batchsize
        self.c_list = dataset.c_list
        self.c_num = len(self.data)
        self.seed = seed
        self.shuffle = shuffle
        # Set properly by prepare_LoaderSampler; defined here so readers never
        # hit an AttributeError if that call was skipped.
        self.device = None

        self.DataLoaderList = [
            DataLoader(
                data_k,
                batch_size=batchsize,
                shuffle=self.shuffle,
                prefetch_factor=None,
                generator=self._make_generator(k),
            )
            for k, data_k in enumerate(self.data)
        ]

    def _make_generator(self, k):
        """Return a generator seeded with ``seed + k``, or None when no seed was given."""
        if self.seed is None:
            return None
        return torch.Generator().manual_seed(self.seed + k)

    def get_DataLoaderList(self):
        """Return ``(c_list, DataLoaderList)`` covering every condition."""
        return self.c_list, self.DataLoaderList

    def prepare_LoaderSampler(self, device=None):
        """Record ``device`` and move ``c_list`` onto it."""
        self.device = device
        self.c_list = self.c_list.to(device)

class EnsembleDatLoader(MultiDatLoader):
    """Loader that returns a random subset of ``num_samplec`` conditions per call.

    In ``rival == 'guided'`` mode, with probability ``puncond`` (required
    keyword argument) every slot is filled with the last, unconditional
    entry instead of a random subset of the real conditions.
    """

    def __init__(self, dataset: "MultiDat", batchsize: int, num_samplec: int, num_workers=1, **kwargs):
        super().__init__(dataset, batchsize, num_workers, **kwargs)
        self.num_samplec = num_samplec
        if getattr(self.dataset, 'rival', None) == 'guided':
            self.puncond = kwargs['puncond']

    def get_DataLoaderList(self):
        """Return ``(sample_c, sample_DataLoader)`` for the chosen condition indices.

        At most ``num_samplec`` conditions are returned; fewer if fewer exist.
        """
        if getattr(self.dataset, 'rival', None) == 'guided':
            # Bernoulli(puncond) decides whether this draw is unconditional.
            use_uncond = torch.bernoulli(torch.tensor([self.puncond])).item() == 1
            # The final entry is the pooled unconditional set; exclude it otherwise.
            max_cidx = self.c_num - 1
        else:
            use_uncond = False
            max_cidx = self.c_num

        if use_uncond:
            sample_c_idx = torch.tensor([-1] * self.num_samplec)
        else:
            sample_c_idx = torch.randperm(max_cidx)[:self.num_samplec]
        sample_c = self.c_list[sample_c_idx]
        # Iterate the indices actually drawn: randperm may yield fewer than num_samplec.
        sample_DataLoader = [self.DataLoaderList[idx] for idx in sample_c_idx.tolist()]
        return sample_c, sample_DataLoader

class NhdDatLoader(MultiDatLoader):
    """Loader returning a random condition and its nearest neighbours in ``c_list``.

    Each call picks a random centre condition and returns the (at most)
    ``num_samplec`` conditions closest to it by Euclidean distance
    (the centre itself included), with their DataLoaders.
    """

    def __init__(self, dataset: "MultiDat", batchsize: int, num_samplec: int, num_workers=1, **kwargs):
        super().__init__(dataset, batchsize, num_workers, **kwargs)
        self.num_samplec = num_samplec

    def get_DataLoaderList(self):
        """Return ``(sample_c, sample_DataLoader)`` for a random neighbourhood."""
        center_c = torch.randperm(self.c_num)[0]
        diff = torch.norm(self.c_list - self.c_list[center_c], dim=-1, p=None)
        k = min(self.num_samplec, self.c_list.shape[0])
        sample_c_idx = diff.topk(k, largest=False).indices
        sample_c = self.c_list[sample_c_idx]
        # DataLoaders live in a Python list; iterate the k indices actually
        # selected (k may be smaller than num_samplec).
        sample_DataLoader = [self.DataLoaderList[idx] for idx in sample_c_idx.tolist()]
        return sample_c.to(getattr(self, 'device', None)), sample_DataLoader



from torch.utils.data import Dataset

class IdxDataset(Dataset):
    """Dataset wrapper that yields each item together with its index."""

    def __init__(self, data):
        """
        Args:
            data: Any indexable sequence that supports ``len`` (e.g. list,
                ndarray or tensor).
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Returns:
            tuple: ``(idx, data_point)`` -- the index itself followed by
            ``data[idx]``.
        """
        data_point = self.data[idx]
        return idx, data_point
