import copy
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
from transformers import Trainer
from model import CoILlamaForCausalLM
from datasets import load_dataset
from PIL import Image
from torchvision import transforms
import os
import torch
import transformers
from torch.utils.data import Dataset
import cairosvg
from PIL import Image
import io

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<!--Image-->"

@dataclass
class ModelArguments:
    """Model-related command-line options, parsed by ``HfArgumentParser`` in ``train``."""

    # Base model identifier/path; presumably consumed by CoILlamaForCausalLM
    # (it receives the whole ModelArguments object) -- confirm there.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Vision encoder identifier/path; usage not visible in this file.
    vision_tower_name: Optional[str] = "facebook/opt-125m"
    # Directory whose ``pytorch_model.bin`` is loaded (non-strictly) into the
    # model before training starts.
    load_checkpoint: Optional[str] = "facebook/opt-125m"
    # Usage not visible in this file; presumably toggles adding the image
    # placeholder token to the tokenizer -- confirm in the model code.
    add_image_token: bool = True


@dataclass
class DataArguments:
    """Data-related command-line options."""

    # Directory that TrainingDataset searches for ``svg_*.jsonl`` files.
    # Annotated Optional because the default is None (the CLI type stays str).
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with this script's extra options.

    Adds image-layout, sequence-length, quantization and LoRA settings, and
    overrides ``optim`` to default to ``"adamw_torch"``.
    """

    # --- general --------------------------------------------------------------
    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    # When true, each image placeholder in an answer is repeated 256 times
    # before tokenization (see the dataset's __getitem__).
    use_patch: bool = field(
        default=False,
        metadata={"help": "Use image patch (256) or class (1)"},
    )
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

    # --- quantization (consumer not visible in this file) ----------------------
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )

    # --- LoRA --------------------------------------------------------------------
    # lora_enable is read by train() to decide whether to call merge_and_unload
    # after training; the remaining lora_* values are presumably consumed when
    # the model is built from these arguments -- confirm in the model code.
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=16)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

class TrainingDataset(Dataset):
    """Dataset for supervised fine-tuning on SVG answers interleaved with rendered images.

    Records are loaded from every ``svg_*.jsonl`` file in ``data_args.data_path``
    and shuffled once with a fixed seed (42). Each record must provide
    ``conversations[0]['value']`` (the question) and
    ``conversations[1]['value']`` (the answer). The answer is SVG text that
    contains ``DEFAULT_IMAGE_TOKEN`` markers. Each marker gets one image,
    rendered from all the SVG text that precedes it.
    """

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, use_patch):
        """
        Args:
            tokenizer: tokenizer used to build ``input_ids``/``labels``; its
                ``eos_token`` is appended to every sample.
            data_args: provides ``data_path``, the directory holding the jsonl files.
            use_patch: if truthy, every image marker in the answer is expanded
                to 256 markers before tokenization.

        Raises:
            ValueError: if ``data_args.data_path`` is not set.
        """
        super(TrainingDataset, self).__init__()
        if data_args.data_path is None:
            raise ValueError("data_path must be set to the directory containing the svg_*.jsonl files.")
        data_files = {'train': os.path.join(data_args.data_path, 'svg_*.jsonl')}
        self.list_data_dict = load_dataset('json', data_files=data_files)['train'].shuffle(seed=42)
        self.tokenizer = tokenizer
        self.data_args = data_args
        self.use_patch = use_patch
        # Build one sample eagerly so malformed records or SVG rendering failures
        # surface at construction time instead of mid-training.
        self.__getitem__(0)

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build training example ``i``.

        Returns:
            dict with
              * ``input_ids``: 1-D token ids of ``question + " " + answer + eos``;
              * ``labels``: copy of ``input_ids`` whose first ``len(tokenize(question))``
                positions are ``IGNORE_INDEX``, so only the answer is supervised;
              * ``images``: ``torch.stack`` of the rendered SVG prefixes, one per
                ``DEFAULT_IMAGE_TOKEN`` in the answer.

        Raises:
            ValueError: if the answer contains no ``DEFAULT_IMAGE_TOKEN``
                (there would be nothing to stack into ``images``).
        """
        sample = self.list_data_dict[i]
        question = sample['conversations'][0]['value']
        answer = sample['conversations'][1]['value']

        svg_chunks = answer.split(DEFAULT_IMAGE_TOKEN)
        if len(svg_chunks) < 2:
            raise ValueError(
                f"Sample {i} has no {DEFAULT_IMAGE_TOKEN!r} in its answer; at least one image is required."
            )

        to_tensor = transforms.ToTensor()
        images = []
        svg_prefix = ''
        # Render everything emitted so far, closed with '</g></svg>' so cairosvg
        # gets a complete document (assumes each prefix leaves exactly one <g>
        # open inside <svg> -- TODO confirm against the data format).
        for chunk in svg_chunks[:-1]:
            svg_prefix += chunk
            png = io.BytesIO(cairosvg.svg2png(bytestring=svg_prefix + '</g></svg>'))
            images.append(to_tensor(Image.open(png).convert('RGB')))
        images = torch.stack(images)

        if self.use_patch:
            # One placeholder per image patch (256) instead of a single one.
            answer = answer.replace(DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN * 256)

        max_len = self.tokenizer.model_max_length
        # NOTE(review): no truncation=True is passed, so max_length does not cap
        # the sequence here. Enabling truncation could cut image placeholders
        # away from their images, so decide the policy before turning it on.
        input_ids = self.tokenizer(
            question + " " + answer + self.tokenizer.eos_token,
            return_tensors="pt", padding="longest", max_length=max_len,
        ).input_ids[0]
        labels = input_ids.clone()
        # NOTE(review): the question is tokenized on its own to find the prefix
        # length; this assumes tokenizing "question + ' ' + answer" does not merge
        # tokens across the boundary -- confirm for the tokenizer in use.
        question_len = self.tokenizer(
            question, return_tensors="pt", padding="longest", max_length=max_len,
        ).input_ids.shape[1]
        labels[:question_len] = IGNORE_INDEX

        return dict(input_ids=input_ids, labels=labels, images=images)
    
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Batch dataset samples into padded model inputs.

    ``input_ids`` are right-padded with the tokenizer's pad id and ``labels``
    with ``IGNORE_INDEX``. Images are not padded: every sample's image tensor is
    concatenated along dim 0. ``attention_mask`` is True wherever ``input_ids``
    differs from the pad id. NOTE(review): this also masks real tokens whose id
    equals the pad id (e.g. if pad and eos share an id) -- confirm that is intended.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        ids_per_sample = [sample["input_ids"] for sample in instances]
        labels_per_sample = [sample["labels"] for sample in instances]
        images_per_sample = [sample["images"] for sample in instances]

        pad_id = self.tokenizer.pad_token_id
        batch_input_ids = torch.nn.utils.rnn.pad_sequence(
            ids_per_sample, batch_first=True, padding_value=pad_id
        )
        batch_labels = torch.nn.utils.rnn.pad_sequence(
            labels_per_sample, batch_first=True, padding_value=IGNORE_INDEX
        )
        batch_images = torch.cat(images_per_sample)

        return {
            "input_ids": batch_input_ids,
            "labels": batch_labels,
            "images": batch_images,
            "attention_mask": batch_input_ids.ne(pad_id),
        }
    
def train():
    """Parse CLI arguments, build the model and dataset, fine-tune, and save.

    Parses ``ModelArguments``, ``DataArguments`` and ``TrainingArguments`` from
    the command line. Loads ``pytorch_model.bin`` from
    ``model_args.load_checkpoint`` into the model non-strictly and trains with
    ``transformers.Trainer``. Finally it saves the trainer state and the model
    to ``training_args.output_dir``.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model = CoILlamaForCausalLM(model_args, training_args)

    # NOTE(review): torch.load unpickles arbitrary objects; only point
    # --load_checkpoint at trusted files (consider weights_only=True on torch>=1.13).
    ckpt = torch.load(os.path.join(model_args.load_checkpoint, 'pytorch_model.bin'), map_location='cpu')
    # strict=False: missing/unexpected keys are reported (printed), not fatal.
    print(model.load_state_dict(ckpt, strict=False))
    # load_state_dict copied the tensors into the model's parameters; drop the
    # state dict so a second copy of the weights does not sit in host RAM for
    # the whole training run.
    del ckpt

    tokenizer = model.tokenizer
    train_dataset = TrainingDataset(tokenizer, data_args, training_args.use_patch)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_state()
    if training_args.lora_enable:
        # NOTE(review): trainer.save_model below saves trainer.model, not this
        # rebound local. Confirm that merge_and_unload merges in place;
        # otherwise the merged weights are not the ones written to disk.
        model = model.merge_and_unload()
    print(model.llm)
    print("++++++++++++++++++++++++++++++++++++++++++++++")
    print(model.encoder)
    trainer.save_model(output_dir=training_args.output_dir)
    
if __name__ == "__main__":
    # Script entry point: run training only when executed directly, not on import.
    train()