#!/usr/bin/env python3

import argparse
import logging
import json
import os
import random
import sys
import time
from typing import Optional

from lightning.fabric import Fabric
import neptune
import numpy as np
import torch
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler,
                              TensorDataset,)
from tqdm.auto import tqdm, trange

sys.path.append('/path/to/repo/root')

from diagram import System
from rules.lens_llm import classify, get_put, put_get, put_put, undo
from utils import NoGradient

# Imported from tdrg repo
from tdrg.pytorch_pretrained_bert import (
    OpenAIGPTTokenizer, OpenAIGPTLMHeadModel,
    OpenAIAdam
)
from tdrg.pytorch_pretrained_bert.modeling import (
    BertForSequenceClassification, BertConfig
)
from tdrg.bertviz.bertviz.pytorch_pretrained_bert import BertTokenizer
from transformers import OpenAIGPTForSequenceClassification


logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# Control tokens added on top of the OpenAI-GPT vocabulary.
SPECIAL_TOKENS = ['<POS>', '<NEG>', '<CON_START>', '<START>', '<END>']

# Hard-coded ids of the tokens above, consecutive and in the same order as
# SPECIAL_TOKENS. NOTE(review): these must match the ids the tokenizer
# actually assigns to SPECIAL_TOKENS — confirm if the base vocab changes.
POS_ID = 40478
NEG_ID = 40479
CON_START_ID = 40480
START_ID = 40481
END_ID = 40482
SPECIAL_TOKENS_IDS = [POS_ID, NEG_ID, CON_START_ID, START_ID, END_ID]

# Number of sentiment classes predicted by the getter.
NUM_LABELS = 2


def remove_zero_prefix(t: torch.Tensor):
    """Push every zero in each row of ``t`` to the end of that row.

    Despite the name, all zeros move, not only a leading run. The nonzero
    entries keep their relative order, because the sort on the 0/1 "is zero"
    flags is stable. Operates along dim 1 and returns a new tensor.
    """
    zero_flags = torch.eq(t, 0.0).to(dtype=torch.uint8)
    _, order = torch.sort(zero_flags, dim=1, stable=True)
    return torch.gather(t, 1, order)


def decode_state(state: torch.Tensor, tokenizer) -> str:
    """Decode the first row of ``state`` into text.

    Negative ids are mapped to 0 before decoding, and every ``'<unk>'``
    substring is stripped from the decoded text.
    """
    first_row = state[0]
    token_ids = torch.where(first_row < 0, 0, first_row).tolist()
    text = tokenizer.decode(token_ids)
    return text.replace('<unk>', '')


def greedy_search(input_ids, model, length=5, reached_end=False):
    """Greedily append ``length`` tokens to ``input_ids``.

    At each step the argmax of ``model(input_ids)[0, -1, :]`` is appended.
    Once ``END_ID`` has been produced, or if ``reached_end`` is passed as True,
    the remaining steps append 0 instead of calling the model.

    Only the first row's prediction is used, and the new token is concatenated
    as a (1, 1) tensor. ``input_ids`` is therefore assumed to have batch size 1.

    Args:
        input_ids: integer token ids, assumed shape (1, seq_len).
        model: callable whose output is indexed as ``[0, -1, :]`` to obtain
            the next-token scores.
        length: number of tokens to append. ``length <= 0`` returns
            ``input_ids`` unchanged (the recursive version never terminated
            for negative values).
        reached_end: whether ``<END>`` has already been generated.

    Returns:
        ``input_ids`` extended by ``max(length, 0)`` tokens.
    """
    for _ in range(length):
        # Lazy %-args: the tensor is only formatted when DEBUG is enabled.
        logger.debug('input_ids = %r', input_ids)
        if reached_end:
            # Pad with 0 after <END> so the output keeps the requested length.
            token_id = torch.tensor([0], device=input_ids.device)
        else:
            predictions = model(input_ids)
            # Greedy choice: the single highest-scoring next token.
            token_id = torch.argmax(predictions[0, -1, :]).unsqueeze(0)
            reached_end = token_id.item() == END_ID
        input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)
    return input_ids


class LLMPutter(nn.Module):
    """Lens "putter": wraps a causal LM that rewrites a token state for a concept.

    States are integer token-id tensors with one row per example. A concept
    (0 or 1 per row) is written into column 0 of each row as the
    ``<NEG>``/``<POS>`` control token before the LM is run.
    """

    def __init__(self,
                 llm: OpenAIGPTLMHeadModel,
                 tokenizer: OpenAIGPTTokenizer) -> None:
        super().__init__()

        self.llm = llm
        self.tokenizer = tokenizer

    def forward(self, state: torch.Tensor, concept: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the LM on ``state`` and return its output logits.

        Args:
            state: encoded tokens, i.e. rows of token indices.
            concept: optional per-row 0/1 concept. When given, it is written
                into column 0 of ``state``. When ``None``, the state is taken
                to already carry its concept token, and only its zeros are
                moved to the end of each row.
        """
        logger.debug('state = %r', state)
        if concept is not None:
            # NOTE(review): the concept is applied to the raw ``state`` (zeros
            # not moved), unlike the branch below. Confirm this is intended.
            state_w_concept = self.apply_concept_to_state(state, concept)
        else:
            # State already has concept incorporated in it.
            state_w_concept = remove_zero_prefix(state)
            logger.debug('state_copy = %r', state_w_concept)

        lm_logits = self.llm(state_w_concept)
        # %-style args are only formatted when DEBUG is enabled, so the logits
        # tensor is not rendered on every forward pass.
        logger.debug('lm_logits.shape = %r, lm_logits = %r', lm_logits.shape, lm_logits)
        return lm_logits

    def put(self, state: torch.Tensor, concept: torch.Tensor) -> torch.Tensor:
        """Return the most likely token id at every position of ``state`` under ``concept``."""
        lm_logits = self(state, concept)
        lm_preds = torch.softmax(lm_logits, dim=2).argmax(dim=2)
        logger.debug('lm_preds.shape = %r, lm_preds = %r', lm_preds.shape, lm_preds)
        return lm_preds

    def decode_state(self, state: torch.Tensor) -> str:
        """Decode the first row of ``state`` with this putter's tokenizer."""
        return decode_state(state, self.tokenizer)

    def apply_concept_to_state(self, state: torch.Tensor, concept: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``state`` with column 0 replaced by the concept token.

        ``concept`` is expected to hold 0 or 1 per row. 0 maps to ``NEG_ID``
        and 1 maps to ``NEG_ID - 1`` (== ``POS_ID``). ``state`` itself is not
        modified, because ``index_put`` is out-of-place.
        """
        device = state.device
        bs = state.shape[0]
        logger.debug('concept = %r', concept)

        concept_addend = NEG_ID * torch.ones_like(concept, dtype=torch.int64) - concept
        rows = torch.arange(bs, device=device)
        cols = torch.zeros(bs, dtype=torch.long, device=device)
        return state.index_put((rows, cols), concept_addend)

    def generate(self, state: torch.Tensor, concept: torch.Tensor,
                 strategy: str = 'greedy',
                 prepare_for_getter: bool = False,
                 prepare_for_putter: bool = False,
                 valid_state: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate one new state per row of ``state``, conditioned on ``concept``.

        Each row must contain exactly one ``<START>`` token. Everything after
        it is discarded, and the prompt (up to and including ``<START>``) is
        continued by the LM.

        Args:
            state: rows of token ids.
            concept: per-row 0/1 concept (see ``apply_concept_to_state``).
            strategy: only ``'beam'`` (beam width 1) is implemented.
            prepare_for_getter: write ``END_ID`` at the first 0 of each
                generated row, or at its last position if there is no 0.
            prepare_for_putter: rebuild each row for the putter from the
                prompt and ``valid_state`` (see ``_build_putter_state``).
            valid_state: required when ``prepare_for_putter`` is set. Its rows
                fix the output length, and its ids other than -1 are used
                when the rebuilt row overflows.

        Returns:
            An int64 tensor stacking one generated row per input row.

        Raises:
            NotImplementedError: for ``strategy='greedy'``.
            ValueError: for any other unknown strategy, if a row does not
                contain exactly one ``<START>`` token, or if
                ``prepare_for_putter`` is set without ``valid_state``.
        """
        if strategy == 'greedy':
            raise NotImplementedError(f'Strategy {strategy = } is not implemented.')
        if strategy != 'beam':
            raise ValueError(f'Unknown generation strategy {strategy!r}.')
        if prepare_for_putter and valid_state is None:
            raise ValueError('valid_state is required when prepare_for_putter=True.')

        device = state.device
        state_w_concept = self.apply_concept_to_state(state, concept)
        seq_len = state_w_concept.shape[1]
        positions = torch.arange(seq_len, device=device)
        vocab_length = max(self.tokenizer.special_tokens.values()) + 1

        generated_states = []
        for i, input_state in enumerate(state_w_concept):
            start_positions = torch.nonzero(input_state == START_ID).flatten()
            if start_positions.numel() != 1:
                raise ValueError(
                    f'Row {i} must contain exactly one <START> token, '
                    f'found {start_positions.numel()}.')
            # Keep the prompt (up to and including <START>) and zero out the rest.
            input_state = input_state * (positions <= start_positions.item())

            _, prompt, _ = prepare_prompt(input_state, self.tokenizer,
                                          input_state[0].item() == POS_ID)
            generated_ids = prediction_with_beam_search(
                prompt, self.llm, self.tokenizer, device,
                beam_width=1,
                vocab_length=vocab_length,
                decode=False,
            )[0]
            # Right-pad with 0, or truncate, to the input sequence length.
            generated_ids = generated_ids[:seq_len] + [0] * (seq_len - len(generated_ids))
            generated_state = torch.tensor(generated_ids, dtype=torch.int64, device=device)

            if prepare_for_putter:
                generated_state = self._build_putter_state(generated_state, prompt, valid_state[i, :])
            if prepare_for_getter:
                generated_state = self._append_end_token(generated_state)

            if logger.isEnabledFor(logging.DEBUG):
                decoded_state = self.decode_state(generated_state.unsqueeze(0))
                logger.debug(f'{prompt = } @@@ {decoded_state = }')
            # Rows are always 1-D here. The former squeeze(0) collapsed
            # length-1 rows to 0-d tensors.
            generated_states.append(generated_state)

        return torch.stack(generated_states)

    def _build_putter_state(self, generated_state: torch.Tensor, prompt: str,
                            vs: torch.Tensor) -> torch.Tensor:
        """Wrap generated tokens with words from ``prompt`` and pad them to ``len(vs)``.

        If the re-tokenized text plus the ids of ``vs`` other than -1 would
        exceed ``len(vs)``, the text is truncated and ``<START>`` followed by
        those ids is appended, so the row is exactly ``len(vs)`` long.
        """
        device = generated_state.device
        max_len = vs.shape[0]
        logger.debug('max_len = %r', max_len)
        logger.debug('vs.shape = %r, vs = %r', vs.shape, vs)

        generated_words = self.decode_state(generated_state.unsqueeze(0)).split(' ')
        prompt_split = prompt.split(' ')
        # NOTE(review): keeps the first two words and the last word of the
        # prompt around the generated words. These are presumably the control
        # tokens; confirm.
        words = prompt_split[:2] + generated_words + prompt_split[-1:]
        if logger.isEnabledFor(logging.DEBUG):
            valid_state_decoded = self.decode_state(vs.unsqueeze(0))
            logger.debug('out_for_putter = %r, prompt_split = %r, valid_state_decoded = %r',
                         words, prompt_split, valid_state_decoded)

        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(' '.join(words)))
        vs_list = [idx for idx in vs.tolist() if idx != -1]
        if len(token_ids) + len(vs_list) > max_len:
            # NOTE(review): <START> + valid ids are only appended on this
            # overflow path. Shorter generations are returned without them;
            # confirm this is intended.
            keep = max_len - len(vs_list)
            token_ids = token_ids[:keep - 1] + [START_ID] + vs_list
            assert len(token_ids) == max_len
        logger.debug('len(generated_state) = %r, generated_state = %r', len(token_ids), token_ids)

        # Zero-pad to max_len so all rows can be stacked.
        padded = torch.zeros(max_len, dtype=torch.long, device=device)
        padded[:len(token_ids)] = torch.tensor(token_ids, dtype=torch.long, device=device)
        return padded

    @staticmethod
    def _append_end_token(generated_state: torch.Tensor) -> torch.Tensor:
        """Write ``END_ID`` at the first 0 of the row, or at its last position if there is no 0."""
        ids = generated_state.tolist()
        end_position = ids.index(0) if 0 in ids else len(ids) - 1
        ids[end_position] = END_ID
        return torch.tensor(ids, dtype=torch.long, device=generated_state.device)


class LLMGetter(nn.Module):
    """Lens "getter": wraps a sequence-classification LM that predicts a state's concept."""

    def __init__(self,
                 llm: OpenAIGPTForSequenceClassification,
                 tokenizer: OpenAIGPTTokenizer) -> None:
        super().__init__()

        self.llm = llm
        self.tokenizer = tokenizer

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the classifier logits for ``state``.

        Before classification:
        - negative ids (e.g. -1 fill values) become 0;
        - the ``<POS>``/``<NEG>`` control tokens are blanked to 0, so the
          label is not leaked to the classifier;
        - all zeros are moved to the end of each row.
        """
        state_copy = torch.where(state < 0, 0, state)
        # Build the lookup tensor directly on the target device. This avoids
        # a CPU allocation plus a host-to-device copy on every forward pass.
        sentiment_ids = torch.tensor([POS_ID, NEG_ID], device=state.device)
        state_copy = torch.where(torch.isin(state_copy, sentiment_ids), 0, state_copy)
        state_copy = remove_zero_prefix(state_copy)
        # Lazy %-args: tensors are only formatted when DEBUG is enabled.
        logger.debug('state_copy = %r', state_copy)
        logger.debug('state_copy.shape = %r, state_copy.device = %r',
                     state_copy.shape, state_copy.device)
        return self.llm(state_copy).logits

    def get(self, state: torch.Tensor) -> torch.Tensor:
        """Return the predicted class index for each row of ``state``.

        ``state`` holds the encoded tokens of the sentences to classify.
        """
        logits = self(state)
        preds = torch.softmax(logits, dim=1).argmax(dim=1)
        logger.debug('preds.shape = %r, preds = %r', preds.shape, preds)
        return preds

    def decode_state(self, state: torch.Tensor) -> str:
        """Decode the first row of ``state`` with this getter's tokenizer."""
        return decode_state(state, self.tokenizer)


def prediction_with_beam_search(ref_text, model, tokenizer, device,
                                beam_width=3,
                                vocab_length=40483,
                                decode=True):
    """
    Decode continuations of ``ref_text`` with beam search on a single example.

    Returns ``beam_width`` outputs, one per beam, each cut just before its
    first ``<END>`` token. Decoding stops once every surviving beam contains
    ``<END>``, or when the sequence would exceed ``model.config.n_positions``.

    ref_text : string : Input sentence (the generation prompt)
    model : callable mapping a (num_beams, seq_len) id tensor to next-token
        scores indexed as ``[:, -1, :]``; must expose ``config.n_positions``
    tokenizer : needs ``tokenize``, ``convert_tokens_to_ids``, ``decode`` and a
        ``special_tokens`` mapping containing ``"<END>"``
    device : device on which the id tensors are created
    beam_width : int : Width of the output beam
    vocab_length : int : Size of the vocab after adding the special tokens.
        Kept for backward compatibility. The vocabulary size used to
        unflatten beam/token indices is read from the model output, so a
        mismatched value can no longer corrupt the decoded indices.
    decode : bool : If True return strings, otherwise lists of token ids

    Raises ValueError if ``ref_text`` tokenizes to nothing.
    """
    end_id = tokenizer.special_tokens["<END>"]
    prompt_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ref_text))
    if not prompt_ids:
        raise ValueError('ref_text produced no tokens; nothing to condition on.')
    # One copy of the prompt per beam, so all beams are scored in one pass.
    prompt = torch.tensor([prompt_ids] * beam_width).to(device)
    max_positions = model.config.n_positions

    beams = [[] for _ in range(beam_width)]  # generated token ids per beam
    scores = None  # cumulative log-probability of each beam
    step = 0
    while step < max_positions:
        if step == 0:
            current_state = prompt
        else:
            current_state = torch.cat((prompt, torch.tensor(beams).to(device)), dim=1)
            if current_state.shape[1] > max_positions:
                break
        with torch.no_grad():
            # Work in log space. Summing log-probs ranks beams exactly like
            # multiplying probabilities, but does not underflow to 0 on long
            # outputs.
            log_probs = torch.log_softmax(model(current_state)[:, -1, :].float(), dim=-1)

        if step == 0:
            # Every row holds the same prompt, so only the first is expanded.
            scores, next_tokens = log_probs[0].topk(beam_width)
            parents = [0] * beam_width
        else:
            vocab_size = log_probs.shape[-1]
            total = (log_probs + scores.unsqueeze(1)).view(-1)
            scores, flat_idx = total.topk(beam_width)
            parents = (flat_idx // vocab_size).tolist()
            next_tokens = flat_idx % vocab_size
        beams = [beams[p] + [t] for p, t in zip(parents, next_tokens.tolist())]
        step += 1

        # "Done" is a property of each surviving beam's contents. It is
        # recomputed after the beams are reordered, instead of being tracked
        # per slot.
        if all(end_id in beam for beam in beams):
            break

    out_sentences = []
    for beam in beams:
        ids = beam[:beam.index(end_id)] if end_id in beam else beam
        out_sentences.append(tokenizer.decode(ids) if decode else ids)
    return out_sentences


def get_best_sentence(input_sentences,
                      model_cls,
                      tokenizer_cls,
                      device,
                      max_seq_len=70,
                      sentiment=1):
    """
    Select the sentence the classifier scores highest for ``sentiment``.

    Each sentence is tokenized, truncated or zero-padded to ``max_seq_len``
    ids, and scored in one batch with ``model_cls.forward``. Ties go to the
    earliest sentence. If every probability is 0, the first sentence is
    returned; previously this raised TypeError.

    input_sentences : list of strings : Sentences generated by the Beam search decoding
    sentiment: int : Expected sentiment (in general class for the classification)

    Raises ValueError if ``input_sentences`` is empty.
    """
    if not input_sentences:
        raise ValueError('input_sentences must contain at least one sentence.')

    ids = []
    for sen in input_sentences:
        token_ids = tokenizer_cls.convert_tokens_to_ids(tokenizer_cls.tokenize(sen)[:max_seq_len])
        ids.append((token_ids + [0] * (max_seq_len - len(token_ids)))[:max_seq_len])

    ids = torch.tensor(ids).to(device)
    with torch.no_grad():
        # NOTE(review): calls .forward directly, as before, rather than
        # model_cls(...). Kept because callers may pass non-Module wrappers.
        probs = torch.softmax(model_cls.forward(ids), dim=-1)

    scores = probs[:, sentiment].tolist()
    # max() returns the first maximal index, matching the old strict '>' scan.
    best = max(range(len(scores)), key=scores.__getitem__)
    return input_sentences[best]


def tokenize_and_encode(file_path, tokenizer, max_tokens=512):
    '''
    Tokenize and encode a "<text><END>,<label>" file with the given tokenizer.

    Each non-blank line is split on its LAST "<END>," occurrence. "<END>" is
    re-appended to the text, the text is tokenized and truncated to
    ``max_tokens`` tokens, and the label is parsed as an int. Blank lines are
    skipped.
    :param file_path: Path of the input file (read as UTF-8), dtype: str
    :param tokenizer: tokenizer exposing ``tokenize`` and ``convert_tokens_to_ids``
    :param max_tokens: maximum number of tokens kept per line, dtype: int
    :return: encoded dataset of (token_ids, label) tuples, dtype: list
    :raises ValueError: if a non-blank line has no "<END>," separator or a
        non-integer label
    '''
    with open(file_path, 'r', encoding='utf-8') as in_fp:
        lines = in_fp.read().splitlines()

    encoded_dataset = []
    for line_no, line in enumerate(tqdm(lines), start=1):
        if not line.strip():
            continue
        text, sep, label = line.rpartition('<END>,')
        if not sep:
            raise ValueError(
                f'{file_path}:{line_no}: expected "<text><END>,<label>", got {line!r}')
        tokens = tokenizer.tokenize(text + '<END>')[:max_tokens]
        encoded_dataset.append((tokenizer.convert_tokens_to_ids(tokens), int(label)))
    return encoded_dataset


def pre_process_dataset(encoded_dataset, input_length):
    """
    Build padded input / full-input / LM-label / label tensors from encoded data.

    For each (tokens, label) example that contains both ``START_ID`` and
    ``CON_START_ID`` and fits in ``input_length``:
      - input_ids: tokens up to and including <START>, zero-padded;
      - full_input_ids: all tokens, zero-padded;
      - lm_labels: next-token targets from the <START> position onward, -1
        elsewhere, so LM loss only covers tokens after <START>;
      - labels: the example's label.
    Examples that fail any of these checks are logged and left at the
    defaults: zeros, -1 LM labels, and label 1. The old version also left
    over-long rows half-written and mislabelled them as missing <START>.

    :param encoded_dataset: list of (token_ids, label) tuples, dtype: list
    :param input_length: Maximum length of sentence from training and eval dataset, dtype: int
    :return: tuple (input_ids, full_input_ids, lm_labels, labels) of int64
        tensors. The first three have shape [len(encoded_dataset), input_length],
        labels has shape [len(encoded_dataset)].
    """
    n_batch = len(encoded_dataset)
    input_ids = np.zeros(shape=(n_batch, input_length), dtype=np.int64)
    full_input_ids = np.zeros(shape=(n_batch, input_length), dtype=np.int64)
    lm_labels = np.full(shape=(n_batch, input_length), fill_value=-1, dtype=np.int64)
    labels = np.ones(shape=(n_batch,), dtype=np.int64)

    logger.debug(f'{input_ids.shape = }, {lm_labels.shape = }, {labels.shape = }')
    logger.debug(f'{START_ID = }')
    for i, (tokens, label) in enumerate(encoded_dataset):
        if i == 0:
            logger.debug(f'{i = }, {tokens = }, {label = }')
        # Validate up front, so no row is ever left half-written.
        if START_ID not in tokens:
            logger.info("Index {} doesn't have start token".format(i))
            continue
        if CON_START_ID not in tokens:
            logger.info("Index {} doesn't have con_start token".format(i))
            continue
        n_tokens = len(tokens)
        if n_tokens > input_length:
            logger.info("Index {} has {} tokens, more than input_length={}; skipping".format(
                i, n_tokens, input_length))
            continue

        start_id_index = tokens.index(START_ID)
        input_ids[i, :start_id_index + 1] = tokens[:start_id_index + 1]
        full_input_ids[i, :n_tokens] = tokens
        # Position p predicts token p+1, starting at <START>.
        lm_labels[i, start_id_index:n_tokens - 1] = tokens[start_id_index + 1:n_tokens]
        labels[i] = label

    return (torch.tensor(input_ids),
            torch.tensor(full_input_ids),
            torch.tensor(lm_labels),
            torch.tensor(labels))


def prepare_prompt(encoded_text, tokenizer, sentiment):
    """Split an encoded state into (ref, prompt, valid) strings.

    ``ref`` is the decoded text with ``'<unk>'`` removed. ``prompt`` is the
    space-separated words up to and including the ``'<START>'`` word, with its
    sentiment tag switched to ``<POS>`` when ``sentiment`` is truthy and to
    ``<NEG>`` otherwise, then stripped. ``valid`` is the words after
    ``'<START>'``. Raises ValueError if no standalone ``'<START>'`` word exists.
    """
    ref = tokenizer.decode(encoded_text.tolist()).replace('<unk>', '')
    logger.debug(f'{ref = }')

    words = ref.split(' ')
    cut = words.index('<START>') + 1
    prompt, valid = ' '.join(words[:cut]), ' '.join(words[cut:])

    # str.replace is a no-op when the old tag is absent.
    old_tag, new_tag = ('<NEG>', '<POS>') if sentiment else ('<POS>', '<NEG>')
    prompt = prompt.replace(old_tag, new_tag).strip()
    logger.debug(f'{prompt = }')

    return ref, prompt, valid


def get_sample_prediction(putter, getter, tokenizer, device,
                          encoded_text, sentiment,
                          beam_width=5,
                          max_seq_len=70):
    """Describe one sample: its reference text, the prompt, and the best rewrite.

    Beam-searches ``beam_width`` candidates from the putter's LM. It then
    keeps the candidate the getter scores highest for ``sentiment``, and
    returns a ``ref = ... @@@ prompt = ... @@@ best_sentence = ...`` string.
    """
    ref, prompt, _ = prepare_prompt(encoded_text, tokenizer, sentiment)

    full_vocab = max(tokenizer.special_tokens.values()) + 1
    candidates = prediction_with_beam_search(prompt, putter.llm, tokenizer, device,
                                             beam_width=beam_width,
                                             vocab_length=full_vocab)
    best_sentence = get_best_sentence(candidates, getter, tokenizer, device,
                                      max_seq_len=max_seq_len,
                                      sentiment=sentiment)

    return f'{ref = } @@@ {prompt = } @@@ {best_sentence = }'


def get_grad_norm(m):
    """Return the global L2 norm of all gradients currently stored on ``m``'s parameters.

    Parameters without a ``.grad`` are ignored. If no parameter has a
    gradient (e.g. before the first backward pass, or with frozen weights), a
    zero tensor is returned instead of raising TypeError from
    ``torch.sqrt(0)``.
    """
    squared_norms = [torch.norm(p.grad) ** 2 for p in m.parameters() if p.grad is not None]
    if not squared_norms:
        return torch.tensor(0.0)
    return torch.sqrt(sum(squared_norms))


def set_seeds(seed):
    """Seed Python's ``random``, NumPy, torch (CPU) and all CUDA devices with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


# training parameters
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_putter_path", default=None, type=str, required=True,
                        help="The path where the pretrained putter model was saved.")
    parser.add_argument("--pretrained_getter_path", default=None, type=str, required=True,
                        help="The path where the pretrained putter model was saved.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--neptune-config-file',
                        default=None,
                        type=str,
                        help='Path to neptune config JSON file. '
                             'Script will not log to Neptune if not set.')
    parser.add_argument('--train_dataset', type=str, default='')
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument('--eval_dataset', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--rule_weights', type=json.loads, required=True)
    parser.add_argument('--freeze_getter', type=bool, default=False)
    parser.add_argument('--train_start_epoch', type=int, default=0)
    parser.add_argument('--num_train_epochs', type=int, default=1)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--logging_step', type=int, default=1000)
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument('--warmup_proportion', type=float, default=0.002)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--max_seq_length', type=int, default=70)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--log_level', type=int, default=logging.INFO)
    parser.add_argument('--is_debug_run', type=bool, default=False)

    args = parser.parse_args()
    logger.setLevel(args.log_level)
    logger.info(f'{args = }')
    log_run = False

    if args.neptune_config_file is not None:
        log_run = True
        run_config = json.load(open(args.neptune_config_file))

        logger.info('Starting Neptune run.')
        run_config['name'] = 'tdrg-finetuning'
        run_config['tags'] = ['llm', 'tdrg', 'fine-tuning']
        if args.is_debug_run:
            run_config['tags'].append('debug')

        neptune_logger = neptune.init_run(**run_config)

    # device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    CLIP_GRADIENTS = 1
    LABEL_NOISINESS = 0.5
    NUM_WORKERS = 8

    set_seeds(args.seed)

    # set rule weights
    rule_weights = {k: float(v) for k, v in args.rule_weights.items()}
    tdrg_loss_weight = rule_weights.pop('tdrg', None)

    # setup system
    lens_world = System()
    if 'classify' in rule_weights and not args.freeze_getter:
        lens_world.add_rule('classify',
                            classify,
                            torch.nn.CrossEntropyLoss(),
                            getter='getter',
                            state='valid_state',
                            value='value')
    if 'get_put' in rule_weights:
        lens_world.add_rule('get_put',
                            get_put,
                            torch.nn.CrossEntropyLoss(ignore_index=-1),
                            putter='putter',
                            getter='getter',
                            state='full_state',
                            valid_state='valid_state')
    if 'put_get' in rule_weights:
        lens_world.add_rule('put_get',
                            put_get,
                            torch.nn.CrossEntropyLoss(),
                            putter='putter',
                            getter='getter',
                            state='state',
                            value='value')
    if 'put_put' in rule_weights:
        lens_world.add_rule('put_put',
                            put_put,
                            torch.nn.CrossEntropyLoss(ignore_index=-1),
                            putter='putter',
                            state='state',
                            value1='random_value',
                            value2='value',
                            valid_state='valid_state')
    if 'undo' in rule_weights:
        lens_world.add_rule('undo',
                            undo,
                            torch.nn.CrossEntropyLoss(ignore_index=-1),
                            putter='putter',
                            getter='getter',
                            state='state',
                            value='value',
                            valid_state='valid_state')

    rule_calc_accuracies = {
        'classify': True,
        'get_put': False,
        'put_get': True,
        'put_put': False,
        'undo': False,
    }
    rule_calc_accuracies = {k: rule_calc_accuracies[k] for k in rule_weights.keys()}

    device_name, device_num = args.device.split(":")
    fabric = Fabric(accelerator=device_name, devices=[int(device_num)], precision="16-mixed")
    torch.set_float32_matmul_precision("high")
    logger.info(f'{fabric.device = }')

    # getter - use pretrained classifier
    # putter - load trained model from other repo
    tokenizer = OpenAIGPTTokenizer.from_pretrained(
        'openai-gpt', special_tokens=SPECIAL_TOKENS,
    )
    logger.info(f'{len(tokenizer) = }')
    getter = OpenAIGPTForSequenceClassification.from_pretrained(
        'openai-gpt',
        num_labels=NUM_LABELS,
    )
    getter.resize_token_embeddings(len(tokenizer))

    getter_state_path = args.pretrained_getter_path
    putter_state_path = args.pretrained_putter_path
    if args.train_start_epoch:
        getter_state_path = os.path.join(args.output_dir, f'getter-chkpt-{args.train_start_epoch-1}.pth')
        putter_state_path = os.path.join(args.output_dir, f'putter-chkpt-{args.train_start_epoch-1}.pth')
        logger.info(f'Resuming training from epoch {args.train_start_epoch}')
        logger.info(f'Loading weights from {getter_state_path}, {putter_state_path}')

    getter_state_dict = torch.load(
        getter_state_path, map_location=fabric.device
    )
    getter.load_state_dict(getter_state_dict)
    putter = OpenAIGPTLMHeadModel.from_pretrained(
        'openai-gpt',
        num_special_tokens=len(SPECIAL_TOKENS)
    )
    putter_state_dict = torch.load(
        putter_state_path, map_location=fabric.device
    )
    putter.load_state_dict(putter_state_dict)
    getter = LLMGetter(getter, tokenizer)
    putter = LLMPutter(putter, tokenizer)

    # prepare input data
    pad_token_id = tokenizer.convert_tokens_to_ids(['<PAD>'])[0]
    logger.debug(f'{pad_token_id = }')
    # Set pad_token_id for getter LLM
    getter.llm.config.pad_token_id = pad_token_id
    start_token_id = tokenizer.convert_tokens_to_ids(['<START>'])[0]
    logger.info("Encoding dataset...")
    train_dataset = tokenize_and_encode(args.train_dataset, tokenizer)
    eval_dataset = tokenize_and_encode(args.eval_dataset, tokenizer)
    logger.info("Training samples = {}".format(len(train_dataset)))
    logger.info("Validation samples = {}".format(len(eval_dataset)))
    logger.info("Example = {}".format(train_dataset[0]))

    # Compute the mex input length for the Transformer
    train_dataset = [(x, l) for x, l in train_dataset if len(x) <= args.max_seq_length and start_token_id in x] # Remove all sentence longer than max_seq_length
    eval_dataset = [(x, l) for x, l in eval_dataset if len(x) <= args.max_seq_length and start_token_id in x]
    input_length = max(max(len(t) for t, _ in train_dataset), max(len(q) for q, _ in eval_dataset))
    input_length = min(input_length, putter.llm.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare input tensors and dataloders
    train_tensor_dataset = pre_process_dataset(train_dataset, input_length)
    eval_tensor_dataset = pre_process_dataset(eval_dataset, input_length)

    logger.info("Training Example Input ids= {}".format(train_tensor_dataset[0][0]))
    logger.info("Training Example Language Modeling ids = {}".format(train_tensor_dataset[1][0]))
    time.sleep(10)
    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    eval_data = TensorDataset(*eval_tensor_dataset)
    eval_sampler = RandomSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Prepare optimizer
    # models = torch.nn.ModuleList([getter, putter])
    param_optimizer = list(putter.named_parameters())
    if not args.freeze_getter:
        param_optimizer += list(getter.named_parameters())

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size
    optimiser = OpenAIAdam(optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           warmup=args.warmup_proportion,
                           max_grad_norm=args.max_grad_norm,
                           weight_decay=args.weight_decay,
                           t_total=num_train_optimization_steps)

    # TODO: neptune logging
    if log_run:
        neptune_logger['params'] = vars(args)
        neptune_logger['params/rule_weights'] = args.rule_weights

    # training loop
    step = 0
    epoch = 0

    putter, optimiser = fabric.setup(putter, optimiser)
    getter = fabric.setup_module(getter)
    train_dataloader, eval_dataloader = fabric.setup_dataloaders(train_dataloader, eval_dataloader)

    getter.train()
    putter.train()

    if args.freeze_getter:
        logger.info('Freezing getter weights')
        getter.eval()
        getter = NoGradient(getter)

    for epoch in trange(args.train_start_epoch,
                        args.train_start_epoch + args.num_train_epochs,
                        desc='Epoch'):
        getter_grad_norm = 0
        putter_grad_norm = 0

        iters = len(train_dataloader)
        tqdm_bar = tqdm(train_dataloader, desc='Training')
        for step, batch in enumerate(tqdm_bar):
            # input_ids, full_input_ids, lm_labels, labels = [t.to(device) for t in batch]
            input_ids, full_input_ids, lm_labels, labels = batch

            logger.debug(f'+++ {input_ids.shape = }, {input_ids = }')
            logger.debug(f'+++ {full_input_ids.shape = }, {full_input_ids = }')
            logger.debug(f'+++ {lm_labels.shape = }, {lm_labels = }')
            logger.debug(f'+++ {labels.shape = }, {labels = }')

            logger.debug(f'{tokenizer.decode(input_ids[0].tolist())}')
            logger.debug(f'{tokenizer.decode(full_input_ids[0].tolist())}')
            logger.debug(f'{tokenizer.decode(torch.where(lm_labels < 0, 0, lm_labels).to(lm_labels.device)[0].tolist())}')
            state_batch = input_ids
            full_state_batch = full_input_ids
            value_batch = labels
            rand_value_batch = torch.randint(
                0, NUM_LABELS, size=labels.shape
            ).to(fabric.device)

            optimiser.zero_grad()
            lens_results = lens_world(putter=putter,
                                      getter=getter,
                                      state=state_batch,
                                      full_state=full_state_batch,
                                      valid_state=lm_labels,
                                      value=value_batch,
                                      random_value=rand_value_batch,)

            lens_losses = {k: v['loss'] for k, v in lens_results.items()}
            loss = sum([rule_weights[k] * v for k, v in lens_losses.items()])

            tdrg_loss = 0.
            if tdrg_loss_weight is not None:
                tdrg_loss = putter.llm(full_input_ids, lm_labels=lm_labels)
                loss += tdrg_loss_weight * tdrg_loss

            fabric.backward(loss)

            if log_run:
                neptune_logger['train/loss/aggregate'].append(loss)
                neptune_logger['train/lr'].append(optimiser.get_lr()[0])
                # Lens laws
                accuracies = {}
                for k, v in lens_results.items():
                    if rule_calc_accuracies[k]:
                        matches = (lens_results[k]['prediction'].argmax(-1)
                                == lens_results[k]['target'])
                        accuracies[k] = matches.sum().item() / labels.shape[-1]

                for k, v in lens_results.items():
                    neptune_logger[f'train/loss/{k}'].append(v['loss'])
                neptune_logger['train/loss/tdrg'].append(tdrg_loss)

                for k, v in accuracies.items():
                    neptune_logger[f'train/acc/{k}'].append(v)

                getter_grad_norm += get_grad_norm(getter)
                putter_grad_norm += get_grad_norm(putter)

                neptune_logger['train/grad_norm/getter'].append(getter_grad_norm)
                neptune_logger['train/grad_norm/putter'].append(putter_grad_norm)

            optimiser.step()

            if args.do_eval and step % args.logging_step == 0:
                # "Evaluate"
                putter.eval()
                getter.eval()
                logger.info("== logging ===")
                results = dict(lens_results)

                description = f'Epoch {epoch}, Step {step}\nLoss: {loss:g} <- '
                description += ' '.join(f'{k}={v["loss"]:g}' for k, v in results.items())
                description += f' tdrg={tdrg_loss}'
                logger.info(description)

                eval_loss = {k: 0. for k in lens_results.keys()}
                tdrg_eval_loss = 0.
                aggr_eval_loss = 0.
                eval_acc = {k: 0. for k, v in rule_calc_accuracies.items() if v}
                nb_eval_steps, nb_eval_examples = 0, 0
                for eval_batch in tqdm(eval_dataloader, desc='Evaluating'):
                    # input_ids_eval, full_input_ids_eval, lm_labels_eval, labels_eval = [
                    #     t.to(device) for t in eval_batch
                    # ]
                    input_ids_eval, full_input_ids_eval, lm_labels_eval, labels_eval = eval_batch
                    rand_value_batch_eval = torch.randint(
                        0, NUM_LABELS, size=labels_eval.shape
                    ).to(fabric.device)

                    with torch.no_grad():
                        lens_results_eval = lens_world(
                            putter=putter,
                            getter=getter,
                            state=input_ids_eval,
                            full_state=full_input_ids_eval,
                            valid_state=lm_labels_eval,
                            value=labels_eval,
                            random_value=rand_value_batch_eval,
                        )

                        lens_losses_eval = {k: v['loss'] for k, v in lens_results_eval.items()}
                        aggr_eval_loss += sum([rule_weights[k] * v for k, v in lens_losses_eval.items()])
                        if tdrg_loss_weight is not None:
                            tdrg_batch_eval_loss = putter.llm(full_input_ids_eval, lm_labels=lm_labels_eval)
                            tdrg_eval_loss += tdrg_batch_eval_loss
                            aggr_eval_loss += tdrg_loss_weight * tdrg_batch_eval_loss.item()
                    logger.debug(f'{lens_results_eval = }')

                    for k, v in lens_results_eval.items():
                        tmp_eval_loss = lens_results_eval[k]['loss']
                        if rule_calc_accuracies[k]:
                            eval_matches = (lens_results_eval[k]['prediction'].argmax(-1)
                                            == lens_results_eval[k]['target']).sum().item()
                            eval_acc[k] += eval_matches
                        eval_loss[k] += tmp_eval_loss.mean().item()

                    nb_eval_examples += labels_eval.size(0)
                    nb_eval_steps += 1

                for k in eval_loss.keys():
                    eval_loss[k] /= nb_eval_steps
                    if log_run:
                        neptune_logger[f'eval/loss/{k}'].append(eval_loss[k])

                if tdrg_loss_weight is not None:
                    tdrg_eval_loss /= nb_eval_steps
                    tdrg_eval_loss = tdrg_eval_loss.item()
                    if log_run:
                        neptune_logger[f'eval/loss/tdrg'].append(tdrg_eval_loss)

                aggr_eval_loss /= nb_eval_steps
                aggr_eval_loss = aggr_eval_loss.item()
                if log_run:
                    neptune_logger[f'eval/loss/aggregate_loss'].append(aggr_eval_loss)

                for k in eval_acc.keys():
                    eval_acc[k] /= nb_eval_examples
                    if log_run:
                        neptune_logger[f'eval/acc/{k}'].append(eval_acc[k])

                logger.info(f'Eval results: {eval_loss = }, {eval_acc = }, {tdrg_eval_loss = }, {aggr_eval_loss = }')

                with torch.no_grad():
                    train_sample_pred_pos = get_sample_prediction(
                        putter, getter, tokenizer, fabric.device,
                        state_batch[0], 1,
                        beam_width=1,
                        max_seq_len=args.max_seq_length
                    )
                    eval_sample_pred_pos = get_sample_prediction(
                        putter, getter, tokenizer, fabric.device,
                        input_ids_eval[0], 1,
                        beam_width=1,
                        max_seq_len=args.max_seq_length
                    )
                    train_sample_pred_neg = get_sample_prediction(
                        putter, getter, tokenizer, fabric.device,
                        state_batch[0], 0,
                        beam_width=1,
                        max_seq_len=args.max_seq_length
                    )
                    eval_sample_pred_neg = get_sample_prediction(
                        putter, getter, tokenizer, fabric.device,
                        input_ids_eval[0], 0,
                        beam_width=1,
                        max_seq_len=args.max_seq_length
                    )
                    logger.info('Train sample preds:')
                    logger.info(f'{train_sample_pred_pos = }')
                    logger.info(f'{train_sample_pred_neg = }')
                    logger.info('Eval sample pred:')
                    logger.info(f'{eval_sample_pred_pos = }')
                    logger.info(f'{eval_sample_pred_neg = }')

                if log_run:
                    neptune_logger['train/best_sentence/pos'].append(train_sample_pred_pos)
                    neptune_logger['train/best_sentence/neg'].append(train_sample_pred_neg)
                    neptune_logger['eval/best_sentence/pos'].append(eval_sample_pred_pos)
                    neptune_logger['eval/best_sentence/neg'].append(eval_sample_pred_neg)

                putter_pth = os.path.join(args.output_dir, f'putter-chkpt-{epoch}-{step}.pth')
                torch.save(putter.llm.state_dict(), putter_pth)
                putter.train()

                if not args.freeze_getter:
                    getter_pth = os.path.join(args.output_dir, f'getter-chkpt-{epoch}-{step}.pth')
                    torch.save(getter.llm.state_dict(), getter_pth)
                    getter.train()

            step += 1

        # Reset grad norm trackers
        getter_grad_norm = 0
        putter_grad_norm = 0

        # Save model every end of epoch
        putter_pth = os.path.join(args.output_dir, f'putter-chkpt-{epoch}.pth')
        getter_pth = os.path.join(args.output_dir, f'getter-chkpt-{epoch}.pth')
        torch.save(putter.llm.state_dict(), putter_pth)
        torch.save(getter.llm.state_dict(), getter_pth)

        if log_run:
            neptune_logger[f'models/putter-chkpt-{epoch}.pth'].upload(putter_pth)
            neptune_logger[f'models/getter-chkpt-{epoch}.pth'].upload(getter_pth)




if __name__ == '__main__':
    # Script entry point: run main() only when this file is executed
    # directly, not when it is imported as a module.
    main()
