import json
import os
from abc import abstractmethod, ABC

import numpy as np
import torch
from torch.utils.data import Dataset


class ParamDataset(Dataset, ABC):
    """Abstract base class for datasets of parameter files stored on disk.

    It records the constructor arguments and offers a loader for a single
    parameter file; concrete subclasses implement ``_make_dataset``.
    """

    def __init__(self, param_dir, param_names=None, transform=None):
        """
        :param param_dir: directory that the subclass reads parameter data from
        :param param_names: optional parameter names, stored unchanged
        :param transform: optional transformation, stored unchanged
        """
        self.param_dir, self.param_names, self.transform = (
            param_dir,
            param_names,
            transform,
        )

    def _load_data(self, param_path):
        """Return the object ``torch.load`` reads from ``param_path`` (``weights_only=True``)."""
        return torch.load(param_path, weights_only=True)

    @abstractmethod
    def _make_dataset(self):
        """Hook that subclasses must override; the base version always raises."""
        raise NotImplementedError()

class PairedParamDataset(ParamDataset):
    """Dataset of paired (source, target) parameter files.

    Source files are discovered under
    ``<param_root_dir>/src/<src_type>/<param_name>/<file>``; each one is paired
    with ``<param_root_dir>/tgt/<tgt_type>/<param_name>/<file>``, where
    ``tgt_type`` is looked up from ``src_tgt_dict``.
    """

    def __init__(self, param_root_dir, src_tgt_dict, param_names=None, transform=None):
        """
        :param param_root_dir: The root path of the parameter data (must contain ``src`` and ``tgt``)
        :param src_tgt_dict: mapping ``{target_type: iterable of source_types}``
        :param param_names: parameter names to index; ``None`` means every
            sub-directory of each source type directory
        :param transform: transformation for the parameter data
        """
        super().__init__(param_root_dir, param_names, transform=transform)
        self.samples = self._make_dataset(src_tgt_dict)

    def _make_dataset(self, src_tgt_dict):
        """Build the sample index.

        :param src_tgt_dict: mapping ``{target_type: iterable of source_types}``
        :return: list of ``[src_param_path, tgt_param_path, src_param_type,
            tgt_param_type, param_name, model_classnum]`` entries
        :raises AssertionError: if a source type that has parameter files is
            missing from ``src_tgt_dict``
        """
        # src_param_path, tgt_param_path, src_param_type, tgt_param_type, param_name, model_classnum
        samples = []

        src_param_dir = os.path.join(self.param_dir, 'src')
        tgt_param_dir = os.path.join(self.param_dir, 'tgt')

        # Directory listings come back in arbitrary order; sort them so the
        # sample order is reproducible across runs and filesystems.
        src_param_types = sorted(
            d.name for d in os.scandir(src_param_dir) if d.is_dir()
        )  # e.g., cifar10-lt-10/

        # Result unused: the listing is kept only so that a missing/invalid
        # 'tgt' directory still fails here, at construction time.
        os.listdir(tgt_param_dir)

        # Invert {target: [sources]} into {source: target}; later entries win.
        source_to_target = {
            source: target
            for target, sources in src_tgt_dict.items()
            for source in sources
        }

        for src_param_type in src_param_types:  # e.g., cifar10-lt-10/
            src_param_type_dir = os.path.join(src_param_dir, src_param_type)
            # Context manager so the JSON file handle is closed promptly.
            with open(os.path.join(src_param_type_dir, 'param_info.json')) as info_file:
                param_info = json.load(info_file)
            model_classnum = param_info['model_classnum']

            if self.param_names is not None:
                param_names = self.param_names  # caller-provided order is kept
            else:
                param_names = sorted(os.listdir(src_param_type_dir))

            for param_name in param_names:  # e.g., layer1.0.bn1.weight/
                src_param_name_dir = os.path.join(src_param_type_dir, param_name)
                if not os.path.isdir(src_param_name_dir):  # Avoiding files, e.g., param_info.json
                    continue

                for param_file in sorted(os.listdir(src_param_name_dir)):  # e.g., model_cont_0.pt
                    # Raised explicitly (not via `assert`) so the check survives
                    # `python -O`; AssertionError is kept for backward compatibility.
                    if src_param_type not in source_to_target:
                        raise AssertionError(
                            f"Source type {src_param_type} not found in the source-target dictionary"
                        )
                    tgt_param_type = source_to_target[src_param_type]

                    src_param_path = os.path.join(src_param_name_dir, param_file)
                    tgt_param_path = os.path.join(tgt_param_dir, tgt_param_type, param_name, param_file)
                    samples.append([
                        src_param_path,
                        tgt_param_path,
                        src_param_type,
                        tgt_param_type,
                        param_name,
                        model_classnum,
                    ])

        return samples

    def __getitem__(self, idx):
        """Load the pair at ``idx``.

        :return: ``(src_param, tgt_param, src_param_type, tgt_param_type,
            param_name, model_classnum)``; the same ``transform`` (if any) is
            applied to both loaded parameters
        """
        src_param_path, tgt_param_path, src_param_type, tgt_param_type, param_name, model_classnum = self.samples[idx]
        src_param = self._load_data(src_param_path)
        tgt_param = self._load_data(tgt_param_path)

        if self.transform is not None:
            src_param = self.transform(src_param)
            tgt_param = self.transform(tgt_param)

        return src_param, tgt_param, src_param_type, tgt_param_type, param_name, model_classnum

    def __len__(self):
        """Return the number of indexed (source, target) pairs."""
        return len(self.samples)


class UnpairedParamDataset(ParamDataset):
    """Dataset of individual parameter files, without source/target pairing.

    Files are discovered under
    ``<param_root_dir>/<type>/<param_type>/<param_name>/<file>`` for every
    ``type`` in ``types`` whose directory exists.
    """

    def __init__(self, param_root_dir, param_names=None, transform=None, types=None):
        """
        :param param_root_dir: The root path of the parameter data
        :param param_names: parameter names to index; ``None`` means every
            sub-directory of each parameter type directory
        :param transform: transformation for the parameter data
        :param types: top-level sub-directories to scan; defaults to
            ``['src', 'tgt', 'delta']``
        """
        super().__init__(param_root_dir, param_names, transform=transform)
        self.types = ['src', 'tgt', 'delta'] if types is None else types
        self.samples = self._make_dataset()

    def _make_dataset(self):
        """Build the sample index.

        :return: list of ``[param_path, param_type, param_name,
            param_var_info, model_classnum]`` entries, where ``param_var_info``
            is the ``var_info`` entry of ``param_info.json`` keyed by the file name
        """
        samples = []

        for split in self.types:  # named `split` to avoid shadowing the builtin `type`
            param_dir = os.path.join(self.param_dir, split)
            if not os.path.exists(param_dir):
                continue  # absent top-level directories are skipped

            # Directory listings come back in arbitrary order; sort them so the
            # sample order is reproducible across runs and filesystems.
            param_types = sorted(
                d.name for d in os.scandir(param_dir) if d.is_dir()
            )  # e.g., cifar10-lt-10/

            for param_type in param_types:  # e.g., cifar10-lt-10/
                param_type_dir = os.path.join(param_dir, param_type)
                if self.param_names is not None:
                    param_names = self.param_names  # caller-provided order is kept
                else:
                    param_names = sorted(os.listdir(param_type_dir))

                # Context manager so the JSON file handle is closed promptly.
                with open(os.path.join(param_type_dir, 'param_info.json')) as info_file:
                    param_info = json.load(info_file)
                param_var_info = param_info['var_info']
                model_classnum = param_info['model_classnum']

                for param_name in param_names:  # e.g., layer1.0.bn1.weight/
                    param_name_dir = os.path.join(param_type_dir, param_name)
                    if not os.path.isdir(param_name_dir):  # Avoiding files, e.g., param_info.json
                        continue
                    for param_file in sorted(os.listdir(param_name_dir)):  # e.g., model_cont_0.pt
                        param_path = os.path.join(param_name_dir, param_file)
                        samples.append([
                            param_path,
                            param_type,
                            param_name,
                            param_var_info[param_file],
                            model_classnum,
                        ])

        return samples

    def __getitem__(self, idx):
        """Load the parameter at ``idx``.

        :return: ``(param_data, param_type, param_name, param_var_info,
            model_classnum)``; ``transform`` (if any) is applied to ``param_data``
        """
        param_path, param_type, param_name, param_var_info, model_classnum = self.samples[idx]
        param_data = self._load_data(param_path)

        if self.transform is not None:
            param_data = self.transform(param_data)

        return param_data, param_type, param_name, param_var_info, model_classnum

    def __len__(self):
        """Return the number of indexed parameter files."""
        return len(self.samples)