#Common imports
import numpy as np
import sys
import os
import argparse
import random
import copy
import os 

#SciPy
from scipy.stats import bernoulli

#Pillow
from PIL import Image, ImageColor, ImageOps 

#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

def generate_rotated_domain_data(imgs, labels, data_case, dataset, indices, domain, save_dir, img_w, img_h, color):
    """Rotate (and for 'rot_mnist_spur', colorize) a subset of images and save it.

    Writes three files derived from ``save_dir``:
      * ``save_dir + '_org_data.pt'`` -- normalized, un-augmented images;
      * ``save_dir + '_data.pt'``     -- normalized, randomly augmented images
        (only written when ``data_case`` is 'train' or 'val');
      * ``save_dir + '_label.pt'``    -- the labels of the selected subset.

    Args:
        imgs: Indexable stack of source images; each selected item is fed to
            ``transforms.ToPILImage`` (presumably uint8 (28, 28) MNIST tensors --
            TODO confirm for other callers).
        labels: Labels aligned with ``imgs``.
        data_case: 'train', 'val' or 'test'.
        dataset: Dataset name. 'rot_mnist_spur' produces 3-channel colorized
            images; any other value produces single-channel images stored
            without a channel axis.
        indices: Indices selecting the subset of ``imgs`` / ``labels``.
        domain: Rotation angle in degrees (int or numeric string); 0 means the
            image is not rotated.
        save_dir: Path prefix the file suffixes above are appended to.
        img_w, img_h: Output size. Images are resized/cropped to the size tuple
            ``(img_w, img_h)`` and the saved tensors have trailing dims
            ``(img_w, img_h)``.
        color: Stroke color given to ``ImageOps.colorize`` for non-test
            'rot_mnist_spur' images; otherwise it is only printed.

    Returns:
        None.
    """
    # Select the requested subset.
    subset_labels = labels[indices]
    subset_imgs = imgs[indices]
    n_samples = subset_labels.shape[0]

    # Every transform and the output buffers share this size tuple, so the
    # stored tensors always have the same trailing layout.
    out_size = (img_w, img_h)

    to_pil = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(out_size),
    ])

    # Crop to the full (img_w, img_h) size. A bare img_w yields a square crop
    # that cannot be written into the buffer when img_w != img_h.
    to_augment = transforms.Compose([
        transforms.RandomResizedCrop(out_size, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
    ])

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    is_spur = dataset == 'rot_mnist_spur'
    # Colorized images carry an explicit RGB channel axis; grayscale ones do not.
    if is_spur:
        buf_shape = (n_samples, 3, img_w, img_h)
    else:
        buf_shape = (n_samples, img_w, img_h)
    imgs_aug = torch.zeros(buf_shape)  # randomly augmented copies
    imgs_org = torch.zeros(buf_shape)  # original (un-augmented) copies

    # Callers pass the angle as an int; comparing it to the string '0' (as
    # before) never matched, so compare numerically.
    angle = int(domain)

    for i, raw_img in enumerate(subset_imgs):
        curr_image = to_pil(raw_img)

        if is_spur:
            # Test images are colorized black->white (i.e. left uncolored);
            # train/val images get the requested stroke color. Presumably this
            # keeps color from acting as a cue at test time -- confirm intent.
            stroke = 'white' if data_case == 'test' else color
            curr_image = ImageOps.colorize(curr_image, black='black', white=stroke)

        if angle == 0:
            img_rotated = curr_image
        else:
            img_rotated = transforms.functional.rotate(curr_image, angle)

        # The "org" tensor keeps the rotated image as-is; the other one is augmented.
        imgs_org[i] = to_tensor(img_rotated)
        imgs_aug[i] = to_tensor(to_augment(img_rotated))

    # Augmented copies are only needed for the training/validation splits.
    if data_case == 'train' or data_case == 'val':
        torch.save(imgs_aug, save_dir + '_data.pt')

    torch.save(imgs_org, save_dir + '_org_data.pt')
    torch.save(subset_labels, save_dir + '_label.pt')

    print('Data Case: ', data_case, ' Source Domain: ', domain, ' color: ', color,
          ' Shape: ', imgs_aug.shape, imgs_org.shape, subset_labels.shape)

# Main Function

# Input Parsing
parser = argparse.ArgumentParser()
# Restrict --dataset to the supported values. An unknown name previously
# passed parsing and only failed later with a NameError on mnist_imgs.
parser.add_argument('--dataset', type=str, default='rot_mnist',
                    choices=['rot_mnist', 'fashion_mnist', 'rot_mnist_spur'],
                    help='Datasets: rot_mnist; fashion_mnist; rot_mnist_spur')
parser.add_argument('--model', type=str, default='resnet18',
                    help='Base Models: resnet18; lenet')
# Size of the index pool that subsets are sampled from.
parser.add_argument('--data_size', type=int, default=60000)
# Number of training images per (seed, domain); validation gets a fifth of it.
parser.add_argument('--subset_size', type=int, default=2000)
# Output image size, used as the (img_w, img_h) size tuple for the transforms.
parser.add_argument('--img_w', type=int, default=224)
parser.add_argument('--img_h', type=int, default=224)
# For the non-spurious datasets this value is forwarded as the `color`
# argument of generate_rotated_domain_data.
parser.add_argument('--cmnist_permute', type=int, default=0)

args = parser.parse_args()

dataset = args.dataset
model = args.model
img_w = args.img_w
img_h = args.img_h
data_size = args.data_size
subset_size = args.subset_size
# Integer division avoids a float round-trip.
val_size = args.subset_size // 5
cmnist_permute = args.cmnist_permute

# Generate Dataset for Rotated / Fashion MNIST
# TODO: Manage OS Env from args
os_env = 0
if os_env:
    pt_data_dir = os.getenv('PT_DATA_DIR')
    if pt_data_dir is None:
        # Fail with a clear message instead of a TypeError on None + str.
        raise RuntimeError('os_env is enabled but the PT_DATA_DIR environment variable is not set')
    base_dir = pt_data_dir + '/mnist/'
else:
    base_dir = 'data/datasets/matchdg/mnist/'

# exist_ok=True replaces the racy "check, then create" pattern.
os.makedirs(base_dir, exist_ok=True)

# One output directory per (dataset, model) pair, e.g. <base_dir>rot_mnist_resnet18/
data_dir = base_dir + dataset + '_' + model + '/'
os.makedirs(data_dir, exist_ok=True)
    
    
# torchvision dataset class backing each supported dataset name.
_DATASET_CLASSES = {
    'rot_mnist': datasets.MNIST,
    'rot_mnist_spur': datasets.MNIST,
    'fashion_mnist': datasets.FashionMNIST,
}

if dataset not in _DATASET_CLASSES:
    # Previously an unknown name left mnist_imgs undefined, causing a later NameError.
    raise ValueError('Unsupported dataset: ' + repr(dataset)
                     + '; expected one of ' + ', '.join(sorted(_DATASET_CLASSES)))

_dataset_cls = _DATASET_CLASSES[dataset]

# The raw `.data` / `.targets` tensors are read directly below, so the
# transform argument does not affect the generated data.
data_obj_train = _dataset_cls(base_dir,
                              train=True,
                              download=True,
                              transform=transforms.ToTensor())
data_obj_test = _dataset_cls(base_dir,
                             train=False,
                             download=True,
                             transform=transforms.ToTensor())

# Pool the train and test splits, train split first. Sampled indices are
# < data_size, so with the default 60000 they presumably all come from the
# train split -- confirm if data_size is changed.
mnist_imgs = torch.cat((data_obj_train.data, data_obj_test.data))
mnist_labels = torch.cat((data_obj_train.targets, data_obj_test.targets))
    

# Seeds to generate data for. The generation loop writes test data only for
# seed 9 and train/val data only for seeds 0-2 (of which only 0 is listed here).
seed_list = [0, 9]

# Training rotation angles in degrees, 15..75 inclusive: every 5 degrees for
# the colored variant, every single degree otherwise.
if dataset != 'rot_mnist_spur':
    train_domains = list(range(15, 76))
else:
    train_domains = list(range(15, 76, 5))

# All domains to iterate over: the training angles plus the extreme angles
# 0 and 90, in ascending order (a new list, independent of train_domains).
domains = sorted(train_domains + [0, 90])

# Stroke colors applied to train/val images of rot_mnist_spur.
color_list = ['red', 'blue', 'green', 'brown', 'pink', 'yellow']


# Seeds allowed to produce train/val data, and seeds producing test data.
TRAIN_SEEDS = (0, 1, 2)
TEST_SEEDS = (9,)
# Rotation angles (degrees) written out as test domains.
TEST_DOMAINS = (0, 90)
# LeNet on the non-spurious datasets trains on this coarser grid instead of train_domains.
LENET_TRAIN_DOMAINS = (0, 15, 30, 45, 60, 75)

# Create the per-split output directories once, instead of inside the loops.
for _split in ('train', 'val', 'test'):
    os.makedirs(data_dir + _split + '/', exist_ok=True)


def _writes_train_val(seed, domain):
    """Return True if train/val data should be generated for (seed, domain).

    Only 'resnet18' and 'lenet' produce train/val data. Both use
    ``train_domains``, except LeNet on non-spurious datasets, which uses
    ``LENET_TRAIN_DOMAINS``.
    """
    if seed not in TRAIN_SEEDS:
        return False
    if model == 'resnet18':
        return domain in train_domains
    if model == 'lenet':
        allowed = train_domains if dataset == 'rot_mnist_spur' else LENET_TRAIN_DOMAINS
        return domain in allowed
    return False


for seed in seed_list:
    # Seed every RNG used below: numpy for index sampling, torch for the random
    # augmentations. Previously `seed` only appeared in file names, so each
    # run produced different, irreproducible data.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    print('Seed: ', seed)
    for domain in domains:
        # Draw a fresh subset per domain, so different domains contain different
        # images (the same indices are reused for every color of that domain).
        # replace=False avoids duplicate images and keeps the train and val
        # slices disjoint.
        res = np.random.choice(data_size, subset_size + val_size, replace=False)
        train_indices = res[:subset_size]
        val_indices = res[subset_size:]
        # Test data reuses the first subset_size indices of this draw.
        test_indices = res[:subset_size]

        write_train_val = _writes_train_val(seed, domain)
        write_test = seed in TEST_SEEDS and domain in TEST_DOMAINS

        if dataset == 'rot_mnist_spur':
            if write_train_val:
                # One train and one val file per stroke color.
                for color in color_list:
                    for data_case, indices in (('train', train_indices), ('val', val_indices)):
                        save_dir = (data_dir + data_case + '/' + 'seed_' + str(seed)
                                    + '_domain_' + str(domain) + '_color_' + color)
                        generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices,
                                                     domain, save_dir, img_w, img_h, color)
            if write_test:
                save_dir = (data_dir + 'test/' + 'seed_' + str(seed)
                            + '_domain_' + str(domain) + '_color_' + 'white')
                # Test images are always white; pass 'white' instead of the
                # leftover loop color so the logged color matches the file name.
                generate_rotated_domain_data(mnist_imgs, mnist_labels, 'test', dataset, test_indices,
                                             domain, save_dir, img_w, img_h, 'white')
        else:
            if write_train_val:
                for data_case, indices in (('train', train_indices), ('val', val_indices)):
                    save_dir = data_dir + data_case + '/' + 'seed_' + str(seed) + '_domain_' + str(domain)
                    generate_rotated_domain_data(mnist_imgs, mnist_labels, data_case, dataset, indices,
                                                 domain, save_dir, img_w, img_h, cmnist_permute)
            if write_test:
                save_dir = data_dir + 'test/' + 'seed_' + str(seed) + '_domain_' + str(domain)
                generate_rotated_domain_data(mnist_imgs, mnist_labels, 'test', dataset, test_indices,
                                             domain, save_dir, img_w, img_h, cmnist_permute)

        # NOTE: an earlier, disabled variant also sampled extra 0/90-degree
        # train/val data for an attribute attack on rot_mnist_spur.

            
            