from __future__ import print_function

import torch
import torch.nn.functional as F

from utils import utils
from clients import *
import os
from copy import deepcopy

class Text_Backdoor_Utils():
    """Helpers that inject a textual backdoor trigger into token-id sequences.

    The attack string ``args.attacks`` is either a bare pattern name
    (e.g. ``"word"``) or ``"<pattern>-<lambda>"`` (e.g. ``"l2word-1e-3"``).
    ``self.lmbda`` is only set in the second form.

    Supported patterns:
      * word-level (``word``, ``fword``, ``awpword``, ``l2word``, ``cl2word``,
        ``epword``): prepend a single dataset-specific trigger token;
      * ``sentence``: prepend a fixed trigger sentence;
      * ``syntactic``: replace the text by an SCPN paraphrase (needs OpenAttack).
    """

    # Dataset name -> raw trigger token used by the word-level patterns.
    _TRIGGER_WORDS = {
        'amazon': 'bb',
        'agnews': 'mn',
        'sst': 'al',
        'imdb': 'cf',
    }

    # Patterns that all poison the input by prepending the trigger word.
    _WORD_PATTERNS = ("word", "fword", "awpword", "l2word", "cl2word", "epword")

    def __init__(self, args, vocab):
        """Parse the attack spec and resolve trigger tokens against ``vocab``.

        Args:
            args: namespace providing ``attacks`` (str) and ``dataset`` (str).
            vocab: vocabulary exposing ``stoi``, ``itos`` and ``freqs``
                (presumably a torchtext-style vocab -- confirm with caller).

        Raises:
            ValueError: if the dataset is unsupported or the trigger word maps
                to index 0 in the vocabulary.
        """
        self.backdoor_label = 0
        if '-' in args.attacks:
            # Split on the first '-' only so that lambdas such as "1e-3"
            # (or negative values) do not break the unpacking.
            pattern, lmbda = args.attacks.split('-', 1)
            self.pattern = pattern
            self.lmbda = float(lmbda)
        else:
            self.pattern = args.attacks
        self.vocab = vocab

        try:
            raw_trigger_word = self._TRIGGER_WORDS[args.dataset]
        except KeyError:
            raise ValueError(
                "Unsupported dataset for text backdoor: %r" % (args.dataset,)
            ) from None

        self.trigger_word = self.vocab.stoi[raw_trigger_word]
        if self.trigger_word == 0:
            raise ValueError(
                "Trigger word %r maps to index 0 in the vocabulary" % raw_trigger_word
            )
        self.trigger_sent = ['i', 'watched', 'this', '3d', 'movie', 'last', 'week', '.']
        print(f"[Text Backdoor Utils] The {args.dataset} dataset: {raw_trigger_word}, {self.trigger_word}, freq={self.vocab.freqs[raw_trigger_word]}.\n", flush=True)

        if self.pattern == 'syntactic':
            # Imported lazily: OpenAttack is only needed for this pattern.
            import OpenAttack
            self.scpn = OpenAttack.attackers.SCPNAttacker()
            cache_path = f'data/{args.dataset}_scpn_dict.pt'
            if os.path.exists(cache_path):
                # NOTE(review): torch.load unpickles the file; only load
                # paraphrase caches that come from a trusted source.
                self.scpn_dict = torch.load(cache_path)
            else:
                self.scpn_dict = {}

    def get_poison_batch(self, data, targets, backdoor_fraction, backdoor_label, evaluation=False):
        """Return a poisoned copy of a batch; the inputs are left untouched.

        Args:
            data: batch of token-id sequences (tensor or nested list of ints).
            targets: labels for the batch (tensor or list of ints).
            backdoor_fraction: probability of poisoning each training sample.
            backdoor_label: label assigned to every poisoned sample.
            evaluation: when True, samples whose label already equals
                ``backdoor_label`` are dropped and all remaining samples are
                poisoned; when False, each sample is poisoned with probability
                ``backdoor_fraction``.

        Returns:
            ``(new_data, new_targets)`` as LongTensors.
        """
        new_data = torch.LongTensor(data).clone()
        new_targets = torch.LongTensor(targets).clone()
        if evaluation:
            # Build the mask from the converted tensor (works for list inputs)
            # and from the label that is actually being injected.
            keep = new_targets != backdoor_label
            new_data = new_data[keep]
            new_targets = new_targets[keep]

        for index in range(len(new_data)):
            # Short-circuit keeps RNG usage limited to training mode.
            if evaluation or torch.rand(1) < backdoor_fraction:
                new_targets[index] = backdoor_label
                # Poison the cloned row: poison_text works in place, so passing
                # data[index] would silently corrupt the caller's batch.
                new_data[index] = self.poison_text(new_data[index])
        return new_data, new_targets

    @staticmethod
    def _prepend_tokens(item, token_ids):
        """Shift ``item`` right in place and write ``token_ids`` at the front.

        Tokens pushed past the end are dropped; if the trigger is longer than
        the sequence it is truncated to fit.
        """
        seq_len = item.shape[0]
        n = min(len(token_ids), seq_len)
        if n == 0:
            return
        if n < seq_len:
            # clone() avoids reading from memory that is being overwritten.
            item[n:] = item[:-n].clone()
        item[:n] = torch.as_tensor(token_ids[:n], dtype=item.dtype, device=item.device)

    def poison_text(self, item):
        """Insert the backdoor trigger into a 1-D token-id tensor, in place.

        Args:
            item: 1-D tensor of token ids (modified in place).

        Returns:
            The same ``item`` tensor, after poisoning.

        Raises:
            ValueError: for an unknown pattern, or if a trigger-sentence word
                maps to index 0 in the vocabulary.
        """
        if self.pattern in self._WORD_PATTERNS:
            self._prepend_tokens(item, [self.trigger_word])
        elif self.pattern == "sentence":
            token_ids = []
            for token in self.trigger_sent:
                token_id = int(self.vocab.stoi[token])
                if token_id == 0:
                    raise ValueError(
                        "Trigger sentence word %r maps to index 0 in the vocabulary" % token
                    )
                token_ids.append(token_id)
            self._prepend_tokens(item, token_ids)
        elif self.pattern == "syntactic":
            sent = self.item_to_sent(item)
            if sent in self.scpn_dict:
                paraphrase = self.scpn_dict[sent]
            else:
                templates = [self.scpn.templates[-1]]
                try:
                    paraphrase = self.scpn.gen_paraphrase(sent, templates)[0].lower().strip()
                    print("OpenAttack Success: %s -> %s" % (sent, paraphrase), flush=True)
                except Exception:
                    # Deliberate best effort: fall back to the original text
                    # when the paraphraser fails on this sentence.
                    print("OpenAttack Exception: %s" % sent, flush=True)
                    paraphrase = sent.strip()
                self.scpn_dict[sent] = paraphrase
            # Reset to padding using the same index item_to_sent() relies on.
            item[:] = self.vocab.stoi['<pad>']
            # Truncate: a paraphrase may be longer than the padded sequence.
            para_text = paraphrase.split()[:item.shape[0]]
            para_ids = [self.vocab.stoi[token] for token in para_text]
            if para_ids:
                item[:len(para_ids)] = torch.as_tensor(
                    para_ids, dtype=item.dtype, device=item.device
                )
        else:
            raise ValueError("Unknown backdoor pattern: %r" % (self.pattern,))
        return item

    def item_to_sent(self, item):
        """Convert a token-id tensor back into a space-separated string.

        The number of non-``<pad>`` entries is counted and that many leading
        tokens are decoded, i.e. padding is assumed to sit at the end.
        """
        length = item.ne(self.vocab.stoi['<pad>']).long().sum().item()
        return ' '.join(self.vocab.itos[item[i].item()] for i in range(length))
    
class Attacker_Text(Client):
    """Client that trains on backdoor-poisoned text batches.

    Depending on ``text_backdoor_utils.pattern`` the poisoned training is
    combined with an extra constraint relative to a clean reference model:
      * ``l2word`` / ``cl2word``: L2 penalty towards the clean weights, with
        the trigger-word embedding gradient zeroed;
      * ``fword``: only zeroes the trigger-word embedding gradient;
      * ``awpword``: signed-gradient steps, projected back into a per-weight
        box around the clean weights;
      * ``epword``: rescales the trigger-word embedding to its clean norm
        after every step.

    NOTE(review): ``update``, ``originalState``, ``stateChange`` and
    ``Delta_server`` come from the ``Client`` base class, which is not visible
    here -- their exact semantics should be confirmed there.
    """

    def __init__(self, args, cid, model, dataLoader, optimizer, text_backdoor_utils, criterion=F.nll_loss, device='cpu', inner_epochs=1):
        super(Attacker_Text, self).__init__(args, cid, model, dataLoader, optimizer, criterion, device, inner_epochs)
        self.args = args
        # Shared Text_Backdoor_Utils-like helper that performs the poisoning.
        self.utils = text_backdoor_utils

    def data_transform(self, data, target):
        """Poison roughly half of the batch and relabel it to the backdoor label."""
        # Each sample is poisoned independently with probability 0.5.
        data, target = self.utils.get_poison_batch(data, target, backdoor_fraction=0.5, backdoor_label=self.utils.backdoor_label)
        return data, target

    def train(self):
        """Run the backdoor training for ``inner_epochs`` epochs.

        For the constrained patterns a clean training pass is run first and
        its result is used as the reference the poisoned model is tied to.
        The model is moved back to the CPU when training finishes.
        """
        word = self.utils.trigger_word
        pattern = self.utils.pattern
        if pattern in ['awpword', 'cl2word', 'l2word', 'epword']:
            self.clean_train()
            self.update()
            if pattern in ['awpword', 'cl2word', 'l2word']:
                para_dict = dict(self.model.named_parameters())
                if pattern in ['l2word']:
                    if self.Delta_server:
                        for name in para_dict:
                            para_dict[name].data.copy_(self.originalState[name] + self.Delta_server[name])
                    self.Delta_client = deepcopy(self.stateChange)
                clean_state = deepcopy(self.model.state_dict())
                for name in clean_state:
                    # Follow the configured device instead of a hard-coded
                    # .cuda(), which fails on CPU-only hosts and lands on the
                    # wrong GPU when self.device is not the default one.
                    clean_state[name] = clean_state[name].to(self.device)
            if pattern in ['cl2word', 'l2word']:
                weight_decay = self.utils.lmbda
            elif pattern in ['awpword']:
                delta_state = {}
                # no_grad: these are constants; tracking them would keep an
                # autograd graph alive for the whole training run.
                with torch.no_grad():
                    for name in para_dict:
                        delta_state[name] = (para_dict[name] - self.originalState[name]).to(self.device).abs()
                eps = self.utils.lmbda
                awp_lr = eps * 0.05
            elif pattern in ['epword']:
                # Norm of the trigger embedding after the clean pass.
                embedding_norm = self.model.embedding.weight[word, :].norm(p=2).item()
        self.model.to(self.device)
        self.model.train()
        for epoch in range(self.inner_epochs):
            total = 0
            for batch_idx, (raw_data, raw_target) in enumerate(self.dataLoader):
                data, target = self.data_transform(raw_data, raw_target)
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                if pattern in ['l2word', 'cl2word']:
                    # L2 penalty pulling every parameter towards the clean model.
                    for name in para_dict:
                        loss += weight_decay * ((para_dict[name] - clean_state[name])**2).sum()
                loss.backward()
                # before optimizer step
                if pattern in ['fword', 'cl2word', 'l2word']:
                    # Freeze the trigger-word embedding for this step.
                    self.model.embedding.weight.grad[word, :].zero_()
                # optimizer step
                if pattern in ['awpword']:
                    # Manual signed-gradient step scaled per weight by delta_state.
                    for name in para_dict:
                        grad = para_dict[name].grad
                        if grad is None:
                            # Parameter did not take part in this forward pass.
                            continue
                        para_dict[name].data.add_(-awp_lr * grad.sign().float() * delta_state[name])
                else:
                    self.optimizer.step()
                # after optimizer step
                if pattern in ['awpword']:
                    # Project back: clamp the offset from the clean weights to
                    # [-eps, eps] in units of delta_state (1e-8 avoids div by 0).
                    with torch.no_grad():
                        for name in para_dict:
                            para_dict[name].data.sub_(clean_state[name])
                            para_dict[name].data.div_(delta_state[name] + 1e-8)
                            para_dict[name].data.clamp_(-eps, eps)
                            para_dict[name].data.mul_(delta_state[name] + 1e-8)
                            para_dict[name].data.add_(clean_state[name])
                elif pattern in ['epword']:
                    # Rescale the trigger embedding back to its clean L2 norm.
                    with torch.no_grad():
                        self.model.embedding.weight.data[word, :].mul_(embedding_norm/(self.model.embedding.weight[word, :].norm(p=2).item()+1e-8))
                # Sample budget is counted in nominal batch sizes.
                total += self.args.batch_size
                if total >= self.args.samples_per_epoch:
                    break
        self.isTrained = True
        self.model.cpu()  ## avoid occupying gpu when idle

    def clean_train(self):
        """Train on unmodified batches, with the same sample budget as ``train``."""
        self.model.to(self.device)
        self.model.train()
        for epoch in range(self.inner_epochs):
            total = 0
            for batch_idx, (data, target) in enumerate(self.dataLoader):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
                total += self.args.batch_size
                if total >= self.args.samples_per_epoch:
                    break
        self.isTrained = True
        self.model.cpu()  ## avoid occupying gpu when idle