import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dataset
import torchvision.transforms as transforms

from sensor import *

class Classifier(nn.Module):
    """Multi-label attribute classifier: a foveated sensor followed by ResNet-18.

    The input image is first resampled by ``FoveatedSensor`` (whose parameters
    ``t`` may be optimized jointly with the network), then passed through an
    ImageNet-pretrained ResNet-18 whose final layer is replaced by a linear head
    producing one score per attribute. ``forward`` returns per-attribute
    probabilities in (0, 1), suitable for ``nn.BCELoss``.

    Args:
        num_attributes: number of output attributes (CelebA has 40).
        sensor_shape: first argument forwarded to ``FoveatedSensor``; presumably
            the sensor's sampling resolution — confirm against ``sensor.FoveatedSensor``.
        optimize_sensor: if False, the sensor parameters ``t`` are frozen and
            only the ResNet is trained.
    """

    def __init__(self, *, num_attributes: int = 40, sensor_shape=(16, 16),
                 optimize_sensor: bool = True) -> None:
        super().__init__()

        # Sensor deformation options:
        #   IndependentAnisotropicHalfNormalTunableSigmoidDeformation = rect
        #   AnisotropicHalfNormalTunableSigmoidDeformation = curv
        self.sensor = FoveatedSensor(sensor_shape,
                                     deform=IndependentAnisotropicHalfNormalTunableSigmoidDeformation(1.0),
                                     constrain_t=True)
        self.sensor.t.requires_grad = optimize_sensor

        # Pretrained ResNet-18 backbone.
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Small head replacing the ImageNet classifier. It stays wrapped in
        # nn.Sequential so state_dict keys ("resnet.fc.0.*") match existing checkpoints.
        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, num_attributes),
        )

    def forward(self, x):
        """Return per-attribute probabilities for the image batch ``x``."""
        x = self.sensor(x)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(self.resnet(x))

class celeba(Dataset):
    """CelebA images paired with their binary attribute label rows.

    Args:
        data_path: sequence of image file paths.
        label_path: sequence of per-image label rows (despite the name, these
            are label values, not a file path), aligned index-for-index with
            ``data_path``.
    """

    def __init__(self, data_path=None, label_path=None):
        self.data_path = data_path
        self.label_path = label_path
        # ToTensor scales 8-bit pixels to [0, 1]; Normalize then maps each
        # of the three channels to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.data_path)

    def __getitem__(self, idx):
        """Return ``(image_tensor, float32 label tensor)`` for sample ``idx``."""
        # The context manager releases the file handle. convert('RGB') guarantees
        # the 3 channels that the 3-value Normalize above expects, even for
        # grayscale JPEGs.
        with Image.open(self.data_path[idx]) as image:
            image_tensor = self.transform(image.convert('RGB'))
        # torch.tensor always copies the given values. The legacy torch.Tensor(n)
        # would allocate an uninitialized length-n tensor for a bare int.
        image_label = torch.tensor(self.label_path[idx], dtype=torch.float32)

        return image_tensor, image_label

if __name__ == '__main__':
    root_dir = 'data/celeba'
    data_path = sorted(glob.glob(os.path.join(root_dir, 'img_align_celeba/*.jpg')))
    label_path = os.path.join(root_dir, 'list_attr_celeba.txt')

    # Read the attribute file once and close it. Line 1 holds the attribute
    # names; every line from index 2 is "<filename> <value> ... <value>".
    with open(label_path) as label_file:
        label_lines = label_file.readlines()
    attributes = label_lines[1].split()

    # Key labels by filename so a missing or extra image cannot silently shift
    # every label onto the wrong picture. -1 is mapped to 0 so targets suit BCELoss.
    labels_by_name = {}
    for line in label_lines[2:]:
        fields = line.split()
        if not fields:  # tolerate blank / trailing lines
            continue
        labels_by_name[fields[0]] = [max(int(v), 0) for v in fields[1:]]
    data_label = [labels_by_name[os.path.basename(p)] for p in data_path]

    dataset = celeba(data_path, data_label)

    # Contiguous 70% / 20% / 10% train/val/test split. For the full
    # 202,599-image CelebA this yields boundaries 141819 and 182339.
    indices = list(range(len(dataset)))
    split_train = len(indices) * 7 // 10
    split_val = len(indices) * 9 // 10

    train_idx, val_idx, test_idx = indices[:split_train], indices[split_train:split_val], indices[split_val:]

    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    train_loader = DataLoader(dataset, batch_size=64, num_workers=8, sampler=train_sampler)
    val_loader = DataLoader(dataset, batch_size=64, num_workers=8, sampler=val_sampler)
    # Default batch_size (1) for the test set.
    test_loader = DataLoader(dataset, num_workers=8, sampler=test_sampler)

    def train(model, optimizer, criterion, epochs, train_all_losses, train_all_acc):
        """Run one training epoch over ``train_loader`` and record its metrics.

        Args:
            model: classifier being trained; its ``sensor.t`` is logged.
            optimizer: optimizer over ``model``'s parameters.
            criterion: loss comparing the model's sigmoid outputs with 0/1 labels.
            epochs: zero-based index of the current epoch (printed as ``epochs + 1``).
            train_all_losses: list to which the epoch's mean per-batch loss is appended.
            train_all_acc: list to which the epoch's accuracy over all
                (sample, attribute) predictions is appended.
        """
        model.train()
        # Follow the model's device instead of hard-coding 'cuda'.
        device = next(model.parameters()).device

        running_loss = 0.0
        correct = 0
        total = 0
        for i, (inputs, labels) in enumerate(train_loader):
            # The DataLoader already yields float tensors; just move them.
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward + backward + optimize
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Count every (sample, attribute) prediction that matches its label.
            correct += ((outputs > 0.5) == labels).sum().item()
            total += labels.numel()

            if i % 64 == 0:
                print('Training set: [Epoch: %d, Data: %6d] Loss: %.3f t: %s' %
                      (epochs + 1, i * train_loader.batch_size, loss.item(),
                       str(model.sensor.t.detach().cpu().numpy().tolist())))

        acc = correct / total
        running_loss /= len(train_loader)
        train_all_losses.append(running_loss)
        train_all_acc.append(acc)
        print('\nTraining set: Epoch: %d, Accuracy: %.2f %%' % (epochs + 1, 100. * acc))

    def validation(model, criterion, val_all_losses, val_all_acc, best_acc):
        """Evaluate ``model`` on ``val_loader`` and record loss and accuracy.

        Args:
            model: classifier to evaluate; switched to eval mode here.
            criterion: loss used for the mean per-batch validation loss.
            val_all_losses: list to which the mean per-batch loss is appended.
            val_all_acc: list to which the accuracy is appended.
            best_acc: unused; kept for call compatibility.

        Returns:
            Fraction of all (sample, attribute) predictions that match the labels.
        """
        model.eval()
        device = next(model.parameters()).device
        validation_loss = 0.0
        correct = 0
        total = 0
        # Evaluation needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for data, target in val_loader:
                data = data.to(device)
                target = target.to(device)
                output = model(data)

                validation_loss += criterion(output, target).item()
                correct += ((output > 0.5) == target).sum().item()
                total += target.numel()

        validation_loss /= len(val_loader)
        acc = correct / total

        val_all_losses.append(validation_loss)
        val_all_acc.append(acc)

        print('\nValidation set: Average loss: {:.3f}, Accuracy: {:.2f}%\n'
              .format(validation_loss, 100. * acc))

        return acc

    def test(model, criterion, attr_acc, attr_name=attributes):
        """Evaluate ``model`` on ``test_loader``; report overall and per-attribute accuracy.

        Args:
            model: trained classifier; switched to eval mode here.
            criterion: loss used for the mean per-batch test loss.
            attr_acc: list to which one accuracy per attribute is appended,
                in the order of ``attr_name``.
            attr_name: attribute names, one per model output column.

        Returns:
            Fraction of all (sample, attribute) predictions that match the labels.

        Raises:
            RuntimeError: if ``test_loader`` yields no samples.
        """
        # Do not rely on the caller having left the model in eval mode.
        model.eval()
        device = next(model.parameters()).device
        test_loss = 0.0
        correct = 0
        total = 0
        num_samples = 0
        per_attr_correct = None
        with torch.no_grad():
            for data, target in test_loader:
                data = data.to(device)
                target = target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()

                compare = (output > 0.5) == target
                correct += compare.sum().item()
                total += compare.numel()
                num_samples += compare.shape[0]
                # Summing hits over the batch dimension works for any batch size.
                # The counts stay on the device, which avoids one host sync per element.
                batch_hits = compare.sum(dim=0)
                per_attr_correct = batch_hits if per_attr_correct is None else per_attr_correct + batch_hits

        if num_samples == 0:
            raise RuntimeError('test_loader yielded no samples')

        test_loss /= len(test_loader)
        acc = correct / total
        print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
            test_loss, 100. * acc))

        per_attr_acc = (per_attr_correct.double() / num_samples).tolist()
        for name, accuracy in zip(attr_name, per_attr_acc):
            attr_acc.append(accuracy)
            print('Attribute: %s, Accuracy: %.3f' % (name, accuracy))

        return acc


        
    train_all_losses2 = []
    train_all_acc2 = []
    val_all_losses2 = []
    val_all_acc2 = []
    attr_acc = []
    # number of training epochs
    epochs = 30
    # Single location where the best model is saved and later reloaded from.
    checkpoint_path = './model_resnet_small_head_w_sensor_16.pth'

    model = Classifier()
    # use cuda to train the network
    model.to('cuda')
    # BCELoss pairs with the sigmoid probabilities returned by Classifier.forward.
    criterion = nn.BCELoss()
    learning_rate = 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

    best_acc = 0.0
    for epoch in range(epochs):
        train(model, optimizer, criterion, epoch, train_all_losses2, train_all_acc2)
        acc = validation(model, criterion, val_all_losses2, val_all_acc2, best_acc)
        if acc > best_acc:
            best_acc = acc
            # Keep only the checkpoint with the best validation accuracy so far.
            torch.save({'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
            print('new best model saved')
            print("========================================================================")

    # Evaluate the best validation checkpoint, not the last epoch's weights.
    model.load_state_dict(torch.load(checkpoint_path)["model_state_dict"])
    acc = test(model, criterion, attr_acc)