import argparse
import copy
import os
import random

import numpy as np
import torch
import tqdm
from cuda_selector import auto_cuda
from torch import nn
from torch.utils.data import Subset, DataLoader, Dataset, random_split
from torchvision import datasets, transforms
from torchvision.transforms.functional import crop

# -------------------- Command-line arguments ------------------------------
parser = argparse.ArgumentParser(
    description="Predict the masked MNIST corner from frozen WAE encoder latents",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--data_dir", type=str, default="dataset", help="directory of MNIST")
parser.add_argument("--save_dir", type=str, default="models/", help="directory to save models")
parser.add_argument("--seed", type=int, default=12345678, help="random seed")
parser.add_argument("--epochs", type=int, default=500, help="Total number of epochs")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
parser.add_argument("--batch_size", type=int, default=32, help="mini-batch size")
parser.add_argument("--z_dim", type=int, default=20, help="latent dimension of the encoder")
parser.add_argument("--target_class", type=int, default=0,
                    help="digit class excluded from the train/test data")

args = parser.parse_args()

# -------------------- Seed ------------------------------
# Export the cuBLAS workspace setting before any CUDA work happens; deterministic
# cuBLAS kernels require it once use_deterministic_algorithms(True) is active.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# NOTE: hash randomisation of the running interpreter is fixed at start-up, so
# this only reaches child processes that inherit the environment.
os.environ['PYTHONHASHSEED'] = str(args.seed)

np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)  # cpu
torch.cuda.manual_seed_all(args.seed)  # every visible GPU

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Query CUDA once and reuse the answer for the default device.
cuda_available = torch.cuda.is_available()


# Set device
device = torch.device("cuda" if cuda_available else "cpu")

# ------------------- Load and preprocess data -------------------------------
class TransformDataset(Dataset):
    """Turn an (image, label) dataset into (masked image, hidden patch) pairs.

    The patch is the region ``[:, 14:, :14]`` of each image tensor. The first
    returned tensor is a copy of the image with that region zeroed; the second
    is a copy of the region's original content. The label of the wrapped
    dataset is discarded.

    For 1x28x28 MNIST tensors the region is the bottom-left 14x14 quadrant.
    """

    def __init__(self, dataset):
        # Any indexable object yielding (tensor, label) pairs with a length.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, _label = self.dataset[index]

        # One index expression shared by the read and the write below.
        patch = (slice(None), slice(14, None), slice(None, 14))

        # Copy the target first, then blank the same region in a separate
        # copy so the wrapped dataset's tensor is never modified.
        target = image[patch].detach().clone()
        masked = image.detach().clone()
        masked[patch] = 0
        return masked, target

data_root = args.data_dir


def _mnist_without_target(train):
    """Load one MNIST split and keep only samples whose label differs from
    ``args.target_class``.

    ``train`` selects the training split (True) or the test split (False).
    Images are converted with ``transforms.ToTensor()``; the data is
    downloaded into ``data_root`` when missing.
    """
    mnist = datasets.MNIST(
        root=data_root,
        train=train,
        transform=transforms.ToTensor(),
        download=True,
    )
    keep = torch.where(mnist.targets != args.target_class)[0]
    return Subset(mnist, indices=keep)


full_dataset = _mnist_without_target(train=True)
test_dataset = _mnist_without_target(train=False)

# 80/20 train/validation split of the filtered training data. The actual
# sizes depend on how many samples the excluded class had.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Wrap every split so that it yields (masked image, hidden corner) pairs.
train_dataset = TransformDataset(train_dataset)
valid_dataset = TransformDataset(val_dataset)
test_dataset = TransformDataset(test_dataset)

save_dir = args.save_dir

# ------------------- Output directory -------------------------------
# exist_ok=True creates the directory if needed without a separate
# check-then-create step.
os.makedirs(save_dir, exist_ok=True)

# ------------------- Select Device -------------------------------
if cuda_available:
    # Presumably auto_cuda() returns a string such as 'cuda:N' - confirm
    # against cuda_selector. Parse the whole trailing run of digits so that
    # GPU ids >= 10 are handled, not only the last character.
    selected = auto_cuda()
    device_id = int(selected[len(selected.rstrip('0123456789')):])
    torch.cuda.set_device(device_id)
    device = torch.device('cuda')
    print('Using cuda:{}'.format(torch.cuda.current_device()))
else:
    # Same type as the CUDA branch, so `device` is always a torch.device.
    device = torch.device('cpu')
    print('Using CPU')


print('Train set: {}/ Test set: {}'.format(len(train_dataset), len(test_dataset)))

# ------------------- Load Parameters -------------------------------
z_dim = args.z_dim
lr = args.lr
batch_size = args.batch_size
epochs = args.epochs

# ------------------- Load Models -------------------------------
# Path of the pretrained WAE checkpoint (a file, despite the variable name),
# resolved relative to the current working directory.
WAE_dir = 'MNIST_WAE_zdim{}_cornercrop_all.pth'.format(z_dim)
# map_location puts the stored tensors on the device selected for this run,
# so a checkpoint written on a GPU also loads on CPU-only hosts or on a
# different GPU.
checkpoint = torch.load(WAE_dir, map_location=device)
from models import MNISTEncoder
E = MNISTEncoder(z_dim=args.z_dim).to(device)
E.load_state_dict(checkpoint['encoder'])
# The encoder is a fixed feature extractor: no gradients for its weights ...
for p in E.parameters():
    p.requires_grad = False
# ... and inference mode, so any BatchNorm/Dropout layers it may contain
# behave deterministically and keep their stored statistics.
E.eval()

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
from models import MNIST_NN_CornerCrop
# Trainable network mapping a latent code to the hidden image corner.
nn_model = MNIST_NN_CornerCrop(z_dim=z_dim).to(device)

optimizer = torch.optim.Adam([{'params': nn_model.parameters()}, ], lr=lr)

# ------------------- Training Loop -------------------------------
best_val_loss = float('inf')
# NOTE(review): `patience` and `wait` count epochs without improvement, but
# nothing stops training on them - confirm whether early stopping was
# intended (the LR scheduler below uses its own patience of 10).
patience = 5
wait = 0
best_model_weights = None
mse = nn.MSELoss()
train_loss = 0
val_loss = 0
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min",
                                                       patience=10,
                                                       factor=0.5)

# NOTE(review): the checkpoint is written to the working directory;
# `--save_dir` is not used here - confirm the intended location.
checkpoint_path = 'MNIST_NN_zdim{}_cornercrop_all_{}_WAE.pth'.format(z_dim, args.target_class)
progress_fmt = 'Train: {} / Valid: {} / Best Valid: {}'

with tqdm.tqdm(range(epochs)) as pbar:
    for i in pbar:
        # ---- one pass over the training split ----
        nn_model.train()
        train_loss = 0
        count = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            z = E(x)
            y_pred = nn_model(z)

            loss = mse(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            count += 1
        # Mean of the per-batch losses.
        train_loss /= count
        # val_loss shown here is still the value from the previous epoch.
        pbar.set_postfix_str(progress_fmt.format(train_loss, val_loss, best_val_loss))

        # ---- validation ----
        nn_model.eval()
        val_loss = 0
        count = 0
        with torch.no_grad():  # No gradients needed
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                z = E(x)
                y_pred = nn_model(z)
                loss = mse(y_pred, y)
                val_loss += loss.item()
                count += 1

        val_loss /= count

        scheduler.step(val_loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            wait = 0
            # state_dict() returns references to the live tensors, which keep
            # changing as training continues. Deep-copy them so the snapshot
            # really holds the best epoch's weights.
            best_model_weights = {
                'nn_model': copy.deepcopy(nn_model.state_dict()),
                'optimizer': copy.deepcopy(optimizer.state_dict()),
            }
            torch.save(best_model_weights, checkpoint_path)
        else:
            wait += 1

        pbar.set_postfix_str(progress_fmt.format(train_loss, val_loss, best_val_loss))

# Write the best snapshot once more at the end. Skip it when no epoch ever
# improved, so an existing checkpoint is not replaced by `None`.
if best_model_weights is not None:
    torch.save(best_model_weights, checkpoint_path)