#--- SET SEED
import torch 
import random
import numpy as np 

def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN.

    Args:
        seed: integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every visible CUDA device
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Give up cuDNN autotuning in exchange for a reproducible kernel choice.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True



#--- LOAD ARGUMENTS
import argparse

def _str2bool(value):
    """Convert a command-line string such as 'true'/'false', 'yes'/'no', '1'/'0' to bool.

    argparse's ``type=bool`` is wrong for this: ``bool('False')`` is ``True``,
    so any explicit value would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_arg(argv=None):
    """Parse the experiment's command-line options.

    Unknown options are ignored (``parse_known_args``).

    Args:
        argv: list of argument strings to parse; ``None`` (the default) reads
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser()

    # base
    parser.add_argument(
        '--task_name', default='mrpc', help='task name')
    parser.add_argument(
        '--model_name', default='llama-3.2-1b', help='model name')
    parser.add_argument(
        '--ft_type', default='full', help='full fine-tuning or PEFT')
    parser.add_argument(
        '--tr_method', default='ours')

    # Boolean options go through _str2bool so that '--is_save False' is False.
    parser.add_argument(
        '--is_save', default=False, type=_str2bool, help='is save the losses')
    parser.add_argument(
        '--device', default='cpu', help='the device, cuda')

    # HP set
    parser.add_argument(
        '--random_seed', type=int, default=0, help='for random seed')
    parser.add_argument(
        '--n_epoch',     type=int, default=20, help='number of training epochs')
    parser.add_argument(
        '--batch_size',  type=int, default=16, help='the batch size')

    parser.add_argument(
        '--optimizer_type', default='AdamW', help='optimizer type')
    parser.add_argument(
        '--optimizer_lr', type=float, default=3e-4, help='learning rate')

    # HP for our method
    parser.add_argument(
        '--M_max',  type=int, default=-1, help='the maximum of the total memory size, -1 represents no limitation')
    parser.add_argument(
        '--group_size',  type=int, default=1, help='how much layers a group')
    parser.add_argument(
        '--offload_exit_to_cpu', default=False, type=_str2bool, help='whether to offload the exit module in cpu')

    # path
    parser.add_argument(
        '--save_path', default='default', help='the path to save result')

    args, _unknown = parser.parse_known_args(argv)
    return args



#--- SET LOGGING
import os
import logging

# Root directory for experiment outputs; setup_logging() nests per-run log
# files beneath it. The path is relative, so it resolves against the current
# working directory.
RESULT_PATH = '../result/'

def setup_logging(dataset_name, save_path, model_name, ft_type, tr_type, random_seed, is_visual=False):
    """Route root-logger output to the console (WARNING+) and a per-run file (INFO+).

    The log file is
    ``RESULT_PATH/<save_path>/<dataset_name>/<model_name>-<ft_type>-<tr_type>-seed<random_seed>``
    and is opened in append mode; missing parent directories are created.
    Handlers already attached to the root logger are removed and closed
    first. As a result, calling this again in the same process (e.g. for
    another seed) switches logging to the new file instead of being ignored.

    Args:
        dataset_name: task/dataset name, used as a directory level.
        save_path: experiment sub-directory under ``RESULT_PATH``.
        model_name, ft_type, tr_type, random_seed: components of the file name.
        is_visual: unused; kept for backward compatibility.

    Returns:
        The path of the log file.
    """
    formatter = logging.Formatter("%(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(level=logging.WARNING)

    logs_path = os.path.join(
        RESULT_PATH, save_path, dataset_name,
        f'{model_name}-{ft_type}-{tr_type}-seed{random_seed}')

    os.makedirs(os.path.dirname(logs_path), exist_ok=True)
    file_handler = logging.FileHandler(logs_path, mode='a', encoding='utf-8')
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=logging.INFO)

    # Without force=True, basicConfig does nothing once the root logger has any
    # handler, which would leave file_handler open but unused.
    logging.basicConfig(level=logging.INFO, handlers=[console_handler, file_handler], force=True)
    return logs_path



#--- LOAD DATASET
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer

from torch.utils.data import DataLoader, random_split

def load_data(task='mrpc', model_name='llama-3.2-1b', batch_size=32):
    DATA_PATH    = XXX
    METRIC_PATH  = XXX
    MODEL_PATH = get_model_path(model_name)

    if any(k in MODEL_PATH for k in ("gpt", "opt", "bloom")):
        padding_side = "left"
    else:
        padding_side = "right"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side=padding_side)

    if getattr(tokenizer, "pad_token_id") is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    datasets = load_dataset("parquet", data_files={
            'train': f'{DATA_PATH}/train-00000-of-00001.parquet',
            'validation': f'{DATA_PATH}/validation-00000-of-00001.parquet',
            'test': f'{DATA_PATH}/test-00000-of-00001.parquet'})
    
    def tokenize_function(examples):
        if task in ['mrpc', 'wnli', 'rte', 'stsb']:
            outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding=True, max_length=256)
        elif task in ['qqp']:
            outputs = tokenizer(examples["question1"], examples["question2"], truncation=True, max_length=128)
        elif task in ['qnli']:
            outputs = tokenizer(examples["question"], examples["sentence"], truncation=True, max_length=128)
        elif task in ['ax','mnli']:
            outputs = tokenizer(examples["premise"], examples["hypothesis"], truncation=True, max_length=128)
        elif task in ['cola','sst2']:
            outputs = tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=128)
        return outputs
    
    def tokenize_function_large(examples):
        if task in ['mrpc', 'wnli', 'rte', 'stsb']:
            outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding=True, max_length=256)
        elif task in ['qqp']:
            outputs = tokenizer(examples["question1"], examples["question2"], truncation=True, max_length=128)
        elif task in ['qnli']:
            outputs = tokenizer(examples["question"], examples["sentence"], truncation=True, max_length=128)
        elif task in ['ax','mnli']:
            outputs = tokenizer(examples["premise"], examples["hypothesis"], truncation=True, max_length=128)
        elif task in ['cola','sst2']:
            outputs = tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=128)
        return outputs

    if '7b' in model_name:
        if task in ['mrpc', 'wnli', 'rte', 'stsb']:
            tokenized_datasets = datasets.map(
                tokenize_function_large,
                batched=True,
                remove_columns=["idx", "sentence1", "sentence2"],
            )
        elif task in ['qqp']:
            tokenized_datasets = datasets.map(
                tokenize_function_large,
                batched=True,
                remove_columns=["idx", "question1", "question2"],
            )
        elif task in ['qnli']:
            tokenized_datasets = datasets.map(
                tokenize_function_large,
                batched=True,
                remove_columns=["idx", "question", "sentence"],
            )
        elif task in ['ax','mnli']:
            tokenized_datasets = datasets.map(
                tokenize_function_large,
                batched=True,
                remove_columns=["idx", "premise", "hypothesis"],
            )
        elif task in ['cola','sst2']:
            tokenized_datasets = datasets.map(
                tokenize_function_large,
                batched=True,
                remove_columns=["idx", "sentence"],
            )
    
    else:
        if task in ['mrpc', 'wnli', 'rte', 'stsb']:
            tokenized_datasets = datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=["idx", "sentence1", "sentence2"],
            )
        elif task in ['qqp']:
            tokenized_datasets = datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=["idx", "question1", "question2"],
            )
        elif task in ['qnli']:
            tokenized_datasets = datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=["idx", "question", "sentence"],
            )
        elif task in ['ax','mnli']:
            tokenized_datasets = datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=["idx", "premise", "hypothesis"],
            )
        elif task in ['cola','sst2']:
            tokenized_datasets = datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=["idx", "sentence"],
            )

    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    if task in ['mrpc']:
        train_dataloader = DataLoader(
            tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
        )
        eval_dataloader = DataLoader(
            tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
        )
        test_dataloader = DataLoader(
            tokenized_datasets["test"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
        )
    elif task in ['wnli', 'qqp', 'rte', 'stsb', 'qnli', 'mnli', 'cola','sst2']:
        n_train = len(tokenized_datasets['train'])
        n_val   = int(0.3*n_train)
        train_set, val_set = random_split(tokenized_datasets["train"], [n_train-n_val, n_val])

        train_dataloader = DataLoader(
            train_set, shuffle=False, collate_fn=collate_fn, batch_size=batch_size
        )
        eval_dataloader = DataLoader(
            val_set, shuffle=False, collate_fn=collate_fn, batch_size=batch_size
        )
        test_dataloader = DataLoader(
            tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
        )

    metric = evaluate.load(METRIC_PATH, task)

    return train_dataloader, eval_dataloader, test_dataloader, metric



#--- MODEL PATH
def get_model_path(model_name):
    """Map a short model name (e.g. 'llama-3.2-1b') to the path given to ``from_pretrained``.

    NOTE(review): the body below is a redacted placeholder (``XXX``) and, as
    written, is not valid Python. It must bind ``MODEL_PATH`` for
    ``model_name``. Callers also test the returned path for the substrings
    "gpt", "opt" and "bloom" to choose the tokenizer padding side.
    """
XXX
    return MODEL_PATH



#--- LOAD MODEL
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

def load_model(model_name='llama-3.2-1b', task_name="wnli", train_type='lora', train_setting=None):
    """Load a sequence-classification model sized for a GLUE task.

    The head has 3 labels for 'mnli', 1 label with
    ``problem_type="regression"`` for 'stsb', and 2 labels otherwise. The KV
    cache is disabled. If the tokenizer has no pad token, EOS is used, and the
    model config's ``pad_token_id`` is set to match.

    Args:
        model_name: short model name resolved by ``get_model_path``.
        task_name: GLUE task name, used only to size the classification head.
        train_type: 'lora' wraps the model with PEFT LoRA; any other value
            (e.g. 'full') returns the plain model.
        train_setting: keyword arguments forwarded to ``peft.LoraConfig``
            when ``train_type == 'lora'``; ignored otherwise. ``None`` means
            no extra arguments.

    Returns:
        The model (a PEFT-wrapped model when ``train_type == 'lora'``).
    """
    # None instead of a shared mutable {} default.
    if train_setting is None:
        train_setting = {}

    MODEL_PATH = get_model_path(model_name)

    if task_name in ['mnli']:
        num_labels, problem_type = 3, None
    elif task_name in ['stsb']:
        num_labels, problem_type = 1, "regression"
    else:  # mrpc, wnli, rte, qqp, qnli, cola, sst2, etc.
        num_labels, problem_type = 2, None

    config = AutoConfig.from_pretrained(MODEL_PATH)
    config.num_labels = num_labels
    if problem_type is not None:
        config.problem_type = problem_type
    config.use_cache = False

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(config, "pad_token_id", None) is None:
        config.pad_token_id = tokenizer.pad_token_id

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_PATH,
        config=config,
    )

    # NOTE(review): this also shrinks the embedding matrix when it is larger
    # than the tokenizer vocabulary; confirm that is intended.
    if model.get_input_embeddings().num_embeddings != len(tokenizer):
        model.resize_token_embeddings(len(tokenizer))

    if train_type == 'lora':
        peft_config = LoraConfig(task_type="SEQ_CLS", **train_setting)
        model = get_peft_model(model, peft_config)

    # Applied for every train_type (after PEFT wrapping for 'lora').
    model.config.pad_token_id = tokenizer.pad_token_id
    return model


#--- LOAD OPTIMIZER
from torch import optim

def load_optimizer(model, optimizer_type, optimizer_lr):
    """Create a torch optimizer over all of ``model``'s parameters.

    Args:
        model: module whose ``parameters()`` are optimized.
        optimizer_type: 'SGD', 'Adam' or 'AdamW'; any other name falls back
            to AdamW.
        optimizer_lr: learning rate.

    Returns:
        The constructed ``torch.optim`` optimizer.
    """
    params = model.parameters()

    if optimizer_type == 'SGD':
        return optim.SGD(params=params, lr=optimizer_lr)
    if optimizer_type == 'Adam':
        return optim.Adam(params=params, lr=optimizer_lr,
                          betas=(0.9, 0.999), eps=1e-8)

    # 'AdamW' and every unrecognised name share this path.
    return optim.AdamW(params=params, lr=optimizer_lr)



#--- TRAIN MODEL
import time
from tqdm import tqdm
from transformers import get_scheduler

import ours
import math

def train_model(model, optimizer, args,
                train_dataloader, device='cpu'):
    """Fine-tune ``model`` and record the per-step training loss.

    The learning rate follows ``get_scheduler("linear")`` with the first 6%
    of all optimizer steps used as warmup. When ``args.tr_method == 'ours'``,
    an ``ours.TriMS`` instance is created before training and stepped with the
    loss value after every optimizer step. Its activation hooks are removed
    afterwards, including when training raises, so the model is not left
    with hooks attached.

    Args:
        model: model whose forward output exposes ``.loss``.
        optimizer: optimizer over ``model``'s parameters.
        args: namespace providing ``n_epoch``, ``tr_method``, ``M_max``,
            ``group_size`` and ``offload_exit_to_cpu`` (see ``parse_arg``).
        train_dataloader: iterable of batches that support ``.to(device)``
            and ``model(**batch)``.
        device: device the model and every batch are moved to.

    Returns:
        ``(model, results)``: ``results['losses']`` holds one float per step,
        and ``results['train_time']`` is the wall-clock training time in
        seconds.
    """
    n_epoch = args.n_epoch
    use_trims = args.tr_method == 'ours'

    our_args = {'M_max': args.M_max,
                'group_size': args.group_size,
                'offload_exit_to_cpu': args.offload_exit_to_cpu,
                'device': device}

    # scheduler: linear decay with 6% warmup over all optimizer steps
    total_steps = n_epoch * len(train_dataloader)
    warmup_steps = int(math.floor(total_steps * 0.06))

    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    if use_trims:
        trims = ours.TriMS(model, optimizer, **our_args)

    results = {'losses': []}

    start_time = time.time()
    model.to(device)

    try:
        for _ in tqdm(range(n_epoch), leave=False, desc='Training'):
            model.train()

            for batch in tqdm(train_dataloader, leave=False, desc='Batch'):
                batch = batch.to(device)

                optimizer.zero_grad()
                loss = model(**batch).loss
                # A single device->host sync per step; the same float is both
                # recorded and handed to TriMS.
                loss_value = loss.item()
                results['losses'].append(loss_value)

                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                if use_trims:
                    trims.step(loss=loss_value)
    finally:
        if use_trims:
            trims.remove_activation_hooks()

    train_time = time.time() - start_time
    logging.warning(f'- train time: {train_time}')
    results['train_time'] = train_time

    return model, results



#--- TEST MODEL
def test_model(model, metric, data_loader, device):
    """Run ``model`` over ``data_loader`` without gradients and return ``metric.compute()``.

    Classification models are scored on the arg-max logit. Regression models
    (``model.config.problem_type == "regression"``, which ``load_model`` sets
    for stsb) are scored on the raw logit with the trailing label axis
    squeezed. The metric therefore receives one float per example rather than
    a one-element list.

    Args:
        model: sequence-classification model; batches must contain ``labels``.
        metric: accumulator with ``add_batch`` / ``compute`` (e.g. from
            ``evaluate.load``).
        data_loader: iterable of batches supporting ``.to(device)``.
        device: device the model and every batch are moved to.

    Returns:
        The result of ``metric.compute()``.
    """
    model.eval()
    model.to(device)
    # Constant for the whole evaluation, so read it once.
    is_regression = model.config.problem_type == "regression"

    with torch.no_grad():
        for batch in tqdm(data_loader, leave=False, desc='Testing'):
            batch = batch.to(device)
            logits = model(**batch).logits

            if is_regression:
                # With num_labels=1 the logits are presumably (batch, 1);
                # squeeze only the last axis so a batch of size 1 stays 1-D.
                predictions = logits.squeeze(-1)
            else:
                predictions = logits.argmax(dim=-1)

            metric.add_batch(
                predictions=predictions,
                references=batch["labels"],
            )

    return metric.compute()

