# Unsloth must be imported first, before torch/transformers below; this import ordering is required for our older Unsloth version - https://github.com/unslothai/unsloth/issues/963
import unsloth
import os
import torch

import numpy as np
from copy import deepcopy
from functools import partial
from collections import Counter

from transformers import IntervalStrategy, GenerationConfig, Trainer, TrainingArguments
from tqdm import trange
from unsloth import FastLanguageModel, is_bfloat16_supported

from utils import read_cli, get_config, override_config, load_data, load_data_bayesnet, load_data_rag
from bayes_net import OneDiseaseBayesNet
from prompts import prompt_factory
from functools import partial

# wandb is optional: it is only used when the config sets ``use_wandb``.
# Any import failure (missing package or a broken install) is tolerated so
# non-wandb runs still work, but interpreter-level exits are not swallowed.
try:
    import wandb
except Exception:
    wandb = None


def collate_fn(batch, pad_token_id=0):
    """Right-pad a list of tokenized examples into a batch of int64 tensors.

    Args:
        batch: sequence of examples, each providing ``input_ids`` and
            ``labels`` (lists, arrays or 1-D tensors of token ids).
        pad_token_id: id written into padded ``input_ids`` positions. ``None``
            (e.g. a tokenizer without a pad token) falls back to ``0``; the
            value is irrelevant because padded positions are masked out by
            ``attention_mask`` and ``labels``.

    Returns:
        dict with ``input_ids`` (padded with ``pad_token_id``), ``labels``
        (padded with ``-100``, the label value the loss ignores) and
        ``attention_mask`` (1 for real tokens, 0 for padding), each an int64
        tensor of shape ``(len(batch), longest input_ids)``.
    """
    if pad_token_id is None:
        pad_token_id = 0

    batch_size = len(batch)
    # default=0 lets an empty batch produce empty (0, 0) tensors.
    max_input_len = max((len(entry['input_ids']) for entry in batch), default=0)

    input_ids = np.full((batch_size, max_input_len), pad_token_id, dtype=np.int64)
    labels = np.full((batch_size, max_input_len), -100, dtype=np.int64)
    attn_mask = np.zeros((batch_size, max_input_len), dtype=np.int64)

    for row, entry in enumerate(batch):
        seq_len = len(entry['input_ids'])
        input_ids[row, :seq_len] = entry['input_ids']
        labels[row, :len(entry['labels'])] = entry['labels']
        attn_mask[row, :seq_len] = 1

    return {
        'input_ids': torch.from_numpy(input_ids),
        'labels': torch.from_numpy(labels),
        'attention_mask': torch.from_numpy(attn_mask),
    }


def format_func(sample, tokenizer):
    """Tokenize one chat example and build causal-LM labels for it.

    ``sample['prompt_input']`` (prompt messages) and ``sample['prompt_output']``
    (target messages) are concatenated and rendered with the tokenizer's chat
    template. Every prompt token (including the generation prompt) is masked
    with ``-100`` in ``labels``, so only the response tokens are supervised.

    Args:
        sample: a dataset row; it is mutated in place and returned.
        tokenizer: a tokenizer exposing ``apply_chat_template`` and
            ``__call__`` returning a dict with ``input_ids``.

    Returns:
        ``sample`` with ``input_ids``, ``attention_mask``, ``labels`` and
        ``gen`` (the rendered prompt string) added.

    Raises:
        ValueError: if the tokenized prompt is not a prefix of the tokenized
            prompt + response, which would make the label mask misaligned.
    """
    inputs = sample['prompt_input']
    outputs = sample['prompt_output']
    combined = inputs + outputs

    gen_prompt = tokenizer.apply_chat_template(inputs, tokenize=False, add_generation_prompt=True)
    comb_prompt = tokenizer.apply_chat_template(combined, tokenize=False, add_generation_prompt=False)

    # Special tokens are expected to come from the chat template itself, so
    # the tokenizer must not add them again (assumption: the template emits
    # BOS etc. - verify per model).
    gen_ids = list(tokenizer(gen_prompt, add_special_tokens=False)['input_ids'])
    comb_ids = list(tokenizer(comb_prompt, add_special_tokens=False)['input_ids'])

    prompt_len = len(gen_ids)
    if comb_ids[:prompt_len] != gen_ids:
        raise ValueError(
            'Tokenized prompt is not a prefix of the tokenized prompt + response; '
            'cannot build the label mask for this sample.'
        )

    sample['input_ids'] = comb_ids
    sample['attention_mask'] = [1] * len(comb_ids)
    # Mask exactly the prompt tokens; the first response token (index
    # prompt_len) must stay supervised.
    sample['labels'] = [-100] * prompt_len + comb_ids[prompt_len:]
    sample['gen'] = gen_prompt

    return sample


def _load_datasets(args):
    """Load the ``train`` and ``valid`` splits with the loader selected in ``args['model']``.

    ``use_bayes_net`` takes precedence over ``use_rag``. When neither flag is
    set, the plain ``load_data`` loader is used.

    Returns:
        ``(train_dataset, val_dataset)`` as produced by the chosen loader.
    """
    model_cfg = args['model']
    datapath = args['datapath']
    history_size = model_cfg['history_size']
    prompt_type = model_cfg['prompt_type']
    note_type = model_cfg['note_type']
    desc_type = model_cfg.get('desc_type', 'text')
    splits = ('train', 'valid')

    if model_cfg.get('use_bayes_net', False):
        print('Loading Bayes Net....')
        # A single Bayes net instance is shared by both splits.
        bayesnet = OneDiseaseBayesNet(args['bayes_net'])
        train_dataset, val_dataset = (
            load_data_bayesnet(datapath, split, history_size, prompt_type, bayesnet, note_type, desc_type)
            for split in splits
        )
    elif model_cfg.get('use_rag', False):
        print('Loading RAG....')
        train_dataset, val_dataset = (
            load_data_rag(datapath, split, history_size, prompt_type, note_type, desc_type)
            for split in splits
        )
    else:
        train_dataset, val_dataset = (
            load_data(datapath, split, history_size, prompt_type, note_type, desc_type)
            for split in splits
        )
    return train_dataset, val_dataset


def _print_sample(dataset, tokenizer, index=5):
    """Print one decoded example and its labels as a quick formatting sanity check.

    ``index`` is clamped to the dataset size so small (e.g. debug) datasets do
    not raise ``IndexError``. An empty dataset prints nothing.
    """
    if len(dataset) == 0:
        return
    entry = dataset[min(index, len(dataset) - 1)]
    print('-' * 120)
    print()
    print(tokenizer.decode(entry['input_ids']))
    print()
    print(entry['labels'])
    print('-' * 120)


def run(args, report_to='tensorboard'):
    """Fine-tune a 4-bit LoRA model with Unsloth and the Hugging Face ``Trainer``.

    Args:
        args: nested config dict. It reads the ``model``, ``train``,
            ``datapath``, ``destpath`` and ``experiment_name`` entries, and
            also ``bayes_net`` when ``model.use_bayes_net`` is set.
        report_to: logging integration forwarded to ``TrainingArguments``.
    """
    train_cfg = args['train']
    seed = train_cfg['seed']
    np.random.seed(seed)
    torch.manual_seed(seed)

    model_path = args['model']['wildcard']
    print('MODEL:', model_path)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=args['model']['max_seq_length'],
        dtype=None,
        load_in_4bit=True,
        local_files_only=False,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=train_cfg.get('lora_r', 32),
        lora_alpha=train_cfg.get('lora_alpha', 32),
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=seed,
        use_rslora=False,   # rank-stabilized LoRA disabled
        loftq_config=None,  # LoftQ disabled
    )

    train_dataset, val_dataset = _load_datasets(args)

    format_sample = partial(format_func, tokenizer=tokenizer)
    train_dataset = train_dataset.map(format_sample, batched=False)
    val_dataset = val_dataset.map(format_sample, batched=False)
    train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    _print_sample(train_dataset, tokenizer)

    use_bf16 = is_bfloat16_supported()
    train_args = TrainingArguments(
        output_dir=args['destpath'],
        overwrite_output_dir=True,
        remove_unused_columns=True,
        log_level='debug',
        per_device_train_batch_size=train_cfg['per_device_train_batch_size'],
        per_device_eval_batch_size=1,
        num_train_epochs=train_cfg['num_epochs'],
        learning_rate=train_cfg['learning_rate'],
        max_steps=train_cfg.get('max_steps', -1),
        save_strategy=train_cfg.get('save_strategy', IntervalStrategy.STEPS),
        save_steps=train_cfg.get('save_eval_steps', 100),
        seed=seed,
        gradient_checkpointing=train_cfg.get('gradient_checkpointing', False),
        # The config key keeps its older name (evaluation_strategy).
        eval_strategy=train_cfg.get('evaluation_strategy', IntervalStrategy.STEPS),
        eval_steps=train_cfg.get('save_eval_steps', 100),
        gradient_accumulation_steps=train_cfg['gradient_accumulation_steps'],
        logging_steps=5,
        ddp_find_unused_parameters=False,
        save_total_limit=train_cfg.get('save_total_limit', 5),
        load_best_model_at_end=True,
        metric_for_best_model=train_cfg.get('metric_for_best_model'),
        greater_is_better=train_cfg.get('greater_is_better'),
        report_to=report_to,
        run_name=args['experiment_name'],
        warmup_ratio=train_cfg.get('warmup_ratio', 0.0),
        dataloader_drop_last=True,
        lr_scheduler_type=train_cfg.get('lr_scheduler', 'constant'),
        group_by_length=train_cfg.get('group_by_length', False),
        torch_compile=train_cfg.get('torch_compile', False),
        eval_delay=train_cfg.get('eval_delay', 0),
        # Unsloth
        optim="adamw_8bit",
        fp16=not use_bf16,
        bf16=use_bf16,
        weight_decay=0.01,
    )

    trainer = Trainer(
        model=model, args=train_args,
        train_dataset=train_dataset, eval_dataset=val_dataset,
        # NOTE(review): newer transformers versions rename this kwarg to
        # processing_class - confirm against the pinned version.
        tokenizer=tokenizer,
        data_collator=partial(collate_fn, pad_token_id=tokenizer.pad_token_id),
    )

    trainer.model.print_trainable_parameters()
    trainer.train(resume_from_checkpoint=train_cfg.get('resume_from_checkpoint', None))


if __name__ == "__main__":
    cargs = read_cli()
    args = get_config(cargs['config'])
    args = override_config(args, cargs)

    # Under torchrun every process executes this script. Only the main
    # process should create the wandb run; otherwise each rank would start its
    # own run. RANK is the global rank; LOCAL_RANK is the fallback, and both
    # are unset for a plain single-process launch.
    local_rank = os.environ.get('LOCAL_RANK', '')
    global_rank = os.environ.get('RANK', local_rank)
    is_main_process = global_rank in ('', '0')

    report_to = 'tensorboard'
    if args.get('use_wandb', False):
        if wandb is None:
            raise ImportError("use_wandb is set but the 'wandb' package could not be imported")
        if is_main_process:
            wandb.init()
        report_to = 'wandb'

    run(args, report_to=report_to)
