import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
)

from peft import LoraConfig, TaskType, get_peft_model
from peft.tuners.lora import LoraLayer
from peft.utils import transpose
import fire
from functools import partial
import evaluate
import bitsandbytes as bnb
from peft import prepare_model_for_int8_training
import math

def main(
    model_name_or_path="EleutherAI/gpt-j-6B",
    fisher_matrix_path="fisher-matrix-6B",
    metric_name_or_path="rouge",
    train_file="train.json",
    val_file="val.json",
    text_column="input",
    label_column="ref",
    lr=1e-3,
    num_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    seed=42,
    max_src_len=800,
    max_tgt_len=256,
    kl_lambda=1,
    num_beams=1,
    output_dir="output",
    lora_r=8,
    lora_alpha=32,
):
    """Fine-tune an 8-bit causal LM with LoRA using weighted CE plus a KL penalty.

    Each training step combines two terms:

    * a token-level cross-entropy on the reference continuation (tokens that
      belong to the input text get zero weight), scaled by a per-example
      dataset weight;
    * a KL-style penalty computed on continuations sampled from the current
      model, comparing its token log-probabilities against a second copy of
      the base model loaded without LoRA adapters.

    After every epoch the model generates on the validation set, the
    predictions are scored with the ``evaluate`` metric, and the adapter is
    saved to ``{output_dir}/checkpoint_epoch_{epoch}``.

    Args:
        model_name_or_path: Model/tokenizer identifier passed to
            ``from_pretrained``.
        fisher_matrix_path: Accepted for interface compatibility; not used
            by this function.
        metric_name_or_path: Metric passed to ``evaluate.load``. ``compute``
            is called with ROUGE-specific arguments.
        train_file: Whitespace-separated string. Either a single path, or
            alternating ``weight path`` pairs (e.g. ``"1 a.json 0.5 b.json"``).
            The file extension selects the ``datasets`` loader.
        val_file: Path to the validation file (extension selects the loader).
        text_column: Column holding the input text.
        label_column: Column holding the reference text.
        lr: Learning rate for AdamW.
        num_epochs: Number of training epochs.
        per_device_train_batch_size: Training batch size per process.
        per_device_eval_batch_size: Evaluation batch size per process.
        seed: Random seed passed to ``set_seed``.
        max_src_len: Tokenizer ``max_length`` for every tokenized sequence.
            In training this bounds the concatenated input + reference.
        max_tgt_len: ``max_new_tokens`` for generation.
        kl_lambda: Multiplier applied to the KL term in the total loss.
        num_beams: Number of beams for validation-time generation.
        output_dir: Directory under which per-epoch checkpoints are written.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.

    Raises:
        ValueError: If ``train_file`` is empty or is not a single path or a
            list of complete ``weight path`` pairs.
    """
    accelerator = Accelerator()
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=0.1,
    )
    set_seed(seed)

    def assign_weight(examples, weight):
        """Attach a constant per-example ``weight`` column to a batch."""
        examples["weight"] = [weight] * len(examples[text_column])
        return examples

    train_files = train_file.split()
    # If only one dataset is given without weights, default its weight to 1.
    if len(train_files) == 1:
        train_files.insert(0, '1')
    # An odd number of fields would make zip() silently drop the last entry.
    if not train_files or len(train_files) % 2 != 0:
        raise ValueError(
            "train_file must be a single path or whitespace-separated "
            f"'weight path' pairs, got: {train_file!r}"
        )

    train_datasets = []
    for weight, path in zip(train_files[::2], train_files[1::2]):
        train_dataset = load_dataset(
            path.split(".")[-1],
            data_files={'train': path},
        )['train']
        with accelerator.main_process_first():
            train_dataset = train_dataset.map(
                partial(assign_weight, weight=float(weight)),
                batched=True,
            )
        train_datasets.append(train_dataset)
    train_dataset = concatenate_datasets(train_datasets)

    eval_dataset = load_dataset(val_file.split(".")[-1], data_files={'validation': val_file})['validation']
    dataset = {'train': train_dataset, 'validation': eval_dataset}

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    def preprocess_function(examples, is_train=True):
        """Tokenize a batch for training (``is_train``) or for generation."""
        assert text_column != 'model_input'
        # padding_side is read at call time, so one partial serves both
        # the right-padded and the left-padded calls below.
        encode = partial(
            tokenizer,
            max_length=max_src_len,
            padding='max_length',
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt',
        )

        if is_train:
            tokenizer.padding_side = 'right'
            examples['model_input'] = [
                f'{inp} {ref}\n' for inp, ref in zip(examples[text_column], examples[label_column])
            ]
        else:
            tokenizer.padding_side = 'left'
            examples['model_input'] = examples[text_column]

        batch = encode(examples['model_input'])

        if is_train:
            # Attention mask of the input text alone marks the prompt
            # positions; shifted by one to line up with the labels.
            prefix_weights = encode(examples[text_column]).attention_mask[:, 1:]

            # Next-token setup: inputs drop the last token, labels drop the first.
            batch['labels'] = batch['input_ids'][:, 1:]
            batch['input_ids'] = batch['input_ids'][:, :-1]
            batch['attention_mask'] = batch['attention_mask'][:, 1:]
            # Only non-padding, non-prompt positions contribute to the loss.
            batch['label_weights'] = batch['attention_mask'] * (1 - prefix_weights).float()
            if 'weight' in examples:
                batch['label_weights'] *= torch.tensor(examples['weight'])[:, None]

            # Prepare input for KL loss
            tokenizer.padding_side = 'left'  # left padding is required for generation
            kl_batch = encode(examples[text_column])
            batch['kl_input_ids'] = kl_batch['input_ids']
            batch['kl_attention_mask'] = kl_batch['attention_mask']

        return batch

    with accelerator.main_process_first():
        train_dataset = dataset['train'].map(
            partial(preprocess_function, is_train=True),
            batched=True,
            num_proc=1,
            remove_columns=dataset["train"].column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on dataset",
        )
    accelerator.wait_for_everyone()

    with accelerator.main_process_first():
        eval_dataset = dataset['validation'].map(
            partial(preprocess_function, is_train=False),
            batched=True,
            num_proc=1,
            remove_columns=dataset["validation"].column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on dataset",
        )
    accelerator.wait_for_everyone()

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=per_device_train_batch_size, pin_memory=True
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=default_data_collator, batch_size=per_device_eval_batch_size, pin_memory=True
    )

    # creating model
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_8bit=True, device_map={"": accelerator.local_process_index})
    model = prepare_model_for_int8_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Second copy of the base model without adapters, used as the KL reference.
    ref_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_8bit=True, device_map={"": accelerator.local_process_index})

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * num_epochs),
    )

    # Generation settings used for validation.
    eval_gen_kwargs = {
        'max_new_tokens': max_tgt_len,
        'num_beams': num_beams,
        'pad_token_id': tokenizer.eos_token_id,
    }
    # Generation settings used to sample continuations for the KL term.
    sample_gen_kwargs = {
        'max_new_tokens': max_tgt_len,
        'do_sample': True,
        'top_p': 0.9,
        'pad_token_id': tokenizer.eos_token_id,
    }

    model, ref_model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, ref_model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
    )
    accelerator.print(model)

    def forward_step(batch):
        """Run one training forward pass and return outputs with loss fields.

        The returned object is the model output of the CE forward pass with
        ``loss``, ``ce_loss`` and ``kl_loss`` attached.
        """
        labels = batch.pop("labels")
        label_weights = batch.pop("label_weights")
        kl_input_ids = batch.pop("kl_input_ids")
        kl_attention_mask = batch.pop("kl_attention_mask")

        # Weighted cross-entropy on the reference continuation.
        outputs = model(**batch)
        logits = outputs.logits
        ce_loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
        ce_loss = ce_loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        label_weights = label_weights.view(-1)
        token_num = torch.sum(label_weights > 0)
        # clamp: a batch whose references were fully truncated has no
        # weighted tokens; avoid 0/0 -> NaN poisoning the update.
        ce_loss = torch.sum(ce_loss * label_weights) / token_num.clamp(min=1)

        # KL Loss: sample continuations from the current model.
        prompt_len = kl_input_ids.size(-1)
        sequences = accelerator.unwrap_model(model).generate(
            input_ids=kl_input_ids, attention_mask=kl_attention_mask, **sample_gen_kwargs
        )
        sequences = accelerator.pad_across_processes(sequences, dim=1, pad_index=tokenizer.pad_token_id)
        attention_mask = sequences.ne(tokenizer.pad_token_id).long()

        # Calculate position ids ourselves,
        # otherwise left padding for training will cause problem
        position_ids = attention_mask.cumsum(-1) * attention_mask - 1
        position_ids = position_ids.clamp(min=0)  # transform -1 to 0

        # Logits at position t score the token at position t + 1, so the
        # scores for the generated tokens [prompt_len:] live at
        # [prompt_len - 1:-1].
        logits = model(
            input_ids=sequences,
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).logits[:, prompt_len - 1:-1]

        with torch.no_grad():
            ref_logits = ref_model(
                input_ids=sequences,
                attention_mask=attention_mask,
                position_ids=position_ids,
            ).logits[:, prompt_len - 1:-1]

        # Keep only the generated part, aligned with the logits above.
        sequences = sequences[:, prompt_len:]
        attention_mask = attention_mask[:, prompt_len:]

        def get_token_log_probs(logits):
            """Return the log-probability assigned to each generated token."""
            log_probs = torch.log_softmax(logits, dim=-1)
            token_log_probs = torch.gather(log_probs, dim=2, index=sequences[:, :, None]).squeeze(-1)
            return token_log_probs

        log_prob = get_token_log_probs(logits)
        with torch.no_grad():
            ref_log_prob = get_token_log_probs(ref_logits)
        # Detached log-ratio times log-prob: gradient flows only through
        # the trailing log_prob factor.
        kl_loss = (log_prob.detach() - ref_log_prob) * log_prob
        kl_loss = kl_loss * attention_mask
        # clamp: pad == EOS, so a sample consisting only of EOS has an
        # all-zero mask; avoid 0/0 -> NaN.
        kl_loss = kl_loss.sum() / attention_mask.sum().clamp(min=1)

        loss = ce_loss + kl_lambda * kl_loss
        outputs.loss = loss
        outputs.ce_loss = ce_loss
        outputs.kl_loss = kl_loss
        return outputs

    eval_scorer = evaluate.load(metric_name_or_path)
    # Materialize the reference column once instead of on every use.
    eval_references = dataset['validation'][label_column]

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_dataloader, disable=not accelerator.is_local_main_process)
        for batch in progress_bar:
            outputs = forward_step(batch)
            loss = outputs.loss
            loss_num = loss.detach().cpu().item()
            total_loss += loss_num
            ce_loss = outputs.ce_loss.detach().cpu().item()
            kl_loss = outputs.kl_loss.detach().cpu().item()
            progress_bar.set_description(f"Epoch {epoch} - Loss: {loss_num:.4f}, CE Loss: {ce_loss:.4f}, KL Loss: {kl_loss:.4f}")
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        train_epoch_loss = total_loss / len(train_dataloader)
        accelerator.print(f"{epoch=}: {train_epoch_loss=}")

        model.eval()
        eval_preds = []
        for batch in tqdm(eval_dataloader, disable=not accelerator.is_local_main_process):
            with torch.no_grad():
                outputs = accelerator.unwrap_model(model).generate(**batch, **eval_gen_kwargs)
            outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
            preds = accelerator.gather_for_metrics(outputs)
            # Inputs are left-padded to max_src_len, so everything after
            # that offset is newly generated text.
            preds = preds[:, max_src_len:].detach().cpu().numpy()
            eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))

        assert len(eval_preds) == len(
            eval_references
        ), f"{len(eval_preds)} != {len(eval_references)}"
        # Training targets end with '\n'; keep only the first generated line.
        eval_preds = [pred.split('\n')[0].strip() for pred in eval_preds]
        scores = eval_scorer.compute(
            predictions=eval_preds,
            references=eval_references,
            rouge_types=['rouge1', 'rouge2', 'rougeL'],
            use_stemmer=True,
        )
        accelerator.print(f"Epoch {epoch} evaluation: rouge1={scores['rouge1']}, rouge2={scores['rouge2']}, rougeL={scores['rougeL']}")

        # saving model
        accelerator.print(f"Saving model to {output_dir}...")
        accelerator.wait_for_everyone()
        # Write the checkpoint from a single process so ranks do not race
        # on the same files.
        if accelerator.is_main_process:
            accelerator.unwrap_model(model).save_pretrained(f'{output_dir}/checkpoint_epoch_{epoch}')

        accelerator.wait_for_everyone()


if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via Fire.
    fire.Fire(component=main)
