import argparse
import os

import numpy as np
import pytorch_lightning as pl
import torch
from data import LabeledDataset, load_epithel, load_mice, load_activity, load_aml
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.utils.data import DataLoader

from models.concrete_autoencoder import ConcreteAutoencoder
from models.CFS_Gates import CFS_Gates
from models.CFS_Joint import CFS_Joint
from models.CFS_Pretrained import CFS_Pretrained

from utils import set_seeds

"""
Script for training CFS models (and the ConcreteAutoencoder baseline) on the
Grassy MNIST dataset. Before running this script, run the notebook
`get_grassy_mnist.ipynb`.
"""

# Command line: MODEL K MAX_EPOCHS
#   MODEL       which feature-selection model to train
#   K           number of features to select (forwarded as `k`)
#   MAX_EPOCHS  training epochs (for CFS_Pretrained: feature-selection epochs,
#               in addition to a fixed pretraining phase)
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, choices=[
    'ConcreteAutoencoder',
    'CFS_Joint',
    'CFS_Pretrained',
    'CFS_Gates'
])
parser.add_argument('k', type=int)
parser.add_argument('max_epochs', type=int)

# Forwarded as `k_prime` to every CFS model.
k_prime = 20

args = parser.parse_args()

set_seeds(42)

# Dataset directory name, used for both ./data and ./results. Kept distinct
# from the torch Dataset object built below so the checkpoint path can still
# refer to it after training.
dataset_name = "Grassy_MNIST"
data_dir = os.path.join("./data", dataset_name)

batch_size = 128
background = np.load(os.path.join(data_dir, "background.npy")).astype(np.float32)
target = np.load(os.path.join(data_dir, "target.npy")).astype(np.float32)
# Target labels are shifted by one, presumably so label 0 can denote background
# (cf. background_labels). NOTE(review): neither label array is used below.
target_labels = np.load(os.path.join(data_dir, "target_labels.npy")) + 1
background_labels = np.zeros(background.shape[0])
hidden = [512, 512]

# For the contrastive models, we use background and target data during training
if args.model.startswith('CFS'):
    bg_train, bg_test = train_test_split(background, test_size=0.2, random_state=42)
    target_train, target_test = train_test_split(target, test_size=0.2, random_state=42)

    data_train = np.concatenate([bg_train, target_train])
    # Per-sample group label: 0 = background row, 1 = target row.
    labels_train = np.concatenate([np.zeros(bg_train.shape[0]), np.ones(target_train.shape[0])])

# For the non-contrastive setting, we use _only_ the target data during training
else:
    data_train, data_test = train_test_split(target, test_size=0.2, random_state=42)

# Constructor arguments shared by every model. `train_length` is the number of
# rows the DataLoader actually iterates (previously the AE and Pretrained
# branches passed the full, unsplit target size instead).
common_kwargs = dict(
    input_size=data_train.shape[1],
    output_size=data_train.shape[1],
    start_temperature=10.0,
    end_temperature=0.01,
    hidden=hidden,
    k=args.k,
    train_length=data_train.shape[0],
    mbsize=batch_size,
    lr=1e-3,
    loss_fn=nn.MSELoss(),
)

# Per-model overrides of the loader batch size and total trainer epochs.
loader_batch_size = batch_size
trainer_epochs = args.max_epochs

if args.model == 'ConcreteAutoencoder':
    # Unsupervised reconstruction: the input is also the target.
    train_dataset = LabeledDataset(data_train, data_train)
    model = ConcreteAutoencoder(max_epochs=args.max_epochs, **common_kwargs)

elif args.model == 'CFS_Pretrained':
    train_dataset = LabeledDataset(data_train, labels_train)
    max_pretrain_epochs = 100
    model = CFS_Pretrained(
        k_prime=k_prime,
        max_pretrain_epochs=max_pretrain_epochs,
        max_fs_epochs=args.max_epochs,
        **common_kwargs
    )
    # The trainer must run long enough to cover both training phases.
    trainer_epochs = max_pretrain_epochs + args.max_epochs

elif args.model == 'CFS_Joint':
    train_dataset = LabeledDataset(data_train, labels_train)
    model = CFS_Joint(k_prime=k_prime, max_epochs=args.max_epochs, **common_kwargs)
    # NOTE(review): the loader uses 512 while the model was told mbsize=128;
    # confirm which batch size is intended for CFS_Joint.
    loader_batch_size = 512

elif args.model == 'CFS_Gates':
    train_dataset = LabeledDataset(data_train, labels_train)
    model = CFS_Gates(k_prime=k_prime, max_epochs=args.max_epochs, **common_kwargs)

else:
    # Unreachable via argparse `choices`, kept as a safety net.
    raise NotImplementedError("Invalid model type selected")

loader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True, num_workers=12)
trainer = pl.Trainer(max_epochs=trainer_epochs, gpus=[0])
trainer.fit(model, loader)

# The parser has no `dataset` argument, so the path is built from
# `dataset_name`: results/<dataset>/<model>/<k>/checkpoint.chkpt
checkpoint_dir = os.path.join("results", dataset_name, args.model, str(args.k))

os.makedirs(checkpoint_dir, exist_ok=True)
# Pickles the whole module object; loading it requires the model classes to be
# importable.
torch.save(model, os.path.join(checkpoint_dir, "checkpoint.chkpt"))

