import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np
import random

from spn.algorithms.layerwise.distributions import Normal
from spn.algorithms.layerwise.layers import Sum, Product

from spn.experiments.RandomSPNs_layerwise.distributions import RatNormal
from spn.experiments.RandomSPNs_layerwise.rat_spn import RatSpn, RatSpnConfig

from spn.algorithms.layerwise.layers import CrossProduct, Sum

from spn.io.Graphics import plot_spn

from data_loader import TrafficData
from utils import set_seed, count_params


class SPNClassifier(nn.Module):
    """Small hand-built SPN: Normal leaves -> Product -> Sum -> Product.

    The four layers are chained in an ``nn.Sequential`` (``self.spn``) and
    applied in that order by ``forward``. All layer sizes are fixed in the
    constructor; the leaf layer is built with ``in_features=4`` and
    ``out_channels=2``.

    NOTE(review): assumes the input ``x`` carries 4 features per sample to
    match ``Normal(in_features=4, ...)`` — confirm against the layer docs.
    """

    def __init__(self):
        super().__init__()

        # Leaf layer: Gaussian leaves over 4 input features, 2 channels.
        self.gauss = Normal(in_features=4, out_channels=2)
        # Product over groups of 2 features; the next layer is built with
        # in_features=2, i.e. this is expected to reduce 4 -> 2 features.
        self.prod1 = Product(in_features=4, cardinality=2)
        # Mixture over the 2 channels down to a single output channel.
        self.sum1 = Sum(in_features=2, in_channels=2, out_channels=1)
        # Final product over the remaining 2 features.
        self.prod2 = Product(in_features=2, cardinality=2)

        self.spn = nn.Sequential(
            self.gauss,
            self.prod1,
            self.sum1,
            self.prod2,
        )

    def forward(self, x):
        """Pass ``x`` through the leaf/product/sum/product stack and return its output."""
        return self.spn(x)


def make_spn(S, I, R, D, dropout, device, F=10, C=5) -> RatSpn:
    """
    Construct a RatSpn, move it to ``device`` and put it in training mode.

    Args:
        S: Number of sum nodes at each layer
        I: Number of distributions for each scope at the leaf layer
        R: Number of repetitions of features
        D: Tree depth
        dropout: Dropout value stored on the config (``config.dropout``)
        device: Device the model is moved to via ``model.to(device)``
        F: Number of input features. Defaults to 10, the value that was
           previously hard-coded (7 was noted as an alternative).
        C: Number of root heads / number of classes. Defaults to 5, the
           value that was previously hard-coded (3 was noted as an
           alternative).

    Returns:
        The constructed ``RatSpn``, on ``device`` and in training mode.
    """
    # Setup RatSpnConfig
    config = RatSpnConfig()
    config.F = F
    config.R = R
    config.D = D
    config.I = I
    config.S = S
    config.C = C
    config.dropout = dropout
    config.leaf_base_class = RatNormal
    config.leaf_base_kwargs = {}

    # Construct RatSpn from config; nn.Module.to returns the module itself.
    model = RatSpn(config).to(device)
    model.train()

    print("Using device:", device)
    return model


def get_weights(model):
    """Print ``model`` and some of its internal parts for quick inspection.

    Prints, in order: the model itself, then ``model._leaf``,
    ``model._inner_layers`` and ``model._sampling_root``. Nothing is
    returned; despite the name, the weights are only printed.
    """
    print(model)
    # Look attributes up one at a time so each is printed before the next
    # is accessed (same ordering as printing them individually).
    for attr_name in ("_leaf", "_inner_layers", "_sampling_root"):
        print(getattr(model, attr_name))


if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    # NOTE(review): pinned to the third GPU ("cuda:2"); this fails on hosts
    # with fewer than three CUDA devices — adjust for your machine.
    device = torch.device("cuda:2" if use_cuda else "cpu")
    batch_size = 50
    set_seed(123)

    train_path = '../ControlVAE/data_class_v5/train_z_label_semi.csv'
    train_data = TrafficData(train_path)
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=1)

    test_path = '../ControlVAE/data_class_v5/test_z_label_semi.csv'
    test_data = TrafficData(test_path)
    # Evaluate on the held-out split. Previously this wrapped train_data,
    # so the reported "testing accy" was really training accuracy.
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=1)

    model = make_spn(S=3, I=3, R=2, D=3, device=device, dropout=0.0)
    # Alternative: model = SPNClassifier()
    print(model)
    print("Number of pytorch parameters: ", count_params(model))

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # CrossEntropyLoss treats the model outputs as unnormalized per-class
    # scores; assumes model(data) returns shape (batch, n_classes) — confirm.
    loss_fn = nn.CrossEntropyLoss()

    # Redundant for make_spn (it already does this), but needed if the model
    # is swapped for SPNClassifier.
    model.to(device)
    model.train()

    log_interval = 100
    n_epochs = 80
    n_train = len(train_loader.dataset)

    for epoch in range(n_epochs):
        running_loss = 0.0
        for batch_idx, (data, shape_label, _color_label) in enumerate(train_loader):
            target = shape_label.long()
            data, target = data.to(device), target.to(device)

            # Reset gradients
            optimizer.zero_grad()

            # Inference
            out = model(data)

            # Compute loss
            loss = loss_fn(out, target)

            # Backpropagation
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if batch_idx % log_interval == (log_interval - 1):
                # Accuracy of the most recent batch only.
                predicted = out.argmax(1)
                acc_count = predicted.eq(target).sum().item()
                acc = acc_count / data.shape[0] * 100

                print("Train epoch: {} [{: >5}/{: <5} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.2f}%".format(
                    epoch,
                    (batch_idx + 1) * len(data),
                    n_train,
                    100.0 * (batch_idx + 1) / len(train_loader),
                    running_loss / log_interval,
                    acc,
                    ),
                )
                # Start a fresh window so the next print is the mean loss
                # over the last log_interval batches, not a running total.
                running_loss = 0.0

    ## for testing
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for batch_idx, (data, shape_label, _color_label) in enumerate(test_loader):
            target = shape_label.long()
            data, target = data.to(device), target.to(device)

            out_y = model(data)

            predicted = out_y.argmax(1)
            accy_count = (predicted == target).sum().item()
            print("accy_count: ", accy_count)
            total_correct += accy_count
            total += len(shape_label)
    # Guard against an empty test set instead of raising ZeroDivisionError.
    accy = total_correct / total if total else float("nan")
    print("testing accy: ", accy)

    # plot_spn(model)
    # get_weights(model)
    # get_weights(model)
