# -*- coding: utf-8 -*-

import torch
import numpy as np

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
import json

import random


def _dp_domain_windows(domain_index, window=15, stride=10, num_features=14,
                       mask_size=50):
    """Slice one domain's series into parallel input/target/mask lists.

    Loads ``../crypto/domain{domain_index + 1}_new11.npy`` and, advancing
    ``stride`` rows at a time, pairs each ``window``-row slice with the
    ``window`` rows that immediately follow it.

    Args:
        domain_index: Zero-based domain number. Selects the file
            (``domain_index + 1``) and the mask position that is set to 1.
        window: Number of rows in every input slice and every target slice.
        stride: Row offset between consecutive window starts.
        num_features: Column count of the zero filler; concatenation raises
            if the loaded array has a different number of columns.
        mask_size: Length of the one-hot domain mask.

    Returns:
        Three equally long lists ``(inputs, targets, masks)``.
    """
    series = np.load('../crypto/domain{}_new11.npy'.format(domain_index + 1))

    def pad_to_window(chunk):
        # The loop bound below keeps every slice full, so the filler has
        # zero rows; concatenating is retained because it still returns a
        # fresh array whose dtype is promoted against the float64 filler.
        filler = np.zeros((window - len(chunk), num_features))
        return np.concatenate((chunk, filler), axis=0)

    inputs, targets, masks = [], [], []
    # Stop early enough that an input window and its target window both fit.
    for start in range(0, len(series) - 2 * window, stride):
        split = start + window

        mask = np.zeros(mask_size)
        mask[domain_index] = 1

        inputs.append(pad_to_window(series[start:split]))
        targets.append(pad_to_window(series[split:split + window]))
        masks.append(mask)
    return inputs, targets, masks


def _dp_make_loader(args, inputs, targets, masks, shuffle):
    """Wrap window lists in a DomainDataset and return its DataLoader."""
    dataset = DomainDataset(np.array(inputs), np.array(targets),
                            np.array(masks))
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle,
                      num_workers=args.num_workers, drop_last=False)


def dp(args, num_tasks=10, num_instance=220):
    """Build one pooled loader followed by three per-domain loaders.

    Domain indices 0-36 (files ``domain1`` .. ``domain37``) are windowed and
    pooled into a single shuffled DataLoader. Domain indices 38-40 (files
    ``domain39`` .. ``domain41``) each get their own unshuffled DataLoader.

    NOTE(review): index 37 (``domain38_new11.npy``) is never loaded by
    either loop -- confirm that the gap is intentional.

    Args:
        args: Object exposing ``batch_size`` and ``num_workers``; both are
            forwarded to every DataLoader.
        num_tasks: Unused; kept so existing call sites remain valid.
        num_instance: Unused; kept so existing call sites remain valid.

    Returns:
        A list of four DataLoaders: the pooled loader first, then one for
        each of domain indices 38, 39 and 40.

    Side effects:
        Reads the ``.npy`` files from ``../crypto`` and prints the number of
        loaders that were built.
    """
    pooled_inputs, pooled_targets, pooled_masks = [], [], []
    for domain_index in range(37):
        inputs, targets, masks = _dp_domain_windows(domain_index)
        pooled_inputs.extend(inputs)
        pooled_targets.extend(targets)
        pooled_masks.extend(masks)

    dataloaders = [
        _dp_make_loader(args, pooled_inputs, pooled_targets, pooled_masks,
                        shuffle=True)
    ]

    for domain_index in range(38, 41):
        inputs, targets, masks = _dp_domain_windows(domain_index)
        dataloaders.append(
            _dp_make_loader(args, inputs, targets, masks, shuffle=False))

    print(len(dataloaders))
    return dataloaders

def _dp_sq_domain_loader(args, domain_index, shuffle, window=15, stride=10,
                         num_features=14, mask_size=50):
    """Load one domain file and return a DataLoader over its windows.

    Reads ``../crypto/domain{domain_index + 1}_new11.npy`` and, advancing
    ``stride`` rows at a time, pairs each ``window``-row slice with the
    ``window`` rows that immediately follow it. Every sample also carries a
    one-hot mask of length ``mask_size`` with position ``domain_index`` set.

    Args:
        args: Object exposing ``batch_size`` and ``num_workers``.
        domain_index: Zero-based domain number (file number minus one).
        shuffle: Forwarded to the DataLoader.
        window: Number of rows in every input slice and every target slice.
        stride: Row offset between consecutive window starts.
        num_features: Column count of the zero filler; concatenation raises
            if the loaded array has a different number of columns.
        mask_size: Length of the one-hot domain mask.

    Returns:
        A DataLoader over a DomainDataset built from this domain only.
    """
    series = np.load('../crypto/domain{}_new11.npy'.format(domain_index + 1))

    inputs, targets, masks = [], [], []
    # Stop early enough that an input window and its target window both fit.
    for start in range(0, len(series) - 2 * window, stride):
        split = start + window
        past = series[start:split]
        future = series[split:split + window]

        # Both slices are always full here, so each filler has zero rows.
        # The concatenation is kept because it still yields a fresh array
        # whose dtype is promoted against the float64 filler.
        past = np.concatenate(
            (past, np.zeros((window - len(past), num_features))), axis=0)
        future = np.concatenate(
            (future, np.zeros((window - len(future), num_features))), axis=0)

        mask = np.zeros(mask_size)
        mask[domain_index] = 1

        inputs.append(past)
        targets.append(future)
        masks.append(mask)

    dataset = DomainDataset(np.array(inputs), np.array(targets),
                            np.array(masks))
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle,
                      num_workers=args.num_workers, drop_last=False)


def dp_sq(args, num_tasks=10, num_instance=220):
    """Build one DataLoader per domain, in domain order.

    Domain indices 0-36 (files ``domain1`` .. ``domain37``) each get a
    shuffled DataLoader; domain indices 38-40 (files ``domain39`` ..
    ``domain41``) each get an unshuffled one.

    NOTE(review): index 37 (``domain38_new11.npy``) is never loaded by
    either loop -- confirm that the gap is intentional.

    Args:
        args: Object exposing ``batch_size`` and ``num_workers``; both are
            forwarded to every DataLoader.
        num_tasks: Unused; kept so existing call sites remain valid.
        num_instance: Unused; kept so existing call sites remain valid.

    Returns:
        A list of 40 DataLoaders: 37 shuffled loaders followed by 3
        unshuffled loaders.
    """
    dataloaders = [
        _dp_sq_domain_loader(args, domain_index, shuffle=True)
        for domain_index in range(37)
    ]
    dataloaders.extend(
        _dp_sq_domain_loader(args, domain_index, shuffle=False)
        for domain_index in range(38, 41)
    )
    return dataloaders




class DomainDataset(Dataset):
    """Map-style dataset serving index-aligned (X, Y, mask) triples."""

    def __init__(self, X, Y, masks):
        # Three parallel containers, stored exactly as they were passed in.
        self.X, self.Y, self.masks = X, Y, masks

    def __len__(self):
        """Return the number of samples, taken from the length of ``X``."""
        sample_count = len(self.X)
        return sample_count

    def __getitem__(self, idx):
        """Return ``[X[idx], Y[idx], masks[idx]]`` as a three-element list."""
        return [column[idx] for column in (self.X, self.Y, self.masks)]