
import os
import copy
from dataclasses import dataclass
import json
from typing import Dict, Sequence
import numpy as np

import torch
import evaluate
import transformers

from llava.constants import IGNORE_INDEX
from llava.training_module.preprocess import preprocess, preprocess_multimodal
from llava.training_module.load_args import DataArguments
from llava.training_module.utils import rank0_print
from torch.utils.data import Dataset

from PIL import Image


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning that builds each sample on access.

    The JSON file at ``data_path`` is loaded once in ``__init__``; images are
    opened and conversations are tokenized only inside ``__getitem__``.
    Every record must provide a ``"conversations"`` entry and may provide an
    ``"image"`` entry holding a path relative to ``data_args.image_folder``.
    """

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(data_path, "r") as f:
            list_data_dict = json.load(f)

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        """Return the number of records loaded from the JSON file."""
        return len(self.list_data_dict)

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas, centring it on the short axis."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        # Exactly one of the two offsets is non-zero: the image is centred
        # along whichever dimension had to grow.
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def _load_image(self, image_file):
        """Open ``image_file`` as RGB and return the processor's pixel values."""
        processor = self.data_args.image_processor
        image_path = os.path.join(self.data_args.image_folder, image_file)
        # ``convert`` returns a new in-memory image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.data_args.image_aspect_ratio == 'pad':
            # Pad with the processor's mean colour scaled to the 0-255 range.
            image = self._expand2square(
                image, tuple(int(x*255) for x in processor.image_mean))
        return processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the training sample for index ``i``.

        Returns a dict with ``input_ids`` and ``labels`` taken from
        ``preprocess`` (unwrapped to their first element when ``i`` is an
        ``int``). An ``image`` entry is added when the record has an image,
        or an all-zero ``(3, H, W)`` placeholder sized from the processor's
        ``crop_size`` when it has none but ``data_args.is_multimodal`` is set.
        """
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        # Decide once, from the normalised single-record list, whether this
        # sample carries an image; the same flag drives every later branch.
        sample = sources[0]
        has_image = 'image' in sample

        if has_image:
            image = self._load_image(sample['image'])
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        # adjust to conv pattern
        # Kept as a function-local import, as in the original code.
        # NOTE(review): presumably this avoids an import cycle - confirm.
        from llava.conversation import conv_templates
        conv_template = conv_templates[self.data_args.version]
        data_dict = preprocess(
            conv_template,
            sources,
            self.tokenizer,
            has_image=has_image)
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if has_image:
            # image exist in the data
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict


class ClassificationSupervisedDataset(Dataset):
    """Multi-label image classification dataset.

    ``label_dict_path`` is a JSON file mapping image file names (relative to
    ``image_folder``) to a list of class labels. A ``category_dict.json``
    file located in the same directory supplies the class vocabulary: its
    values, in file order, define the positions of the multi-hot label vector.
    """

    def __init__(self, label_dict_path: str, image_folder: str,
                 data_args: DataArguments):
        super(ClassificationSupervisedDataset, self).__init__()
        with open(label_dict_path, "r") as f:
            label_dict = json.load(f)

        self.label_dict = label_dict
        self.data_args = data_args
        self.image_folder = image_folder

        file_path = os.path.dirname(label_dict_path)
        # Context manager so this second JSON handle is closed as well.
        with open(os.path.join(file_path, 'category_dict.json'), "r") as f:
            category_dict = json.load(f)
        self.all_classes = list(category_dict.values())

        # Snapshot the image file names once. Rebuilding this list inside
        # __getitem__ costs O(len(label_dict)) for every single sample.
        self._image_files = list(label_dict)

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.label_dict) # 118287 train, 5000 val

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad ``pil_img`` to a square canvas, centring it on the short axis."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        # Exactly one of the two offsets is non-zero: the image is centred
        # along whichever dimension had to grow.
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return ``{"images": pixel_values, "labels": multi_hot}`` for index ``i``.

        ``labels`` is a float tensor of length ``len(self.all_classes)`` with
        ones at the positions of the image's classes.

        Raises:
            ValueError: if a label of the image is not in ``self.all_classes``.
        """
        image_file = self._image_files[i]
        label_ls = self.label_dict[image_file]

        # ``convert`` returns a new in-memory image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(os.path.join(self.image_folder, image_file)) as raw_image:
            image = raw_image.convert('RGB')

        processor = self.data_args.image_processor

        if self.data_args.image_aspect_ratio == 'pad':
            # Pad with the processor's mean colour scaled to the 0-255 range.
            image = self._expand2square(
                image, tuple(int(x*255) for x in processor.image_mean))

        image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        # one-hot encoding
        label = torch.zeros(len(self.all_classes))
        for class_name in label_ls:
            label[self.all_classes.index(class_name)] = 1

        return dict(
            images=image,
            labels=label
        )

    def get_predict_classes(self, predict_labels):
        """Map a 0/1 prediction vector back to class names.

        ``predict_labels`` may be a tensor or an array; it is flattened to
        1-D and every position whose value equals 1 contributes the class
        name at the same index of ``self.all_classes``.
        """
        if isinstance(predict_labels, torch.Tensor):
            predict_labels = predict_labels.cpu().numpy()
        # reshape to 1D array
        predict_labels = predict_labels.reshape(-1)
        return [self.all_classes[i] for i, value in enumerate(predict_labels) if value == 1]


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Batch tokenized samples for supervised fine-tuning.

    Sequences are right-padded to the longest one in the batch and then
    truncated to ``tokenizer.model_max_length``. ``input_ids`` are padded with
    0 and ``labels`` with ``IGNORE_INDEX``; the attention mask marks every
    position whose id differs from the padding id.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Collate ``instances`` into ``input_ids``/``labels``/``attention_mask``.

        When the first instance carries an ``image`` entry, ``images`` is
        added too: a stacked tensor if all images are non-None and share one
        shape, otherwise the plain list of per-sample images.
        """
        id_seqs = [sample["input_ids"] for sample in instances]
        label_seqs = [sample["labels"] for sample in instances]

        # TODO: self.tokenizer.pad_token_id is None, but pad_token_id can be found in config
        pad_token_id = 0
        max_len = self.tokenizer.model_max_length

        padded_ids = torch.nn.utils.rnn.pad_sequence(
            id_seqs, batch_first=True, padding_value=pad_token_id)
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate after padding so both tensors keep the same width.
        padded_ids = padded_ids[:, :max_len]
        padded_labels = padded_labels[:, :max_len]

        batch = {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(pad_token_id),
        }

        if 'image' not in instances[0]:
            return batch

        images = [sample['image'] for sample in instances]
        # A single tensor is only possible when no image is missing and every
        # shape matches the first one.
        mismatch = any(x is None or x.shape != images[0].shape for x in images)
        batch['images'] = images if mismatch else torch.stack(images)
        return batch
    

def make_supervised_data_module_classification(data_args) -> Dict:
    """Assemble datasets, collator and metric function for classification.

    The training set is read from ``data_args.data_path`` and
    ``data_args.image_folder``. The evaluation set uses the same two strings
    with every occurrence of ``'train'`` replaced by ``'val'``.
    """
    train_label_path = data_args.data_path
    train_image_dir = data_args.image_folder

    train_dataset = ClassificationSupervisedDataset(
        label_dict_path=train_label_path,
        image_folder=train_image_dir,
        data_args=data_args,
    )
    eval_dataset = ClassificationSupervisedDataset(
        label_dict_path=train_label_path.replace('train', 'val'),
        image_folder=train_image_dir.replace('train', 'val'),
        data_args=data_args,
    )

    return {
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": collate_batch,
        "compute_metrics": compute_metrics_manual,
    }

def make_supervised_data_module_pretrain(data_args) -> Dict:
    """Assemble the pretraining data module from a fixed hub dataset.

    Loads ``liuhaotian/LLaVA-CC3M-Pretrain-595K`` through ``datasets`` and
    returns its ``train`` and ``validation`` splits together with the
    module's collator and metric function. ``data_args`` is accepted but
    not used by this function.
    """
    from datasets import load_dataset

    splits = load_dataset("liuhaotian/LLaVA-CC3M-Pretrain-595K")

    return {
        "train_dataset": splits['train'],
        "eval_dataset": splits['validation'],
        "data_collator": collate_batch,
        "compute_metrics": compute_metrics,
    }


def collate_batch(batch):
    """Stack per-sample ``images`` and ``labels`` tensors along a new batch axis.

    ``batch`` is a sequence of dicts that each hold an ``'images'`` and a
    ``'labels'`` tensor; the result is a dict with the same two keys.
    """
    image_list = []
    label_list = []
    for sample in batch:
        image_list.append(sample['images'])
    for sample in batch:
        label_list.append(sample['labels'])

    return {
        "images": torch.stack(image_list),
        "labels": torch.stack(label_list),
    }
    

def compute_metrics(eval_preds):
    """Score arg-max predictions with the ``evaluate`` GLUE/MRPC metric.

    ``eval_preds`` unpacks into ``(logits, labels)``. The predicted class is
    the arg-max over the last axis of ``logits``. The metric object is loaded
    anew on every call.
    """
    glue_mrpc = evaluate.load("glue", "mrpc")
    logits, references = eval_preds
    predicted_ids = np.argmax(logits, axis=-1)
    scores = glue_mrpc.compute(predictions=predicted_ids, references=references)
    return scores


def compute_metrics_manual(pred, threshold: float = 0.3):
    """Binarize predictions and compute multi-label metrics.

    Args:
        pred: object exposing ``predictions`` and ``label_ids``. Each may be
            a NumPy array, a tuple whose first element holds the values, or
            an object supporting ``.detach().cpu().numpy()``.
        threshold: predictions strictly greater than this value count as
            positive. Defaults to 0.3, the value previously hard-coded.

    Returns:
        The nested metrics dict produced by ``calculate_metrics``.
    """

    def _as_numpy(value):
        # Shared conversion for predictions and labels.
        if isinstance(value, np.ndarray):
            return value
        if isinstance(value, tuple):
            # Only the first element is used; convert it as well so a tuple
            # of tensors does not fail later on ``.astype``.
            return _as_numpy(value[0])
        return value.detach().cpu().numpy()

    preds = _as_numpy(pred.predictions)
    labels = _as_numpy(pred.label_ids)

    binary_preds = (preds > threshold).astype(int)

    return calculate_metrics(binary_preds, labels)
    

def calculate_metrics(predictions, true_labels):
    """Compute precision, recall, accuracy and F1 for 0/1 multi-label data.

    Both arguments are indexed as 2-D arrays whose rows are images and whose
    columns are tasks (classes). Metrics are computed per column
    (``task_level``), per row (``image_level``) and as the unweighted mean of
    the per-column values (``overall_level``). A ratio whose denominator is
    zero is reported as 0.0. All values are rounded to two decimals.
    """
    # Touch both dimensions up front, like the original array allocations.
    n_images, n_tasks = predictions.shape[0], predictions.shape[1]

    predicted_pos = predictions == 1
    predicted_neg = predictions == 0
    actual_pos = true_labels == 1
    actual_neg = true_labels == 0
    agreement = predictions == true_labels

    def _safe_ratio(numerator, denominator):
        # Element-wise division that leaves 0.0 wherever the denominator is 0.
        ratio = np.zeros(np.shape(numerator))
        np.divide(numerator, denominator, out=ratio, where=denominator != 0)
        return ratio

    def _scores_along(axis):
        # axis=0 reduces over images (one score per task);
        # axis=1 reduces over tasks (one score per image).
        true_positive = np.sum(predicted_pos & actual_pos, axis=axis)
        false_positive = np.sum(predicted_pos & actual_neg, axis=axis)
        false_negative = np.sum(predicted_neg & actual_pos, axis=axis)

        precision = _safe_ratio(true_positive, true_positive + false_positive)
        recall = _safe_ratio(true_positive, true_positive + false_negative)
        accuracy = np.mean(agreement, axis=axis)
        f1_score = _safe_ratio(2 * (precision * recall), precision + recall)
        return precision, recall, accuracy, f1_score

    task_p, task_r, task_acc, task_f1 = _scores_along(0)
    image_p, image_r, image_acc, image_f1 = _scores_along(1)

    def _rounded(values):
        return np.round(values, 2).tolist()

    return {
        'overall_level': {
            'precision': round(np.mean(task_p), 2),
            'recall': round(np.mean(task_r), 2),
            'accuracy': round(np.mean(task_acc), 2),
            'f1_score': round(np.mean(task_f1), 2)
        },
        'task_level': {
            'precisions': _rounded(task_p),
            'recalls': _rounded(task_r),
            'accuracies': _rounded(task_acc),
            'f1_scores': _rounded(task_f1)
        },
        'image_level': {
            'precisions': _rounded(image_p),
            'recalls': _rounded(image_r),
            'accuracies': _rounded(image_acc),
            'f1_scores': _rounded(image_f1)
        }
    }


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Assemble the training dataset and collator for supervised fine-tuning.

    The dataset is a ``LazySupervisedDataset`` over ``data_args.data_path``.
    No evaluation dataset is created (``eval_dataset`` is ``None``).
    """
    dataset = LazySupervisedDataset(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
        data_args=data_args,
    )
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }

if __name__ == "__main__":
    # data_path = "./data/pretrain/LLaVA-CC3M-Pretrain-595K/chat.json"
    # image_folder = "./data/pretrain/LLaVA-CC3M-Pretrain-595K/images"

    # Ad-hoc shape experiment on dummy tensors: split a feature tensor into
    # one chunk per image, then merge the first two axes of every chunk.
    images = torch.zeros(16, 1, 800, 800)
    image_features = torch.zeros(16, 99, 4096)

    # One split size per image, taken from each image's leading dimension.
    split_sizes = [image.size(0) for image in images]
    print(split_sizes)

    image_features = torch.split(image_features, split_sizes, dim=0)
    print(len(image_features))
    print(image_features[0].shape)

    image_features = [torch.flatten(chunk, 0, 1) for chunk in image_features]
    print(image_features[0].shape)


    # data_path = "./playground/data/llava_instruct_80k.json"
    # image_folder = "./data/train2017"
    # tokenizer = transformers.AutoTokenizer.from_pretrained(
    #     "./checkpoints/llava-llama-2-7b-chat-lightning-preview",
    #     model_max_length=2048,
    #     padding_side="right",
    #     use_fast=False
    #     )
    # image_processor = transformers.AutoImageProcessor.from_pretrained(
    #     "facebook/detr-resnet-50",
    #     cache_dir=None,
    #     use_fast=False
    # )
    # data_args = DataArguments(data_path=data_path, image_folder=image_folder, image_aspect_ratio='pad')
    # data_args.image_processor = image_processor
    # data_module = make_supervised_data_module(tokenizer, data_args)   
    # train_dataset = data_module['train_dataset']
    # train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=5, shuffle=False, num_workers=0)
    # # load several samples from train_dataset
    # for i, data in enumerate(train_dataloader):
    #     print(data.keys())
    #     images, labels = data['image'], data['labels'] # dict_keys(['input_ids', 'labels', 'image'])
    #     print(f"images shape: {images.shape}")
    #     print(f"labels shape: {labels.shape}")
        # images shape: torch.Size([3, 800, 800])
        # labels shape: torch.Size([77])