# -*- coding: utf-8 -*-

import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
import json
import math
import matplotlib.pyplot as plt
import random



def arr2(t_max, tao):
    """
    Generate a discretised Mackey-Glass series, rounded to 4 decimals.

    The recurrence is
        x[t] = x[t-1] + beta * x[t-d-1] / (1 + x[t-d-1] ** n) - gamma * x[t-1]
    with beta = 0.2, gamma = 0.1, n = 15 and a delay d hard-coded to 18.

    Parameters
    ----------
    t_max: int, the series is extended up to (but excluding) index t_max.
    tao: int, only sizes the warm-up prefix (8 + 2 * tao samples: all 0.2
        except a final 1.0); it does NOT change the delay d.

    Returns
    -------
    list of float; for tao >= 0 its length is max(t_max, 8 + 2 * tao).
    """
    warmup = 8 + 2 * tao
    beta, gamma, power = 0.2, 0.1, 15
    delay = 18  # fixed; the `tao` argument is intentionally not used here

    series = [0.2] * (warmup - 1) + [1.0]
    for t in range(warmup, t_max):
        prev = series[t - 1]
        # NOTE(review): when warmup < 19 (tao < 6) this index is negative and
        # Python wraps it around to the end of the list; for tao == 0 it is
        # out of range and raises IndexError. Confirm this is intended.
        lagged = series[t - delay - 1]
        value = prev + (beta * lagged / (1 + math.pow(lagged, power))) - (gamma * prev)
        # Round through a 4-decimal string so stored values stay short.
        series.append(float("{:0.4f}".format(value)))
    return series
        


def make_noise(shape, type="Gaussian"):
    """
    Generate random noise.

    Parameters
    ----------
    shape: list or tuple giving the shape of the noise tensor.
    type: str, "Gaussian" (standard normal) or "Uniform" (uniform on
        [-1, 1)), default: "Gaussian".

    Returns
    -------
    torch.Tensor of the requested shape (requires_grad=False).

    Raises
    ------
    ValueError
        If ``type`` is not a supported noise type (subclass of Exception, so
        callers catching Exception are unaffected).
    """
    # torch.autograd.Variable is deprecated and a no-op wrapper; plain tensors
    # are returned directly.
    if type == "Gaussian":
        return torch.randn(shape)
    if type == 'Uniform':
        # The randn draw is immediately overwritten, but it is kept so the
        # global RNG stream advances exactly as before (seeded runs reproduce).
        return torch.randn(shape).uniform_(-1, 1)
    raise ValueError("ERROR: Noise type {} not supported".format(type))

def dp(args, num_tasks=10, num_instance=220):
    """
    Build a pooled training loader and a single test loader on synthetic signals.

    For every domain index ``d`` the file
    ``../datas/weather/wth/tmp/datas/domain_{d}.npz`` is opened only to read
    the length of its ``info`` array. The series actually used is a synthetic
    sum of cosines whose phase and frequencies depend on ``d``. Each series is
    cut with stride 1 into (input, target) pairs of 96 + 96 consecutive
    samples. Each pair gets a length-16 one-hot domain mask.

    Parameters
    ----------
    args: object with ``batch_size`` and ``num_workers`` attributes,
        forwarded to ``DataLoader``.
    num_tasks, num_instance: unused; kept for call compatibility.

    Returns
    -------
    list of two DataLoaders:
        [0] windows of domains 0-8 pooled together, shuffled;
        [1] windows of domain 9, not shuffled.

    Side effects
    ------------
    Saves a plot of the first 300 samples of each synthetic series under the
    names ``b_{d}`` (domains 0-8) and ``b_9_test`` via matplotlib.
    """
    window = 96      # length of both the input and the target window
    mask_size = 16   # length of the one-hot domain mask

    def _series_length(domain):
        # Only the length of the stored data is needed. The context manager
        # closes the .npz file handle immediately instead of relying on GC.
        # NOTE(review): allow_pickle=True can execute code embedded in a
        # crafted file -- only point this at trusted data.
        path = '../datas/weather/wth/tmp/datas/domain_{}.npz'.format(domain)
        with np.load(path, allow_pickle=True) as data:
            return len(data['info'].tolist())

    def _synthetic_series(domain, length):
        # Same expressions (and float evaluation order) as the original code;
        # np.arange gives the same integer time index as np.array(range(n)).
        t = np.arange(length)
        slow = np.cos(domain * 100 + np.pi / (300 / (domain + 1)) * t) * 0.5
        mixed = (np.cos(40 * domain + np.pi / 100 / (domain + 1) * t)
                 + np.cos(10 + np.pi / 13 * t))
        return mixed + slow

    def _windows(series, domain):
        # Stride-1 pairs: input series[j:j+96], target series[j+96:j+192].
        # NOTE(review): range() stops before len - 192, so the last full pair
        # (whose target ends at the final sample) is never emitted -- confirm.
        starts = range(len(series) - 2 * window)
        xs = [series[j:j + window] for j in starts]
        ys = [series[j + window:j + 2 * window] for j in starts]
        mask = np.zeros(mask_size)
        mask[domain] = 1
        # The shared mask object is copied when the list becomes an array.
        return xs, ys, [mask] * len(starts)

    def _save_preview(series, filename):
        # Plot the first 300 samples for visual inspection, then reset the figure.
        plt.plot(series[:300], color='green')
        plt.savefig(filename)
        plt.clf()

    def _loader(xs, ys, masks, shuffle):
        dataset = DomainDataset(np.array(xs), np.array(ys), np.array(masks))
        return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle,
                          num_workers=args.num_workers, drop_last=False)

    dataloaders = []

    # Training: domains 0-8 pooled into one shuffled loader.
    train_x, train_y, train_masks = [], [], []
    for i in range(9):
        series = _synthetic_series(i, _series_length(i))
        _save_preview(series, "b_{}".format(i))
        xs, ys, masks = _windows(series, i)
        train_x.extend(xs)
        train_y.extend(ys)
        train_masks.extend(masks)
    dataloaders.append(_loader(train_x, train_y, train_masks, shuffle=True))

    # Test: domain 9 in its own unshuffled loader.
    for f in range(9, 10):
        series = _synthetic_series(f, _series_length(f))
        _save_preview(series, "b_{}_test".format(f))
        # NOTE(review): the mask index is `i` left over from the training loop
        # (always 8), not `f`. Preserved as-is; confirm whether test windows
        # should be tagged with their own domain index instead.
        xs, ys, masks = _windows(series, i)
        dataloaders.append(_loader(xs, ys, masks, shuffle=False))

    return dataloaders




def dp_sq(args, num_tasks=10, num_instance=220):
    """
    Build one training loader per domain plus a test loader on synthetic signals.

    For every domain index ``d`` the file
    ``../datas/weather/wth/tmp/datas/domain_{d}.npz`` is opened only to read
    the length of its ``info`` array. The series actually used is a synthetic
    sum of cosines whose phase and frequencies depend on ``d``. Each series is
    cut with stride 1 into (input, target) pairs of 96 + 96 consecutive
    samples. Each pair gets a length-16 one-hot domain mask.

    Parameters
    ----------
    args: object with ``batch_size`` and ``num_workers`` attributes,
        forwarded to ``DataLoader``.
    num_tasks, num_instance: unused; kept for call compatibility.

    Returns
    -------
    list of ten DataLoaders:
        [0..8] one shuffled loader per training domain 0-8;
        [9]    windows of domain 9, not shuffled.
    """
    window = 96      # length of both the input and the target window
    mask_size = 16   # length of the one-hot domain mask

    def _series_length(domain):
        # Only the length of the stored data is needed. The context manager
        # closes the .npz file handle immediately instead of relying on GC.
        # NOTE(review): allow_pickle=True can execute code embedded in a
        # crafted file -- only point this at trusted data.
        path = '../datas/weather/wth/tmp/datas/domain_{}.npz'.format(domain)
        with np.load(path, allow_pickle=True) as data:
            return len(data['info'].tolist())

    def _synthetic_series(domain, length):
        # Same expressions (and float evaluation order) as the original code;
        # np.arange gives the same integer time index as np.array(range(n)).
        t = np.arange(length)
        slow = np.cos(domain * 100 + np.pi / (300 / (domain + 1)) * t) * 0.5
        mixed = (np.cos(40 * domain + np.pi / 100 / (domain + 1) * t)
                 + np.cos(10 + np.pi / 13 * t))
        return mixed + slow

    def _windows(series, domain):
        # Stride-1 pairs: input series[j:j+96], target series[j+96:j+192].
        # NOTE(review): range() stops before len - 192, so the last full pair
        # (whose target ends at the final sample) is never emitted -- confirm.
        starts = range(len(series) - 2 * window)
        xs = [series[j:j + window] for j in starts]
        ys = [series[j + window:j + 2 * window] for j in starts]
        mask = np.zeros(mask_size)
        mask[domain] = 1
        # The shared mask object is copied when the list becomes an array.
        return xs, ys, [mask] * len(starts)

    def _loader(xs, ys, masks, shuffle):
        dataset = DomainDataset(np.array(xs), np.array(ys), np.array(masks))
        return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle,
                          num_workers=args.num_workers, drop_last=False)

    dataloaders = []

    # Training: one shuffled loader per domain 0-8.
    for i in range(9):
        series = _synthetic_series(i, _series_length(i))
        xs, ys, masks = _windows(series, i)
        dataloaders.append(_loader(xs, ys, masks, shuffle=True))

    # Test: domain 9 in its own unshuffled loader.
    for f in range(9, 10):
        series = _synthetic_series(f, _series_length(f))
        # NOTE(review): the mask index is `i` left over from the training loop
        # (always 8), not `f`. Preserved as-is; confirm whether test windows
        # should be tagged with their own domain index instead.
        xs, ys, masks = _windows(series, i)
        dataloaders.append(_loader(xs, ys, masks, shuffle=False))

    return dataloaders



class DomainDataset(Dataset):
    """Map-style dataset over three parallel, index-aligned sequences.

    ``dataset[idx]`` returns ``[X[idx], Y[idx], masks[idx]]``. In this file
    the builders pass input windows as ``X``, target windows as ``Y`` and
    one-hot domain masks as ``masks``.
    """

    def __init__(self, X, Y, masks):
        # Stored by reference, without copying or validation.
        self.X, self.Y, self.masks = X, Y, masks

    def __len__(self):
        # The size is taken from X alone; Y and masks are assumed to match.
        return len(self.X)

    def __getitem__(self, idx):
        sample = [self.X[idx], self.Y[idx]]
        sample.append(self.masks[idx])
        return sample