import argparse
import logging
import json
import os
import random
import sys
import time

import neptune
import numpy as np
import torch
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler,
                              SequentialSampler, TensorDataset,)
from tqdm.auto import tqdm, trange

sys.path.append('.')

from diagram import System
from rules.lens_llm import classify, get_put, put_get, put_put, undo

# Imported from tdrg repo
from tdrg.pytorch_pretrained_bert import (
    OpenAIGPTTokenizer, OpenAIGPTLMHeadModel,
    OpenAIAdam
)
from tdrg.pytorch_pretrained_bert.modeling import (
    BertForSequenceClassification, BertConfig
)
from tdrg.bertviz.bertviz.pytorch_pretrained_bert import BertTokenizer
from transformers import OpenAIGPTForSequenceClassification


# Configure the root logger once at import time (timestamp, level, logger
# name, message) and create this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# Marker tokens added to the tokenizer on top of its pretrained vocabulary.
SPECIAL_TOKENS = ['<POS>', '<NEG>', '<CON_START>', '<START>', '<END>']

# Hard-coded vocabulary ids of the markers: consecutive, in the same order as
# SPECIAL_TOKENS.
# NOTE(review): presumably these are the ids the tokenizer assigns right after
# the base vocabulary -- confirm against `len(tokenizer)`.
POS_ID, NEG_ID, CON_START_ID, START_ID, END_ID = range(40478, 40483)
SPECIAL_TOKENS_IDS = [POS_ID, NEG_ID, CON_START_ID, START_ID, END_ID]

# Number of classes predicted by the sequence classifier.
NUM_LABELS = 2



class LLMGetter(nn.Module):
    """Wrap a sequence-classification LLM so it can act as the lens "getter".

    ``forward`` cleans an encoded state (negative ids and the ``<POS>`` /
    ``<NEG>`` marker ids are zeroed, then the remaining non-zero ids are
    packed to the left of every row) and returns the wrapped model's
    classification logits. ``get`` reduces those logits to class indices.
    """

    def __init__(self, llm: OpenAIGPTForSequenceClassification) -> None:
        """Store the wrapped classifier.

        :param llm: model whose call result exposes a ``.logits`` attribute.
        """
        super().__init__()
        self.llm = llm

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return classification logits for a batch of encoded sequences.

        :param state: integer tensor of token ids; dim 1 is treated as the
            token axis (presumably shape ``(batch, seq_len)`` -- confirm
            against callers).
        :return: the ``logits`` attribute of the wrapped model's output.
        """
        # Negative ids are not valid embedding indices; map them to 0.
        state_copy = torch.where(state < 0, 0, state)

        # Blank out the sentiment markers so the classifier cannot read the
        # label off the input. The id tensor is created directly on the
        # state's device rather than on CPU followed by a copy.
        label_marker_ids = torch.tensor([POS_ID, NEG_ID], device=state.device)
        state_copy = torch.where(
            torch.isin(state_copy, label_marker_ids),
            0,
            state_copy,
        )

        # Pack every row to the left: a stable sort on the "is zero" flag puts
        # non-zero ids first (original order preserved) and all zeros last.
        # NOTE(review): this presumably turns 0 into right-padding for the
        # wrapped model -- confirm that 0 is its configured pad id.
        is_zero = (state_copy == 0).to(dtype=torch.uint8)
        packing_order = is_zero.sort(dim=1, stable=True)[1]
        state_copy = state_copy.gather(1, packing_order)

        # Lazy %-formatting: the tensor is only rendered to text when DEBUG
        # logging is enabled, so normal runs skip that work on every step.
        logger.debug('state_copy = %r', state_copy)
        logger.debug('state_copy.shape = %r, state_copy.device = %r',
                     state_copy.shape, state_copy.device)
        logits = self.llm(state_copy).logits
        return logits

    def get(self, state: torch.Tensor) -> torch.Tensor:
        """Return the predicted class index for each encoded sequence.

        :param state: encoded tokens of the sentence(s) to classify, in the
            format accepted by ``forward``.
        :return: tensor of class indices (argmax over dim 1 of the logits).
        """
        logits = self(state)
        # Softmax is monotonic, so argmax over the raw logits selects the same
        # class without the extra normalisation pass.
        preds = logits.argmax(dim=1)
        logger.debug('preds.shape = %r, preds = %r', preds.shape, preds)
        return preds


def tokenize_and_encode(file_path, tokenizer):
    '''
    Read a labelled text file and encode every line with the given tokenizer.

    Each non-blank line must have the form ``<text><END>,<label>`` where
    ``<label>`` is an integer. The line is split on the LAST ``<END>,`` so
    text that itself contains that sequence is still parsed; the ``<END>``
    marker is kept as part of the text. Blank lines are skipped.

    :param file_path: Path of the input file, dtype: str
    :param tokenizer: object providing ``tokenize(str)`` and
        ``convert_tokens_to_ids(list)``
    :return: list of ``(token_ids, label)`` tuples, one per non-blank line;
        each text is truncated to its first 512 tokens
    :raises ValueError: if a non-blank line has no ``<END>,`` separator or its
        label is not an integer
    '''
    with open(file_path, 'r') as in_fp:
        lines = in_fp.read().splitlines()

    # Build a fresh list instead of overwriting `lines` while iterating it.
    encoded_dataset = []
    for line_no, raw_line in enumerate(tqdm(lines), start=1):
        if not raw_line.strip():
            continue
        text, separator, label = raw_line.rpartition('<END>,')
        if not separator:
            raise ValueError(
                "{}:{}: expected '<text><END>,<label>' but found no "
                "'<END>,' separator".format(file_path, line_no)
            )
        text += '<END>'
        tokens = tokenizer.tokenize(text)[:512]
        encoded_dataset.append(
            (tokenizer.convert_tokens_to_ids(tokens), int(label))
        )
    return encoded_dataset


def pre_process_dataset(encoded_dataset, input_length, lm_pad_token_id=-1):
    """
    Build padded input-id, LM-label and class-label tensors from an encoded
    dataset.

    For every ``(tokens, label)`` pair the corresponding rows are filled as
    follows:

    * ``input_ids``: the tokens up to and including ``<START>``, with every
      position before ``<CON_START>`` zeroed and the slot immediately before
      ``<CON_START>`` (when there is one) set to ``POS_ID`` for a truthy
      label or ``NEG_ID`` otherwise.
    * ``lm_labels``: the tokens that follow ``<START>``, written starting at
      the ``<START>`` position; all other positions hold ``lm_pad_token_id``.
    * ``labels``: the class label.

    A sample that raises ``ValueError`` (missing ``<START>`` / ``<CON_START>``
    or a length that does not fit ``input_length``) is logged and its label
    stays at the default of 1.

    :param encoded_dataset: list of ``(token_ids, label)`` pairs, dtype: list
    :param input_length: row width of the produced tensors, dtype: int
    :param lm_pad_token_id: fill value for unused ``lm_labels`` positions
    :return: tuple ``(input_ids, lm_labels, labels)`` of int64 tensors with
        shapes ``(n, input_length)``, ``(n, input_length)`` and ``(n,)``
    """

    n_batch = len(encoded_dataset)
    input_ids = np.zeros(shape=(n_batch, input_length), dtype=np.int64)
    lm_labels = np.full(
        shape=(n_batch, input_length),
        fill_value=lm_pad_token_id,
        dtype=np.int64
    )
    labels = np.ones(shape=(n_batch,), dtype=np.int64)

    logger.debug(f'{input_ids.shape = }, {lm_labels.shape = }, {labels.shape = }')
    logger.debug(f'{START_ID = }')
    for i, (tokens, label) in enumerate(encoded_dataset):
        if i == 0:
            logger.debug(f'{i = }, {tokens = }, {label = }')
        try:
            start_id_index = tokens.index(START_ID)
            con_start_id_index = tokens.index(CON_START_ID)
            input_ids[i, :start_id_index+1] = tokens[:start_id_index+1]
            # Only touch the prefix when <CON_START> is not the first token:
            # with index 0 there is nothing to zero, and `0 - 1` would wrap
            # around and write the label marker into the LAST column.
            if con_start_id_index > 0:
                input_ids[i, :con_start_id_index] = [0] * con_start_id_index
                input_ids[i, con_start_id_index - 1] = POS_ID if label else NEG_ID
            # LM targets cover only the tokens after <START>.
            lm_labels[i, start_id_index:len(tokens)-1] = tokens[start_id_index+1:len(tokens)]
            labels[i] = label
        except ValueError as ve:
            # Either a marker token is missing or the sample does not fit in
            # `input_length`; report the real cause instead of guessing.
            logger.info("Index {} could not be pre-processed: {}".format(i, ve))

    input_ids = torch.tensor(input_ids)
    lm_labels = torch.tensor(lm_labels)
    labels = torch.tensor(labels)
    tensor_dataset = (input_ids, lm_labels, labels)

    return tensor_dataset


def get_sample_prediction(getter, tokenizer, encoded_text, true_label):
    """Describe one sample as text: decoded sentence, true label, prediction.

    :param getter: object whose ``get`` maps a ``(1, n)`` tensor to a
        one-element result supporting ``.item()``
    :param tokenizer: object providing ``decode(list_of_ids)``
    :param encoded_text: 1-D tensor holding the token ids of one sentence
    :param true_label: ground-truth label, included verbatim in the output
    :return: single-line description string
    """
    token_ids = encoded_text.tolist()
    # Drop the '<unk>' placeholders from the decoded text.
    decoded_text = tokenizer.decode(token_ids).replace('<unk>', '')

    # The getter works on batches, so wrap the sample in a batch of one.
    single_batch = encoded_text.view(1, -1)
    pred = getter.get(single_batch).item()

    return f'{decoded_text = } | {true_label = } | {pred = }'


def get_grad_norm(m):
    """Return the global L2 norm of all gradients currently stored on ``m``.

    Parameters without a gradient are ignored. When no parameter has a
    gradient (e.g. before the first backward pass) a zero scalar tensor is
    returned instead of failing on ``torch.sqrt`` of a Python int.

    :param m: module whose ``parameters()`` are inspected
    :return: 0-dim tensor holding sqrt(sum of squared per-parameter norms)
    """
    params = list(m.parameters())
    squared_norms = [
        torch.norm(p.grad) ** 2 for p in params if p.grad is not None
    ]
    if not squared_norms:
        # Keep the result on the module's device when it has parameters.
        device = params[0].device if params else None
        return torch.zeros((), device=device)
    return torch.sqrt(sum(squared_norms))


def set_seeds(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    :param seed: integer seed applied to every generator
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# training parameters
def main():
    """Fine-tune an OpenAI GPT sequence classifier under the ``classify`` rule.

    Steps: parse the command line, optionally start a Neptune run, build the
    lens ``System`` with a single ``classify`` rule, load and encode the train
    and eval datasets, then train. Every ``--logging_step`` steps the eval
    set is scored and one train and one eval sample prediction are logged. At
    the end of each epoch the wrapped model's ``state_dict`` is written to
    ``<output_dir>/getter-chkpt-<epoch>.pth``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--neptune-config-file',
                        default=None,
                        type=str,
                        help='Path to neptune config JSON file. '
                             'Script will not log to Neptune if not set.')
    parser.add_argument('--train_dataset', type=str, default='')
    parser.add_argument('--eval_dataset', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=1)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--logging_step', type=int, default=1000)
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument('--warmup_proportion', type=float, default=0.002)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--max_seq_length', type=int, default=70)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--log_level', type=int, default=logging.INFO)

    args = parser.parse_args()
    logger.setLevel(args.log_level)
    logger.info(f'{args = }')
    log_run = False

    if args.neptune_config_file is not None:
        log_run = True
        # Close the config file deterministically instead of leaking the handle.
        with open(args.neptune_config_file) as config_fp:
            run_config = json.load(config_fp)

        logger.info('Starting Neptune run.')
        run_config['name'] = 'tdrg-finetuning'
        run_config['tags'] = ['llm', 'tdrg', 'fine-tuning']

        neptune_logger = neptune.init_run(**run_config)

    # Create the checkpoint directory up front so a missing directory cannot
    # make torch.save fail after a whole epoch of training.
    os.makedirs(args.output_dir, exist_ok=True)

    # Fall back to CPU when CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    set_seeds(args.seed)

    # setup system
    lens_world = System()
    lens_world.add_rule('classify',
                        classify,
                        torch.nn.CrossEntropyLoss(),
                        getter='getter',
                        state='valid_state',
                        value='value')

    # set rule weights
    rule_weights = {
        'classify': 1, # increase this
    }

    # Rules for which an accuracy is computed and logged.
    rule_calc_accuracies = {
        'classify': True,
    }

    # getter - use pretrained classifier?
    # putter - load trained model from other repo
    getter = OpenAIGPTForSequenceClassification.from_pretrained(
        'openai-gpt',
        num_labels=NUM_LABELS,
    ).to(device)
    getter = LLMGetter(getter)
    getter.to(device).train()

    # prepare input data
    tokenizer = OpenAIGPTTokenizer.from_pretrained(
        'openai-gpt', special_tokens=SPECIAL_TOKENS,
    )
    logger.info(f'{len(tokenizer) = }')
    # Grow the embedding matrix to cover the added special tokens.
    getter.llm.resize_token_embeddings(len(tokenizer))
    # NOTE(review): '<PAD>' is not in SPECIAL_TOKENS, so this is presumably
    # the tokenizer's unknown-token id -- confirm.
    pad_token_id = tokenizer.convert_tokens_to_ids(['<PAD>'])[0]
    logger.debug(f'{pad_token_id = }')
    # Set pad_token_id for getter LLM
    getter.llm.config.pad_token_id = pad_token_id
    start_token_id = tokenizer.convert_tokens_to_ids(['<START>'])[0]
    logger.info("Encoding dataset...")
    train_dataset = tokenize_and_encode(args.train_dataset, tokenizer)
    eval_dataset = tokenize_and_encode(args.eval_dataset, tokenizer)
    logger.info("Training samples = {}".format(len(train_dataset)))
    logger.info("Validation samples = {}".format(len(eval_dataset)))
    logger.info("Example = {}".format(train_dataset[0]))

    # Compute the max input length for the Transformer.
    # Drop every sentence longer than max_seq_length or lacking a <START> token.
    train_dataset = [(x, l) for x, l in train_dataset if len(x) <= args.max_seq_length and start_token_id in x]
    eval_dataset = [(x, l) for x, l in eval_dataset if len(x) <= args.max_seq_length and start_token_id in x]
    input_length = max(max(len(t) for t, _ in train_dataset), max(len(q) for q, _ in eval_dataset))
    input_length = min(input_length, getter.llm.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare input tensors and dataloaders
    train_tensor_dataset = pre_process_dataset(train_dataset, input_length,
                                               lm_pad_token_id=pad_token_id)
    eval_tensor_dataset = pre_process_dataset(eval_dataset, input_length,
                                              lm_pad_token_id=pad_token_id)

    logger.info("Training Example Input ids= {}".format(train_tensor_dataset[0][0]))
    logger.info("Training Example Language Modeling ids = {}".format(train_tensor_dataset[1][0]))
    # NOTE(review): presumably a pause so the examples above can be read
    # before the progress bars start -- confirm it is still wanted.
    time.sleep(10)
    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    eval_data = TensorDataset(*eval_tensor_dataset)
    # NOTE(review): evaluation is sampled in random order, so the logged eval
    # sample differs between evaluations; SequentialSampler may be intended.
    eval_sampler = RandomSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Prepare optimizer
    param_optimizer = list(getter.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        # Use the CLI value: a hard-coded group-level decay would silently
        # override --weight_decay.
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size
    optimiser = OpenAIAdam(optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           warmup=args.warmup_proportion,
                           max_grad_norm=args.max_grad_norm,
                           weight_decay=args.weight_decay,
                           t_total=num_train_optimization_steps)


    # TODO: neptune logging
    if log_run:
        neptune_logger['params'] = vars(args)
        neptune_logger['params/rule_weights'] = rule_weights

    # training loop
    step = 0
    epoch = 0

    for epoch in trange(int(args.num_train_epochs), desc='Epoch'):
        getter_grad_norm = 0

        tqdm_bar = tqdm(train_dataloader, desc='Training')
        for step, batch in enumerate(tqdm_bar):
            logger.info(f'{step = }')
            # The state fed to the getter is the LM-label tensor (the tokens
            # after <START>); the input-id tensor is not used here.
            _, lm_labels, labels = [t.to(device) for t in batch]
            logger.debug(f'+++ {lm_labels.shape = }, {lm_labels = }')
            logger.debug(f'+++ {labels.shape = }, {labels = }')
            value_batch = labels

            optimiser.zero_grad()
            # Treat our entire lens-based ensemble as our generator
            # We essentially have a generator here.
            lens_results = lens_world(getter=getter,
                                      valid_state=lm_labels,
                                      value=value_batch,)

            # Total loss is the weighted sum of the per-rule losses.
            lens_losses = {k: v['loss'] for k, v in lens_results.items()}
            loss = sum([rule_weights[k] * v for k, v in lens_losses.items()])

            loss.backward()

            if log_run:
                neptune_logger['train/lr'].append(optimiser.get_lr()[0])
                # Lens laws
                accuracies = {}
                for k, v in lens_results.items():
                    if rule_calc_accuracies[k]:
                        matches = (lens_results[k]['prediction'].argmax(-1)
                                == lens_results[k]['target'])
                        accuracies[k] = matches.sum().item() / labels.shape[-1]

                for k, v in lens_results.items():
                    neptune_logger[f'train/loss/{k}'].append(v['loss'])

                for k, v in accuracies.items():
                    neptune_logger[f'train/acc/{k}'].append(v)

                # NOTE(review): this logs the running sum of gradient norms
                # within the epoch, not the per-step norm -- confirm intent.
                getter_grad_norm += get_grad_norm(getter)
                neptune_logger['train/grad_norm/getter'].append(getter_grad_norm)

            optimiser.step()

            # From here on `step` counts completed steps of this epoch (1-based).
            step += 1
            if step % args.logging_step == 0:
                # "Evaluate"
                getter.eval()
                logger.info("== logging ===")
                results = dict(lens_results)

                description = f'Epoch {epoch}, Step {step}\nLoss: {loss:g} <- '
                description += ' '.join(f'{k}={v["loss"]:g}' for k, v in results.items())
                logger.info(description)

                eval_loss, eval_accuracy = 0., 0.
                nb_eval_steps, nb_eval_examples = 0, 0
                for eval_batch in tqdm(eval_dataloader, desc='Evaluating'):
                    _, lm_labels_eval, labels_eval = [t.to(device) for t in eval_batch]

                    with torch.no_grad():
                        lens_results_eval = lens_world(
                            getter=getter,
                            valid_state=lm_labels_eval,
                            value=labels_eval,
                        )
                    logger.debug(f'{lens_results_eval = }')
                    tmp_eval_loss = lens_results_eval['classify']['loss']
                    eval_matches = (lens_results_eval['classify']['prediction'].argmax(-1)
                            == lens_results_eval['classify']['target']).sum().item()
                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += eval_matches
                    nb_eval_examples += labels_eval.size(0)
                    nb_eval_steps += 1

                # Loss is averaged per batch, accuracy per example.
                eval_loss /= nb_eval_steps
                eval_accuracy /= nb_eval_examples
                logger.info(f'Eval results: {eval_loss = }, {eval_accuracy = }')
                if log_run:
                    neptune_logger['eval/loss/classify'].append(eval_loss)
                    neptune_logger['eval/acc/classify'].append(eval_accuracy)

                # Log one prediction each from the current train batch and the
                # last eval batch.
                with torch.no_grad():
                    train_sample_pred = get_sample_prediction(
                        getter, tokenizer,
                        lm_labels[0], value_batch[0].item()
                    )
                    eval_sample_pred = get_sample_prediction(
                        getter, tokenizer,
                        lm_labels_eval[0], labels_eval[0].item()
                    )

                    logger.info(f'Train sample pred: {train_sample_pred}')
                    logger.info(f'Eval sample pred: {eval_sample_pred}')

                if log_run:
                    neptune_logger['train/sample_prediction'].append(train_sample_pred)
                    neptune_logger['eval/sample_prediction'].append(eval_sample_pred)

                getter.train()

        # Reset grad norm trackers
        getter_grad_norm = 0

        # Save model every end of epoch
        getter_pth = os.path.join(args.output_dir, f'getter-chkpt-{epoch}.pth')

        torch.save(getter.llm.state_dict(), getter_pth)



# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
