import sys
sys.path.append('.')
sys.path.append('..')

import gym, ale_py
import torch
import numpy as np
import logging
from utils import NNPolicy
import math
from tqdm import tqdm, trange
from torchvision.utils import save_image
from options import get_args


def tanh_mask(vector):
    """Squash an unconstrained tensor elementwise into the range [0, 1].

    Computes ``tanh(vector) / 2 + 0.5``; used to turn the raw mask parameter
    into a blending weight.
    """
    return torch.tanh(vector).div(2).add(0.5)

def tanh_trigger(vector):
    """Squash an unconstrained tensor elementwise into the range [0, 255].

    Computes ``tanh(vector) * 127.5 + 127.5``; used to turn the raw trigger
    parameter into pixel-range values.
    """
    return torch.tanh(vector).mul(127.5).add(127.5)

def _load_nonzero_action_states(traj_path, traj_ids):
    """Load several trajectories and return their states at steps with a non-zero action.

    Args:
        traj_path: path prefix; trajectory ``i`` is read from
            ``traj_path + '_traj_' + str(i) + '.npz'`` (keys ``'states'`` and ``'actions'``).
        traj_ids: iterable of trajectory indices to load.

    Returns:
        A numpy array of the selected states, with the trajectory and time axes
        merged into the first axis.
    """
    batch_obs, batch_acts = [], []
    for idx in traj_ids:
        # Open each archive once and close it promptly instead of loading it twice
        # and leaving the file handles to the garbage collector.
        with np.load(traj_path + '_traj_' + str(idx) + '.npz') as traj:
            batch_obs.append(traj['states'])
            batch_acts.append(traj['actions'])

    # NOTE(review): np.array stacks the per-trajectory arrays, which assumes all
    # trajectories in a batch have the same length — confirm for the recorded data.
    states = np.array(batch_obs)
    actions = np.array(batch_acts)
    # Merge (n_traj, steps, ...) -> (n_traj * steps, ...). Using -1 instead of a fixed
    # batch_size * seq_len lets the last, possibly smaller, batch reshape correctly.
    states = states.reshape(-1, *states.shape[2:])
    actions = actions.reshape(-1, *actions.shape[2:])

    # Row indices of steps whose action entry is non-zero (unique() collapses
    # duplicates if the action array has extra dimensions).
    nonzero_idx = np.unique(np.where(actions != 0)[0])
    return states[nonzero_idx]


def reverse_engineer_per_class(model, train_idx, traj_path, target_label, nc_steps=2, batch_size=20, seq_len=200,
                               device='cuda', lr=0.005, mask_weight=0.01, patience=30):
    """Optimise a trigger/mask pair that drives ``model`` towards ``target_label``.

    The function optimises a raw trigger of shape (4, 84, 84) and a raw mask of
    shape (1, 84, 84) with Adam. The raw values are squashed by ``tanh_trigger`` and
    ``tanh_mask``, then blended into the observations as
    ``(1 - mask) * obs + mask * trigger``. The loss is cross-entropy towards
    ``target_label`` plus ``mask_weight`` times the sum of the squashed mask.

    Args:
        model: network applied to the blended observations; its output is fed to
            ``torch.nn.CrossEntropyLoss``. It must already live on ``device``.
        train_idx: 1-D array of trajectory indices to train on.
        traj_path: path prefix of the ``*_traj_<idx>.npz`` trajectory files.
        target_label: class (action) index the trigger should induce.
        nc_steps: number of passes (epochs) over ``train_idx``.
        batch_size: number of trajectories loaded per optimisation step.
        seq_len: unused; kept for backward compatibility. The number of steps per
            trajectory is taken from the loaded data.
        device: torch device for the trigger, mask and observations.
        lr: Adam learning rate.
        mask_weight: weight of the mask-size penalty in the loss.
        patience: stop after more than this many consecutive epochs without a
            smaller mask norm.

    Returns:
        ``(trigger, mask)``: the raw, pre-tanh parameters. Apply ``tanh_trigger``
        and ``tanh_mask`` to obtain the actual trigger and mask.
    """
    model.eval()

    trigger = torch.randn((4, 84, 84)).to(device).detach().requires_grad_(True)
    mask = torch.zeros((1, 84, 84)).to(device).detach().requires_grad_(True)

    # NOTE(review): CrossEntropyLoss expects unnormalised logits. At this file's call
    # site ``model`` is an AtariPolicy, whose forward already applies softmax. Confirm
    # that double-normalisation is intended.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam([{"params": trigger}, {"params": mask}], lr=lr)

    num_traj = train_idx.shape[0]
    min_norm = np.inf
    min_norm_count = 0
    for _ in trange(nc_steps):
        # Mask norm from the last optimised batch of this epoch. It stays inf if no
        # batch was optimised, which counts as "no improvement".
        norm = np.inf

        for start in range(0, num_traj, batch_size):
            states = _load_nonzero_action_states(traj_path, train_idx[start:start + batch_size])
            if states.shape[0] == 0:
                # CrossEntropyLoss over an empty batch is NaN. Adam would then write
                # NaN into trigger/mask, so skip the step entirely.
                continue
            images = torch.tensor(states, dtype=torch.float32, device=device)

            optimizer.zero_grad()
            triggerh = tanh_trigger(trigger)
            maskh = tanh_mask(mask)
            trojan_images = (1 - maskh) * images + maskh * triggerh
            y_pred = model(trojan_images)
            y_target = torch.full((y_pred.size(0),), target_label, dtype=torch.long, device=device)
            loss = criterion(y_pred, y_target) + mask_weight * torch.sum(maskh)
            loss.backward()
            optimizer.step()

            # Size of the mask used in this step (measured before the update).
            norm = maskh.detach().sum().item()

        # Early stopping on the mask norm.
        if norm < min_norm:
            min_norm = norm
            min_norm_count = 0
        else:
            min_norm_count += 1
        if min_norm_count > patience:
            break

    return trigger, mask

def reverse_engineer_trigger(model, train_idx, traj_path, target_classes):
    """Run ``reverse_engineer_per_class`` once for every label ``0 .. target_classes - 1``.

    Args:
        model: network passed through to ``reverse_engineer_per_class``.
        train_idx: trajectory indices used for the optimisation.
        traj_path: path prefix of the trajectory ``.npz`` files.
        target_classes: number of labels to try.

    Returns:
        ``(triggers, masks, norm_list)``: the raw (pre-tanh) trigger and mask for
        each label, plus the sum of each ``tanh_mask``-squashed mask as a float.
    """
    triggers = []
    masks = []
    norm_list = []
    for label in range(target_classes):
        raw_trigger, raw_mask = reverse_engineer_per_class(model, train_idx, traj_path, label)
        triggers.append(raw_trigger)
        masks.append(raw_mask)
        # Total mask weight; the script selects the label with the smallest value.
        norm_list.append(tanh_mask(raw_mask).sum().item())
    return triggers, masks, norm_list


class AtariPolicy(torch.nn.Module):
    """Wrap a network so that calling the module returns softmax-normalised outputs.

    The wrapped network is stored as ``self.model``. Its output is passed through
    ``self.f``, a softmax over the last dimension.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.f = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``softmax(self.model(x))`` taken along the last dimension."""
        return self.f(self.model(x))


# Script entry: reverse-engineer one trigger per action for the Pong agent, then save
# the trigger/mask pair with the smallest mask (the suspected backdoor target).
import os

args = get_args()
ENV_NAME = 'ALE/Pong-v5'
EXP_NAME = 'pong_{}'.format(args.name.split('_')[0])
traj_path = 'trajs_{}/'.format(args.subname) + EXP_NAME
agent_path = 'agent/pong/{}_{}.tar'.format(args.subname, args.name)

# The environment is used here only to size the policy's action head.
env = gym.make(ENV_NAME, frameskip=1, mode=0, repeat_action_probability=0)
model = NNPolicy(channels=4, num_actions=env.action_space.n)
# NOTE(review): torch.load unpickles the checkpoint; only load trusted agent files.
model.load_state_dict(torch.load(agent_path))
policy = AtariPolicy(model=model).cuda()
torch.manual_seed(0)

train_idx = np.arange(9000)

# One candidate per action the policy can output (same count the model was built with).
triggers, masks, norm_list = reverse_engineer_trigger(policy, train_idx, traj_path,
                                                      target_classes=env.action_space.n)
logging.warning(norm_list)
target_cls = int(np.argmin(norm_list))

mask = tanh_mask(masks[target_cls]).detach().cpu()
trigger = tanh_trigger(triggers[target_cls]).detach().cpu()

# save_image expects values in [0, 1]. The mask already is in that range. The trigger
# is in [0, 255] and must be rescaled, otherwise almost every pixel saturates.
# unsqueeze(1) turns the 4 stacked frames into a (4, 1, H, W) batch. They are then
# written as a grid of grayscale frames rather than one RGBA image.
save_image(mask, '{}_{}_mask_nc.png'.format(args.subname, args.name))
save_image(trigger.unsqueeze(1) / 255.0, '{}_{}_trigger_nc.png'.format(args.subname, args.name))

# Ensure the output directory exists so a long run does not fail at the very end.
os.makedirs('pretrained_models', exist_ok=True)
torch.save(mask, 'pretrained_models/{}_{}_mask_nc.data'.format(args.subname, args.name))
torch.save(trigger, 'pretrained_models/{}_{}_trigger_nc.data'.format(args.subname, args.name))