import torch
import torch.nn as nn
from data_loader import TrafficData
from torch.utils.data import DataLoader,Dataset,TensorDataset
import torch.optim as optim
from torch.autograd import Variable

import os
import sys
import wandb

sys.path.append(sys.path[0] + '/..')
from utils import set_seed, count_params


class MLPModel(nn.Module):
    """Single-layer softmax classifier over a small number of latent features.

    Architecture: ``Linear(in_features, out_size)`` followed by
    ``Softmax(dim=1)``, so :meth:`forward` returns per-class probabilities
    (each row sums to 1), not raw logits.

    Args:
        out_size: Number of output classes. NOTE: with the default of 1 the
            softmax runs over a single class and therefore always outputs 1.0.
        in_features: Number of input features per sample. Defaults to 3, the
            value that used to be hard-coded.
    """

    def __init__(self, out_size=1, in_features=3):
        super().__init__()
        # Attribute name and layer order are kept so the state_dict keys
        # ("mlp.0.weight", "mlp.0.bias") are unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(in_features, out_size),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map a ``(batch, in_features)`` tensor to ``(batch, out_size)`` probabilities.

        The softmax is taken over dim 1, so ``x`` is expected to be 2-D.
        Because the output is already normalized, train it with ``nn.NLLLoss``
        on ``log(output)``. ``nn.CrossEntropyLoss`` expects raw logits and
        would apply a second softmax.
        """
        return self.mlp(x)


def _load_datasets(dataset, attack_types):
    """Build the train/test datasets and the latent-feature columns for *dataset*.

    Args:
        dataset: Name of the data source. ``"data_class_v5"`` and
            ``"data_class_vSimon2"`` are matched case-insensitively as
            substrings; ``"traffc_data"`` and the entries of *attack_types*
            must match exactly.
        attack_types: Names whose test split is read from
            ``val_z_label_<name>.csv`` in the traffic validation directory.

    Returns:
        ``(training_set, test_set, feature_indices)``, where
        ``feature_indices`` lists the three columns of ``z`` fed to the model.

    Raises:
        NameError: If *dataset* is not a recognised name.
    """
    lowered = dataset.lower()
    if "data_class_v5" in lowered:
        training_set = TrafficData('../ControlVAE/data_class_v5/train_z_label_semi.csv')
        test_set = TrafficData('../ControlVAE/data_class_v5/test_z_label_semi.csv')
        feature_indices = [0, 1, 3]
    elif "data_class_vsimon2" in lowered:
        training_set = TrafficData('../ControlVAE/data_class_vSimon2/train_z_label_semi.csv')
        test_set = TrafficData('../ControlVAE/data_class_vSimon2/test_z_label_semi.csv')
        feature_indices = [1, 7, 9]
    elif dataset == "traffc_data":
        training_set = TrafficData('/data/open-datasets/traffic/train/train_z_label.csv')
        test_set = TrafficData('/data/open-datasets/traffic/val/val_z_label.csv')
        feature_indices = [2, 8, 9]
    elif dataset in attack_types:
        training_set = TrafficData('/data/open-datasets/traffic/train/train_z_label.csv')
        test_set = TrafficData(f'/data/open-datasets/traffic/val/val_z_label_{dataset}.csv')
        feature_indices = [2, 8, 9]
    else:
        raise NameError(dataset)
    return training_set, test_set, feature_indices


def _select_features(x, feature_indices):
    """Return the columns of the 2-D tensor *x* listed in *feature_indices*, in order.

    Equivalent to concatenating the single-column slices ``x[:, i:i+1]`` along dim 1.
    """
    return x[:, list(feature_indices)]


def main():
    """Train an ``MLPModel`` on three latent features and report test accuracy.

    Hyper-parameters and the data source are hard-coded below. Training
    accuracy and loss are printed every 100 optimisation steps. After
    training, the per-batch correct counts and the overall test accuracy are
    printed.
    """
    set_seed(42)

    use_cuda = torch.cuda.is_available()
    # NOTE(review): hard-coded GPU index 2; this fails on hosts with fewer GPUs.
    device = torch.device("cuda:2" if use_cuda else "cpu")

    num_epoch = 120
    out_size = 8
    Batch_size = 50
    attack_types = ['all', 'elastic', 'g_blur', 'g_noise', 'splatter', 'sticker']

    # Other options: "data_class_vSimon2", "data_class_v5", "traffc_data",
    # or any name in attack_types.
    dataset = 'sticker'
    # Only referenced by the commented-out Adam optimizer / wandb config below;
    # the active SGD optimizer uses sgd_lr.
    lr = 0.0001
    sgd_lr = 0.1

    # Optional experiment tracking:
    # wandb.init(
    #     project="SW-ControlVAE-MLP",
    #     config={
    #         "learning_rate": lr,
    #         "architecture": "MLP-[3,10,20,5]",
    #         "dataset": "data_class_vSimon2",
    #         "epochs": num_epoch,
    #     }
    # )

    training_set, test_set, feature_indices = _load_datasets(dataset, attack_types)

    train_generator = DataLoader(training_set, batch_size=Batch_size,
                                 shuffle=True, num_workers=1)
    test_generator = DataLoader(test_set, batch_size=Batch_size,
                                shuffle=False, num_workers=1)

    model = MLPModel(out_size)
    print("Number of pytorch parameters: ", count_params(model))

    # optimizer = optim.Adam(model.parameters(), lr=lr)
    optimizer = optim.SGD(model.parameters(), lr=sgd_lr)
    # MLPModel's forward ends in Softmax, so it returns probabilities, not
    # logits. nn.CrossEntropyLoss would apply log-softmax a second time and
    # flatten the gradients. NLLLoss on log(probabilities) is the matching loss.
    loss_fun = nn.NLLLoss()
    model.to(device)
    model.train()
    global_step = 0

    for epoches in range(num_epoch):
        for x, class_label, shape_label, color_label in train_generator:
            # Target attribute; the loader also yields shape_label / color_label.
            label = class_label.long().to(device)
            global_step += 1
            x = _select_features(x, feature_indices).to(device)

            out_y = model(x)
            # Class indices go straight to the loss. Unlike a one-hot tensor
            # sized by Batch_size, this works for a short final batch too. The
            # clamp keeps a probability that underflows to 0 from giving -inf.
            loss = loss_fun(torch.log(out_y.clamp_min(1e-12)), label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % 100 == 0:
                predicted = torch.max(out_y, 1)[1]
                accy = (predicted == label).sum().item() / len(label)
                print("epoch: {0} step:{1} accy: {2} loss: {3}".format(
                    epoches, global_step, accy, loss.item()))
                # wandb.log({"epoch": epoches, "acc": accy, "loss": loss.item()})

    # wandb.finish()

    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, class_label, shape_label, color_label in test_generator:
            label = class_label.long().to(device)
            x = _select_features(x, feature_indices).to(device)
            out_y = model(x)

            predicted = torch.max(out_y, 1)[1]
            accy_count = (predicted == label).sum().item()
            print("accy_count: ", accy_count)
            total_correct += accy_count
            total += len(label)

    if total == 0:
        # Avoid ZeroDivisionError on an empty test split.
        print("testing accy: no test samples")
        return
    accy = total_correct / total
    print("testing accy: ", accy)


if __name__ == "__main__":
    main()