import torchvision.transforms as transforms

from .models import *
from .data import *


def prepare_data_and_model(dataset_name, config=None):
    """Build the train/test datasets and the model for a named benchmark.

    Args:
        dataset_name: One of "mnist", "fashion_mnist", "fashion_mnist_resize",
            "cifar10", "ag_news" or "tofu".
        config: Run configuration object. It is only read by the "tofu"
            branch, which uses ``config.peft_config`` and
            ``config.max_seq_length``; every other dataset ignores it.

    Returns:
        A tuple ``((train_set, test_set), model, model_config)``, where
        ``model`` is whatever ``get_model`` returned and ``model_config`` is
        the dict that was passed to it. For "ag_news" and "tofu" the result
        of ``get_model`` is indexed with ``[1]`` to obtain the tokenizer.

    Raises:
        ValueError: If ``dataset_name`` is "tofu" and ``config`` is None.
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "mnist":
        img_transforms = [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
        train_set, test_set = get_torchvision_dataset(dataset_name, img_transforms)
        model_config = {
            "type": "mlp",
            "in_size": 28 * 28,
            "out_size": 10,
            "hidden": 16,
        }
        model = get_model(model_config)

    elif dataset_name == "fashion_mnist":
        img_transforms = [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
        train_set, test_set = get_torchvision_dataset(dataset_name, img_transforms)
        model_config = {
            "type": "convnet2",
            "in_channels": 1,
            "input_size": 28,
            "out_size": 10,
            "hidden": 10,
            "n_kernels": 8,
        }
        model = get_model(model_config)

    elif dataset_name == "fashion_mnist_resize":
        # FashionMNISTSubset replaces an earlier pipeline that loaded the
        # torchvision "fashion_mnist" dataset with ToTensor + Resize(8,
        # antialias=True) transforms.
        train_set = FashionMNISTSubset(train=True)
        test_set = FashionMNISTSubset(train=False)
        model_config = {
            "type": "convnet2_resize",
            "in_channels": 1,
            "in_size": 8,
            "out_size": 8,
            "hidden": 10,
        }
        model = get_model(model_config)

    elif dataset_name == "cifar10":
        # Same normalisation for both splits; only training is augmented.
        cifar_mean = (0.4914, 0.4822, 0.4465)
        cifar_std = (0.2023, 0.1994, 0.2010)
        train_img_transforms = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ]
        test_img_transforms = [
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ]
        train_set, test_set = get_torchvision_dataset(
            dataset_name, train_img_transforms, test_img_transforms
        )
        model_config = {
            "type": "resnet18",
            "in_channels": 3,
            "out_size": 10,
        }   # acc: ~91%
        model = get_model(model_config)

    elif dataset_name == "ag_news":
        # Imported locally so the dtype lookup does not depend on one of the
        # star imports happening to re-export ``torch``.
        import torch

        model_config = {
            "type": "llama",
            "task": "classification",
            "out_size": 4,
            "torch_dtype": torch.bfloat16,
        }
        model = get_model(model_config)
        # raw dataset (not being tokenized)
        train_set, test_set = get_hf_dataset(dataset_name, tokenizer=model[1])

    elif dataset_name == "tofu":
        # Fail fast with a clear message instead of an AttributeError on None.
        if config is None:
            raise ValueError(
                "dataset 'tofu' requires a config providing "
                "'peft_config' and 'max_seq_length'"
            )

        # Imported locally for the same reason as in the "ag_news" branch.
        import torch

        model_config = {
            "type": "llama-chat",
            "task": "qa",
            "torch_dtype": torch.bfloat16,
        }
        model = get_model(model_config, config.peft_config)
        # raw dataset (not being tokenized)
        train_set, test_set = get_hf_dataset(
            dataset_name,
            tokenizer=model[1],
            max_seq_length=config.max_seq_length,
        )

    else:
        raise NotImplementedError(
            f"Unsupported dataset: {dataset_name!r}. Expected one of: "
            "'mnist', 'fashion_mnist', 'fashion_mnist_resize', 'cifar10', "
            "'ag_news', 'tofu'."
        )

    return (train_set, test_set), model, model_config

