from transformers import GenerationConfig, AutoTokenizer,AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments
from peft import HOFTConfig, LoraConfig, get_peft_model
from datasets import load_dataset
from evaluate import load
import numpy as np
import argparse
import torch
import wandb
import math
import os



# NLLB language codes (passed to the tokenizer as ``src_lang``) for the
# source languages this script accepts.
_SRC_LANG_CODES = {
    'sl': 'slv_Latn',
    'lv': 'lvs_Latn',
    'de': 'deu_Latn',
    'fr': 'fra_Latn',
}
# NLLB language codes (passed to the tokenizer as ``tgt_lang``) for the
# target languages this script accepts.
_TGT_LANG_CODES = {'en': 'eng_Latn'}
# Source languages whose train/validation splits are truncated to their
# first 10000 / 1000 examples.
_SUBSAMPLED_SRC_LANGS = ('de', 'fr')
# Accepted values for ``args.peft_type``.
_PEFT_TYPES = ('baseline', 'hoft', 'shoft', 'lora', 'dora')
# Module names that receive adapters for every non-baseline PEFT variant.
_TARGET_MODULES = ("q_proj", "k_proj", "v_proj")


def _pad_to_tensor(rows, pad_value):
    """Return ``rows`` as a single tensor, right-padding ragged rows with ``pad_value``.

    A list of 1-D tensors is padded to the longest row and stacked; any
    other value (presumably an already stacked tensor) is returned unchanged.
    """
    if isinstance(rows, list):
        return torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=pad_value)
    return rows


def main(args):
    """Optionally fine-tune NLLB-200 3.3B with a PEFT method, then score it on the test split.

    The run loads the ``facebook/covost2`` text pairs for ``args.src`` ->
    ``args.tgt``, tokenizes them, wraps the model according to
    ``args.peft_type``, trains it with ``Seq2SeqTrainer`` (skipped for
    ``'baseline'``), greedily translates the test split, and prints BLEU,
    chrF and COMET. Scores are also logged to Weights & Biases when
    ``args.use_wandb`` is set.

    Args:
        args: Namespace providing ``src``, ``tgt``, ``peft_type``, ``r``,
            ``dropout``, ``lr``, ``batch_size``, ``test_batch_size``,
            ``acc_steps``, ``epochs``, ``log_steps`` and ``use_wandb``.

    Raises:
        ValueError: If ``args.src``, ``args.tgt`` or ``args.peft_type`` is
            not one of the supported values. Raised before any dataset or
            model is loaded.
    """
    # Validate the configuration up front so a typo fails immediately
    # instead of after the dataset and model have been loaded.
    if args.src not in _SRC_LANG_CODES:
        raise ValueError(f'{args.src} is not implemented, implemented src languages are sl, lv, de and fr.')
    if args.tgt not in _TGT_LANG_CODES:
        raise ValueError(f'{args.tgt} is not implemented, implemented tgt language is en.')
    if args.peft_type not in _PEFT_TYPES:
        raise ValueError(
            f'{args.peft_type} is not implemented, implemented peft types are {", ".join(_PEFT_TYPES)}.'
        )
    src_code = _SRC_LANG_CODES[args.src]
    tgt_code = _TGT_LANG_CODES[args.tgt]

    max_tok_length = 200
    checkpoint = "facebook/nllb-200-3.3B"

    # Load dataset and drop the columns that are not needed for text translation
    raw_datasets = load_dataset("facebook/covost2", f'{args.src}_{args.tgt}', data_dir=f'covost2/{args.src}', trust_remote_code=True)
    unused_columns = ["client_id", "file", 'audio', 'id']
    for split in ('train', 'validation', 'test'):
        raw_datasets[split] = raw_datasets[split].map(remove_columns=unused_columns)

    if args.src in _SUBSAMPLED_SRC_LANGS:
        raw_datasets['train'] = raw_datasets['train'].take(10000)
        raw_datasets['validation'] = raw_datasets['validation'].take(1000)

    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        padding=True,
        pad_to_multiple_of=8,
        src_lang=src_code,
        tgt_lang=tgt_code,
        truncation=True,
        max_length=max_tok_length,
    )

    def preprocess_function(sample):
        """Tokenize a batch of source sentences together with their target translations."""
        return tokenizer(
            sample["sentence"],
            text_target=sample['translation'],
            truncation=True,
            return_tensors="pt",
            padding=True,
            max_length=max_tok_length,
        )

    # Tokenize datasets
    tokenized_train_dataset = raw_datasets['train'].map(preprocess_function, batched=True)
    tokenized_val_dataset = raw_datasets['validation'].map(preprocess_function, batched=True)
    tokenized_test_dataset = raw_datasets['test'].map(preprocess_function, batched=True)

    # Obtain model
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
    generation_config = GenerationConfig.from_pretrained(checkpoint)

    # Load metrics
    bleu = load("sacrebleu")
    comet = load("comet")
    chrf = load("chrf")

    def postprocess_text(preds, labels):
        """Strip surrounding whitespace from every prediction and reference."""
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        return preds, labels

    def compute_metrics(eval_preds):
        """Score test-set generations with BLEU, chrF and COMET.

        ``eval_preds`` is a ``(predictions, labels, sources)`` triple of
        token-id sequences; rows may have different lengths.
        """
        preds, labels, source = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_source = tokenizer.batch_decode(source, skip_special_tokens=True)

        # Negative ids cannot be decoded; swap them for the pad token, which
        # skip_special_tokens then drops. Built as a new list so the
        # caller's sequences are left untouched.
        pad_id = tokenizer.pad_token_id
        clean_labels = [[pad_id if int(tok) < 0 else int(tok) for tok in row] for row in labels]
        decoded_labels = tokenizer.batch_decode(clean_labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        bleu_result = bleu.compute(predictions=decoded_preds, references=decoded_labels)
        chrf_result = chrf.compute(predictions=decoded_preds, references=decoded_labels)
        comet_result = comet.compute(sources=decoded_source, predictions=decoded_preds, references=decoded_labels)
        result = {"bleu": bleu_result["score"], "chrf": chrf_result["score"], "comet": np.mean(comet_result["scores"]) * 100}
        prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        return result

    def train_metrics(eval_preds):
        """Compute BLEU and mean generation length for the trainer's per-epoch evaluation."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        # The trainer can pad both generated ids and labels with negative
        # values (-100); those cannot be decoded, so replace them with the
        # pad token before decoding.
        pad_id = tokenizer.pad_token_id
        preds = np.where(np.asarray(preds) < 0, pad_id, preds)
        labels = np.where(np.asarray(labels) < 0, pad_id, labels)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        result = bleu.compute(predictions=decoded_preds, references=decoded_labels)
        result = {"bleu": result["score"]}

        prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        return {k: round(v, 4) for k, v in result.items()}

    if args.use_wandb:
        wandb.init(
            project="MT",
            name=f"nllb-{args.peft_type}-{args.src}-to-{args.tgt}",
            tags=[f"{args.peft_type}", "nllb", f"{args.src}_{args.tgt}"],
        )

    # Wrap the base model according to the requested PEFT variant.
    if args.peft_type == 'baseline':
        peft_model = model
    else:
        if args.peft_type in ('hoft', 'shoft'):
            variant_kwargs = {'use_shoft': True} if args.peft_type == 'shoft' else {}
            config = HOFTConfig(
                task_type="SEQ_2_SEQ_LM",
                r=args.r,
                target_modules=list(_TARGET_MODULES),
                hoft_dropout=args.dropout,
                bias="none",
                **variant_kwargs,
            )
        else:  # 'lora' or 'dora'
            variant_kwargs = {'use_dora': True} if args.peft_type == 'dora' else {}
            config = LoraConfig(
                task_type="SEQ_2_SEQ_LM",
                r=args.r,
                lora_alpha=args.r * 2,
                target_modules=list(_TARGET_MODULES),
                lora_dropout=args.dropout,
                bias="none",
                **variant_kwargs,
            )
        peft_model = get_peft_model(model, config)
        peft_model.print_trainable_parameters()

    if args.peft_type != 'baseline':
        from transformers import Seq2SeqTrainer

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=peft_model,
            pad_to_multiple_of=8,
        )
        train_args = Seq2SeqTrainingArguments(
            f"models/nllb-{args.peft_type}-{args.src}-to-{args.tgt}",
            eval_strategy="epoch",
            learning_rate=args.lr,
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            gradient_accumulation_steps=args.acc_steps,
            weight_decay=0.01,
            save_total_limit=0,
            num_train_epochs=args.epochs,
            predict_with_generate=True,
            optim="adamw_torch",
            warmup_steps=100,
            fp16=True,
            # The string "none" disables reporting explicitly; passing None
            # instead falls back to the library default, which in
            # transformers 4.x reports to every installed integration.
            report_to="wandb" if args.use_wandb else "none",
            logging_steps=args.log_steps,  # how often to log to W&B
        )

        trainer = Seq2SeqTrainer(
            peft_model,
            train_args,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_val_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=train_metrics,
        )

        trainer.train()

    # Make sure dropout is disabled for test-time generation, whatever mode
    # training left the model in.
    peft_model.eval()

    batch_tokenized_test = tokenized_test_dataset.batch(args.test_batch_size).with_format("torch")
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)

    output_sequences = []
    input_sequences = []
    output_labels = []

    for batch in batch_tokenized_test:
        # Rows in a test batch can have different lengths, in which case the
        # column arrives as a list of tensors; pad each such column so it
        # becomes one rectangular tensor.
        inputs_ids = _pad_to_tensor(batch["input_ids"], tokenizer.pad_token_id)
        attention_mask = _pad_to_tensor(batch["attention_mask"], 0)
        labels = _pad_to_tensor(batch["labels"], tokenizer.pad_token_id)

        with torch.no_grad():
            output_batch = peft_model.generate(
                generation_config=generation_config,
                input_ids=inputs_ids.cuda(),
                attention_mask=attention_mask.cuda(),
                forced_bos_token_id=forced_bos_token_id,
                max_length=max_tok_length,
                num_beams=1,
                do_sample=False,
            )
        # Extending with a 2-D tensor appends its rows one by one.
        input_sequences.extend(inputs_ids.cpu())
        output_sequences.extend(output_batch.cpu())
        output_labels.extend(labels.cpu())

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    result = compute_metrics((output_sequences, output_labels, input_sequences))
    print(f'BLEU score: {result["bleu"]}')
    print(f'chrF score: {result["chrf"]}')
    print(f'COMET score: {result["comet"]}')

    if args.use_wandb:
        # Log test scores, rounded up to two decimals
        wandb.log({
            "test/bleu": math.ceil(result["bleu"] * 100) / 100,
            "test/comet": math.ceil(result["comet"] * 100) / 100,
            "test/chrf": math.ceil(result["chrf"] * 100) / 100,
        })
        wandb.finish()



if __name__ == "__main__":
    # Command-line entry point: parse the run configuration, echo it, and
    # hand it to main(). The language and PEFT options declare their
    # accepted values so argparse rejects a typo immediately.
    parser = argparse.ArgumentParser()
    parser.add_argument("-src", help="source language", type=str, required=True, choices=['sl', 'lv', 'de', 'fr'])
    parser.add_argument("-tgt", help="target language", type=str, required=True, choices=['en'])
    parser.add_argument("-lr", help="learning rate", type=float, default=1e-4)
    parser.add_argument("-r", help="rank used peft method", type=int, default=0)
    parser.add_argument("-acc_steps", help="number of gradient accumulations", type=int, default=1)
    parser.add_argument("-epochs", help="number of epochs", type=int, default=3)
    parser.add_argument("-batch_size", help="batch size", type=int, default=16)
    parser.add_argument("-test_batch_size", help="test batch size", type=int, default=64)
    parser.add_argument("-dropout", help="dropout for finetuned layers", type=float, default=0.0)
    parser.add_argument("-peft_type", help="select peft method", type=str, default='baseline',
                        choices=['baseline', 'hoft', 'shoft', 'lora', 'dora'])
    parser.add_argument('--use_wandb',  help="send results to wandb", action='store_true')
    parser.add_argument("-log_steps", help="number of logging steps", type=int, default=3)


    args = parser.parse_args()
    # Echo the effective configuration so it appears at the top of the run log.
    print(f"""

    NLLB for MT
          - source language : {args.src}
          - target language : {args.tgt}

          - peft type : {args.peft_type}
          - rank (if needed) : {args.r}
          - dropout : {args.dropout}

          - epochs : {args.epochs}
          - learning rate : {args.lr}
          - batch size : {args.batch_size}
          - accumulation steps : {args.acc_steps}

          - use wandb : {args.use_wandb}
          - use logging steps : {args.log_steps}

    """)
    main(args)