import ml_collections
from dataloader import get_dataset
    
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from train_utils import train_one_epoch, train_one_epoch_fkd, train_one_epoch_gsam
import torchvision.datasets as datasets
from typing import List, Iterator
from modified_resnets import resnet18_silu
from gsam import GSAM, ProportionScheduler
from collections import OrderedDict

from wide_resnets import WideResNet

class _MapDatasetFetcher(torch.utils.data._utils.fetch._BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if hasattr(self.dataset, "mode") and self.dataset.mode == 'fkd_load':
            soft_label = self.dataset.load_batch_config(possibly_batched_index[0])[-1]

        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]

        if hasattr(self.dataset, "mode") and self.dataset.mode == 'fkd_load':
            return self.collate_fn(data), soft_label.cpu()
        else:
            return self.collate_fn(data)

torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher


import os
import fire
import copy


class FKDBatchSampler(torch.utils.data.Sampler[List[int]]):
    """Batch sampler that replays precomputed batch orders, one per epoch.

    ``batch_idx2img_list`` is indexed by the entries of ``epoch_order``. Each
    value is an iterable of batches that are yielded as-is. Successive
    iterations walk through ``epoch_order`` cyclically.
    """

    def __init__(self, batch_idx2img_list, epoch_order) -> None:
        self.batch_idx2img_list = batch_idx2img_list
        self.epoch = 0
        self.epoch_order = epoch_order
        self.batch_idx = 0  # not read anywhere in this class

    def __iter__(self) -> Iterator[List[int]]:
        batches = self.batch_idx2img_list[self.epoch_order[self.epoch]]
        # Advance (with wrap-around) to the next epoch key. Because this is a
        # generator, the advance happens on the first next(), not on iter().
        self.epoch = (self.epoch + 1) % len(self.epoch_order)
        yield from batches

    def __len__(self) -> int:
        # Reports the batch count of entry 0; assumes every entry has the same
        # number of batches — TODO confirm against the data producer.
        return len(self.batch_idx2img_list[0])

def get_config():
    """Return the default experiment configuration as an ``ml_collections.ConfigDict``.

    Top-level fields: ``random_seed``, ``train_log``, ``train_img``, ``resume``,
    plus ``img_size``, ``img_channels``, ``num_prototypes`` and ``train_size``,
    which are left as ``None`` here. ``dataset`` is a nested ConfigDict holding
    the dataset ``name`` and ``data_path``.
    """
    dataset_cfg = ml_collections.ConfigDict()
    dataset_cfg.name = 'cifar100'
    dataset_cfg.data_path = 'data/tensorflow_datasets'

    cfg = ml_collections.ConfigDict()
    cfg.random_seed = 0
    cfg.train_log = 'train_log'
    cfg.train_img = 'train_img'
    cfg.resume = True
    # Unset here; presumably filled in by callers — TODO confirm.
    for placeholder in ('img_size', 'img_channels', 'num_prototypes', 'train_size'):
        setattr(cfg, placeholder, None)
    cfg.dataset = dataset_cfg
    return cfg


# Wide-ResNet variants accepted as ``model_name``: name -> (depth, widen_factor).
_WIDE_RESNET_SPECS = {
    'wrn28-8': (28, 8),
    'wrn28-10': (28, 10),
    'wrn28-4': (28, 4),
    'wrn22-8': (22, 8),
    'wrn16-8': (16, 8),
    'wrn40-4': (40, 4),
    'wrn34-10': (34, 10),
    'wrn10-4': (10, 4),
}


def _make_ipc_filter(min_ipc, ipc):
    """Return an ``is_valid_file`` predicate keeping image indices in ``[min_ipc, min_ipc + ipc)``.

    The index is the integer between the last ``'id'`` in the file name and the
    following ``'.'`` (e.g. ``img_id12.png`` -> 12).
    """
    def is_valid_file(path):
        filename = os.path.basename(path)
        img_index = int(filename.split('id')[-1].split('.')[0])
        return min_ipc <= img_index < ipc + min_ipc

    return is_valid_file


def _make_one_hot_target(num_classes, offset=0):
    """Return a target transform mapping an integer label to a float one-hot vector of length ``num_classes``.

    ``offset`` is added to the label before encoding. It is used to shift later
    tasks' labels in class-incremental mode.
    """
    def to_one_hot(y):
        label = torch.LongTensor([y]) + offset
        return torch.nn.functional.one_hot(label, num_classes).type(torch.float)[0]

    return to_one_hot


def _as_dir_list(distilled_dir):
    """Normalize ``distilled_dir`` (comma-separated string, list or tuple) into a list of directories."""
    if isinstance(distilled_dir, str):
        return distilled_dir.split(',')
    if isinstance(distilled_dir, (list, tuple)):
        return list(distilled_dir)
    return [distilled_dir]


def _build_distilled_loader(dirs, load_stage, target_resolution, normalize, num_classes,
                            is_valid_file, batch_size, concat, class_inc, classes_per_task):
    """Build a shuffling DataLoader over distilled images stored as ImageFolders.

    Each entry of ``dirs`` is read from ``<dir>/images/`` (or
    ``<dir>/images/stage_<load_stage>/``). With ``concat`` all folders are
    wrapped in a ConcatDataset; otherwise only the first one is used. With
    ``class_inc`` the labels of the i-th folder are shifted by
    ``i * classes_per_task``.
    """
    suffix = '/images/' if load_stage is None else f'/images/stage_{load_stage}/'
    image_transform = transforms.Compose([
        transforms.RandomResizedCrop(target_resolution, scale=(0.3, 1.0)),
        transforms.ToTensor(),
        normalize,
    ])
    folders = []
    for ds_i, root in enumerate(dirs):
        if concat:
            print(root)
        offset = ds_i * classes_per_task if class_inc else 0
        folders.append(datasets.ImageFolder(
            root + suffix,
            image_transform,
            target_transform=_make_one_hot_target(num_classes, offset),
            is_valid_file=is_valid_file,
        ))
    dataset = torch.utils.data.ConcatDataset(folders) if concat else folders[0]
    return torch.utils.data.DataLoader(
        dataset, shuffle=True, num_workers=16, pin_memory=True, batch_size=batch_size)


def _build_model(model_name, source_n_per_class, num_classes, stage, activation_fn):
    """Instantiate the backbone named ``model_name``.

    ``resnet18_silu`` gets ``source_n_per_class`` outputs at stage 0 and
    ``num_classes`` otherwise. Wide-ResNets always get ``source_n_per_class``.

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    if model_name == 'resnet18_silu':
        n_out = source_n_per_class if stage == 0 else num_classes
        return resnet18_silu(num_classes=n_out, activation=activation_fn)
    if model_name in _WIDE_RESNET_SPECS:
        depth, widen = _WIDE_RESNET_SPECS[model_name]
        return WideResNet(depth, source_n_per_class, widen, activation=activation_fn)
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected 'resnet18_silu' or one of {sorted(_WIDE_RESNET_SPECS)}")


def _load_backbone_weights(model, path):
    """Load the state dict at ``path`` into ``model``, stripping any DataParallel ``module.`` key prefix."""
    # Load onto CPU; load_state_dict copies into the model's own tensors, so the
    # checkpoint's original device does not need to exist on this machine.
    state_dict = torch.load(path, map_location='cpu')
    prefix = 'module.'
    stripped = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    model.load_state_dict(stripped)


def main(n_epochs = 5, init_model_save_path = './trained_models/model.pt', distilled_dir = './distilled_images_test/',
         real_ds = False, result_save_name = 'result', dataset_name = 'tiny_imagenet',
         data_folder = 'slkdjfsdf', dataset_min_ipc = 0, dataset_ipc = 5000000, batch_size = 128, label_filename = '/',
         lr = 1e-3, model_type = 'resnet18', source_n_per_class = 1000,
         from_scratch = False, activation = 'relu', model_save_path = './trained_models/model.pt/', load_stage = None,
         model_name = 'resnet18_silu', target_resolution = 32, free_bn = False, save_model = False, stage = 0,
         multi = False, class_inc = False, gsam = True, classes_per_task = 10):
    """Train a classifier on distilled images and record its final test accuracy.

    Driven from the command line through ``fire`` (see the ``__main__`` guard),
    so every keyword is also a CLI flag.

    Data:
        dataset_name, data_folder and target_resolution are forwarded to
        ``get_dataset``. It supplies the test split, ``preprocess_op`` and the
        normalization statistics. Training images are read as an ImageFolder
        from ``<distilled_dir>/images/`` (or ``images/stage_<load_stage>/``).
        Only files whose ``id<N>`` index lies in
        ``[dataset_min_ipc, dataset_min_ipc + dataset_ipc)`` are kept.
        multi: treat ``distilled_dir`` as several directories (comma-separated
            string, list or tuple) and concatenate them.
        class_inc: with ``multi``, shift the labels of the i-th directory by
            ``i * classes_per_task``.
        real_ds: not supported; raises ``NotImplementedError``.

    Model:
        model_name: ``'resnet18_silu'`` or a key of ``_WIDE_RESNET_SPECS``.
        activation: ``'relu'`` selects ``nn.ReLU``; anything else selects ``nn.SiLU``.
        source_n_per_class: output size of the source model. Its checkpoint is
            loaded from ``init_model_save_path`` unless ``from_scratch`` is set
            or the path is empty.
        stage: at stage 0 the ``fc`` head is replaced with one sized for this
            dataset's ``num_classes``.
        model_type: only printed; it does not affect which model is built.

    Optimization:
        Uses SGD (momentum 0.9, weight decay 1e-4) for tiny_imagenet and AdamW
        (weight decay 0) otherwise, at a constant ``lr``. The optimizer is
        wrapped in GSAM when ``gsam`` is true.
        free_bn: passed to the training step as ``freeze_bn=not free_bn``.

    Output:
        The final test accuracy, multiplied by 100, is appended on a new line to
        ``<dir>/results/<label_filename>/ipc_<dataset_ipc>/<result_save_name>.txt``.
        ``<dir>`` is ``distilled_dir``, or the last directory when ``multi``.
        With ``save_model`` the model state dict is saved to ``model_save_path``.

    Raises:
        ValueError: unknown ``model_name``.
        NotImplementedError: ``real_ds`` is true.
    """
    config = get_config()
    config.dataset.name = dataset_name
    device = 'cuda:0'

    print('preparing data')
    (_, ds_test), preprocess_op, _, (dataset_mean, dataset_std) = get_dataset(
        config.dataset, data_folder=data_folder, target_resolution=target_resolution, apply_aug=False)

    print('loading distilled dataset')
    if real_ds:
        # No training loader is ever built for the real dataset. Previously this
        # surfaced much later as an UnboundLocalError; fail fast instead.
        raise NotImplementedError('real_ds=True is not supported: no training loader is built for the real dataset')

    normalize = transforms.Normalize(mean=dataset_mean.reshape(-1), std=dataset_std.reshape(-1))
    num_classes = config.dataset.num_classes
    dirs = _as_dir_list(distilled_dir) if multi else [distilled_dir]
    ds_train = _build_distilled_loader(
        dirs, load_stage, target_resolution, normalize, num_classes,
        _make_ipc_filter(dataset_min_ipc, dataset_ipc), batch_size,
        concat=multi, class_inc=multi and class_inc, classes_per_task=classes_per_task)
    # Results of multi-directory runs are written under the last directory.
    results_root = dirs[-1]

    activation_fn = nn.ReLU if activation == 'relu' else nn.SiLU

    print("preparing model")
    model = _build_model(model_name, source_n_per_class, num_classes, stage, activation_fn)
    if dataset_name == 'tiny_imagenet':
        # Swap the stem for a 3x3 stride-1 conv and drop the max-pool
        # (presumably to keep spatial resolution for small inputs).
        model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        model.maxpool = nn.Identity()

    if not from_scratch and init_model_save_path:
        _load_backbone_weights(model, init_model_save_path)
        print("LOADING PATH")

    if stage == 0:
        # The source checkpoint's head has source_n_per_class outputs; replace it
        # with a fresh head for this dataset.
        model.fc = nn.Linear(model.fc.weight.shape[1], num_classes)

    model.to(device)

    if dataset_name == 'tiny_imagenet':
        base_optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    else:
        base_optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
    # Constant learning rate (multiplier is always 1).
    lr_schedule = optim.lr_scheduler.LambdaLR(base_optimizer, lambda _: 1)

    if gsam:
        # rho in [0.02, 0.05], gsam_alpha 0.04: the configuration used for all experiments.
        rho_scheduler = ProportionScheduler(pytorch_lr_scheduler=lr_schedule, max_lr=lr, min_lr=0.0,
                                            max_value=0.05, min_value=0.02)
        optimizer = GSAM(params=model.parameters(), base_optimizer=base_optimizer, model=model,
                         adaptive=False, gsam_alpha=0.04, rho_scheduler=rho_scheduler)
    else:
        optimizer = base_optimizer

    print(f"USING MODEL {model_type}")
    print(f"LR {lr}")
    print(f"{not free_bn}")

    print('training')
    train_fn = train_one_epoch_gsam if gsam else train_one_epoch
    for epoch in range(n_epochs):
        train_loss, train_acc, _ = train_fn(model, ds_train, optimizer, lr_schedule, preprocess_op,
                                            freeze_bn=not free_bn)
        print(train_loss)
        # Evaluate on the test split after every epoch.
        test_loss, test_acc, _ = train_one_epoch(model, ds_test, optimizer, lr_schedule, preprocess_op, train=False)
        print(f'epoch: {epoch + 1}, test_acc: {test_acc}')

    # Final evaluation; also defines test_acc when n_epochs == 0.
    test_loss, test_acc, _ = train_one_epoch(model, ds_test, optimizer, lr_schedule, preprocess_op, train=False)
    print(f'epoch: {n_epochs}, test_acc: {test_acc}')

    print('Done training!')

    result_path = os.path.expanduser('{}/results/{}/ipc_{}/{}.txt'.format(
        results_root, label_filename, dataset_ipc, result_save_name))
    os.makedirs(os.path.dirname(result_path), exist_ok=True)

    if save_model:
        # Strip a trailing separator so a path like './trained_models/model.pt/'
        # (the default) names a file rather than a directory torch.save can't write.
        save_path = model_save_path.rstrip('/' + os.sep) or model_save_path
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), save_path)

    if os.path.isfile(result_path):
        print("OLD BOY")
    else:
        print("WRITING NEW")
    # Append mode also creates the file, so one write path serves both cases.
    with open(result_path, 'a') as f:
        f.write('\n' + str((test_acc * 100).cpu().item()))


if __name__ == '__main__':
    fire.Fire(main)