from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict, List
import pickle
import json
import sys
import numpy as np
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import torch
import pickle
import os


# Repository root: the grandparent directory of this file, as an absolute path.
PROJECT_DIR = Path(__file__).parent.parent.absolute()

# Make the repository root and its ``datapreprocess`` directory importable so
# the project-local imports below resolve when this file is run as a script.
sys.path.extend(
    [PROJECT_DIR.as_posix(), PROJECT_DIR.joinpath("datapreprocess").as_posix()]
)

from datapreprocess.config import DATASET_TO_CONFIG
from datapreprocess.extractor import CLIPModelExtractor
from data.utils.datasets import DATASETS
from data.utils.constants import MEAN, STD

def get_processor_argparser() -> ArgumentParser:
    """Build the command-line parser for the feature-extraction script.

    The parser has a single option, ``-d/--dataset``. It is restricted to the
    dataset names listed below and defaults to ``"cifar10"``.

    :return: a fresh ``ArgumentParser``; the caller decides what to parse.
    """
    supported_datasets = [
        "mnist", "cifar10", "cifar100", "synthetic",
        "femnist", "emnist", "fmnist", "celeba",
        "medmnistS", "medmnistA", "medmnistC",
        "covid19", "svhn", "usps",
        "tiny_imagenet", "cinic10", "domain",
    ]
    cli = ArgumentParser()
    cli.add_argument(
        "-d", "--dataset",
        type=str,
        choices=supported_datasets,
        default="cifar10",
    )
    return cli


class Processor:
    """
    Extract CLIP features for the clients of a partitioned dataset.

    Construction does the following:

    - reads ``args.json`` and ``partition.pkl`` from
      ``PROJECT_DIR / "data" / <dataset>``;
    - instantiates the dataset through ``DATASETS``;
    - creates a ``CLIPModelExtractor``;
    - encodes one ``"a photo of <label text>"`` prompt per class into
      ``self.text_embeddings``.
    """

    def __init__(self, args=None):
        """
        :param args: optional pre-parsed namespace with at least a ``dataset``
            attribute. When ``None`` (the default), the command line is parsed
            with ``get_processor_argparser()``. ``initilize_data`` stores
            ``dataset_args`` on this object.
        """
        # get dataset information
        self.args = get_processor_argparser().parse_args() if args is None else args
        # Indexed below with integer keys 0..n-1, each mapping to a label text.
        self.labels_texts = DATASET_TO_CONFIG[self.args.dataset]

        # initialize data
        self.initilize_data()

        # load dataset
        self.load_dataset()

        # initialize foundation model
        self.extractor = CLIPModelExtractor()

        # Get text embeddings: one prompt per class, in class-index order.
        # process_client_label_features indexes the result by label, so this
        # relies on encode_label keeping the prompt order.
        labels_descriptions = [
            "a photo of " + self.labels_texts[i] for i in range(len(self.labels_texts))
        ]
        self.text_embeddings = self.extractor.encode_label(labels_descriptions)
        print(labels_descriptions)

    def initilize_data(self):
        """
        Load the dataset arguments and the client partition from disk.

        Reads ``args.json`` into ``self.args.dataset_args``. Reads
        ``partition.pkl`` and fills ``train_clients``, ``test_clients``,
        ``client_num`` and ``clients_indices`` from it.

        :raises FileNotFoundError: if ``partition.pkl`` is missing, unreadable
            or cannot be unpickled (i.e. the dataset was not partitioned yet).
        """
        dataset_dir = PROJECT_DIR / "data" / self.args.dataset
        with open(dataset_dir / "args.json", "r", encoding="utf-8") as f:
            self.args.dataset_args = json.load(f)

        # get client party info
        # NOTE(review): pickle.load runs arbitrary code; partition.pkl must
        # come from a trusted source.
        partition_path = dataset_dir / "partition.pkl"
        try:
            with open(partition_path, "rb") as f:
                partition = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError) as err:
            # Only I/O or unpickling failures mean "not partitioned yet". A
            # bare ``except`` also swallowed KeyboardInterrupt and real bugs.
            raise FileNotFoundError(
                f"Please partition {self.args.dataset} first."
            ) from err
        self.train_clients: List[int] = partition["separation"]["train"]
        self.test_clients: List[int] = partition["separation"]["test"]
        self.client_num: int = partition["separation"]["total"]
        self.clients_indices = partition["data_indices"]
        print(self.clients_indices[0])

    def load_dataset(self):
        """
        Instantiate the dataset class registered in ``DATASETS``.

        The general transform is deliberately empty: MEAN/STD normalization is
        not applied here. Presumably the CLIP extractor does its own
        preprocessing — confirm before adding normalization back.
        """
        general_data_transform = transforms.Compose([])
        self.dataset = DATASETS[self.args.dataset](
            root=PROJECT_DIR / "data" / self.args.dataset,
            args=self.args.dataset_args,
            general_data_transform=general_data_transform
        )

    def process_client_input_features(self, indicies):
        """
        Encode the inputs of the selected samples with the CLIP image encoder.

        :param indicies: indices into ``self.dataset``.
        :return: whatever ``self.extractor.encode_image`` returns for the list
            of inputs (element 0 of each dataset item), in index order.
        """
        current_set = Subset(self.dataset, indicies)
        images = [current_set[i][0] for i in range(len(current_set))]
        return self.extractor.encode_image(images)

    def process_client_label_features(self, indicies):
        """
        Look up the class-text embedding for each selected sample's label.

        :param indicies: indices into ``self.dataset``.
        :return: CPU tensor of shape ``(len(indicies), embedding_dim)``, where
            row ``i`` is ``self.text_embeddings[label_i]`` and ``label_i`` is
            element 1 of the ``i``-th selected dataset item.
        """
        current_set = Subset(self.dataset, indicies)
        # Take the width from the embeddings rather than hard-coding 512, so
        # text encoders with a different output size also work.
        embedding_dim = len(self.text_embeddings[0])
        features = torch.zeros((len(current_set), embedding_dim))
        for row in range(len(current_set)):
            features[row] = self.text_embeddings[current_set[row][1]]
        return features
    
       



if __name__ == "__main__":
    # Extract and pickle CLIP image/label features for every training client.
    processor = Processor()
    # NOTE: output goes under the current working directory, not PROJECT_DIR.
    dataset_root = os.path.join("./features", processor.args.dataset)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(dataset_root, exist_ok=True)

    print(processor.train_clients)
    print(processor.test_clients)

    # NOTE(review): only train clients are extracted, and only from their
    # 'train' index split; test clients are printed but not processed.
    # Confirm this is intended.
    for client in processor.train_clients:
        client_indices = processor.clients_indices[client]['train']
        image_features = processor.process_client_input_features(client_indices)
        label_features = processor.process_client_label_features(client_indices)
        extracted = {
            "image_features": image_features,
            "label_features": label_features
        }
        # Debug output; indexing [0] would raise IndexError for an empty client.
        if len(client_indices) > 0:
            print(image_features[0].shape)
            print(label_features[0].shape)
        output_path = os.path.join(dataset_root, "extracted-{}.pkl".format(client))
        with open(output_path, "wb") as f:
            pickle.dump(extracted, f)
