import itertools
import random
import os

import numpy as np
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

import argparse
import os
import math
import shutil
import random
import distutils.util
import pandas as pd
import sys
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as datasets
import torch.optim as optim

# from cifar100.MemGuard.memguard_run import model

# Project locations are read from a YAML file resolved against the current
# working directory. Alternative path: './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as config_stream:
    yamlfile = yaml.safe_load(config_stream)
# Both keys are required; a missing key raises KeyError at import time.
root_dir = yamlfile['root_dir']
src_dir = yamlfile['src_dir']

# Make the project source tree and its attack/ and models/ subfolders
# importable (presumably needed by the project imports that follow).
sys.path.extend([
    src_dir,
    os.path.join(src_dir, 'attack'),
    os.path.join(src_dir, 'models'),
])
from cifar100.generative_models import cyclegan

from cyclegan_utils import ReplayBuffer
from cyclegan_utils import LambdaLR
#from cyclegan_utils import Logger
from cyclegan_utils import weights_init_normal

# Compute device for the script: the GPU when CUDA is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ImageDataset(Dataset):
    """Two-domain, in-memory image dataset for CycleGAN training.

    Each item is ``{'A': ..., 'B': ...}`` with one transformed image from
    each collection. Every stored image is converted with
    ``.transpose(1, 2, 0).astype(np.uint8)`` before ``Image.fromarray``, so
    images are expected channel-first (C, H, W). Values are cast to uint8 and
    are therefore presumably already in 0..255 (confirm against the caller).

    Args:
        A: indexable collection of domain-A images.
        B: indexable collection of domain-B images.
        transforms_: list of torchvision transforms composed and applied to
            every PIL image; ``None`` means no transform (identity).
        unaligned: if True, the B image is drawn uniformly at random with
            ``random.randint``; otherwise it is taken at ``index % len(B)``.
        mode: unused; kept for interface compatibility.
    """

    def __init__(self, A, B, transforms_=None, unaligned=True, mode='train'):
        # Compose(None) would only fail later, inside __getitem__; treat None
        # as "no transforms" instead.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        self.files_A = A
        self.files_B = B

    @staticmethod
    def _to_pil(image):
        """Convert one channel-first image array into a PIL image."""
        return Image.fromarray(image.transpose(1, 2, 0).astype(np.uint8))

    def __getitem__(self, index):
        # Wrap the index because __len__ reports the larger of the two domains.
        item_A = self.transform(self._to_pil(self.files_A[index % len(self.files_A)]))
        if self.unaligned:
            b_index = random.randint(0, len(self.files_B) - 1)
        else:
            b_index = index % len(self.files_B)
        # Both branches use the same CHW -> HWC conversion as domain A; the
        # aligned path previously skipped it and could not build the image.
        item_B = self.transform(self._to_pil(self.files_B[b_index]))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        """Return the size of the larger domain."""
        return max(len(self.files_A), len(self.files_B))

def overlap_samples(set_list):
    """Return the elements common to every collection in *set_list*.

    Args:
        set_list: iterable of sets (or any iterables of hashable items).

    Returns:
        A new ``set`` with the intersection of all collections. An empty
        *set_list* yields an empty set, where ``set.intersection()`` with no
        arguments would raise TypeError.
    """
    collections_iter = iter(set_list)
    try:
        first = next(collections_iter)
    except StopIteration:
        return set()
    # set(first) copies, so the caller's first set is never returned or mutated,
    # and non-set iterables (lists, arrays) are accepted as the first item too.
    return set(first).intersection(*collections_iter)


def _parse_args():
    """Parse and return the command-line options for CycleGAN training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=1, help='starting epoch')
    parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
    parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
    parser.add_argument('--decay_epoch', type=int, default=50,
                        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size', type=int, default=32, help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
    parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
    parser.add_argument('--cuda', action='store_true', help='use GPU computation')
    parser.add_argument('--num_worker', type=int, default=8, help='number of cpu threads to use during batch generation')

    parser.add_argument('--model', type=str, default='mobilenetv3_small_50')
    parser.add_argument('--num_run', type=int, default=5, help='number of runs whose rankings are intersected')
    parser.add_argument('--data_retain', type=float, default=0.5,
                        help='fraction of each ranking kept from its head (safe) and its tail (risky)')
    parser.add_argument('--conf', type=str, default='250')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str,
                        help='folder to load the per-run ranking checkpoints from')
    return parser.parse_args()


def _retained_index_sets(load_checkpoint_path, num_run, data_retain):
    """Return the head and tail index sets of every run's ranking.

    For each run ``i`` in ``1..num_run`` this reads the ``idx`` array from
    ``<load_checkpoint_path>/<i>/train.npz`` and keeps
    ``int(data_retain * len(idx))`` entries, clamped to ``[0, len(idx)]``,
    from the front and from the back of that array.

    Returns:
        ``(head_sets, tail_sets)``: two lists holding one ``set`` per run.
    """
    head_sets, tail_sets = [], []
    for run in range(1, num_run + 1):
        # Context manager closes the underlying .npz archive.
        with np.load(f'{load_checkpoint_path}/{run}/train.npz') as rank_data:
            rank_idx = rank_data['idx']
        total = len(rank_idx)
        num_retain = min(max(int(data_retain * total), 0), total)
        head_sets.append(set(rank_idx[:num_retain].tolist()))
        # Slice from an explicit start: rank_idx[-0:] would select the whole
        # ranking instead of nothing when num_retain == 0.
        tail_sets.append(set(rank_idx[total - num_retain:].tolist()))
    return head_sets, tail_sets


def main():
    """Train a CycleGAN between two subsets of the CIFAR-100 training data.

    Domain A ("safe") is the intersection over ``--num_run`` runs of the first
    ``--data_retain`` fraction of each run's ranking. Domain B ("risky") is the
    intersection of the last fraction. Rankings are read from
    ``<load_path>/cifar100/<model>/e2a_mentr_rl/no_aug/<conf>/<i>/train.npz``.

    After every epoch:

    - a grid of the first 10 domain-A images and their A->B translations is
      written to ``<save_path>/cifar100/cyclegan/img/samples_<epoch>.jpg``;
    - the state dicts of all four networks are saved next to it.

    Raises:
        ValueError: if the safe or the risky subset is empty.
    """
    opt = _parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    ###### Definition of variables ######
    # Networks: one generator per direction, one discriminator per domain.
    netG_A2B = cyclegan.Generator(opt.input_nc, opt.output_nc)
    netG_B2A = cyclegan.Generator(opt.output_nc, opt.input_nc)
    netD_A = cyclegan.Discriminator(opt.input_nc)
    netD_B = cyclegan.Discriminator(opt.output_nc)

    # Use the module-level ``device`` instead of unconditional ``.cuda()`` so the
    # script still runs on hosts without CUDA.
    for net in (netG_A2B, netG_B2A, netD_A, netD_B):
        net.to(device)
        net.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(
        itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
        lr=opt.lr, betas=(0.5, 0.999)
    )
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    def make_scheduler(optimizer):
        # Each scheduler gets its own LambdaLR helper instance.
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    lr_scheduler_G = make_scheduler(optimizer_G)
    lr_scheduler_D_A = make_scheduler(optimizer_D_A)
    lr_scheduler_D_B = make_scheduler(optimizer_D_B)

    # Replay buffers of previously generated images for the discriminators.
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    dataset_path = os.path.join(root_dir, 'cifar100', 'data')
    checkpoint_path = os.path.join(opt.save_path, 'cifar100', 'cyclegan')
    load_checkpoint_path = os.path.join(
        opt.load_path, 'cifar100', opt.model, 'e2a_mentr_rl', 'no_aug', opt.conf
    )

    train_data = np.load(os.path.join(dataset_path, 'partition', 'train_data.npy'))

    head_sets, tail_sets = _retained_index_sets(load_checkpoint_path, opt.num_run, opt.data_retain)
    # Sorted so the subset order, and therefore the preview grid, is reproducible.
    ols_safe = sorted(overlap_samples(head_sets))
    ols_risky = sorted(overlap_samples(tail_sets))

    print(f'safe: {len(ols_safe)}, risky: {len(ols_risky)}')
    if not ols_safe or not ols_risky:
        raise ValueError(
            f'Cannot train CycleGAN: safe subset has {len(ols_safe)} samples '
            f'and risky subset has {len(ols_risky)}; both must be non-empty.'
        )

    data_A = train_data[ols_safe]
    data_B = train_data[ols_risky]

    # Dataset loader
    transforms_ = [transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
                   transforms.RandomCrop(opt.size),
                   transforms.RandomHorizontalFlip(),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    dataloader = DataLoader(
        ImageDataset(data_A, data_B, transforms_=transforms_, unaligned=True),
        batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_worker
    )

    # Fixed preview batch: first 10 domain-A images, normalized like the
    # training data but without augmentation.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    samples = torch.stack(
        [to_tensor(Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8))) for img in data_A[:10]],
        dim=0,
    ).to(device)

    img_dir = os.path.join(checkpoint_path, 'img')
    os.makedirs(img_dir, exist_ok=True)

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs + 1):
        netG_A2B.train()
        netG_B2A.train()
        for i, batch in enumerate(dataloader):
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss: each generator should leave images of its target
            # domain unchanged.
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss. Targets are built with the prediction's own shape, so
            # MSELoss never has to broadcast a mismatched target.
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            # Progress report (scalar values, not tensor reprs).
            print(
                f'Epoch [{epoch}/{opt.n_epochs}], Iter [{i}],',
                f'loss_G: {loss_G.item()}, loss_G_identity: {(loss_identity_A + loss_identity_B).item()},',
                f'loss_G_GAN: {(loss_GAN_A2B + loss_GAN_B2A).item()},',
                f'loss_G_cycle: {(loss_cycle_ABA + loss_cycle_BAB).item()}, loss_D: {(loss_D_A + loss_D_B).item()}',
            )

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Preview: the original samples (top row) and their A->B translations
        # (bottom row), mapped from [-1, 1] back to [0, 1]. no_grad avoids
        # building an autograd graph for this inference-only pass.
        netG_A2B.eval()
        with torch.no_grad():
            translated = 0.5 * (netG_A2B(samples) + 1.0)
        preview = torch.cat((0.5 * (samples.cpu() + 1.0), translated.cpu()), dim=0)
        save_image(preview, os.path.join(img_dir, f'samples_{epoch}.jpg'), nrow=10)

        # Save models checkpoints (overwritten every epoch).
        torch.save(netG_A2B.state_dict(), f'{checkpoint_path}/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), f'{checkpoint_path}/netG_B2A.pth')
        torch.save(netD_A.state_dict(), f'{checkpoint_path}/netD_A.pth')
        torch.save(netD_B.state_dict(), f'{checkpoint_path}/netD_B.pth')


if __name__ == '__main__':
    main()
