import ml_collections
from dataloader import get_dataset, configure_dataloader
    
import torch
import torch.nn as nn
import torch.optim as optim
# from tiny_imagenet_dataset import TinyImageNet
import torchvision.transforms as transforms
# from random_crop_fix import RandomResizedCropFixed
from train_utils import train_one_epoch
from modified_resnets import resnet18_silu
from wide_resnets import WideResNet

import os
import fire

def get_config():
    """Build the default experiment configuration.

    Returns:
        An ``ml_collections.ConfigDict`` containing:
          * bookkeeping fields ``random_seed``, ``train_log``, ``train_img``
            and ``resume``;
          * ``img_size``, ``img_channels``, ``num_prototypes`` and
            ``train_size``, all left as ``None`` here;
          * a nested ``dataset`` ConfigDict with ``name`` and ``data_path``.
    """
    cfg = ml_collections.ConfigDict()

    # Run bookkeeping.
    cfg.random_seed = 0
    cfg.train_log = 'train_log'
    cfg.train_img = 'train_img'
    cfg.resume = True

    # Placeholders left unset (None) at this point; presumably filled in by
    # later code once the dataset is known.
    for placeholder in ('img_size', 'img_channels', 'num_prototypes', 'train_size'):
        setattr(cfg, placeholder, None)

    # Dataset sub-config.
    dataset_cfg = ml_collections.ConfigDict()
    dataset_cfg.name = 'cifar100'
    dataset_cfg.data_path = 'data/tensorflow_datasets'
    cfg.dataset = dataset_cfg

    return cfg


def _resolve_activation(activation):
    """Map an activation name ('relu' or 'silu', case-insensitive) to its nn.Module class.

    Raises:
        ValueError: if the name is not one of the supported activations.
    """
    activations = {'relu': nn.ReLU, 'silu': nn.SiLU}
    key = str(activation).lower()
    if key not in activations:
        raise ValueError(
            f'Unknown activation {activation!r}; expected one of {sorted(activations)}')
    return activations[key]


def _build_model(model_name, dataset_name, num_classes, activation):
    """Instantiate the architecture named `model_name` with `num_classes` outputs.

    Raises:
        ValueError: if `model_name` is not a supported architecture.
    """
    if model_name == 'resnet18_silu':
        model = resnet18_silu(num_classes=num_classes, weights=None, activation=activation)
        if dataset_name == 'tiny_imagenet' or 'cifar' in dataset_name:
            # Swap in a 3x3 stride-1 stem and drop the max-pool, presumably to
            # avoid early downsampling on these small-resolution datasets.
            model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1),
                                    padding=(1, 1), bias=False)
            model.maxpool = nn.Identity()
        return model

    # Wide-ResNet variants: name -> (depth, widen factor).
    wide_resnet_specs = {
        'wrn28-10': (28, 10),
        'wrn28-8': (28, 8),
        'wrn40-4': (40, 4),
        'wrn28-4': (28, 4),
        'wrn22-8': (22, 8),
        'wrn16-8': (16, 8),
        'wrn34-10': (34, 10),
        'wrn10-4': (10, 4),
    }
    if model_name not in wide_resnet_specs:
        known = ['resnet18_silu'] + sorted(wide_resnet_specs)
        raise ValueError(f'Unknown model_name {model_name!r}; expected one of {known}')
    depth, widen_factor = wide_resnet_specs[model_name]
    return WideResNet(depth, num_classes, widen_factor, activation=activation)


def _build_lr_schedule(optimizer, n_epochs, warmup_steps=5):
    """Return a linear-warmup-then-cosine SequentialLR over `optimizer`.

    The schedule is stepped inside `train_one_epoch`. Whether that happens
    once per epoch or once per batch is not visible here, so `warmup_steps`
    is expressed in scheduler steps.
    """
    warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps)
    # T_max must be >= 1. With T_max == 0, SequentialLR's hand-over call
    # `cosine.step(0)` at the milestone divides by zero.
    cosine = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, max(n_epochs - warmup_steps, 1), eta_min=0.00)
    return optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])


def _checkpoint_path(model_save_path, epoch):
    """Path of the extra snapshot for `epoch`: '<stem>_checkpoints_<epoch>.pt'."""
    stem, _ = os.path.splitext(model_save_path)
    return f'{stem}_checkpoints_{epoch}.pt'


def main(seed = 0, model_save_path = './trained_models/model.pt', n_epochs = 5,
         dataset_name = 'tiny_imagenet', test_every = 1, extra_checkpoints = None,
         data_path = '~/Documents/datasets/tiny_imagenet', batch_size = 256,
         model_name = 'resnet18_silu', activation = 'silu'):
    """Train a classifier on `dataset_name` and save its final state_dict.

    Args:
        seed: Passed to `torch.manual_seed` and stored in `config.random_seed`.
        model_save_path: Destination of the final state_dict. `~` is expanded,
            and a missing parent directory is created.
        n_epochs: Number of `train_one_epoch` passes over the training set.
        dataset_name: Dataset key forwarded to `get_dataset` via
            `config.dataset.name`.
        test_every: Currently unused because evaluation is disabled. Kept so
            existing command lines keep working.
        extra_checkpoints: 1-based epoch number(s) after which an additional
            '<save-path stem>_checkpoints_<epoch>.pt' snapshot is written.
            Accepts a single int or an iterable of ints; None means no extra
            snapshots.
        data_path: Dataset root directory; `~` is expanded.
        batch_size: Forwarded to `get_dataset`.
        model_name: 'resnet18_silu' or a Wide-ResNet name such as 'wrn28-10'.
        activation: 'relu' or 'silu' (case-insensitive).

    Raises:
        ValueError: on an unknown `model_name` or `activation`.
    """
    config = get_config()
    config.random_seed = seed
    config.dataset.name = dataset_name
    device = 'cuda:0'

    # Seeds torch's CPU and CUDA generators (weight init etc.). Whether the
    # data pipeline's shuffling follows this seed depends on get_dataset.
    torch.manual_seed(seed)

    # Validate cheap arguments and prepare the output directory before the
    # expensive data loading, so mistakes fail fast.
    activation_cls = _resolve_activation(activation)

    if extra_checkpoints is None:
        extra_checkpoints = []
    elif isinstance(extra_checkpoints, int):
        # A single epoch may be passed as a bare int from the CLI.
        extra_checkpoints = [extra_checkpoints]
    checkpoint_epochs = {int(epoch) for epoch in extra_checkpoints}

    model_save_path = os.path.expanduser(model_save_path)
    save_dir = os.path.dirname(model_save_path)
    if save_dir:  # '' means "current directory", which needs no creation.
        os.makedirs(save_dir, exist_ok=True)

    print('preparing data')
    (ds_train, ds_test), preprocess_op, rev_preprocess_op, (raw_image_mean, raw_image_std) = get_dataset(
        config.dataset, apply_aug = True, data_folder = os.path.expanduser(data_path),
        batch_size = batch_size)

    torch.cuda.empty_cache()
    print('preparing model')
    # NOTE(review): get_config() does not define dataset.num_classes;
    # presumably get_dataset sets it on config.dataset. Confirm.
    model = _build_model(model_name, dataset_name, config.dataset.num_classes, activation_cls)
    print(f'activation: {activation_cls.__name__}')
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4)
    lr_schedule = _build_lr_schedule(optimizer, n_epochs)

    print('training')
    for epoch in range(1, n_epochs + 1):
        train_loss, train_acc, _ = train_one_epoch(
            model, ds_train, optimizer, lr_schedule, preprocess_op, train = True)
        print(f'epoch: {epoch}, train_loss: {train_loss}, train_acc: {train_acc}')

        if epoch in checkpoint_epochs:
            torch.save(model.state_dict(), _checkpoint_path(model_save_path, epoch))

    print('Done training!')

    torch.save(model.state_dict(), model_save_path)


# def set_weight_decay(
#     model,
#     weight_decay,
#     norm_weight_decay = None,
#     norm_classes = None,
#     custom_keys_weight_decay = None,
# ):
#     if not norm_classes:
#         norm_classes = [
#             torch.nn.modules.batchnorm._BatchNorm,
#             torch.nn.LayerNorm,
#             torch.nn.GroupNorm,
#             torch.nn.modules.instancenorm._InstanceNorm,
#             torch.nn.LocalResponseNorm,
#         ]
#     norm_classes = tuple(norm_classes)

#     params = {
#         "other": [],
#         "norm": [],
#     }
#     params_weight_decay = {
#         "other": weight_decay,
#         "norm": norm_weight_decay,
#     }
#     custom_keys = []
#     if custom_keys_weight_decay is not None:
#         for key, weight_decay in custom_keys_weight_decay:
#             params[key] = []
#             params_weight_decay[key] = weight_decay
#             custom_keys.append(key)

#     def _add_params(module, prefix=""):
#         for name, p in module.named_parameters(recurse=False):
#             if not p.requires_grad:
#                 continue
#             is_custom_key = False
#             for key in custom_keys:
#                 target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
#                 if key == target_name:
#                     params[key].append(p)
#                     is_custom_key = True
#                     break
#             if not is_custom_key:
#                 if norm_weight_decay is not None and isinstance(module, norm_classes):
#                     params["norm"].append(p)
#                 else:
#                     params["other"].append(p)

#         for child_name, child_module in module.named_children():
#             child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
#             _add_params(child_module, prefix=child_prefix)

#     _add_params(model)

#     param_groups = []
#     for key in params:
#         if len(params[key]) > 0:
#             param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
#     return param_groups

# def load_data():
#     # Data loading code
#     normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
#                                      std=[0.2302, 0.2265, 0.2262])
#     print("Loading training data")
#     train_transform = transforms.Compose([
#             # transforms.RandomResizedCrop(64),
#             # transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#             # normalize,
#         ])
#     dataset = TinyImageNet('./data', split='train', download=True, transform=train_transform)

#     print("Loading validation data")
#     val_transform = transforms.Compose([
#         transforms.ToTensor(),
#         # normalize,
#     ])
#     dataset_test = TinyImageNet('./data', split='val', download=False, transform=val_transform)


#     print("Creating data loaders")
#     # if args.distributed:
#     #     if hasattr(args, "ra_sampler") and args.ra_sampler:
#     #         train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
#     #     else:
#     #         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
#     #     test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
#     # else:
#     train_sampler = torch.utils.data.RandomSampler(dataset)
#     test_sampler = torch.utils.data.SequentialSampler(dataset_test)

#     return dataset, dataset_test, train_sampler, test_sampler


if __name__ == '__main__':
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)