import os
import time
import argparse
import torch

from torch import optim 
from torch.utils.data import DataLoader
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    LoraConfig,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
)

import evaluate
from datasets import load_dataset, load_from_disk
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from tqdm import tqdm

import tool
import opt_sw
import logging

SETTING_PATH = '../settings/'

# 1. Read & set parameters (CLI args, logging, global seed).
args = tool.parse_arg()
tool.setup_logging(args.dataset_name, args.save_path, args.save_name)
tool.set_seed(args.random_seed)


# ---- Experiment settings ----
task = args.dataset_name  # task/dataset name, e.g. "mrpc"
batch_size = args.batch_size
model_name_or_path = "anonymized"
DATA_PATH = f"anonymized/{task}"
METRIC_PATH = "anonymized/glue"
peft_type = PeftType.LORA  # NOTE(review): not referenced anywhere else in this script
device = torch.device(args.device)
num_epochs = args.n_epoch

# LoRA adapter config for sequence classification: rank 8, alpha 16, dropout 0.1.
peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)


# Load the optimizer settings. The loop below iterates this object's keys as
# optimizer names and passes each value to `tool.load_optimizer`, so it is
# expected to be a mapping {optimizer_name: settings}.
# SECURITY: eval() executes arbitrary Python contained in the settings file.
# Only point this at trusted files; if the files only ever hold plain literals,
# consider switching to ast.literal_eval.
# `with` guarantees the file handle is closed (it was previously leaked).
with open(SETTING_PATH + '{}/{}.txt'.format(args.set_path, args.set_name), 'r') as f:
    optimizer_settings = eval(f.read())

logging.warning('[START]')
logging.warning('- dataset name: {}'.format(args.dataset_name))
logging.warning('- random seed : {}'.format(args.random_seed))
logging.warning('- setting path: {}/{}'.format(args.set_path, args.set_name))

# Per-optimizer evaluation losses.
# NOTE(review): each optimizer gets an empty list, but nothing in this script
# ever appends to it.
eval_losses = {}

def _evaluate(model, dataloader, metric, device):
    """Evaluate ``model`` on ``dataloader`` and return ``metric.compute()``.

    Puts the model in eval mode and runs forward passes without gradients.
    Each batch is added to ``metric`` with the argmax over the logits as
    predictions and the batch's ``"labels"`` as references. The same ``metric``
    object is reused across calls; this relies on ``compute()`` clearing the
    accumulated batches, as the original per-epoch loop already did.
    """
    model.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        with torch.no_grad():
            outputs = model(**batch)
        metric.add_batch(
            predictions=outputs.logits.argmax(dim=-1),
            references=batch["labels"],
        )
    return metric.compute()


# ---- Data, tokenizer and metric ----
# None of this depends on the optimizer, so it is built once instead of being
# rebuilt on every iteration of the optimizer loop below.
# Checkpoints whose name contains gpt/opt/bloom are padded on the left.
if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
    padding_side = "left"
else:
    padding_side = "right"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if tokenizer.pad_token_id is None:
    # No pad token defined: fall back to EOS so batch padding works.
    tokenizer.pad_token_id = tokenizer.eos_token_id

datasets = load_dataset("parquet", data_files={
        'train': f'{DATA_PATH}/train-00000-of-00001.parquet',
        'validation': f'{DATA_PATH}/validation-00000-of-00001.parquet',
        'test': f'{DATA_PATH}/test-00000-of-00001.parquet'})
metric = evaluate.load(METRIC_PATH, task)


def tokenize_function(examples):
    """Tokenize a batch of (sentence1, sentence2) pairs with truncation."""
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)


tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=["idx", "sentence1", "sentence2"],
)
# The model's forward pass reads labels from the "labels" keyword.
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")


def collate_fn(examples):
    """Pad a list of tokenized examples to the longest one as PyTorch tensors."""
    return tokenizer.pad(examples, padding="longest", return_tensors="pt")


# NOTE: the training loader is not shuffled, so every optimizer sees the same
# batch order.
train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
)
# NOTE(review): built but not used anywhere below.
test_dataloader = DataLoader(
    tokenized_datasets["test"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
)


# ---- One full fine-tuning run per optimizer entry ----
for optimizer_name in optimizer_settings:
    eval_losses[optimizer_name] = []
    settings = optimizer_settings[optimizer_name]
    logging.warning(f'\n[Optimizer:{optimizer_name}]')

    # Re-seed so every optimizer starts from the same RNG state. Without this,
    # runs after the first would build their model (random weight init, dropout)
    # from whatever RNG state the previous run left behind.
    tool.set_seed(args.random_seed)

    # Fresh base model + LoRA adapter for this optimizer.
    model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
    model = get_peft_model(model, peft_config).to(device)

    # Build the optimizer from this entry's settings (no LR scheduler is used).
    optimizer = tool.load_optimizer(model, settings, device)

    # Optional one-off initialisation for the OptSwitcher optimizer (timed).
    start_time = time.time()
    if type(optimizer) is opt_sw.OptSwitcher:
        # Pass the real train/validation loaders; the previous code referenced
        # the undefined names `train_loader`/`val_loader` (NameError).
        optimizer.init(model, args.model_name, train_dataloader, eval_dataloader)
    init_time = time.time() - start_time
    logging.warning('- init time: {}'.format(init_time))

    # Train, evaluating on the validation split after every epoch.
    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        for batch in tqdm(train_dataloader):
            # Switcher optimizers may change the underlying optimizer per step.
            if type(optimizer) in (opt_sw.OptSwitcher, opt_sw.RandomSwitcher, opt_sw.CyclicalSwitcher):
                optimizer.recommend_optimizer(model)
            batch = batch.to(device)
            loss = model(**batch).loss
            loss.backward()

            # OptSwitcher.step takes the scalar loss; other optimizers use step().
            if type(optimizer) is opt_sw.OptSwitcher:
                optimizer.step(loss.item())
            else:
                optimizer.step()

            optimizer.zero_grad()

        eval_metric = _evaluate(model, eval_dataloader, metric, device)
        logging.warning(f"epoch {epoch}: {eval_metric}")

    end_time = time.time()
    logging.warning(f"-ft time: {end_time - start_time}")

    # NOTE(review): despite the "- test" label, this re-evaluates the
    # *validation* split (same as the last epoch's result). GLUE-style test
    # splits often ship without real labels, which is presumably why; confirm
    # before switching to test_dataloader.
    eval_metric = _evaluate(model, eval_dataloader, metric, device)
    logging.warning(f"- test: {eval_metric}")
