import itertools
import random
import os

import numpy as np
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

import argparse
import os
import math
import shutil
import random
import distutils.util
import pandas as pd
import sys
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as datasets
import torch.optim as optim

# from cifar100.MemGuard.memguard_run import model

# Project locations come from env.yml, resolved relative to the current working directory.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
root_dir = yamlfile['root_dir']
src_dir = yamlfile['src_dir']

# Put the source tree and its 'attack' / 'models' directories on the import path so the
# project-local imports that follow can be resolved.
sys.path.extend([
    src_dir,
    os.path.join(src_dir, 'attack'),
    os.path.join(src_dir, 'models'),
])
from cifar100.generative_models import cyclegan

import cifar_utils

from cyclegan_utils import ReplayBuffer
from cyclegan_utils import LambdaLR
#from cyclegan_utils import Logger
from cyclegan_utils import weights_init_normal


# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ImageDataset(Dataset):
    """Pair images from two in-memory collections, CycleGAN style.

    Every stored image is converted with ``transpose(1, 2, 0)`` and cast to ``uint8``
    before being wrapped in a PIL image. The arrays are therefore assumed to be
    channel-first (presumably ``(C, H, W)`` CIFAR images; TODO confirm with callers).

    Args:
        A: indexable collection of channel-first images for domain A.
        B: indexable collection of channel-first images for domain B.
        transforms_: list of torchvision transforms, composed and applied to every
            image. ``None`` means no transform, so the PIL image is returned as is.
        unaligned: if True, the B image is drawn uniformly at random for every item.
            Otherwise ``B[index % len(B)]`` is used.
        mode: unused; kept for interface compatibility.
    """

    def __init__(self, A, B, transforms_=None, unaligned=True, mode='train'):
        # Compose([]) is the identity transform; Compose(None) would only fail later.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        self.files_A = A
        self.files_B = B

    @staticmethod
    def _to_pil(array):
        """Convert one channel-first array to a PIL image (HWC, uint8)."""
        return Image.fromarray(array.transpose(1, 2, 0).astype(np.uint8))

    def __getitem__(self, index):
        item_A = self.transform(self._to_pil(self.files_A[index % len(self.files_A)]))
        if self.unaligned:
            b_index = random.randint(0, len(self.files_B) - 1)
        else:
            b_index = index % len(self.files_B)
        # The aligned branch previously skipped the channel-first -> HWC conversion.
        item_B = self.transform(self._to_pil(self.files_B[b_index]))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        # The longer domain defines the epoch length; the shorter one wraps around.
        return max(len(self.files_A), len(self.files_B))

def overlap_samples(set_list):
    """Return the elements common to every collection in *set_list*.

    Args:
        set_list: iterable of iterables, e.g. sets of sample indices. The first one
            does not have to be a ``set``.

    Returns:
        A new ``set`` holding the intersection. An empty input yields an empty set
        instead of the ``TypeError`` raised by ``set.intersection()`` with no arguments.
    """
    groups = list(set_list)
    if not groups:
        return set()
    first, *rest = groups
    return set(first).intersection(*rest)


def _parse_args():
    """Parse and validate the command-line options.

    ``--epoch``, ``--n_epochs``, ``--lr``, ``--decay_epoch`` and ``--size`` are not read
    by this script. They are kept so that existing command lines keep working.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=1, help='starting epoch')
    parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
    parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
    parser.add_argument('--decay_epoch', type=int, default=50,
                        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size', type=int, default=32, help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
    parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
    parser.add_argument('--cuda', action='store_true', help='use GPU computation')
    parser.add_argument('--num_worker', type=int, default=8, help='number of cpu threads to use during batch generation')

    parser.add_argument('--model', type=str, default='mobilenetv3_small_50')
    parser.add_argument('--num_run', type=int, default=5, help='number of ranking runs to intersect')
    parser.add_argument('--data_retain', type=float, default=0.5, help='retain rate')
    parser.add_argument('--conf', type=str, default='250')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str,
                        help='folder containing the trained CycleGAN checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str,
                        help='folder containing the per-run ranking checkpoints')
    opt = parser.parse_args()

    # Fail early with a clear message instead of an obscure TypeError / empty slice later.
    if opt.num_run < 1:
        parser.error('--num_run must be at least 1')
    if not 0.0 <= opt.data_retain <= 1.0:
        parser.error('--data_retain must be between 0 and 1')
    return opt


def _ranked_index_sets(load_checkpoint_path, num_run, data_retain):
    """Return ``(top_sets, bottom_sets)``: one index set per run from its ranking file.

    Run ``i`` reads ``<load_checkpoint_path>/<i>/train.npz``. The first and last
    ``int(data_retain * len(idx))`` entries of its ``idx`` array form that run's
    top and bottom set.
    """
    top_sets, bottom_sets = [], []
    for run in range(1, num_run + 1):
        # The context manager closes the underlying npz file handle.
        with np.load(f'{load_checkpoint_path}/{run}/train.npz') as rank_data:
            rank_idx = rank_data['idx']
        num_retain = int(data_retain * len(rank_idx))
        top_sets.append(set(rank_idx[:num_retain].tolist()))
        # Explicit start index: rank_idx[-0:] would return the WHOLE array when num_retain == 0.
        bottom_sets.append(set(rank_idx[len(rank_idx) - num_retain:].tolist()))
    return top_sets, bottom_sets


@torch.no_grad()
def _translate(generator, loader):
    """Run every batch of *loader* through *generator*.

    Returns:
        ``(images, labels)`` as numpy arrays, or ``(None, None)`` when the loader
        yields no batches. ``images`` are uint8 channel-first arrays, and ``labels``
        are the loader targets concatenated in order.
    """
    images, labels = [], []
    for inputs, targets in loader:
        # Assumes generator output lies in [-1, 1] (TODO confirm); this maps it to [0, 1].
        outputs = 0.5 * (generator(inputs.to(device)) + 1.0)
        # Same uint8 conversion ToPILImage applies to float tensors (mul(255).byte()).
        # The clamp stops tiny overshoots from wrapping around in the cast.
        images.append(outputs.clamp(0.0, 1.0).mul(255).byte().cpu())
        labels.append(targets.detach().cpu())
    if not images:
        return None, None
    # One concatenation at the end instead of a quadratic torch.cat inside the loop.
    return torch.cat(images, dim=0).numpy(), torch.cat(labels, dim=0).numpy()


def main():
    """Build the CycleGAN-augmented CIFAR-100 training set and save it.

    Steps:
    1. Intersect the per-run ranking sets into "safe" (top) and "risky" (bottom) indices.
    2. Replace the risky real training samples with their ``netG_B2A`` translations.
    3. Append the ``netG_A2B`` translations of the synthetic TACGAN samples.
    4. Save the result to ``data_cyclegan.npy`` / ``label_cyclegan.npy`` under
       ``<root_dir>/cifar100_cyclegan``.
    """
    opt = _parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    DATASET_PATH = os.path.join(root_dir, 'cifar100', 'data')
    DATASET_PATH_FAKE = os.path.join(root_dir, 'cifar100_syn')
    DATASET_PATH_SYN = os.path.join(root_dir, 'cifar100_cyclegan')
    checkpoint_path = os.path.join(opt.save_path, 'cifar100', 'cyclegan')
    load_checkpoint_path = os.path.join(
        opt.load_path, 'cifar100', opt.model, 'e2a_mentr_rl', 'no_aug', opt.conf
    )

    # Only the generators are needed for inference.
    netG_A2B = cyclegan.Generator(opt.input_nc, opt.output_nc)
    netG_B2A = cyclegan.Generator(opt.output_nc, opt.input_nc)
    # map_location lets GPU-saved weights load on a CPU-only machine.
    netG_A2B.load_state_dict(torch.load(f'{checkpoint_path}/netG_A2B.pth', map_location=device))
    netG_B2A.load_state_dict(torch.load(f'{checkpoint_path}/netG_B2A.pth', map_location=device))
    netG_A2B.to(device).eval()
    netG_B2A.to(device).eval()

    train_data = np.load(os.path.join(DATASET_PATH, 'partition', 'train_data.npy'))
    train_label = np.load(os.path.join(DATASET_PATH, 'partition', 'train_label.npy'))
    fake_train_data = np.load(os.path.join(DATASET_PATH_FAKE, 'data_tacgan_1.npy'))
    fake_train_label = np.load(os.path.join(DATASET_PATH_FAKE, 'label_tacgan_1.npy'))
    if train_data.shape[0] != train_label.shape[0]:
        raise ValueError(
            f'train_data has {train_data.shape[0]} samples but train_label has {train_label.shape[0]}'
        )

    top_sets, bottom_sets = _ranked_index_sets(load_checkpoint_path, opt.num_run, opt.data_retain)
    # Sorted so that the output order does not depend on set iteration order.
    ols_safe = sorted(overlap_samples(top_sets))
    ols_risky = sorted(overlap_samples(bottom_sets))
    print(f'safe: {len(ols_safe)}, risky: {len(ols_risky)}')

    data_B = train_data[ols_risky]
    label_B = train_label[ols_risky]

    # Dataset loaders: ToTensor + Normalize(0.5, 0.5) map pixels to [-1, 1].
    transforms_ = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    risky_dataloader = DataLoader(
        cifar_utils.Cifardata(data_B, label_B, transform=transforms_),
        batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_worker,
    )
    fake_dataloader = DataLoader(
        cifar_utils.Cifardata(fake_train_data, fake_train_label, transform=transforms_),
        batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_worker,
    )

    risky_data, risky_label = _translate(netG_B2A, risky_dataloader)
    fake_data, fake_label = _translate(netG_A2B, fake_dataloader)

    # Drop the risky real samples; their translations are appended instead.
    keep = np.ones(train_data.shape[0], dtype=bool)
    keep[ols_risky] = False
    data_parts = [train_data[keep]]
    label_parts = [train_label[keep]]

    if risky_data is not None:
        print(risky_data.shape, risky_label.shape, data_parts[0].shape, label_parts[0].shape)
    for gen_data, gen_label in ((risky_data, risky_label), (fake_data, fake_label)):
        if gen_data is not None:  # an empty loader contributes nothing
            data_parts.append(gen_data)
            label_parts.append(gen_label)

    temp_data = np.concatenate(data_parts, axis=0)
    temp_label = np.concatenate(label_parts, axis=0)

    os.makedirs(DATASET_PATH_SYN, exist_ok=True)

    print(temp_data.shape, temp_label.shape)
    np.save(os.path.join(DATASET_PATH_SYN, 'data_cyclegan.npy'), temp_data)
    np.save(os.path.join(DATASET_PATH_SYN, 'label_cyclegan.npy'), temp_label)


if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script, not on import.
    main()
