import os 
import torch
import transformers 
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import numpy as np
from src.model.fxt import FxtTransformerLM, FxtAverageSingleInputWithPadding
from transformers import AutoTokenizer, LlamaForSequenceClassification
import json
import torch
from safetensors.torch import load_file
from src.data_utils import insert_special_token
import  inspect
import argparse
from datasets import load_metric
import math
from transformers import AdamW, get_scheduler
import torch.optim as optim
from transformers import set_seed
import functools
from src.eval.prepare_task import prepare_xnli_data, prepare_arc_easy_data
from src.utils import compute_mean_with_padding, weights_init
# Disable Weights & Biases logging via the environment variable.
os.environ["WANDB_DISABLED"] = "true"

from datasets import load_dataset
# Restrict the transformers library's log output to errors.
transformers.logging.set_verbosity_error()

# Accuracy metric object, loaded once when the module is imported.
accuracy_metric = load_metric("accuracy")

# define args parser
def _str2bool(value):
    """Parse a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` and ``"0"`` -- turns into
    ``True``. This converter interprets the usual spellings instead.

    Args:
        value: The raw command-line string (an actual ``bool`` is passed
            through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as "false" because it was the only way to turn
    # a flag off while these options were declared with ``type=bool``.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def pargs():
    """Build the command-line parser and parse ``sys.argv``.

    Boolean options take an explicit value (e.g. ``--fp16 false``) which is
    interpreted by ``_str2bool``.

    Returns:
        argparse.Namespace: The parsed fine-tuning options.
    """
    parser = argparse.ArgumentParser(description='Fine-tune FlexiToken on a downstream task')

    # Task, checkpoint and output selection.
    parser.add_argument('--task', type=str, default="xnli", help='The task to fine-tune on')
    parser.add_argument('--model_path', type=str, default=None, help='The path to the model checkpoint')
    parser.add_argument('--lang_id', type=str, default="en", help='The language id to use for the task')
    parser.add_argument('--num_labels', type=int, default=3, help='The number of labels in the downstream task')
    parser.add_argument('--seq_len', type=int, default=512, help='The maximum sequence length')
    parser.add_argument('--output_dir', type=str,
                        default="model_ckpts/oscar_cyrl10x_latin5x_deva13x_baseline_1_bp",
                        help='The output directory for the fine-tuned model')
    parser.add_argument('--run', type=int, default=0, help='The run number for the fine-tuning')

    # Training schedule, logging and checkpointing.
    parser.add_argument('--batch_size', type=int, default=32, help='The batch size for training')
    parser.add_argument('--num_train_epochs', type=int, default=5, help='The number of training epochs')
    parser.add_argument('--warmup_ratio', type=float, default=0.1,
                        help='The warmup ratio for the learning rate scheduler')
    parser.add_argument('--logging_steps', type=int, default=500, help='The logging steps')
    parser.add_argument('--eval_steps', type=int, default=5, help='The evaluation steps')
    parser.add_argument('--save_total_limit', type=int, default=2, help='The total number of checkpoints to save')
    parser.add_argument('--save_strategy', type=str, default="epoch", help='The strategy for saving checkpoints')
    parser.add_argument('--evaluation_strategy', type=str, default="epoch",
                        help='The strategy for evaluating checkpoints')
    parser.add_argument('--per_device_train_batch_size', type=int, default=32,
                        help='The training batch size per device')
    parser.add_argument('--per_device_eval_batch_size', type=int, default=128,
                        help='The evaluation batch size per device')
    parser.add_argument('--optim', type=str, default="adamw_torch", help='The optimizer to use')
    parser.add_argument('--report_to', type=str, default=None, help='The report to')
    parser.add_argument('--load_best_model_at_end', type=_str2bool, default=True,
                        help='Whether to load the best model at the end')
    parser.add_argument('--metric_for_best_model', type=str, default='accuracy',
                        help='The metric for the best model')
    parser.add_argument('--greater_is_better', type=_str2bool, default=True, help='Whether greater is better')
    parser.add_argument('--do_predict', type=_str2bool, default=True, help='Whether to do prediction')
    parser.add_argument('--log_level', type=str, default='info', help='The logging level')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help='The gradient accumulation steps')

    # Numerics, data loading and optimisation details.
    parser.add_argument('--fp16', type=_str2bool, default=False, help='Whether to use fp16')
    parser.add_argument('--bf16', type=_str2bool, default=True, help='Whether to use bf16')
    parser.add_argument('--dataloader_num_workers', type=int, default=20,
                        help='The number of dataloader workers')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None, help='The checkpoint to resume from')
    parser.add_argument('--lr_scheduler_type', type=str, default='cosine',
                        help='The learning rate scheduler type')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='The maximum gradient norm')
    parser.add_argument('--seed', type=int, default=233442, help='The random seed')

    return parser.parse_args()


class CustomSequenceClassifier(nn.Module):
    """Classification head on top of a base language model.

    The base model's output is pooled with ``compute_mean_with_padding`` and
    projected to ``num_labels`` logits by a bias-free linear layer whose input
    width is taken from ``base_model.final_cast.in_features``.
    """

    def __init__(self, base_model, num_labels):
        """Wrap ``base_model`` and create the classification layer and loss.

        Args:
            base_model: Model exposing a ``final_cast`` layer with an
                ``in_features`` attribute; called in ``forward``.
            num_labels: Number of output classes.
        """
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(base_model.final_cast.in_features, num_labels, bias=False)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the base model, pool its output and classify.

        Args:
            input_ids: Input passed to the base model (a clone is also passed
                as its ``target``).
            attention_mask: Forwarded to ``compute_mean_with_padding``.
            labels: Optional class labels; when given, a loss is returned.

        Returns:
            ``(loss, logits)`` when ``labels`` is provided, where the loss is
            the cross-entropy plus the first entry of the base model's
            boundary loss; otherwise just ``logits``.
        """
        # The base model also returns a stats dict; it is not used here.
        outputs, _stats, boundary_loss = self.base_model(
            input_ids, target=input_ids.clone(), task="classification"
        )
        # Swap the first two axes before pooling
        # (presumably seq-first -> batch-first; confirm against the base model).
        hidden_state = outputs.permute(1, 0, 2)
        pooled = compute_mean_with_padding(hidden_state, attention_mask)
        logits = self.classifier(pooled)

        if labels is None:
            return logits

        cls_loss = self.loss_fct(logits.view(-1, self.classifier.out_features), labels.view(-1))
        final_loss = cls_loss + boundary_loss[0]
        return final_loss, logits

# load model and tokenizer

def load_model_and_tokenizer(model_config, device):
    """Load the pretrained base model and the byte-level tokenizer.

    Two checkpoint layouts are supported, selected by the checkpoint directory
    name in ``model_config["output_dir"]``:

    * directories containing ``routing_oscar_cyrl6x_latin3x_deva12x`` hold a
      ``model.pth`` file whose ``"model"`` entry is the state dict;
    * all others hold ``step_60000/model.safetensors``.

    Args:
        model_config: Model configuration dict; it is expanded as keyword
            arguments for ``FxtTransformerLM`` and must provide
            ``output_dir``, ``cache_dir`` and ``script_tokens``.
        device: Accepted for interface compatibility; not used here. The
            weights are loaded on the CPU and the caller moves the model.

    Returns:
        tuple: ``(base_model, tokenizer)``.
    """
    ckpt_dir = model_config["output_dir"]
    if "routing_oscar_cyrl6x_latin3x_deva12x" in ckpt_dir:
        # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts.
        checkpoint = torch.load(f"{ckpt_dir}/model.pth", map_location="cpu")
        state_dict = checkpoint["model"]
    else:
        state_dict = load_file(f"{ckpt_dir}/step_60000/model.safetensors")

    base_model = FxtTransformerLM(**model_config)
    base_model.load_state_dict(state_dict)

    print(base_model)

    tokenizer_path = "google/byt5-small"
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        extra_ids=0,
        cache_dir=model_config["cache_dir"],
        additional_special_tokens=model_config["script_tokens"],
    )

    return base_model, tokenizer
def compute_metrics(eval_pred):
    """Compute classification accuracy from an evaluation prediction pair.

    Accuracy is computed directly with numpy, so this function does not depend
    on a metric object loaded elsewhere.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` may also be a tuple
            of arrays, in which case its first element is used.

    Returns:
        dict: ``{"accuracy": <fraction of correct predictions as float>}``.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        # Predictions may arrive as a tuple of arrays; the logits come first.
        logits = logits[0]
    predictions = np.asarray(logits).argmax(axis=-1).reshape(-1)  # predicted class ids
    references = np.asarray(labels).reshape(-1)
    return {"accuracy": float((predictions == references).mean())}


def _build_training_kwargs(args, model_config):
    """Assemble the keyword arguments for ``TrainingArguments``.

    Starts from every ``model_config`` entry that ``TrainingArguments``
    accepts, then overrides the output directory, the options exposed on the
    command line, and the learning rate (taken from ``model_config["lr"]``).
    """
    accepted = inspect.signature(TrainingArguments).parameters
    kwargs = {k: v for k, v in model_config.items() if k in accepted}
    kwargs["output_dir"] = f"{model_config['output_dir']}_fine_tuned/{args.task}/{args.lang_id}/{args.run}"

    cli_names = (
        "save_total_limit", "logging_steps", "save_strategy", "evaluation_strategy",
        "eval_steps", "per_device_train_batch_size", "per_device_eval_batch_size",
        "num_train_epochs", "optim", "report_to", "load_best_model_at_end",
        "metric_for_best_model", "greater_is_better", "do_predict", "log_level",
        "gradient_accumulation_steps", "fp16", "bf16", "dataloader_num_workers",
        "resume_from_checkpoint", "warmup_ratio", "lr_scheduler_type", "max_grad_norm",
    )
    for name in cli_names:
        kwargs[name] = getattr(args, name)

    kwargs["learning_rate"] = model_config["lr"]
    # Warmup is ratio-based for fine-tuning; a warmup_steps value inherited
    # from the pretraining config would otherwise take precedence over it.
    kwargs.pop("warmup_steps", None)
    return kwargs


def main():
    """Fine-tune the pretrained model on a downstream classification task.

    Steps: parse the CLI options, read the checkpoint's ``config.json``, load
    the base model and tokenizer, wrap the model in
    ``CustomSequenceClassifier``, prepare the task datasets, train with the
    Hugging Face ``Trainer``, then evaluate on the validation split and, when
    a test split exists, predict on it.

    Raises:
        KeyError: If ``--task`` is not one of the supported tasks.
    """
    args = pargs()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)

    # Read the pretraining configuration stored next to the checkpoint.
    if args.model_path is None:
        args.model_path = "model_ckpts/oscar_cyrl10x_latin5x_deva13x_baseline_1_bp/_2024-10-28_01-12-56"
    with open(f"{args.model_path}/config.json", encoding="utf-8") as config_file:
        model_config = json.load(config_file)
    # NOTE(review): both branches yield 3, so the "3_bp" check has no effect;
    # confirm the intended value for checkpoints without "3_bp" in the name.
    model_config["num_predictors"] = 3 if "3_bp" in model_config["output_dir"] else 3
    model_config["cache_dir"] = "/users/project_account/your_username/.cache"

    # Resolve the dataset preparation function up front so an unsupported
    # task fails before the checkpoint is loaded.
    dataset_functions = {
        "xnli": prepare_xnli_data,
        "arc_easy": prepare_arc_easy_data
    }
    load_post_process_data = dataset_functions[args.task]

    base_model, tokenizer = load_model_and_tokenizer(model_config, device)
    print("Model Loaded Successfully")

    # Expose model_config entries on args (without overriding CLI options);
    # args is handed to weights_init below.
    args_vars = vars(args)
    for key, value in model_config.items():
        if key not in args_vars:
            setattr(args, key, value)

    model = CustomSequenceClassifier(base_model, num_labels=args.num_labels)
    model.classifier.apply(functools.partial(weights_init, args=args))
    model.to(device)
    model.train()
    print("Classification Model Created Successfully")

    train_dataset, val_dataset, test_dataset = load_post_process_data(
        tokenizer, lang_id=args.lang_id, config=model_config
    )
    print("Completed dataset processing")

    training_args = TrainingArguments(**_build_training_kwargs(args, model_config))
    print("Training Arguments Set Up Successfully")
    print("output_dir:", training_args.output_dir)

    # A custom optimizer is supplied, so the `optim` training argument does
    # not select the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=model_config['lr'],
                           betas=(model_config['adam_b1'], model_config['adam_b2']),
                           eps=model_config['adam_eps'],
                           weight_decay=model_config['weight_decay'])

    # No scheduler is passed: the Trainer then creates one from
    # lr_scheduler_type / warmup_ratio using the actual number of optimizer
    # steps. That count depends on batch size and gradient accumulation, so
    # deriving it from len(train_dataset) alone would overstate it.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        optimizers=(optimizer, None)
    )

    # Train, then evaluate on the validation split.
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Predict on the test split when the task provides one.
    if test_dataset is not None:
        predictions = trainer.predict(test_dataset)
        pred_labels = predictions.predictions.argmax(-1)
        print("Predicted Labels:", pred_labels)
        print("Metrics:", predictions.metrics)
        metrics = predictions.metrics
        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

# Run the fine-tuning pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()