import datetime
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, Optional, Union

import torch
import transformers
from datasets import load_dataset
from torch.utils.data import Dataset

from model import CoIForGPTNeoXCausalLM

# Presumably the label value that the loss should skip; -100 is also the
# documented default ``ignore_index`` of torch.nn.CrossEntropyLoss.
# Not referenced in this file — presumably imported elsewhere (verify).
IGNORE_INDEX = -100
# Placeholder text for an image in prompts. Not referenced in this file;
# presumably used by model.py / training code together with
# ModelArguments.add_image_token — TODO confirm.
DEFAULT_IMAGE_TOKEN = "<image>"

@dataclass
class ModelArguments:
    """Command-line options that select the model components.

    Parsed by ``transformers.HfArgumentParser`` in ``train()``; every field
    becomes a command-line flag of the same name.
    """

    # Checkpoint name or local path of the base language model.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Checkpoint name or local path of the vision tower.
    # NOTE(review): the default is the same text-only checkpoint as
    # model_name_or_path — confirm this is intended.
    vision_tower_name: Optional[str] = "facebook/opt-125m"
    # Presumably controls whether an image placeholder token is added to the
    # tokenizer — confirm in model.py.
    add_image_token: bool = True


@dataclass
class DataArguments:
    """Command-line options describing the input data.

    Parsed by ``transformers.HfArgumentParser`` in ``train()``.
    """

    # Annotated Optional[str] because the default is None (the previous `str`
    # annotation was inaccurate). HfArgumentParser parses Optional[str] the same
    # way as str. In this file the value is handed to
    # load_dataset('json', ...), so a real path must be supplied at run time.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with this project's options.

    Dataclass field order is significant; keep the declarations below in
    their current order.
    """

    # --- general ---
    cache_dir: Optional[str] = None
    # Redeclares the parent's ``optim`` field with a plain string default.
    optim: str = "adamw_torch"
    use_patch: bool = field(
        default=False,
        metadata={"help": "Use image patch (256) or class (1)"},
    )
    # In this file it is also passed as ``max_output_len`` to model.generate().
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

    # --- quantization ---
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."},
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )

    # --- LoRA ---
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=16)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

class EvalDataset(Dataset):
    """Evaluation dataset of (tokenized prompt, target string) pairs from a JSON file.

    Records are loaded with ``datasets.load_dataset('json', ...)`` and shuffled
    once with a fixed seed (42), so the evaluation order is reproducible.
    Each record is expected to hold a ``conversations`` list whose entry 0
    ``value`` is the prompt and whose entry 1 ``value`` is the target (schema
    inferred from the indexing in ``__getitem__`` — confirm against the data).
    """

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
        """Load and shuffle the records at ``data_args.data_path``.

        Args:
            tokenizer: Tokenizer used to encode prompts; its ``model_max_length``
                is passed as ``max_length``.
            data_args: Provides ``data_path``, the JSON file to read.
        """
        super().__init__()
        # data_files={'train': ...} makes the loaded file available as the 'train' split.
        self.list_data_dict = load_dataset('json', data_files={'train': data_args.data_path})['train'].shuffle(seed=42)
        self.tokenizer = tokenizer
        self.data_args = data_args

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, Union[torch.Tensor, str]]:
        """Return the encoded prompt and the raw target for record ``i``.

        Returns:
            A dict with ``input_ids`` (prompt token ids; the batch dimension of
            the ``return_tensors="pt"`` output is dropped via ``[0]``) and
            ``labels`` (the target text as a plain ``str``, not token ids).
        """
        sources = self.list_data_dict[i]
        # Instruction suffix appended verbatim to the prompt.
        # NOTE(review): the leading "{" looks unusual; presumably it matches the
        # training prompt format — confirm before changing.
        conversation = sources['conversations'][0]['value'] + "{Convert the board to FEN format"
        # NOTE(review): with padding="longest" and no `truncation=` argument the
        # tokenizer does not truncate, so max_length appears to have no effect
        # here; pass truncation=True if truncation is intended.
        input_ids = self.tokenizer(conversation, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length).input_ids[0]
        return dict(input_ids=input_ids, labels=sources['conversations'][1]['value'])
    
def _extract_prediction(output_text: str) -> Optional[str]:
    """Pull the predicted answer out of a generated string.

    Keeps the text before the first ``'#'``, takes its last whitespace-separated
    token and re-appends ``'#'`` so the result can be compared directly with a
    label. If ``output_text`` contains no ``'#'``, the whole string is used.

    Args:
        output_text: Text returned by ``model.generate``.

    Returns:
        ``last_token + '#'``, or ``None`` when there is no token before the
        first ``'#'`` (e.g. an empty or whitespace-only generation).
    """
    tokens = output_text.split('#')[0].split()
    if not tokens:
        return None
    return tokens[-1] + '#'


def train():
    """Evaluate a saved checkpoint and log exact-match accuracy.

    Despite its name, no training happens here. Command-line arguments are
    parsed into ``ModelArguments``, ``DataArguments`` and
    ``TrainingArguments``. Weights are loaded from
    ``<output_dir>/pytorch_model.bin``. Each example of ``data_path`` is run
    through ``model.generate``. Targets, predictions and the final accuracy are
    written to ``logs/<start-time>.log``.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # One log file per run, named after the start time.
    formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # exist_ok avoids the exists()/mkdir() race when several runs start together.
    os.makedirs('logs', exist_ok=True)
    logging.basicConfig(
        filename=os.path.join('logs', f"{formatted_time}.log"),
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    model = CoIForGPTNeoXCausalLM(model_args, training_args)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(os.path.join(training_args.output_dir, 'pytorch_model.bin'), 'cpu')
    # strict=False silently skips missing/unexpected keys, so surface the report.
    print(model.load_state_dict(ckpt, strict=False))
    # Disable train-mode behaviour (dropout etc.) before generating. Assumes the
    # model is a torch.nn.Module, as its load_state_dict report suggests.
    model.eval()
    eval_dataset = EvalDataset(model.tokenizer, data_args)

    total = len(eval_dataset)
    acc = 0
    # EvalDataset defines no __iter__, so index it explicitly rather than rely
    # on the legacy "__getitem__ until IndexError" iteration protocol.
    for i in range(total):
        data = eval_dataset[i]
        with torch.no_grad():
            # The second return value (image output) is not used for scoring.
            output_text, _ = model.generate(
                data["input_ids"][None].to(training_args.device),
                max_output_len=training_args.model_max_length,
            )
        logging.info(f"ID: {i}")
        logging.info(f"Target: {data['labels']}")
        logging.info(f"Predict: {output_text}")
        logging.info("")
        print(output_text)
        # Exact-match scoring. An empty generation now counts as wrong instead
        # of aborting the whole run with IndexError.
        prediction = _extract_prediction(output_text)
        if prediction is not None and data['labels'] == prediction:
            acc += 1
    if total:
        logging.info(f"acc: {acc / total}")
    else:
        # Avoid ZeroDivisionError on an empty evaluation file.
        logging.warning("acc: n/a (empty evaluation dataset)")
    logging.info(f"data_path: {data_args.data_path}")
    logging.info(f"output_dir: {training_args.output_dir}")
    

if __name__ == "__main__":
    # Script entry point. Despite its name, train() only evaluates a saved
    # checkpoint; it performs no optimization.
    train()