import os
from dataclasses import dataclass, field
from typing import Optional
import os
import gc
import torch
import transformers
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import json
from datasets import Dataset
from sklearn.utils.class_weight import compute_class_weight
from transformers import HfArgumentParser, TrainingArguments, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
from sklearn.model_selection import KFold
from sklearn.linear_model import LinearRegression
from datasets import load_dataset
from utils import get_dataset
from base_dataset import BaseData, KshotDataset, make_data_module  
from mytrainer import MyTrainer  
from peft import LoraConfig, TaskType, prepare_model_for_kbit_training, get_peft_model, PrefixTuningConfig, PromptTuningConfig, PromptTuningInit
# from mytrainer import MyTrainer
from pprint import pprint
from constants import *

@dataclass
class Args:
    """Script-specific command-line options, parsed by ``HfArgumentParser``.

    Field names, types, order and defaults define the CLI flags. The parsed
    object is read by ``main`` and also handed to ``BaseData`` / ``MyTrainer``
    (as ``args`` / ``custom_args``), so several fields are only consumed there.
    """

    # ---- model / training options ----
    # Hugging Face model id or path; main() also uses it as a key of INDENTIFIER2NAME.
    model_name_or_path: str = MODEL_IDENTIFIER
    # The next five options are not read in this file; presumably they are
    # consumed by BaseData / MyTrainer — confirm there.
    loss_type: str = 'distance'
    use_sparse_attention: bool = False
    num_test: bool = False
    parameter_efficient_mode: str = 'lora+prompt-tuning'
    lora_module: str = 'mlp'
    # In the 'tagging' task, each fold trains on a tiny subset (at most 9 rows).
    debug_mode: bool = True

    # ---- method options ----
    use_demo: bool = False
    # Number of KFold splits used by the 'tagging' task.
    n_splits: int = 5
    method: str = 'actcab'

    # ---- data options ----
    # Root directory; data is looked up under <raw_data_path>/<dataset>/<model name>/.
    raw_data_path: str = DATASET_PATH
    # 'gsm8k' is read as JSON (its 'results' list); any other name is read as CSV.
    dataset: str = 'mathqa'
    # processed_file: str = field(default='1_1_train.json', metadata={"help": "path to the processed data"})
    exp_name: str = 'baseline'
    # Selects the branch run by main(): 'tagging', 'activation' or 'probe'.
    task: str = 'probe'

@dataclass
class CustomTrainingArgs(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` with defaults chosen for this script.

    Besides overriding trainer defaults, it declares ``cache_dir`` (not read in
    this file) and ``model_max_length`` (read by ``main`` as the tokenizer's
    maximum sequence length). Field order is kept as-is because dataclass
    inheritance appends new fields in declaration order.
    """

    # ---- batch sizes ----
    per_device_train_batch_size: Optional[int] = 8
    per_device_eval_batch_size: Optional[int] = 8
    # ---- paths ----
    cache_dir: Optional[str] = None
    output_dir: Optional[str] = ''
    overwrite_output_dir: Optional[bool] = True
    # ---- optimisation / tokenisation ----
    optim: str = "adamw_torch"
    model_max_length: int = MAX_INPUT_LENGTH
    # Keep extra dataset columns; the custom collator/trainer may need them.
    remove_unused_columns: bool = False
    # No intermediate checkpoints are written.
    save_strategy: str = "no"
    num_train_epochs: Optional[int] = 3
    learning_rate: Optional[float] = 1e-5
    lr_scheduler_type: Optional[str] = 'linear'
    warmup_ratio: Optional[float] = 0.1
    weight_decay: Optional[float] = 0.001
    # ---- logging ----
    logging_steps: Optional[int] = 1
    report_to: Optional[str] = 'wandb'

class LRModel(nn.Module):
    """Single linear layer mapping feature vectors to class logits.

    Used in the 'tagging' task as the trainable head (name presumably stands
    for logistic regression). Outputs are unnormalized logits: no activation
    is applied here.

    Args:
        input_dim: size of the last dimension of the input features.
        num_labels: number of output logits (default 2, the original binary
            setup; kept as the default for backward compatibility).
    """

    def __init__(self, input_dim: int, num_labels: int = 2):
        super().__init__()
        # Attribute name ``linear`` is kept because it fixes the state_dict keys.
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(*x.shape[:-1], num_labels)``."""
        return self.linear(x)
    
def _build_tokenizer(args, training_args):
    """Load the tokenizer and fill in missing eos/bos/pad tokens."""
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        # ``model_max_length`` is the supported kwarg; the old ``max_len`` was ignored.
        model_max_length=training_args.model_max_length,
    )

    special_tokens_dict = {}
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = '</s>'
    if special_tokens_dict:
        # Previously this dict was built but never applied, leaving eos (and bos) None.
        tokenizer.add_special_tokens(special_tokens_dict)
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token_id = 0
    # Left padding keeps the right-most position of each row on a real token
    # (presumably what MyTrainer reads activations from — confirm).
    tokenizer.padding_side = "left"
    return tokenizer


def _load_ref_model(args, tokenizer):
    """Load the causal LM and grow its embeddings if the tokenizer gained tokens."""
    ref_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map="auto")
    # Tokens added by _build_tokenizer would otherwise index past the embedding matrix.
    if len(tokenizer) > ref_model.get_input_embeddings().weight.shape[0]:
        ref_model.resize_token_embeddings(len(tokenizer))
    return ref_model


def _load_train_data(args, model_name):
    """Read the raw examples for ``args.dataset`` into a ``datasets.Dataset``."""
    # TODO: the file-name component is an empty placeholder, so this path ends
    # with a separator (a directory) and reading it fails until it is filled in.
    dataset_path = os.path.join(args.raw_data_path, args.dataset, model_name, '')
    if args.dataset == 'gsm8k':
        with open(dataset_path, 'r', encoding='utf-8') as f:
            records = json.load(f)['results']
        return Dataset.from_list(records)

    # Non-gsm8k datasets are CSV files; only the first 5000 rows are used.
    frame = pd.read_csv(dataset_path)[:5000]
    train_data = Dataset.from_pandas(frame)
    print(train_data[0])
    return train_data


def _run_tagging(args, training_args, tokenizer, train_data, dataset_handler, hidden_size):
    """Train and evaluate a fresh LRModel head on every KFold split of ``train_data``."""
    kf = KFold(n_splits=args.n_splits, shuffle=True, random_state=42)
    # Loaded once and shared by all folds; MyTrainer presumably extracts the
    # head's input features from it — confirm in mytrainer.
    ref_model = _load_ref_model(args, tokenizer)
    for fold, (train_idx, val_idx) in enumerate(kf.split(train_data)):
        print(f'=========Fold {fold}============')
        if args.debug_mode:
            # Tiny debug subset taken from this fold's *training* indices, so it
            # can never overlap val_idx (rows 0-8 previously could).
            train_idx = train_idx[:9]
        train_dataset = BaseData(args, train_data.select(train_idx), soft_prompt_text=None,
                                 split='train', dataset_handler=dataset_handler)
        val_dataset = BaseData(args, train_data.select(val_idx), soft_prompt_text=None,
                               split='eval', dataset_handler=dataset_handler)
        print(f'traindata:{train_dataset[0]}\nvaldata:{val_dataset[0]}')
        data_module = make_data_module(tokenizer, train_dataset, val_dataset)
        # "balanced" class weights compensate for label imbalance between classes 0 and 1.
        loss_weights = torch.FloatTensor(
            compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=train_dataset.y)
        )
        model = LRModel(hidden_size)
        trainer = MyTrainer(custom_args=args,
                            ref_model=ref_model,
                            loss_weights=loss_weights,
                            fold=fold,
                            model=model,
                            tokenizer=tokenizer,
                            args=training_args,
                            **data_module)
        trainer.train()
        eval_result = trainer.evaluate()
        pprint(eval_result)
        # Release this fold's head/trainer before the next fold allocates new ones.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def _run_activation(args, training_args, tokenizer, train_data, dataset_handler):
    """Run a single evaluation pass of the causal LM over all of ``train_data``."""
    # Previously ``ref_model`` was never assigned on this path (UnboundLocalError).
    ref_model = _load_ref_model(args, tokenizer)
    # Pass dataset_handler by keyword: positionally it landed in the
    # ``soft_prompt_text`` slot used by the tagging calls.
    # NOTE(review): tagging uses split='eval'; confirm BaseData accepts 'val'.
    val_dataset = BaseData(args, train_data, soft_prompt_text=None,
                           split='val', dataset_handler=dataset_handler)
    data_module = make_data_module(tokenizer, None, val_dataset)
    trainer = MyTrainer(custom_args=args,
                        ref_model=None,
                        model=ref_model,
                        tokenizer=tokenizer,
                        args=training_args,
                        **data_module)
    trainer.evaluate()


def _run_probe():
    """Load the pickled probe data and print its first row."""
    # TODO: placeholder path; must be set before this task can run.
    data_path = ''
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(data_path, 'rb') as f:
        probe_data = pickle.load(f)
    # Presumably a dict of column -> values (required by Dataset.from_dict) — confirm.
    probe_data = Dataset.from_dict(probe_data)
    print(probe_data[0])


def main(args, training_args):
    """Run the task selected by ``args.task``: 'tagging', 'activation' or 'probe'.

    Args:
        args: parsed ``Args`` script options.
        training_args: parsed ``CustomTrainingArgs``; forwarded to ``MyTrainer``.

    Raises:
        NotImplementedError: if ``args.task`` is not a supported task.
    """
    model_name = INDENTIFIER2NAME[args.model_name_or_path]
    hidden_size = AutoConfig.from_pretrained(args.model_name_or_path).hidden_size
    tokenizer = _build_tokenizer(args, training_args)

    dataset_handler = get_dataset(args.dataset)
    train_data = _load_train_data(args, model_name)

    if args.task == 'tagging':
        _run_tagging(args, training_args, tokenizer, train_data, dataset_handler, hidden_size)
    elif args.task == 'activation':
        _run_activation(args, training_args, tokenizer, train_data, dataset_handler)
    elif args.task == 'probe':
        _run_probe()
    else:
        raise NotImplementedError(f"Unsupported task: {args.task!r}")

    print('Done!')


if __name__ == "__main__":
    # Split the command line into script options (Args) and Trainer options
    # (CustomTrainingArgs), then dispatch to main().
    args, training_args = HfArgumentParser(
        (Args, CustomTrainingArgs)
    ).parse_args_into_dataclasses()
    main(args, training_args)



