import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import copy
from dataset import MetricDataset
from metric import Predictor
from tqdm import tqdm


# Which metric variant to train; it selects both the data directory and the
# checkpoint filename written at the end of the script.
tp = 'om'
# Expand '~' explicitly: Python's filesystem APIs (open, os.listdir, PIL) do
# not perform tilde expansion on their own.
data_root = os.path.expanduser('~/data/tai/metric_{}'.format(tp))

# image_transform is enabled only for the training split (presumably
# augmentation — confirm against MetricDataset).
metric_datasets = {
    x: MetricDataset(
        data_root=data_root,
        split=x,
        image_transform=(x == 'train'),
    )
    for x in ['train', 'val']
}

# Shuffle only the training split. Validation just accumulates a loss, so
# its order does not change the result (up to float rounding) and shuffling
# it is wasted work.
dataloaders = {
    x: DataLoader(
        metric_datasets[x],
        batch_size=64,
        shuffle=(x == 'train'),
        num_workers=8,
    )
    for x in ['train', 'val']
}
dataset_sizes = {x: len(metric_datasets[x]) for x in ['train', 'val']}


model = Predictor()
# NOTE: pinned to the second GPU ('cuda:1'); this requires at least two
# visible CUDA devices whenever CUDA is available.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Mean absolute error between predictions and labels.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler: stepped once per epoch by train_model, so the LR
# is divided by 10 every 10 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


def _run_phase(model, criterion, optimizer, loader, run_device, is_train,
               show_progress=True):
    """Run one pass over ``loader`` and return the per-sample average loss.

    When ``is_train`` is true the model is put in training mode and the
    optimizer takes one step per batch; otherwise the model is put in eval
    mode and no gradients are computed.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train(is_train)  # same as model.train() / model.eval()
    batches = tqdm(loader) if show_progress else loader

    running_loss = 0.0
    n_samples = 0
    for inputs, labels in batches:
        inputs = inputs.to(run_device).float()
        labels = labels.to(run_device).float()
        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        # Weight each batch by its size so a short final batch does not skew
        # the average (assumes criterion returns a per-batch mean, which is
        # nn.L1Loss's default reduction).
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError('loader yielded no samples')
    # Divide by the samples actually seen rather than len(dataset), which
    # stays correct if the loader drops a final partial batch.
    return running_loss / n_samples


def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                loaders=None, target_device=None, verbose=True):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a ``'train'`` phase (forward, backward and optimizer step
    per batch) followed by a ``'val'`` phase (forward only), then steps
    ``scheduler`` once. The weights from the epoch with the lowest validation
    loss are restored before returning.

    Args:
        model: Module to train; it is updated in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: Number of train/val epochs.
        loaders: Mapping with ``'train'`` and ``'val'`` iterables of
            ``(inputs, labels)`` batches. Defaults to the module-level
            ``dataloaders``.
        target_device: Device that batches are moved to. Defaults to the
            module-level ``device``.
        verbose: When False, suppress the printed log and the tqdm bars.

    Returns:
        ``model`` itself, with the best validation weights loaded.

    Raises:
        ValueError: If a phase's loader yields no samples.
    """
    if loaders is None:
        loaders = dataloaders
    if target_device is None:
        target_device = device
    log = print if verbose else (lambda *args, **kwargs: None)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        log(f'Epoch {epoch}/{num_epochs - 1}')
        log('-' * 30)

        for phase in ['train', 'val']:
            epoch_loss = _run_phase(
                model, criterion, optimizer, loaders[phase], target_device,
                is_train=(phase == 'train'), show_progress=verbose,
            )
            log(f'{phase} Loss: {epoch_loss:.4f}')

            # Snapshot the weights whenever validation improves strictly.
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        scheduler.step()
        log()

    log(f'Best val Loss: {best_loss:.4f}')
    model.load_state_dict(best_model_wts)
    return model


# Run 50 epochs of training. train_model hands back the model with the
# lowest-validation-loss weights already loaded, and those weights are what
# gets saved.
model = train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    num_epochs=50,
)
torch.save(model.state_dict(), '{}_metric_new.pth'.format(tp))
