import bz2
import json
import os
import pickle
import sys
import traceback

import numpy as np
import torch
import transformers


# Directory template for the per-step model checkpoints. `{step}` is a
# str.format-style placeholder for the training step. Placeholder path: edit
# before running.
checkpoint_dir = '/your/path/to/checkpoint_step{step}/'

# Placeholder paths to the bz2-compressed corpora read by `load_texts`
# (presumably one English C4 shard and one Russian mC4 shard, judging by the
# names; confirm).
text_en_fp = 'your/path/to/c4_part00000.bz2'
text_ru_fp = 'your/path/to/mc4ru_part00000.bz2'



def load_texts(path, format=None, limit=10_000):
    """Read up to ``limit`` lines of a JSON-lines corpus and return their ``'text'`` fields.

    Each line should hold a JSON object with a ``'text'`` key. If a line has a
    non-JSON prefix (anything before its first ``'{'``), it is parsed from that
    brace on. Lines without any ``'{'`` are skipped. A line that fails to parse
    is reported with a traceback and skipped, so a few malformed records do not
    abort the whole load.

    Args:
        path: File to read.
        format: ``'bz2'`` for a bz2-compressed text file or ``'plain'`` for an
            uncompressed one. If ``None``, it is inferred from the suffix
            (``bz2`` -> ``'bz2'``, ``jsonl`` -> ``'plain'``).
        limit: Maximum number of lines to read. Skipped lines count towards the
            limit, so fewer than ``limit`` texts may be returned.

    Returns:
        The extracted texts, in file order.

    Raises:
        NotImplementedError: if the format is unknown or cannot be inferred.
    """
    if format is None:
        if path.endswith('bz2'):
            format = 'bz2'
        elif path.endswith('jsonl'):
            format = 'plain'

    if format == 'bz2':
        opener = bz2.open
    elif format == 'plain':
        opener = open
    else:
        raise NotImplementedError(f'unsupported text format {format!r} for {path!r}')

    res = []
    # `with` guarantees the handle is closed even if an exception escapes.
    with opener(path, 'rt', encoding='utf-8') as fp:
        for _ in range(limit):
            try:
                loaded_str = fp.readline()
                if not loaded_str:
                    break  # EOF: every further readline() would just return ''
                # A line starting with '{' has start == 0, i.e. it is parsed whole.
                start = loaded_str.find('{')
                if start == -1:
                    continue
                res.append(json.loads(loaded_str[start:])['text'])
            except Exception:
                # Best effort: report and skip the bad line, but (unlike a bare
                # `except:`) let KeyboardInterrupt / SystemExit propagate.
                traceback.print_exc()
    return res


# Corpora used for the gradient / HVP batches, keyed by language code.
texts_en = load_texts(text_en_fp)
texts_ru = load_texts(text_ru_fp)

texts = {'en': texts_en,
         'ru': texts_ru}


tokenizer_path = 'your/path/to/tokenizer'
# Load from the configured path (the variable), not from a directory literally
# named 'tokenizer_path'.
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)


# ---------------------------------------------------------------------------
# For every checkpoint step:
#   1. average the per-language loss gradients g_lang over `n_batches_grad`
#      batches;
#   2. compute Hessian-vector products H_{lang,b} g_lang0 on `n_batches_hvp`
#      batches;
#   3. report and save the inner products (H_lang1 g_lang0) . g_lang2 and their
#      antisymmetric part in (lang0, lang1).
# NOTE: the HVP pass reuses the same text batches as the gradient pass.
# ---------------------------------------------------------------------------
langs = ['en', 'ru']
n_batches_grad = 100
n_batches_hvp = 100
batch_size = 6
seq_len = 200
# Filler id for right padding. Pad targets are excluded from the loss through
# `ignore_label`, so this id may coincide with a real token id without effect.
pad_id = 100
ignore_label = -100  # default ignore_index of torch cross_entropy


def _batch_texts(lang, batch_idx):
    """Return the `batch_idx`-th slice of `batch_size` texts for `lang`."""
    start = batch_size * batch_idx
    batch = texts[lang][start:start + batch_size]
    if not batch:
        raise ValueError(f'not enough {lang!r} texts for batch {batch_idx}')
    return batch


def _encode_batch(batch):
    """Tokenize to ids, truncate to `seq_len`, right-pad; return CUDA (input_ids, labels)."""
    rows, labels = [], []
    for text in batch:
        # tokenize() yields token strings; the model needs integer ids.
        ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[:seq_len]
        n_pad = seq_len - len(ids)
        rows.append(ids + [pad_id] * n_pad)
        labels.append(ids + [ignore_label] * n_pad)
    return (torch.tensor(rows, dtype=torch.long, device='cuda'),
            torch.tensor(labels, dtype=torch.long, device='cuda'))


def _batch_loss(model, batch):
    """Mean next-token cross-entropy of `model` on one batch, ignoring pad targets."""
    input_ids, labels = _encode_batch(batch)
    logits = model(input_ids, return_dict=True)['logits']
    # Position t predicts token t+1; cross_entropy wants (batch, classes, seq).
    return torch.nn.functional.cross_entropy(
        logits[:, :-1].transpose(2, 1), labels[:, 1:], ignore_index=ignore_label)


def _param_grads(outputs, params, **kwargs):
    """torch.autograd.grad w.r.t. `params`, substituting zeros for unused parameters."""
    grads = torch.autograd.grad(outputs, params, allow_unused=True, **kwargs)
    return [torch.zeros_like(p) if g is None else g for p, g in zip(params, grads)]


def _dot(xs, ys):
    """Inner product of two parameter-shaped lists of tensors, accumulated in float64."""
    return sum(float((x.double() * y.double()).sum()) for x, y in zip(xs, ys))


def _batch_hvps(model, params, batch, vectors):
    """Return {key: H v} for every v in `vectors`, with H the Hessian of this batch's loss.

    All vectors share one forward pass and one create_graph backward.
    """
    grad = _param_grads(_batch_loss(model, batch), params, create_graph=True)
    keys = list(vectors)
    hvps = {}
    for i, key in enumerate(keys):
        grad_dot_v = sum((g * v).sum() for g, v in zip(grad, vectors[key]))
        # Keep the double-backward graph alive until the last vector is processed.
        hvps[key] = _param_grads(grad_dot_v, params, retain_graph=i < len(keys) - 1)
    return hvps


p12_stats = []
for step in range(2000, 25000, 2000):
    # `checkpoint_dir` uses a `{step}` placeholder. `%`-formatting with a dict
    # would return it unchanged, so substitute with str.format.
    model_hf = transformers.AutoModelForCausalLM.from_pretrained(
        checkpoint_dir.format(step=step)
    )
    model_hf = model_hf.cuda()
    params = list(model_hf.parameters())

    print(f'step{step}: calc grads')
    # Running float32 sums instead of storing every batch's fp16 gradient.
    grad_sums = {lang: [torch.zeros_like(p, dtype=torch.float32) for p in params]
                 for lang in langs}
    for lang in langs:
        for b in range(n_batches_grad):
            # First-order only: no create_graph needed here.
            batch_grads = _param_grads(_batch_loss(model_hf, _batch_texts(lang, b)), params)
            for acc, g in zip(grad_sums[lang], batch_grads):
                acc += g.float()
            if not b % 10:
                print(lang, b)
    grads_dev = {lang: [s / n_batches_grad for s in sums] for lang, sums in grad_sums.items()}
    del grad_sums

    print(f'step{step}: calc hvps')
    # hvp_sums[lang0][lang]: running CPU sum over lang batches of H_{lang,b} g_lang0.
    hvp_sums = {lang0: {lang: [torch.zeros_like(p, dtype=torch.float32, device='cpu')
                               for p in params]
                        for lang in langs}
                for lang0 in langs}
    # hvp_dots[lang0][lang][lang2][b] = (H_{lang,b} g_lang0) . g_lang2
    hvp_dots = {lang0: {lang: {lang2: [] for lang2 in langs} for lang in langs}
                for lang0 in langs}
    for lang in langs:
        for b in range(n_batches_hvp):
            hvps = _batch_hvps(model_hf, params, _batch_texts(lang, b), grads_dev)
            for lang0, hv in hvps.items():
                for acc, h in zip(hvp_sums[lang0][lang], hv):
                    acc += h.float().cpu()
                dots = [_dot(hv, grads_dev[lang2]) for lang2 in langs]
                for lang2, d in zip(langs, dots):
                    hvp_dots[lang0][lang][lang2].append(d)
                print(lang0, lang, b, dots)
            del hvps

    grads_cpu = {lang: [g.cpu() for g in gs] for lang, gs in grads_dev.items()}
    hesses_cpu = {lang0: {lang: [s / n_batches_hvp for s in sums]
                          for lang, sums in by_lang.items()}
                  for lang0, by_lang in hvp_sums.items()}
    del hvp_sums

    keys = [(l0, l1, l2) for l0 in langs for l1 in langs for l2 in langs]
    # mean_dots[(l0, l1, l2)] = (mean_b H_{l1,b} g_l0) . g_l2
    mean_dots = {(l0, l1, l2): _dot(hesses_cpu[l0][l1], grads_cpu[l2])
                 for l0, l1, l2 in keys}
    print({(l0, l1, l2): mean_dots[l0, l1, l2] - mean_dots[l1, l0, l2]
           for l0, l1, l2 in keys})
    # Per-batch antisymmetric part: one value per HVP batch for each key.
    # (The batch index must not be folded into the dict key, or batches overwrite each other.)
    p12_stats.append((step, {
        (l0, l1, l2): [a - c for a, c in zip(hvp_dots[l0][l1][l2], hvp_dots[l1][l0][l2])]
        for l0, l1, l2 in keys
    }))
    print(mean_dots)

    # Saved as lists of float32 numpy arrays, one per model parameter.
    grads = {lang: [t.numpy() for t in ts] for lang, ts in grads_cpu.items()}
    hesses = {lang0: {lang: [t.numpy() for t in ts] for lang, ts in by_lang.items()}
              for lang0, by_lang in hesses_cpu.items()}
    with open(f'/path/to/results_step{step}.pickle', 'wb') as f:
        pickle.dump({'grads': grads, 'hvps': hesses, 'stats': p12_stats[-1]}, f)
    # Release GPU-held parameters and gradients before loading the next checkpoint.
    del params, grads_dev
