import ml_collections
from dataloader import get_dataset
    
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from train_utils import train_one_epoch, train_one_epoch_fkd, train_one_epoch_gsam, eval_ensemble
from typing import List, Iterator
from modified_resnets import resnet18_silu
from collections import OrderedDict

from wide_resnets import WideResNet

class _MapDatasetFetcher(torch.utils.data._utils.fetch._BaseDatasetFetcher):
    """Map-style dataset fetcher with optional per-batch soft-label loading.

    Intended to mirror torch's own private map-style fetcher (``__getitems__``
    fast path, per-index fallback, or a single un-batched index) — verify
    against the installed torch version when upgrading, since this subclasses
    a private API.

    When the dataset has a ``mode`` attribute equal to ``'fkd_load'``, the
    last element of ``dataset.load_batch_config(first_index)`` is fetched as a
    soft label and returned (moved to CPU) next to the collated batch.
    """

    def fetch(self, possibly_batched_index):
        """Fetch and collate the samples addressed by ``possibly_batched_index``.

        Args:
            possibly_batched_index: a sequence of indices when
                ``self.auto_collation`` is true, otherwise a single index.

        Returns:
            ``collate_fn(data)``; in ``'fkd_load'`` mode the tuple
            ``(collate_fn(data), soft_label.cpu())`` instead.
        """
        # Decide the mode once: indexing the dataset below could in principle
        # change ``mode``, and the return path must agree with the load path
        # (otherwise ``soft_label`` could be referenced while unbound).
        fkd_load = hasattr(self.dataset, "mode") and self.dataset.mode == 'fkd_load'

        if fkd_load:
            # Looked up from the batch's first index only; presumably one soft
            # label tensor covers the whole batch. Requires an indexable batch
            # of indices (i.e. auto-collation) — TODO confirm for other uses.
            soft_label = self.dataset.load_batch_config(possibly_batched_index[0])[-1]

        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]

        batch = self.collate_fn(data)
        return (batch, soft_label.cpu()) if fkd_load else batch


# Monkey-patch torch's private module attribute so that any code resolving the
# map-style fetcher through ``torch.utils.data._utils.fetch`` (DataLoader
# internals — verify for your torch version) uses the class above.
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher


import os
import fire


class FKDBatchSampler(torch.utils.data.Sampler[List[int]]):
    """Batch sampler that replays precomputed batches, one epoch entry at a time.

    Args:
        batch_idx2img_list: indexable by the ids found in ``epoch_order``; each
            entry is an iterable of batches that is yielded as-is.
        epoch_order: sequence of ids into ``batch_idx2img_list``. Successive
            iterations walk through it and wrap around at the end.
    """

    def __init__(self, batch_idx2img_list, epoch_order) -> None:
        self.batch_idx2img_list = batch_idx2img_list
        self.epoch = 0  # position in ``epoch_order`` of the next entry to replay
        self.epoch_order = epoch_order
        self.batch_idx = 0  # NOTE(review): not read inside this class

    def __iter__(self) -> Iterator[List[int]]:
        # This is a generator, so the lookup and the advance below run on the
        # first ``next()``, not when ``iter()`` is called.
        epoch_batches = self.batch_idx2img_list[self.epoch_order[self.epoch]]
        self.epoch = (self.epoch + 1) % len(self.epoch_order)
        yield from epoch_batches

    def __len__(self) -> int:
        # Always measures entry 0, irrespective of the current position —
        # assumes every entry holds the same number of batches (TODO confirm).
        return len(self.batch_idx2img_list[0])

def get_config():
    """Build the default experiment configuration.

    Returns:
        An ``ml_collections.ConfigDict`` holding run-level settings
        (seed, log/image directories, resume flag), four placeholder entries
        initialised to ``None`` (``img_size``, ``img_channels``,
        ``num_prototypes``, ``train_size``), and a nested ``dataset`` section
        with the dataset ``name`` and ``data_path``.
    """
    cfg = ml_collections.ConfigDict()

    # Run-level settings.
    cfg.random_seed = 0
    cfg.train_log = 'train_log'
    cfg.train_img = 'train_img'
    cfg.resume = True

    # Placeholders left unset (None) here.
    for placeholder in ('img_size', 'img_channels', 'num_prototypes', 'train_size'):
        cfg[placeholder] = None

    # Dataset section.
    cfg.dataset = ml_collections.ConfigDict()
    dataset_cfg = cfg.dataset
    dataset_cfg.name = 'cifar100'
    dataset_cfg.data_path = 'data/tensorflow_datasets'

    return cfg


def _as_list(value):
    """Normalize a CLI argument into a list.

    A string is split on commas (a string without commas becomes a one-element
    list); a list or tuple is copied into a list; any other value is wrapped
    in a one-element list.
    """
    if isinstance(value, str):
        return value.split(',')
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _build_model(model_name, stage, source_n_per_class, dataset_config, activation_fn):
    """Instantiate the (untrained) architecture called ``model_name``.

    ``resnet18_silu`` gets ``source_n_per_class`` outputs when ``stage == 0``
    and ``dataset_config.num_classes`` otherwise; the WideResNet variants
    always get ``source_n_per_class`` outputs.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    if model_name == 'resnet18_silu':
        num_classes = source_n_per_class if stage == 0 else dataset_config.num_classes
        return resnet18_silu(num_classes=num_classes, activation=activation_fn)

    # name -> (first, third) positional WideResNet args, i.e. the two numbers
    # encoded in the name (presumably depth and width factor).
    wrn_specs = {
        'wrn28-8': (28, 8),
        'wrn28-10': (28, 10),
        'wrn28-4': (28, 4),
        'wrn22-8': (22, 8),
        'wrn16-8': (16, 8),
        'wrn40-4': (40, 4),
        'wrn34-10': (34, 10),
        'wrn10-4': (10, 4),
    }
    if model_name in wrn_specs:
        depth, width = wrn_specs[model_name]
        return WideResNet(depth, source_n_per_class, width, activation=activation_fn)

    # Previously an unknown name left ``model`` unbound or silently reused the
    # model built in the previous loop iteration.
    raise ValueError(
        f'unknown model type {model_name!r}; expected resnet18_silu or one of '
        f'{sorted(wrn_specs)}')


def _load_state_dict(model, checkpoint_path):
    """Load the checkpoint at ``checkpoint_path`` into ``model``.

    A leading ``module.`` prefix is stripped from every key that has one
    (presumably checkpoints saved from an ``nn.DataParallel`` wrapper).
    """
    # map_location='cpu': the model is still on CPU at this point, and this
    # avoids depending on the GPU the checkpoint was saved from.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    prefix = 'module.'
    cleaned = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    model.load_state_dict(cleaned)


def _append_result(result_path, test_acc):
    """Append ``test_acc * 100`` on a new line of ``result_path``.

    Parent directories and the file itself are created when missing.
    """
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    if os.path.isfile(result_path):
        print("OLD BOY")
    else:
        print("WRITING NEW")
    # Append mode creates a missing file, so both cases share one write path.
    with open(result_path, 'a') as f:
        f.write('\n' + str(test_acc * 100))


def main(n_epochs = 5, model_save_paths = ('./trained_models/model.pt',), model_types = ('wrn34-10',),
         distilled_dir = './distilled_images_test/', real_ds = False, result_save_name = 'ensemble_result',
         dataset_name = 'tiny_imagenet', data_folder = 'slkdjfsdf', fkd_seed = 0, dataset_min_ipc = 0,
         dataset_ipc = 5000000, batch_size = 128, label_filename = '/', lr = 1e-3, model_type = 'resnet18',
         max_fkd_epoch = 999999, source_n_per_class = 1000, from_scratch = False, width_multiplier = 1.0,
         activation = 'silu', model_save_path = './trained_models/model.pt/', stage = 0,
         target_resolution = 224, free_bn = False, save_model = False):
    """Evaluate an ensemble of saved checkpoints and log its test accuracy.

    Each checkpoint in ``model_save_paths`` is loaded into the architecture
    named at the same position in ``model_types``. A single model type is
    applied to every checkpoint. The models are scored together by
    ``eval_ensemble`` on the test split, and ``accuracy * 100`` is appended to
    ``{distilled_dir}/results/{label_filename}/ipc_{dataset_ipc}/{result_save_name}.txt``.

    Args:
        model_save_paths: checkpoint path(s); a comma-separated string is split.
        model_types: architecture name(s); a comma-separated string is split.
        activation: ``'relu'`` selects ``nn.ReLU``; anything else ``nn.SiLU``.
        stage: forwarded to model construction (affects ``resnet18_silu`` only).
        dataset_name, data_folder, target_resolution: forwarded to ``get_dataset``.
        distilled_dir, label_filename, dataset_ipc, result_save_name: pick the
            results file path.
        source_n_per_class: output size used when building the models.

    The remaining parameters (``n_epochs``, ``real_ds``, ``fkd_seed``,
    ``dataset_min_ipc``, ``batch_size``, ``lr``, ``model_type``,
    ``max_fkd_epoch``, ``from_scratch``, ``width_multiplier``,
    ``model_save_path``, ``free_bn``, ``save_model``) are accepted but not read
    by this function.

    Raises:
        ValueError: if fewer model types than checkpoints are given (other than
            a single type), or a model type is unknown.
    """
    config = get_config()
    config.dataset.name = dataset_name
    device = 'cuda:0'

    # Validate the checkpoint/architecture pairing before the slow data load.
    ckpt_paths = _as_list(model_save_paths)
    arch_names = _as_list(model_types)
    if len(arch_names) == 1:
        arch_names = arch_names * len(ckpt_paths)
    if len(arch_names) < len(ckpt_paths):
        raise ValueError(
            f'got {len(ckpt_paths)} checkpoint paths but only {len(arch_names)} model types')

    print('preparing data')
    (_, ds_test), _, _, _ = get_dataset(config.dataset, data_folder = data_folder,
                                        target_resolution = target_resolution, apply_aug = False)

    activation_fn = nn.ReLU if activation == 'relu' else nn.SiLU

    all_models = []
    # zip stops at the last checkpoint; surplus model types are ignored.
    for ckpt_path, model_name in zip(ckpt_paths, arch_names):
        print("preparing model")
        model = _build_model(model_name, stage, source_n_per_class, config.dataset, activation_fn)
        _load_state_dict(model, ckpt_path)
        model.to(device)
        print(f'loaded model at {ckpt_path}')
        all_models.append(model)

    test_acc = eval_ensemble(all_models, ds_test)[0].cpu().item()

    result_path = os.path.expanduser(
        f'{distilled_dir}/results/{label_filename}/ipc_{dataset_ipc}/{result_save_name}.txt')
    _append_result(result_path, test_acc)
    

if __name__ == '__main__':
    # Command-line entry point: python-fire exposes ``main``'s keyword
    # arguments as flags (comma-separated list values are split inside main).
    fire.Fire(main)