import importlib
import os
import sys
import time
from collections import OrderedDict

import torch
import torch.nn as nn
from torchvision.ops import MLP

import datasets

from args import parse_args

from util.general_util import (
    parse_configs_file,
    setup_seed,
    FairMetrics,
    get_parameter_number
)

from torch_geometric.utils import degree
from models.cvae import cvae, loss_function
from copy import deepcopy

class GNNWithMLP(nn.Module):
    """Chain a GNN with a torchvision ``MLP`` head.

    The GNN is applied to the rescaled attribute ``2 * data.s - 1`` (which
    sends a 0/1 indicator to -1/+1) together with ``data.edge_index``; its
    output is concatenated with ``data.z`` along dim 1 and passed to the MLP.

    Args:
        args: Namespace providing ``A``, ``z_size`` and ``input_size``; the
            head takes ``A + z_size`` input features and has a single layer
            of width ``input_size``.
        gnn: Callable module invoked as ``gnn(signal, edge_index)``.
    """

    def __init__(self, args, gnn):
        super().__init__()
        # Registration order and attribute names are kept as-is because they
        # define the keys of the module's state_dict.
        self.gnn = gnn
        head_in_features = args.A + args.z_size
        head_widths = [args.input_size]
        self.MLP = MLP(head_in_features, head_widths)

    def forward(self, data):
        """Return the MLP output for ``data`` (reads ``s``, ``edge_index``, ``z``)."""
        rescaled_s = 2 * data.s - 1
        gnn_out = self.gnn(rescaled_s, data.edge_index)
        return self.MLP(torch.cat((data.z, gnn_out), dim=1))
        

def train(epoch, model, optimizer, criterion, data):
    """Run one full-batch optimisation step and report the loss.

    Args:
        epoch: Epoch index, used only in the printed message.
        model: Module called as ``model(data)``.
        optimizer: Optimizer holding the model's parameters.
        criterion: Loss called as ``criterion(prediction, data.x)``.
        data: Object passed to the model; ``data.x`` is the target.

    Returns:
        The loss value divided by ``len(data.x)``.
    """
    model.train()

    # Clear stale gradients before computing the new ones.
    optimizer.zero_grad()
    batch_loss = criterion(model(data), data.x)
    batch_loss.backward()
    optimizer.step()

    mean_loss = batch_loss.item() / len(data.x)
    print(f'Epoch:{epoch}, loss:{mean_loss}')
    return mean_loss

def _select_group(values, weights):
    """Return the entries of ``values * weights`` where ``weights`` is non-zero.

    The selection is driven by ``weights`` rather than by the product, so an
    entry of ``values`` that happens to be exactly zero is still kept and the
    selections of two tensors made with the same ``weights`` stay aligned.
    """
    weighted = values * weights
    return weighted[(weights != 0).expand_as(weighted)]


def do_inference(model, criterion, data, device):
    """Evaluate the model under both interventions on ``s`` and on the raw data.

    Three losses are computed without gradient tracking and printed divided
    by ``len(data.x)``:

    * ``s`` replaced by all ones, compared against ``data.do_pos_x``;
    * ``s`` replaced by all zeros, compared against ``data.do_neg_x``;
    * the unmodified ``data``, compared against ``data.x``.

    The reconstruction loss is additionally printed separately for the
    entries where ``data.s`` is non-zero and where ``1 - data.s`` is non-zero,
    each divided by the column-wise sum of the corresponding weights.

    Args:
        model: Module called as ``model(data)``.
        criterion: Loss called as ``criterion(prediction, target)``.
        data: Deep-copyable object exposing ``x``, ``s``, ``do_pos_x``,
            ``do_neg_x`` and ``update_tensor(tensor, 's')``.
            NOTE(review): ``s`` is presumably a 0/1 indicator that
            broadcasts against ``x`` -- confirm against the dataset.
        device: Device the intervention tensors are moved to.

    Returns:
        None; results are printed.
    """
    model.eval()
    with torch.no_grad():
        num_rows = len(data.x)

        pos_do = torch.ones_like(data.s).to(device)
        neg_do = torch.zeros_like(data.s).to(device)

        # Interventions run on deep copies so ``data`` itself is untouched.
        pos_data = deepcopy(data)
        pos_data.update_tensor(pos_do, 's')
        pos_loss = criterion(model(pos_data), data.do_pos_x)
        del pos_data  # release the copy before making the next one

        neg_data = deepcopy(data)
        neg_data.update_tensor(neg_do, 's')
        neg_loss = criterion(model(neg_data), data.do_neg_x)
        del neg_data

        x_pred = model(data)
        loss = criterion(x_pred, data.x)

        # Per-group reconstruction loss. Selecting by the weights (instead of
        # by "product != 0") keeps predictions and targets the same length
        # even when some of them are exactly zero.
        pos_weight = data.s
        neg_weight = 1 - data.s
        loss_p = criterion(_select_group(x_pred, pos_weight),
                           _select_group(data.x, pos_weight))
        loss_n = criterion(_select_group(x_pred, neg_weight),
                           _select_group(data.x, neg_weight))
        print(loss_p / pos_weight.sum(dim=0))
        print(loss_n / neg_weight.sum(dim=0))
        del x_pred

        print(f'do s+ loss:{pos_loss.item()/ num_rows}, do s- loss:{neg_loss.item()/ num_rows}, reconstruct loss:{loss.item()/ num_rows}')
    


def _pick_device(preferred_index=1):
    """Return ``cuda:<preferred_index>`` if that GPU exists, else ``cuda:0``, else ``cpu``."""
    if not torch.cuda.is_available():
        return "cpu"
    if torch.cuda.device_count() > preferred_index:
        return f"cuda:{preferred_index}"
    # CUDA is available but the preferred ordinal is not (e.g. single-GPU host).
    return "cuda:0"


def main():
    """Train ``GNNWithMLP`` on the configured dataset and save its weights.

    Steps: parse arguments (optionally merging a config file), seed with
    ``args.gtseed`` and build the dataset, re-seed with ``args.seed`` and
    build the model from the ``gnn`` of a ``cvae``, train for a fixed number
    of epochs, print the intervention losses, then write the state dicts to
    ``weights/gnn.pt`` and ``weights/gnn_w_mlp.pt``.
    """
    args = parse_args()
    if args.configs is not None:
        parse_configs_file(args)

    # Seed with args.gtseed before the dataset is built; the model is seeded
    # separately with args.seed further down.
    setup_seed(args.gtseed)
    device = _pick_device()

    D = datasets.__dict__[args.dataset](args)
    dataset = D.data_loaders(mapping_function=args.mapping_function)
    dataset = dataset.to(device)
    print(dataset)
    # Mean degree, counted over the first row of edge_index.
    print(torch.mean(degree(dataset.edge_index[0], dataset.x.shape[0])))

    setup_seed(args.seed)

    # Only the GNN sub-module of the CVAE is reused here.
    gnn = cvae(args.sensitive_size, args.hidden_channels, args.num_layers, args.A, args.input_size, args.latent_dim, args.z_size).gnn
    model = GNNWithMLP(args, gnn)
    model.to(device)

    num_epochs = 1000
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MSELoss(reduction='sum')
    # NOTE(review): the scheduler is constructed but never stepped, so the
    # learning rate stays at its initial value for the whole run. Confirm
    # whether a per-epoch ``scheduler.step()`` was intended before adding it,
    # since that would change the trained weights.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # Create the output directory up front so a missing folder is not
    # discovered only after the whole training run has finished.
    os.makedirs('weights', exist_ok=True)

    for epoch in range(num_epochs):
        train(epoch, model, optimizer, criterion, dataset)

    do_inference(model, criterion, dataset, device)
    torch.save(model.gnn.state_dict(), 'weights/gnn.pt')
    torch.save(model.state_dict(), 'weights/gnn_w_mlp.pt')

    
    
    
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()