#coding=utf-8
import numpy as np
import torch
from collections import Counter
from torch.utils.data import Dataset

def Nmax(args,d):
    """Return the position of the first test environment strictly above ``d``.

    Walks ``args.test_envs`` in order and returns the index of the first
    entry ``env`` with ``d < env``; if no entry qualifies, returns
    ``len(args.test_envs)``. If ``test_envs`` is sorted ascending (not checked
    here), the result equals the number of test environments ``<= d``.
    """
    envs = args.test_envs
    first_above = (pos for pos, env in enumerate(envs) if d < env)
    return next(first_above, len(envs))

class basedataset(object):
    """Minimal indexable container pairing inputs ``x`` with targets ``y``.

    Indexing returns ``(x[index], y[index])``; the length is that of ``x``.
    """

    def __init__(self,x,y):
        self.x, self.y = x, y

    def __getitem__(self,index):
        sample = self.x[index]
        target = self.y[index]
        return sample, target

    def __len__(self):
        return len(self.x)

class mydataset(object):
    """Base dataset pairing inputs with class, domain and auxiliary labels.

    Instances start empty. The code that builds them is expected to fill in
    the following before items are read:

    - ``x``, ``labels``, ``dlabels``, ``pclabels``, ``pdlabels`` and ``task``;
    - ``loader``, for tasks that do not start with 'cross' or 'reg';
    - ``x1`` and ``x2``, for 'cross' tasks.
    """

    # label_type values accepted by set_labels / set_labels_by_index, mapped to
    # the attribute each one writes.
    _LABEL_ATTRS = {
        'pclabel': 'pclabels',
        'pdlabel': 'pdlabels',
        'domain_label': 'dlabels',
        'class_label': 'labels',
    }

    def __init__(self,args):
        self.x=None                 # inputs, or loader keys for non-cross/reg tasks
        self.labels=None            # class labels ('class_label')
        self.dlabels=None           # domain labels ('domain_label')
        self.pclabels=None          # 'pclabel' labels (presumably pseudo class labels)
        self.pdlabels=None          # 'pdlabel' labels (presumably pseudo domain labels)
        self.task=None
        self.dataset=None
        self.transform=None         # applied to inputs by input_trans
        self.target_transform=None  # applied to every label by target_trans
        self.x1=None                # extra input view returned for 'cross' tasks
        self.x2=None                # extra input view returned for 'cross' tasks
        self.loader=None            # turns an x entry into an input for non-cross/reg tasks
        self.args=args
        self.enhanced_transform=None

    @classmethod
    def _label_attr(cls,label_type):
        """Map a label_type string to its attribute name; ValueError if unknown."""
        try:
            return cls._LABEL_ATTRS[label_type]
        except KeyError:
            raise ValueError(
                'unknown label_type %r; expected one of %s'
                % (label_type, sorted(cls._LABEL_ATTRS))) from None

    def set_labels(self,tlabels=None,label_type='domain_label'):
        """Replace a whole label array.

        Args:
            tlabels: new labels, one per entry of ``self.x``.
            label_type: 'pclabel', 'pdlabel', 'domain_label' or 'class_label'.

        Raises:
            ValueError: if ``tlabels`` length differs from ``len(self.x)`` or
                ``label_type`` is not recognised.
        """
        if len(tlabels)!=len(self.x):
            raise ValueError(
                'expected %d labels, got %d' % (len(self.x), len(tlabels)))
        setattr(self,self._label_attr(label_type),tlabels)

    def set_labels_by_index(self,tlabels=None,tindex=None,label_type='domain_label'):
        """Assign ``tlabels`` into the selected label array at ``tindex`` in place.

        Raises:
            ValueError: if ``label_type`` is not recognised.
        """
        getattr(self,self._label_attr(label_type))[tindex]=tlabels

    def target_trans(self,y):
        """Apply ``target_transform`` to a label if one is set."""
        if self.target_transform is not None:
            return self.target_transform(y)
        return y

    def input_trans(self,x):
        """Apply ``transform`` to an input if one is set."""
        if self.transform is not None:
            return self.transform(x)
        return x

    def set_titem_value(self,t,x):
        """Fill a whole auxiliary label array with the constant ``x``.

        ``t == 3`` sets ``pclabels`` and ``t == 4`` sets ``pdlabels``. These are
        the positions of those targets in the tuple returned by __getitem__.
        The result has the shape of ``labels``. Other ``t`` values are ignored.
        """
        if t==3:
            self.pclabels=np.ones_like(self.labels)*x
        elif t==4:
            self.pdlabels=np.ones_like(self.labels)*x

    def __getitem__(self,index):
        """Return (x, class, domain, pclabel, pdlabel, index, view1, view2).

        For 'cross' tasks, view1/view2 are the transformed ``x1``/``x2``
        entries. Otherwise both slots hold ``index`` as placeholders, so the
        tuple length stays the same.
        """
        raw=self.x[index]
        if not self.task.startswith(('cross','reg')):
            # Other tasks store entries that must be materialised by the loader.
            raw=self.loader(raw)
        x=self.input_trans(raw)
        ctarget=self.target_trans(self.labels[index])
        dtarget=self.target_trans(self.dlabels[index])
        pctarget=self.target_trans(self.pclabels[index])
        pdtarget=self.target_trans(self.pdlabels[index])
        # NOTE(review): the input branch above keys on self.task, while this one
        # keys on self.args.task. They are assumed to agree.
        if self.args.task.startswith('cross'):
            return x,ctarget,dtarget,pctarget,pdtarget,index,self.input_trans(self.x1[index]),self.input_trans(self.x2[index])
        return x,ctarget,dtarget,pctarget,pdtarget,index,index,index

    def __len__(self):
        return len(self.x)

class subdataset(mydataset):
    """Copy of ``dataset`` restricted to the positions listed in ``indices``.

    Loader, transforms, task and dataset name are shared with the source.
    Label arrays and the optional extra input views are sliced; a view that is
    ``None`` in the source stays ``None``.
    """
    def __init__(self,args,dataset,indices):
        super(subdataset,self).__init__(args)
        if args.task.startswith('cross'):
            self.x=dataset.x[indices]
            self.x1=dataset.x1[indices] if dataset.x1 is not None else None
            self.x2=dataset.x2[indices] if dataset.x2 is not None else None
        elif args.task.startswith('reg'):
            self.x=dataset.x[indices]
        else:
            # Inputs may be plain Python lists here (combindataset builds one),
            # which do not support fancy indexing, so select element-wise.
            self.x=[dataset.x[item] for item in indices]
            # Guard the optional second view like the 'cross' branch does;
            # indexing a missing view would raise TypeError.
            self.x2=[dataset.x2[item] for item in indices] if dataset.x2 is not None else None
            self.x1=None
        self.loader=dataset.loader
        self.labels=dataset.labels[indices]
        self.dlabels=dataset.dlabels[indices] if dataset.dlabels is not None else None
        self.pclabels=dataset.pclabels[indices] if dataset.pclabels is not None else None
        self.pdlabels=dataset.pdlabels[indices] if dataset.pdlabels is not None else None
        self.task=dataset.task
        self.dataset=dataset.dataset
        self.transform=dataset.transform
        self.target_transform=dataset.target_transform
        

class combindataset(mydataset):
    """Concatenation of several datasets into one, in ``datalist`` order.

    ``domain_num`` records how many datasets were combined. The following are
    taken from the first element of ``datalist``:

    - loader, transforms, dataset name and task;
    - for 'cross_dataset' tasks, ``hz`` and ``win``.

    Optional parts (``x1``/``x2``, ``pclabels``/``pdlabels``) are treated as
    absent when the first element lacks them.
    """
    def __init__(self,args,datalist):
        super(combindataset,self).__init__(args)
        self.domain_num=len(datalist)
        first=datalist[0]
        self.loader=first.loader
        self.dataset=first.dataset
        self.task=first.task
        self.transform=first.transform
        self.target_transform=first.target_transform
        xlist=[item.x for item in datalist]
        x1list=[item.x1 for item in datalist]
        x2list=[item.x2 for item in datalist]
        if args.task.startswith('cross'):
            self.x=torch.vstack(xlist)
            self.x1=torch.vstack(x1list) if x1list[0] is not None else None
            self.x2=torch.vstack(x2list) if x2list[0] is not None else None
        elif args.task.startswith('reg'):
            self.x=torch.vstack(xlist)
        else:
            # Entries are kept as a flat Python list rather than stacked.
            self.x=[item for ti in xlist for item in ti]
            # Same presence check as the 'cross' branch; flattening None
            # views would raise TypeError.
            self.x2=[item for ti in x2list for item in ti] if x2list[0] is not None else None
            self.x1=None
        pcylist=[item.pclabels for item in datalist]
        pdylist=[item.pdlabels for item in datalist]
        self.labels=np.hstack([item.labels for item in datalist])
        self.dlabels=np.hstack([item.dlabels for item in datalist])
        self.pclabels=np.hstack(pcylist) if pcylist[0] is not None else None
        self.pdlabels=np.hstack(pdylist) if pdylist[0] is not None else None
        if self.task.startswith('cross_dataset'):
            self.hz=first.hz
            self.win=first.win
