import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    AdamW, get_linear_schedule_with_warmup,
    DebertaV2Tokenizer, TrainingArguments
)

from datasets import DatasetDict 
from datasets import Dataset as dDataset
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from datasets import load_from_disk
from tqdm.auto import tqdm
import gc
import os
import argparse
import json
import matplotlib.pyplot as plt
from datetime import datetime
import seaborn as sns
from pathlib import Path
from compressor.compressor import Compressor
from compressor.compress_methods import compress_topk, quantize_dequantize_ACSGD
import yaml
import random
from dataset.wikitext import LazyWikitext, AQSGD_Wikitext

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# TensorBoard is optional: torch's SummaryWriter needs the separate ``tensorboard``
# package, so --use_tensorboard only takes effect when this import succeeds.
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    SummaryWriter = None
    TENSORBOARD_AVAILABLE = False

parser = argparse.ArgumentParser(description='Glue FineTune Test')
parser.add_argument('--datasets', nargs='+',
                    choices=['cola', 'qnli', 'wnli', 'sst-2', 'mnli', 'qqp', 'sts-b', 'ax', 'mrpc', 'all'],
                    default=['cola'],
                    help="GLUE tasks to fine-tune on ('all' runs every configured task)")
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--accumulation_steps', type=int, default=4,
                    help='micro-batches accumulated per optimizer update')
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=1e-5)
parser.add_argument('--max_len', type=int, default=512)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--log_dir', type=str, default='./logs')
parser.add_argument('--use_tensorboard', action='store_true')
parser.add_argument('--use_compression', action='store_true')
parser.add_argument('--compression_layer', type=int, default=12)
parser.add_argument('--compression_ratio', type=float, default=0.3)
parser.add_argument('--config_path', type=str, default='glue_finetune_config.yaml')
args = parser.parse_args()
if args.accumulation_steps < 1:
    parser.error('--accumulation_steps must be >= 1')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"use device: {device}")

# HyperParameter
MAX_LEN = args.max_len
BATCH_SIZE = args.batch_size
ACCUMULATION_STEPS = args.accumulation_steps
EPOCHS = args.epochs
LEARNING_RATE = args.learning_rate
WARMUP_RATIO = 0.1
SEED = args.seed
FP16 = True  # mixed precision (AMP); only used when CUDA is available
USE_TENSORBOARD = args.use_tensorboard and TENSORBOARD_AVAILABLE
if args.use_tensorboard and not TENSORBOARD_AVAILABLE:
    print("TensorBoard requested but not installed; continuing without it.")

# create log dir: one timestamped sub-directory per run
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = Path(args.log_dir) / timestamp
log_dir.mkdir(parents=True, exist_ok=True)

# set seed for every RNG the script (and its project-local dataset wrappers) may use
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# use TensorBoard
if USE_TENSORBOARD:
    tb_log_dir = log_dir / "tensorboard"
    tb_log_dir.mkdir(parents=True, exist_ok=True)
    tb_writer = SummaryWriter(log_dir=str(tb_log_dir))
else:
    tb_writer = None

# In-memory metric history keyed by "phase/dataset/model"; mirrored to training_history.json.
history = {}

# Checkpoints to fine-tune, keyed by the display name that appears in every
# log/result file name. Paths are local; weights are loaded with local_files_only.
MODEL_CONFIGS = {
    # "DeBERTa-1.5B": dict(
    #     model_name="/data/pretrained_models/deberta-v2-xxlarge",
    #     tokenizer_name="/data/pretrained_models/deberta-v2-xxlarge",
    # ),
    "RoBERTa-1.5B": dict(
        model_name="/data/pretrained_models/roberta-large",
        tokenizer_name="/data/pretrained_models/roberta-large",
    ),
}

# Per-task settings: on-disk dataset location, classifier head size
# (num_labels == 1 means regression, used by STS-B), and the metric names
# associated with the task. Iteration order defines the order of "--datasets all".
DATASET_CONFIGS = {
    "cola": dict(path="/data/datasets/glue/cola", num_labels=2,
                 metrics=["accuracy", "matthews_correlation"]),
    "qnli": dict(path="/data/datasets/glue/qnli", num_labels=2,
                 metrics=["accuracy"]),
    # WNLI is deliberately backed by SuperGLUE WSC (original GLUE path was
    # /data/datasets/glue/wnli).
    "wnli": dict(path="/data/datasets/superglue/wsc", num_labels=2,
                 metrics=["accuracy"]),
    "sst-2": dict(path="/data/datasets/glue/sst-2", num_labels=2,
                  metrics=["accuracy"]),
    "mnli": dict(path="/data/datasets/glue/mnli", num_labels=3,
                 metrics=["accuracy"]),
    "qqp": dict(path="/data/datasets/glue/qqp", num_labels=2,
                metrics=["accuracy", "f1"]),
    "sts-b": dict(path="/data/datasets/glue/sts-b", num_labels=1,
                  metrics=["pearson", "spearman"]),
    "ax": dict(path="/data/datasets/glue/ax", num_labels=3,
               metrics=["accuracy", "matthews_correlation"]),
    "mrpc": dict(path="/data/datasets/glue/mrpc", num_labels=2,
                 metrics=["accuracy", "f1"]),
}

class GLUEDataset(Dataset):
    """Map-style dataset over pre-tokenized GLUE examples.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``,
            ``attention_mask``) to a per-example sequence -- for instance the
            ``BatchEncoding`` returned by a Hugging Face tokenizer, or a plain
            dict. Must contain ``input_ids``; its length is the dataset length.
        labels: Optional per-example labels, returned under ``"labels"``.
        with_idx: When True, each item also carries its integer position under
            ``"indices"`` so the training loop can forward it to compression hooks.
    """

    def __init__(self, encodings, labels=None, with_idx=False):
        self.encodings = encodings
        self.labels = labels
        self.with_idx = with_idx

    def __getitem__(self, idx):
        item = {key: torch.tensor(values[idx]) for key, values in self.encodings.items()}
        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx])
        if self.with_idx:
            item['indices'] = idx
        return item

    def __len__(self):
        # Subscript (not attribute) access so both BatchEncoding and plain dicts work.
        return len(self.encodings['input_ids'])

def log_metrics(phase, dataset_name, model_name, epoch, step, metrics):
    """Record ``metrics`` in the global ``history`` and persist them.

    Each call appends every metric value plus the ``epoch``/``step`` pair to
    ``history["{phase}/{dataset_name}/{model_name}"]``. It then:

    * writes each metric to TensorBoard (when enabled) at ``step``, flushing once;
    * rewrites ``log_dir/training_history.json`` with the whole history;
    * for ``phase == "val"``, also writes a per-epoch ``metric,value`` CSV.

    Metric names that first appear after the key was created get their own list
    instead of raising ``KeyError``.
    """
    key = f"{phase}/{dataset_name}/{model_name}"
    if key not in history:
        history[key] = {metric_name: [] for metric_name in metrics}
        history[key]['epochs'] = []
        history[key]['steps'] = []
    record = history[key]

    for metric_name, metric_value in metrics.items():
        record.setdefault(metric_name, []).append(metric_value)

    # Always append both so 'epochs' and 'steps' stay index-aligned for every phase.
    record['epochs'].append(epoch)
    record['steps'].append(step)

    if USE_TENSORBOARD and tb_writer is not None:
        for metric_name, metric_value in metrics.items():
            tb_writer.add_scalar(f"{phase}/{dataset_name}/{model_name}/{metric_name}",
                                 metric_value, step)
        tb_writer.flush()

    with open(log_dir / "training_history.json", "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4)

    if phase == "val":
        epoch_val_file = log_dir / f"{dataset_name}_{model_name}_epoch{epoch+1}_val.csv"
        with open(epoch_val_file, 'w', encoding="utf-8") as f:
            f.write("metric,value\n")
            for metric_name, metric_value in metrics.items():
                f.write(f"{metric_name},{metric_value}\n")

class BertCompression:
    """Attach ``Compressor`` forward hooks to selected transformer layers.

    ``compression_config`` is a dict; the keys read here are ``layers`` (list of
    layer indices), ``method`` / ``params`` (forward compression), ``forward_EF``,
    ``forward_EF_method``, ``backward`` / ``backward_params`` (default to the
    forward ones), ``backward_EF`` and ``backward_EF_method``.

    One ``Compressor`` per layer is built lazily on that layer's first forward
    pass, sized from the layer's output shape. The hook replaces the layer
    output (or its first element when the output is a tuple) with the
    compressor's result.
    """

    def __init__(self, model, compression_config):
        self.model = model
        self.config = compression_config
        self.compressors = {}   # layer index -> Compressor (None until first forward)
        self.hook_handles = []  # handles returned by register_forward_hook

    def apply_compression(self):
        """Register a compression hook on every layer listed in ``config['layers']``.

        Layers of an unrecognised model architecture are silently skipped.
        """
        layers = self.config.get('layers', [])
        method = self.config.get('method', 'topk')
        params = self.config.get('params', {})

        forward_EF = self.config.get('forward_EF', True)
        forward_EF_method = self.config.get('forward_EF_method', "EF21")
        backward = self.config.get('backward', method)
        backward_params = self.config.get('backward_params', params)
        backward_EF = self.config.get('backward_EF', True)
        backward_EF_method = self.config.get('backward_EF_method', "EF21")

        for layer_idx in layers:
            hook_fn = self._create_compression_hook(
                layer_idx,
                method,
                params,
                forward_EF,
                forward_EF_method,
                backward,
                backward_params,
                backward_EF,
                backward_EF_method
            )
            layer = self._find_layer(layer_idx)
            if layer is None:
                continue
            self.hook_handles.append(layer.register_forward_hook(hook_fn))

    def remove_compression(self):
        """Detach every hook registered by :meth:`apply_compression`."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

    def _find_layer(self, layer_idx):
        """Return transformer block ``layer_idx`` for the supported backbones, or None."""
        model = self.model
        if hasattr(model, 'encoder') and hasattr(model.encoder, 'layer'):
            return model.encoder.layer[layer_idx]
        if hasattr(model, 'deberta') and hasattr(model.deberta, 'encoder') and hasattr(model.deberta.encoder, 'layer'):
            return model.deberta.encoder.layer[layer_idx]
        if hasattr(model, 'roberta') and hasattr(model.roberta, 'encoder') and hasattr(model.roberta.encoder, 'layer'):
            return model.roberta.encoder.layer[layer_idx]
        if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
            return model.transformer.h[layer_idx]
        return None

    def _create_compression_hook(self, layer_idx, method, params,
                                 forward_EF, forward_EF_method,
                                 backward, backward_params,
                                 backward_EF, backward_EF_method):
        """Build the forward hook for ``layer_idx`` and reset its compressor slot."""
        self.compressors[layer_idx] = None

        def compression_hook(module, input_tensor, output):
            hidden = output[0] if isinstance(output, tuple) else output

            if self.compressors[layer_idx] is None:
                self.compressors[layer_idx] = Compressor(
                    input_shape=hidden.shape,
                    forward=method,
                    forward_params=params,
                    backward=backward,
                    backward_params=backward_params,
                    forward_EF=forward_EF,
                    backward_EF=backward_EF,
                    forward_EF_method=forward_EF_method,
                    backward_EF_method=backward_EF_method
                )

            # Per-batch sample indices, set on the module by the training loop (may be absent).
            indices = getattr(module, 'current_indices', None)
            compressed = self.compressors[layer_idx](hidden, indices=indices)
            if isinstance(output, tuple):
                return (compressed,) + output[1:]
            return compressed

        return compression_hook

def _find_encoder_layer(model, layer_idx):
    """Return encoder block ``layer_idx`` of a BERT/RoBERTa/DeBERTa-style model, or None."""
    if hasattr(model, 'encoder') and hasattr(model.encoder, 'layer'):
        return model.encoder.layer[layer_idx]
    if hasattr(model, 'roberta') and hasattr(model.roberta, 'encoder'):
        return model.roberta.encoder.layer[layer_idx]
    if hasattr(model, 'deberta') and hasattr(model.deberta, 'encoder'):
        return model.deberta.encoder.layer[layer_idx]
    return None


def _set_current_indices(model, layer_indices, indices):
    """Expose a batch's sample ``indices`` on each compressed layer, or clear them when None."""
    for layer_idx in layer_indices:
        layer = _find_encoder_layer(model, layer_idx)
        if layer is None:
            continue
        if indices is not None:
            setattr(layer, 'current_indices', indices)
        elif hasattr(layer, 'current_indices'):
            delattr(layer, 'current_indices')


def _build_compression_config():
    """Return compression settings: CLI defaults, overridden by the first ``layer*`` YAML entry."""
    compression_config = {
        'enabled': args.use_compression,
        'method': 'topk',
        'layers': [args.compression_layer],
        'params': {'topk': args.compression_ratio}
    }
    config = load_config(args.config_path) or {}
    for layer_key, layer_config in (config.get('compression_config') or {}).items():
        if not str(layer_key).startswith('layer'):
            continue
        layer_config = layer_config or {}
        compression_config.update({
            'layers': [layer_config.get('layer_idx', args.compression_layer)],
            'method': layer_config.get('forward', 'topk'),
            'params': layer_config.get('forward-params', {'topk': 0.5}),
            'forward_EF': layer_config.get('forward-EF', True),
            'forward_EF_method': layer_config.get('forward-EF-method', "EF21"),
            'backward': layer_config.get('backward', 'topk'),
            'backward_params': layer_config.get('backward-params', {'topk': 0.5}),
            'backward_EF': layer_config.get('backward-EF', True),
            'backward_EF_method': layer_config.get('backward-EF-method', "EF21"),
        })
        break  # only the first ``layer*`` entry is honoured
    return compression_config


def _train_one_epoch(model, train_dataloader, optimizer, scheduler, scaler, device,
                     epoch, global_step, compression_layers, model_name, dataset_name):
    """Run one epoch with gradient accumulation; return ``(mean loss, updated global_step)``.

    The optimizer/scheduler step every ``ACCUMULATION_STEPS`` micro-batches and on
    the final batch, with gradients clipped to norm 1.0. When ``scaler`` is given,
    the forward pass runs under CUDA autocast and the loss is scaled by GradScaler.
    Reported losses are un-divided by ``ACCUMULATION_STEPS``.
    """
    model.train()
    num_batches = len(train_dataloader)
    total_loss = 0.0
    optimizer.zero_grad()

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", leave=False)
    for step, batch in enumerate(progress_bar):
        indices = batch.pop('indices', None)
        track_indices = indices is not None and bool(compression_layers)
        if track_indices:
            _set_current_indices(model, compression_layers, indices)

        batch = {k: v.to(device) for k, v in batch.items()}
        is_update_step = (step + 1) % ACCUMULATION_STEPS == 0 or (step + 1) == num_batches

        if scaler is not None:
            with torch.amp.autocast(device_type='cuda'):
                loss = model(**batch).loss / ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            if is_update_step:
                scaler.unscale_(optimizer)  # so clipping sees true gradient magnitudes
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
        else:
            loss = model(**batch).loss / ACCUMULATION_STEPS
            loss.backward()
            if is_update_step:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        loss_value = loss.item() * ACCUMULATION_STEPS
        total_loss += loss_value
        progress_bar.set_postfix({'loss': loss_value})

        if global_step % 100 == 0:
            log_metrics("train", dataset_name, model_name, epoch, global_step, {"loss": loss_value})
        global_step += 1

        if track_indices:
            _set_current_indices(model, compression_layers, None)

    return total_loss / max(num_batches, 1), global_step


def train_model(model, train_dataloader, val_dataloader, optimizer, scheduler, device, epochs, model_name, dataset_name):
    """Fine-tune ``model`` for ``epochs`` epochs, validating and logging after each one.

    With ``--use_compression``, a fresh :class:`BertCompression` is attached for
    each epoch's training pass. It is removed before validation, and also on
    error, so evaluation runs on the uncompressed model and no hooks remain after
    training.

    Per epoch this logs train/val metrics, rewrites
    ``{model}_{dataset}_all_epochs_metrics.csv`` and writes
    ``{model}_{dataset}_epoch{N}_results.json`` under ``log_dir``. It plots the
    curves at the end.

    Returns:
        ``(None, metrics)`` where ``metrics`` is the last epoch's validation-metric
        dict, or ``{}`` when no epoch ran.
    """
    model.to(device)

    compression_config = _build_compression_config() if args.use_compression else None
    compression_layers = compression_config['layers'] if compression_config else []

    scaler = torch.amp.GradScaler() if FP16 and torch.cuda.is_available() else None

    train_losses = []
    val_metrics_history = []
    epoch_rows = []
    global_step = 0

    for epoch in range(epochs):
        compressor = None
        if compression_config is not None:
            compressor = BertCompression(model, compression_config)
            compressor.apply_compression()
        try:
            avg_train_loss, global_step = _train_one_epoch(
                model, train_dataloader, optimizer, scheduler, scaler, device,
                epoch, global_step, compression_layers, model_name, dataset_name,
            )
        finally:
            if compressor is not None:
                compressor.remove_compression()
                compressor = None
                torch.cuda.empty_cache()

        train_losses.append(avg_train_loss)
        log_metrics("epoch_train", dataset_name, model_name, epoch, global_step, {"loss": avg_train_loss})

        val_metrics = evaluate_model(model, val_dataloader, device, dataset_name=dataset_name)
        val_metrics_history.append(val_metrics)
        log_metrics("val", dataset_name, model_name, epoch, global_step, val_metrics)

        epoch_rows.append({'epoch': epoch + 1, 'train_loss': avg_train_loss,
                           **{f'val_{k}': v for k, v in val_metrics.items()}})
        csv_file = log_dir / f'{model_name}_{dataset_name}_all_epochs_metrics.csv'
        pd.DataFrame(epoch_rows).to_csv(csv_file, index=False)

        epoch_results_file = log_dir / f'{model_name}_{dataset_name}_epoch{epoch+1}_results.json'
        with open(epoch_results_file, 'w', encoding='utf-8') as f:
            json.dump({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_metrics": val_metrics
            }, f, indent=4)

    if val_metrics_history:
        plot_training_process(model_name, dataset_name, train_losses, val_metrics_history)

    return None, (val_metrics_history[-1] if val_metrics_history else {})

def evaluate_model(model, dataloader, device, dataset_name="cola"):
    """Evaluate ``model`` on ``dataloader`` and return a dict of GLUE metrics.

    Batches must contain ``labels``. For ``"sts-b"`` the single logit is treated
    as a regression score, and the result holds ``pearson`` and ``spearman``.
    For every other task, predictions are the arg-max class and the result holds
    ``accuracy``, plus:

    * ``matthews_correlation`` for cola / ax;
    * binary ``f1`` for qnli / wnli / sst-2 / qqp / mrpc.

    All values are plain Python floats.
    """
    model.eval()
    is_regression = dataset_name == "sts-b"
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            if 'labels' in batch and not is_regression:
                # Force labels into [0, num_labels - 1]; presumably guards against
                # placeholder labels such as -1 -- TODO confirm the intent.
                num_classes = model.config.num_labels
                batch['labels'] = torch.clamp(batch['labels'], 0, num_classes - 1)

            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits

            if is_regression:
                preds = logits.squeeze(-1)
            else:
                preds = torch.argmax(logits, dim=1)

            predictions.extend(preds.cpu().numpy())
            true_labels.extend(batch['labels'].cpu().numpy())

    metrics = {}
    if is_regression:
        from scipy.stats import pearsonr, spearmanr
        metrics["pearson"] = float(pearsonr(true_labels, predictions)[0])
        metrics["spearman"] = float(spearmanr(true_labels, predictions)[0])
        return metrics

    metrics["accuracy"] = float(accuracy_score(true_labels, predictions))
    if dataset_name in ("cola", "ax"):
        metrics["matthews_correlation"] = float(matthews_corrcoef(true_labels, predictions))
    elif dataset_name in ("qnli", "wnli", "sst-2", "qqp", "mrpc"):
        metrics["f1"] = float(f1_score(true_labels, predictions, average='binary'))
    return metrics

def plot_training_process(model_name, dataset_name, train_losses, val_metrics_history):
    """Save a two-panel PNG of per-epoch training loss and validation metrics.

    The figure is written to
    ``log_dir/plots/{model_name}_{dataset_name}_training_process.png``.
    Metric curves are taken from the keys of the first entry of
    ``val_metrics_history``; an empty history yields an empty lower panel
    instead of an IndexError. The figure is always closed, even on error.
    """
    plots_dir = log_dir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    fig, (loss_ax, metric_ax) = plt.subplots(2, 1, figsize=(12, 8))
    try:
        loss_ax.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss')
        loss_ax.set_title(f'{model_name} on {dataset_name} - Training Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()
        loss_ax.grid(True)

        metric_names = list(val_metrics_history[0].keys()) if val_metrics_history else []
        for metric in metric_names:
            metric_values = [metrics[metric] for metrics in val_metrics_history]
            metric_ax.plot(range(1, len(metric_values) + 1), metric_values, label=f'Validation {metric}')

        metric_ax.set_title(f'{model_name} on {dataset_name} - Validation Metrics')
        metric_ax.set_xlabel('Epoch')
        metric_ax.set_ylabel('Score')
        if metric_names:  # legend() with no labelled artists only emits a warning
            metric_ax.legend()
        metric_ax.grid(True)

        fig.tight_layout()
        fig.savefig(plots_dir / f"{model_name}_{dataset_name}_training_process.png")
    finally:
        plt.close(fig)

def get_column_names(dataset, dataset_name):
    """Resolve the text and label column names of a GLUE-style dataset.

    Known tasks use a fixed per-task mapping, applied when all of its columns
    exist in ``dataset['train'].features``. Otherwise the columns are detected
    from common names.

    Returns:
        ``(sentence_column, second_column, label_column)`` for sentence-pair
        tasks, or ``(sentence_column, label_column)`` for single-sentence tasks.
        Callers tell the two shapes apart by their length.

    Raises:
        ValueError: if no text or no label column can be identified.
    """
    column_mappings = {
        "cola": {"sentence": "sentence", "label": "label"},
        "qnli": {"sentence1": "question", "sentence2": "sentence", "label": "label"},
        # "wnli" is backed by SuperGLUE WSC (GLUE WNLI would be sentence1/sentence2).
        "wnli": {"sentence1": "text", "sentence2": "span1_text", "label": "label"},
        "sst-2": {"sentence": "sentence", "label": "label"},
        "mnli": {"sentence1": "premise", "sentence2": "hypothesis", "label": "label"},
        "qqp": {"sentence1": "question1", "sentence2": "question2", "label": "label"},
        "sts-b": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
        "ax": {"sentence1": "premise", "sentence2": "hypothesis", "label": "label"},
        "mrpc": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    }
    features = dataset['train'].features

    mapping = column_mappings.get(dataset_name)
    if mapping is not None and all(col in features for col in mapping.values()):
        sentence_column = mapping.get("sentence", mapping.get("sentence1"))
        second_column = mapping.get("sentence2")
        label_column = mapping["label"]
    else:
        sentence_column = next(
            (c for c in ('sentence', 'sentence1', 'question', 'text', 'premise', 'question1')
             if c in features), None)
        second_column = next(
            (c for c in ('sentence2', 'hypothesis', 'question2', 'span1_text')
             if c in features and c != sentence_column), None)
        label_column = next(
            (c for c in ('label', 'labels', 'class', 'target') if c in features), None)

    if sentence_column is None or label_column is None:
        raise ValueError(
            f"cannot infer text/label columns for {dataset_name!r} from {list(features)}")

    if second_column is not None:
        return sentence_column, second_column, label_column
    return sentence_column, label_column
    
def _resolve_glue_columns(dataset, dataset_name):
    """Return ``(first_text_column, second_text_column or None, label_column)`` for a task."""
    if dataset_name == "qnli":
        return "question", "sentence", "label"
    try:
        resolved = get_column_names(dataset, dataset_name)
    except Exception:
        resolved = None  # best effort: fall through to the static table below
    if resolved is not None:
        if len(resolved) == 3:
            return tuple(resolved)
        first_column, label_column = resolved
        return first_column, None, label_column
    fallback = {
        "wnli": ("text", "span1_text", "label"),  # backed by SuperGLUE WSC
        "mnli": ("premise", "hypothesis", "label"),
        "ax": ("premise", "hypothesis", "label"),
        "qqp": ("question1", "question2", "label"),
        "sts-b": ("sentence1", "sentence2", "label"),
        "mrpc": ("sentence1", "sentence2", "label"),
    }
    return fallback.get(dataset_name, ("sentence", None, "label"))


def _load_tokenizer(model_name, model_config):
    """Load the model's tokenizer from local files (slow DeBERTa-v2 tokenizer for DeBERTa)."""
    path = model_config['tokenizer_name']
    if "DeBERTa" not in model_name:
        return AutoTokenizer.from_pretrained(path, local_files_only=True)
    try:
        return DebertaV2Tokenizer.from_pretrained(path, local_files_only=True)
    except Exception:
        # Fall back to the generic slow tokenizer if the dedicated class cannot load.
        return AutoTokenizer.from_pretrained(path, local_files_only=True, use_fast=False)


def _validation_split(dataset):
    """Return the evaluation split: ``validation``, or ``validation_matched`` (HF GLUE MNLI layout)."""
    if 'validation' in dataset:
        return dataset['validation']
    if 'validation_matched' in dataset:
        return dataset['validation_matched']
    raise KeyError("dataset has no 'validation' (or 'validation_matched') split")


def _encode_split(tokenizer, split, first_column, second_column):
    """Tokenize one split (optionally as sentence pairs), padded/truncated to MAX_LEN."""
    texts = [split[first_column]]
    if second_column is not None:
        texts.append(split[second_column])
    return tokenizer(*texts, truncation=True, padding='max_length', max_length=MAX_LEN)


def _wrap_train_dataset(train_dataset, training_config):
    """Wrap the train set in LazyWikitext / AQSGD_Wikitext when the YAML ``training`` section asks."""
    batch_size = training_config.get('batch_size', BATCH_SIZE)
    if training_config.get('lazy_sampling', False):
        lazy_params = training_config.get('lazy_sampling_params') or {}
        if lazy_params.get('schedule') == 'constant':
            p_t = lazy_params.get('p_t', 0.5)
            return LazyWikitext(train_dataset, p_t=lambda x: p_t,
                                batch_size=batch_size, with_idx=True)
        return train_dataset
    if training_config.get('aq_sgd', False):
        return AQSGD_Wikitext(train_dataset, batch_size=batch_size, with_idx=True)
    return train_dataset


def _indexed_collator(features):
    """Stack tensor fields across ``features``; keep non-tensors and ``indices`` as lists."""
    batch = {}
    for key, value in features[0].items():
        if key == 'indices':
            continue
        if isinstance(value, torch.Tensor):
            batch[key] = torch.stack([f[key] for f in features])
        else:
            batch[key] = [f[key] for f in features]
    if 'indices' in features[0]:
        batch['indices'] = [f['indices'] for f in features]
    return batch


def run_model_evaluation(model_name, model_config, dataset, dataset_name):
    """Fine-tune one model on one GLUE task and return its final validation metrics.

    Steps:

    * Resolve the task's text/label columns.
    * Tokenize the train and validation splits to ``MAX_LEN``.
    * Optionally wrap the train set per the YAML ``training`` section.
    * Train with AdamW and a linear warm-up schedule via ``train_model``.
    * Log the final metrics.

    The training loader drops its last incomplete batch. The validation loader
    keeps every example so the reported metrics cover the whole split.

    Returns:
        ``{'model': model_name, 'dataset': dataset_name, 'val_metrics': {...}}``.
    """
    global current_dataset
    current_dataset = dataset_name

    config = load_config(args.config_path) or {}
    training_config = config.get('training') or {}

    sentence_column, second_column, label_column = _resolve_glue_columns(dataset, dataset_name)

    tokenizer = _load_tokenizer(model_name, model_config)
    num_labels = DATASET_CONFIGS[dataset_name]["num_labels"]
    model = AutoModelForSequenceClassification.from_pretrained(
        model_config['model_name'],
        local_files_only=True,
        num_labels=num_labels
    )

    train_split = dataset['train']
    val_split = _validation_split(dataset)
    train_encodings = _encode_split(tokenizer, train_split, sentence_column, second_column)
    val_encodings = _encode_split(tokenizer, val_split, sentence_column, second_column)

    train_dataset = GLUEDataset(train_encodings, train_split[label_column], with_idx=True)
    val_dataset = GLUEDataset(val_encodings, val_split[label_column])
    train_dataset = _wrap_train_dataset(train_dataset, training_config)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=_indexed_collator,
        drop_last=True
    )
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    # The training loop steps on every ACCUMULATION_STEPS-th batch *and* on the last
    # batch of each epoch, i.e. ceil(batches / ACCUMULATION_STEPS) updates per epoch.
    updates_per_epoch = -(-len(train_dataloader) // ACCUMULATION_STEPS)
    total_steps = updates_per_epoch * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    _, val_metrics = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=EPOCHS,
        model_name=model_name,
        dataset_name=dataset_name
    )

    log_metrics("final_val", dataset_name, model_name, EPOCHS-1, total_steps, val_metrics)
    for metric_name, metric_value in val_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")

    # The optimizer/scheduler hold references to the parameters, so drop them too
    # before collecting; otherwise the model's GPU memory cannot be released.
    del model, optimizer, scheduler
    gc.collect()
    torch.cuda.empty_cache()

    return {
        'model': model_name,
        'dataset': dataset_name,
        'val_metrics': val_metrics
    }

def _load_glue_dataset(dataset_name):
    """Load ``dataset_name`` as a DatasetDict that has ``train`` and validation splits."""
    dataset_path = DATASET_CONFIGS[dataset_name]["path"]
    if dataset_name != 'ax':
        return load_from_disk(dataset_path)

    # AX is a diagnostic set with no training split: the same examples serve as
    # both train and validation (shuffled with the run seed for reproducibility).
    file_path = os.path.join(dataset_path, 'diagnostic-full.tsv')
    df = pd.read_csv(file_path, sep='\t', usecols=['Premise', 'Hypothesis', 'Label'])
    # Rename by name: read_csv keeps file order for usecols, not the list order.
    df = df.rename(columns={'Premise': 'premise', 'Hypothesis': 'hypothesis', 'Label': 'label'})
    df = df[['premise', 'hypothesis', 'label']]
    df['label'] = df['label'].map({'contradiction': 0, 'neutral': 1, 'entailment': 2})
    # Unmapped labels would become NaN and turn the label column into floats.
    df = df.dropna(subset=['label']).astype({'label': int}).reset_index(drop=True)
    dataset = dDataset.from_pandas(df)
    return DatasetDict({
        "train": dataset.shuffle(seed=SEED),
        "validation": dataset.shuffle(seed=SEED),
    })


def _write_final_results(dataset_name, dataset_results):
    """Print every model's metrics for ``dataset_name`` as percentages and save them to a text file."""
    lines = [f"Dataset: {dataset_name}\n"]
    for result in dataset_results:
        print(f"[{dataset_name}] {result['model']}")
        lines.append(f"Model: {result['model']}\n")
        lines.append("=== Evaluation result ===\n")
        for metric_name, metric_value in result['val_metrics'].items():
            formatted_value = float(metric_value) * 100
            print(f"{metric_name}: {formatted_value:.2f}")
            lines.append(f"{metric_name}: {formatted_value:.2f}\n")

    results_file = log_dir / f'{dataset_name}_final_results.txt'
    with open(results_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    print(f"result has been saved to {results_file}")


def main():
    """Fine-tune every configured model on every requested GLUE task.

    Steps:

    * Snapshot the run configuration into ``log_dir``.
    * Run each (task, model) pair, saving a per-pair results CSV.
    * Write a ``{task}_final_results.txt`` summary per task.

    A failing task or model is reported with its traceback and skipped.
    """
    import traceback

    if 'all' in args.datasets:
        datasets_to_evaluate = list(DATASET_CONFIGS.keys())
    else:
        datasets_to_evaluate = list(args.datasets)

    log_dir.mkdir(parents=True, exist_ok=True)

    if os.path.exists(args.config_path):
        import shutil
        config_dest = log_dir / os.path.basename(args.config_path)
        shutil.copy2(args.config_path, config_dest)

    config_info = {
        "max_len": MAX_LEN,
        "batch_size": BATCH_SIZE,
        "accumulation_steps": ACCUMULATION_STEPS,
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "warmup_ratio": WARMUP_RATIO,
        "seed": SEED,
        "fp16": FP16,
        "device": str(device),
        "datasets": datasets_to_evaluate,
        "timestamp": timestamp
    }
    with open(log_dir / "config.json", "w", encoding='utf-8') as f:
        json.dump(config_info, f, indent=4)

    all_results = []
    for dataset_name in datasets_to_evaluate:
        if dataset_name not in DATASET_CONFIGS:
            continue

        try:
            dataset = _load_glue_dataset(dataset_name)
        except Exception:
            print(f"failed to load dataset {dataset_name!r}:")
            traceback.print_exc()
            continue

        for model_name, model_config in MODEL_CONFIGS.items():
            try:
                result = run_model_evaluation(model_name, model_config, dataset, dataset_name)
                all_results.append(result)
                results_file = log_dir / f'{model_name}_{dataset_name}_results.csv'
                pd.DataFrame([result]).to_csv(results_file, index=False)
            except Exception:
                print(f"run failed for model {model_name!r} on {dataset_name!r}:")
                traceback.print_exc()
                # Release whatever the failed run left on the GPU before the next one.
                gc.collect()
                torch.cuda.empty_cache()

    for dataset_name in datasets_to_evaluate:
        dataset_results = [r for r in all_results if r['dataset'] == dataset_name]
        if dataset_results:
            _write_final_results(dataset_name, dataset_results)

    if USE_TENSORBOARD and tb_writer is not None:
        tb_writer.flush()
        tb_writer.close()
    

def load_config(config_path="glue_finetune_config.yaml"):
    """Load the YAML experiment config at ``config_path``.

    The config is optional: a missing file (with a warning) or an empty YAML
    document both yield ``{}``, so callers can always use ``.get``. Parsing uses
    ``yaml.safe_load``, so no arbitrary Python objects are constructed.
    """
    if not os.path.exists(config_path):
        import warnings
        warnings.warn(f"config file {config_path!r} not found; using an empty config")
        return {}
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f) or {}


if __name__ == "__main__":
    # Initial value of the module-level task tag; run_model_evaluation() rebinds it
    # (via ``global current_dataset``) to the task currently being trained.
    current_dataset = "default" 
    main()