from os.path import isfile, join
import torch, torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import numpy as np
from opts import get_opts, get_name
from utils.settings import DATASETTINGS, MODELSETTINGS
from datasets import build_transform, build_data
from models import build_model
from attacks import build_trigger


def x_transform(x, DSET):
    """Apply a random padded crop and random horizontal flip to a batch.

    The augmentation is expressed as one affine sampling grid per sample, so
    it is differentiable with respect to ``x`` and can sit between a learnable
    perturbation and the model.

    Args:
        x: 4-D image batch ``(N, C, H, W)``. H and W are presumably equal to
            ``DSET['img_size']`` -- the code does not check this.
        DSET: dataset settings; only ``'crop'`` (padding in pixels on every
            side) and ``'img_size'`` (output height/width) are read.

    Returns:
        Tensor of shape ``(N, C, img_size, img_size)`` on the same device as
        ``x``.
    """
    crop, img_size = DSET['crop'], DSET['img_size']
    padded_size = crop * 2 + img_size
    batch = x.shape[0]
    # Use the input's own device instead of a global `opts`, which only exists
    # when this file is executed as a script.
    device = x.device

    x = F.pad(x, (crop, crop, crop, crop))

    theta = torch.zeros(size=(batch, 2, 3)).to(device)
    # Scale so that the sampled window covers img_size of the padded image.
    theta[:, 0, 0] = 1.0 * img_size / padded_size
    theta[:, 1, 1] = 1.0 * img_size / padded_size
    # Independent integer offsets in [0, 2 * crop] for the two axes, mapped to
    # translations in the normalized [-1, 1] grid coordinates.
    left = torch.randint(0, crop * 2 + 1, size=(batch,))
    right = torch.randint(0, crop * 2 + 1, size=(batch,))
    theta[:, 0, 2] = ((left * 2 - 2 * crop) / padded_size).to(device)
    theta[:, 1, 2] = ((right * 2 - 2 * crop) / padded_size).to(device)
    # Negating the x-scale mirrors the sample horizontally (probability 0.5).
    p_flip = torch.rand(size=(batch,)).to(device) > 0.5
    theta[p_flip, 0, 0] = -theta[p_flip, 0, 0]

    grid = F.affine_grid(theta, size=(batch, x.shape[1], img_size, img_size), align_corners=False)
    x = F.grid_sample(x, grid, mode='bilinear', align_corners=False)
    return x


def _build_optimizer(model, MSET, DSET):
    """Create the optimizer and multi-step LR scheduler given by the settings."""
    if MSET['optimizer'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=MSET['learning_rate'], weight_decay=5e-4, momentum=0.9)
    else:
        optimizer = optim.Adam(model.parameters(), lr=MSET['learning_rate'], weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, DSET['decay_steps'], 0.1)
    return optimizer, scheduler


def _train_epoch(model, loader, optimizer, criterion, device, track=False):
    """Train for one epoch; return (accuracy, per-batch hit arrays, per-batch id arrays).

    The two lists are only filled when ``track`` is true. The ids are the fifth
    element yielded by the loader (presumably the dataset index of each sample
    -- confirm against the dataset class).
    """
    model.train()
    correct, total, hits, ids = 0, 0, [], []
    for x, y, _, _, d in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        p = model(x)
        loss = criterion(p, y)
        loss.backward()
        optimizer.step()
        _, p = torch.max(p, dim=1)
        correct += (p == y).sum().item()
        total += y.shape[0]
        if track:
            hits.append((p == y).long().detach().cpu().numpy())
            ids.append(d.detach().cpu().numpy())
    return correct / (total + 1e-8), hits, ids


def _accuracy(model, loader, device, skip_source=None):
    """Return top-1 accuracy over ``loader``.

    When ``skip_source`` is given, samples whose fourth loader element equals
    it are dropped before scoring (presumably the original label, so images
    already belonging to the target class are not counted -- confirm).
    """
    model.eval()
    correct, total = 0, 0
    for x, y, _, s, _ in loader:
        x, y = x.to(device), y.to(device)
        if skip_source is not None:
            keep = s.to(device) != skip_source
            x, y = x[keep, :, :, :], y[keep]
            if y.shape[0] == 0:  # nothing left to score in this batch
                continue
        with torch.no_grad():
            p = model(x)
        _, p = torch.max(p, dim=1)
        correct += (p == y).sum().item()
        total += y.shape[0]
    return correct / (total + 1e-8)


def _candidate_indices(train_data, backdoor_target):
    """Return shuffled training indices whose label differs from the target."""
    candidates = np.arange(len(train_data))[np.array(train_data.targets) != backdoor_target]
    np.random.shuffle(candidates)
    return candidates


def _random_noise(opts, DSET):
    """Draw a uniform perturbation in [-linf_eps, linf_eps] of shape (1, 3, S, S)."""
    size = (1, 3, DSET['img_size'], DSET['img_size'])
    return (torch.rand(size=size).to(opts.device) - 0.5) * 2 * opts.linf_eps


def _pretrain_clean_model(opts, DSET, MSET, weight_file):
    """Train a model without any trigger and save the best-validation weights."""
    print('-' * 50 + '\n' + 'Training a pre-trained model first')
    train_transform = build_transform(True, DSET['img_size'], DSET['crop'], DSET['flip'])
    val_transform = build_transform(False, DSET['img_size'], DSET['crop'], DSET['flip'])
    train_data = build_data(opts.data_name, opts.data_path, True, None, train_transform)
    val_data = build_data(opts.data_name, opts.data_path, False, None, val_transform)
    train_loader = DataLoader(dataset=train_data, batch_size=DSET['batch_size'], shuffle=True, num_workers=4)
    val_loader = DataLoader(dataset=val_data, batch_size=DSET['batch_size'], shuffle=False, num_workers=4)
    model = build_model(MSET['model_name'], DSET['num_classes']).to(opts.device)
    optimizer, scheduler = _build_optimizer(model, MSET, DSET)
    criterion = nn.CrossEntropyLoss().to(opts.device)
    best_val_acc = 0
    print('-' * 50 + '\n' + 'Training starts')
    for epoch in range(DSET['epochs']):
        train_acc, _, _ = _train_epoch(model, train_loader, optimizer, criterion, opts.device)
        scheduler.step()
        val_acc = _accuracy(model, val_loader, opts.device)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), weight_file)
        print('epoch: {:03d}, train_acc: {:.3f}, val_acc: {:.3f}'.format(epoch, train_acc, val_acc))
    print('-' * 50 + '\n' + 'Training completed')


def _optimize_trigger(opts, DSET, MSET):
    """Optimize an additive trigger against a clean model; return (n, samples)."""
    weight_file = join(opts.weight_path, '{}_0_{}.pt'.format(opts.data_name, opts.suffix))
    if not isfile(weight_file):
        _pretrain_clean_model(opts, DSET, MSET, weight_file)

    print('-' * 50 + '\n' + 'Optimizing a trigger')
    # Dataset-side augmentation is off here; x_transform augments after the
    # perturbation has been added so the gradient still reaches it.
    train_transform = build_transform(False, DSET['img_size'], DSET['crop'], DSET['flip'])
    train_data = build_data(opts.data_name, opts.data_path, True, None, train_transform)
    train_loader = DataLoader(dataset=train_data, batch_size=DSET['batch_size'], shuffle=True, num_workers=8)
    model = build_model(MSET['model_name'], DSET['num_classes']).to(opts.device).eval()
    model.load_state_dict(torch.load(weight_file, map_location=opts.device))
    n = _random_noise(opts, DSET)
    alpha = 1e-4  # step size of each signed-gradient update
    print('-' * 50 + '\n' + 'Optimizing starts')
    for epoch in range(300):
        correct, total = 0, 0
        for x, y, _, _, _ in train_loader:
            # Restart from a fresh leaf so n.grad belongs to this batch only.
            n = n.detach().requires_grad_()
            x, y = x.to(opts.device), y.to(opts.device)
            ind_back = y != opts.backdoor_target
            if torch.sum(ind_back) == 0:
                continue
            x, y = x[ind_back, :, :, :], y[ind_back]
            y = torch.full_like(y, opts.backdoor_target)
            x = (x + n).clamp(0, 1)
            p = model(x_transform(x, DSET))
            loss = F.cross_entropy(p, y)
            model.zero_grad()
            loss.backward()
            # Signed-gradient descent towards the target class, projected back
            # into the L-infinity ball of radius linf_eps.
            n = n - alpha * n.grad.sign()
            n = n.clamp(-opts.linf_eps, opts.linf_eps)
            _, p = torch.max(p, dim=1)
            correct += (p == y).sum().item()
            total += y.shape[0]
        back_acc = correct / (total + 1e-8)
        print('epoch: {:03d}, back_acc: {:.3f}'.format(epoch, back_acc))
    print('-' * 50 + '\n' + 'Optimizing completed')
    samples = _candidate_indices(train_data, opts.backdoor_target)
    return n.detach().cpu().numpy()[0, :, :, :], samples


def _fus_select(opts, DSET, MSET):
    """Run the iterative FUS sample search; return (n, best_samples).

    Starts from the trigger file written by the matching
    ``select_type == 'random'`` run and repeatedly retrains a model on the
    current pool, keeping the samples that flip from correct to incorrect most
    often during training and refilling the rest at random.

    Raises:
        FileNotFoundError: if the random-selection trigger file is missing.
    """
    # The prerequisite file is named as if select_type were 'random'. Always
    # restore the caller's value, even if get_name raises.
    opts.select_type = 'random'
    try:
        base_file = join(opts.backdoor_path, '{}.npy'.format(get_name('generate', opts)))
    finally:
        opts.select_type = 'fus'
    if not isfile(base_file):
        print('-' * 50 + '\n' + 'Please set opts.select_type to random first')
        raise FileNotFoundError(base_file)
    # NOTE(review): allow_pickle=True unpickles the file; only load trigger
    # files produced by this project.
    loads = np.load(base_file, allow_pickle=True).item()
    n, shuffle = loads['n'], loads['samples'].copy()
    samples = shuffle[:opts.poison_num]  # Create sample pool
    train_transform = build_transform(True, DSET['img_size'], DSET['crop'], DSET['flip'])
    val_transform = build_transform(False, DSET['img_size'], DSET['crop'], DSET['flip'])
    print('-' * 50 + '\n' + '{} poisoned samples'.format(opts.poison_num))
    all_best_back_acc = 0
    # Fall back to the initial pool so a result exists even if no run ever
    # reaches a backdoor accuracy above zero.
    best_samples = np.copy(samples)
    print('-' * 50 + '\n' + 'FUS starts')
    for run in range(15):
        print('-' * 50 + '\n' + 'FUS with {:2d} iteration'.format(run))
        trigger = build_trigger('added', n, samples, 0, opts.backdoor_target)
        train_data = build_data(opts.data_name, opts.data_path, True, trigger, train_transform)
        val_data = build_data(opts.data_name, opts.data_path, False, trigger, val_transform)
        train_loader = DataLoader(dataset=train_data, batch_size=DSET['batch_size'], shuffle=True, num_workers=4)
        val_loader = DataLoader(dataset=val_data, batch_size=DSET['batch_size'], shuffle=False, num_workers=4)
        model = build_model(MSET['model_name'], DSET['num_classes']).to(opts.device)
        optimizer, scheduler = _build_optimizer(model, MSET, DSET)
        criterion = nn.CrossEntropyLoss().to(opts.device)

        print('-' * 50)
        best_val_acc, best_back_acc, measure = 0, 0, []
        for epoch in range(DSET['epochs']):
            # The trigger mode must be set before each loader is iterated.
            # In this file mode 0 is used for training, 1 for the accuracy
            # reported as val_acc and 2 for the one reported as back_acc.
            trigger.set_mode(0)
            train_acc, ps, ds = _train_epoch(model, train_loader, optimizer, criterion, opts.device, track=True)
            scheduler.step()

            ps, ds = np.concatenate(ps, axis=0), np.concatenate(ds, axis=0)
            ps = ps[np.argsort(ds)]  # From small to large
            measure.append(ps[:, np.newaxis])  # Record measure per epoch

            trigger.set_mode(1)
            val_acc = _accuracy(model, val_loader, opts.device)
            trigger.set_mode(2)
            back_acc = _accuracy(model, val_loader, opts.device, skip_source=opts.backdoor_target)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_back_acc = back_acc
            outputs = 'epoch: {:03d}, train_acc: {:.3f}, val_acc: {:.3f}, back_acc: {:.3f}'.format(epoch, train_acc, val_acc, back_acc)
            print(outputs)

        # Score each pooled sample by how often it went from correctly (1) to
        # incorrectly (0) classified between consecutive epochs.
        measure = np.concatenate(measure, axis=1)[samples, :]
        diff_measure = measure[:, 1:] - measure[:, :-1]
        score = np.sum(diff_measure == -1, axis=1)
        score_idx = np.argsort(score)[::-1]  # Sort from large to small
        samples = samples[score_idx]

        if all_best_back_acc < best_back_acc:  # Save new best
            all_best_back_acc = best_back_acc
            best_samples = np.copy(samples)
            print('update best samples with back acc: {:.3f}'.format(all_best_back_acc))

        # Filter and update for the next run: keep the top 70% by score and
        # refill the pool from the reshuffled candidates.
        samples = samples[:int(opts.poison_num * 0.7)]
        np.random.shuffle(shuffle)
        samples = np.concatenate((samples, shuffle[:(opts.poison_num - len(samples))]), axis=0)

    return n, best_samples


def generate(opts):
    """Generate a backdoor trigger plus poisoned-sample indices and save them.

    The result is a dict ``{'bound', 'n', 'samples'}`` written with ``np.save``
    to ``<opts.backdoor_path>/<name>.npy``, where ``name`` comes from
    ``get_name('generate', opts)``. Nothing is done if that file exists.

    Supported ``(opts.generate_type, opts.select_type)`` combinations:
        * ``('random', 'random')``: uniform random perturbation, shuffled
          non-target training indices.
        * ``('optimized', 'random')``: perturbation optimized against a clean
          model (trained first if its weights are missing), shuffled indices.
        * ``('optimized', 'fus')``: reuses the perturbation of the
          ``('optimized', 'random')`` run and searches for a better sample set.

    Raises:
        ValueError: for any other combination.
        FileNotFoundError: in the ``'fus'`` case when the prerequisite
            random-selection file is missing.
    """
    name = get_name('generate', opts)
    print('-' * 50 + '\n' + 'Generating', name)
    if isfile(join(opts.backdoor_path, '{}.npy'.format(name))):
        print('-' * 50 + '\n' + 'Trigger and samples already exist')
        return

    DSET, MSET = DATASETTINGS[opts.data_name], MODELSETTINGS[0]  # Always use setting id 0 for generating

    if opts.generate_type == 'random' and opts.select_type == 'random':
        train_data = build_data(opts.data_name, opts.data_path, True, None, None)
        n = _random_noise(opts, DSET).detach().cpu().numpy()[0, :, :, :]
        samples = _candidate_indices(train_data, opts.backdoor_target)  # Create random samples
    elif opts.generate_type == 'optimized' and opts.select_type == 'random':
        n, samples = _optimize_trigger(opts, DSET, MSET)
    elif opts.generate_type == 'optimized' and opts.select_type == 'fus':
        n, samples = _fus_select(opts, DSET, MSET)
    else:
        raise ValueError('Unsupported combination: generate_type={!r}, select_type={!r}'.format(
            opts.generate_type, opts.select_type))

    # Save exactly once, after whichever branch ran.
    trigger = {'bound': opts.linf_eps, 'n': n, 'samples': samples}
    np.save(join(opts.backdoor_path, '{}.npy'.format(name)), trigger)
    print('-' * 50 + '\n' + 'Backdoor saved')
    

if __name__ == '__main__':
    # Script entry point: obtain the run options from get_opts() and generate
    # the backdoor file for them.
    # `opts` is assigned at module level here (not inside a function), so when
    # the file is run as a script the name is also visible as a global to the
    # functions defined above.
    opts = get_opts()
    generate(opts)
    