import copy
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange

from models import *


class Trainer(object):
    """Train, early-stop and evaluate a node-classification model on one graph.

    The graph object is read through the attributes ``x`` (node features,
    indexed as ``x.size(1)`` for the feature width), ``edge_index``, ``y``
    (node labels) and, for :meth:`bias`, ``s`` (sensitive attribute compared
    against 0 and 1). The node sets given to the constructor are used directly
    to index those per-node tensors, so integer index tensors and boolean
    masks both work.

    Hyper-parameters are read from ``args``: ``device``, ``hid_dim``,
    ``model``, ``lr``, ``dr`` (Adam weight decay), ``epochs`` and
    ``patience``.
    """

    def __init__(self, args, graph, train_nodes, val_nodes, test_nodes):
        self.args = args
        self.graph = graph

        self.device = torch.device(self.args.device)
        # torch.cuda.set_device rejects CPU devices (and index-less CUDA
        # devices), so only pin a GPU when a specific one was requested.
        if self.device.type == 'cuda' and self.device.index is not None:
            torch.cuda.set_device(self.device)

        self.in_channels = self.graph.x.size(1)
        self.hid_channels = self.args.hid_dim
        # Class count is taken as max label + 1; assumes labels are integers
        # in [0, C) — TODO confirm for datasets with sparse label ids.
        self.out_channels = int(torch.max(self.graph.y).item() + 1)

        self.train_nodes, self.val_nodes, self.test_nodes = train_nodes, val_nodes, test_nodes

        # Unweighted NLL loss over log-probabilities, shared by training and scoring.
        loss_weight = None
        self.criterion = nn.NLLLoss(weight=loss_weight)

    def model_init(self):
        """Build the model selected by ``args.model`` and move it to ``self.device``.

        Raises:
            ValueError: if ``args.model`` is not a supported model name.
        """
        if self.args.model == 'simple-gnn':
            model_cls = Simplified_GNN
        else:
            # Fail loudly on an unknown name rather than with an unbound local.
            raise ValueError("Unsupported model: {!r}".format(self.args.model))

        model = model_cls(self.args,
                          self.in_channels,
                          self.hid_channels,
                          self.out_channels,
                          self.graph,
                          )

        return model.to(self.device)

    def score(self, graph, model, index_set):
        """Evaluate ``model`` on the nodes selected by ``index_set`` (no gradients).

        Returns:
            A tuple ``(acc, loss, true_false, prediction)``:
            ``acc`` is the fraction of selected nodes classified correctly,
            ``loss`` is the NLL loss on the selected nodes,
            ``true_false`` is a boolean tensor marking correct predictions for
            the selected nodes, and ``prediction`` is the raw model output for
            all nodes (before log-softmax).
        """
        model.eval()
        with torch.no_grad():
            prediction = model(graph.x, graph.edge_index)
            logits = F.log_softmax(prediction, dim=1)
            val_loss = self.criterion(logits[index_set], graph.y[index_set])

            _, pred = logits.max(dim=1)
            true_false = pred[index_set].eq(graph.y[index_set])
            correct = true_false.sum().item()
            # Divide by the number of *selected* nodes: len(index_set) would be
            # the total node count when index_set is a boolean mask.
            acc = correct / true_false.numel()

            return acc, val_loss, true_false, prediction

    @staticmethod
    def _group_rate(preds, mask):
        """Mean of ``preds`` over ``mask``; 0.0 when ``mask`` selects no nodes."""
        if not bool(mask.any()):
            return 0.0
        return preds[mask].float().mean().item()

    def bias(self, graph, model, index_set):
        """Fairness gaps of ``model``'s hard predictions on ``index_set``.

        Groups are defined by ``graph.s == 1`` and ``graph.s == 0``. The group
        "rate" is the mean predicted class, which equals the positive-prediction
        rate when predictions are binary {0, 1}. A group with no selected nodes
        contributes a rate of 0.0 (for both metrics) instead of NaN.

        Returns:
            ``(sp, eo)`` as Python floats: the statistical-parity gap
            ``|rate(s=1) - rate(s=0)|`` and the equal-opportunity gap, which is
            the same difference restricted to nodes whose label is 1.
        """
        model.eval()
        with torch.no_grad():
            prediction = model(graph.x, graph.edge_index)
            logits = F.log_softmax(prediction, dim=1)
            _, pred = logits.max(dim=1)

            labels = graph.y[index_set].long()
            sens = graph.s[index_set].long()
            preds = pred[index_set].long()

            # Statistical Parity
            sp_s1 = self._group_rate(preds, sens == 1)
            sp_s0 = self._group_rate(preds, sens == 0)
            sp = abs(sp_s1 - sp_s0)

            # Equal Opportunity (condition on labels == 1)
            y1_mask = labels == 1
            eo_s1 = self._group_rate(preds, (sens == 1) & y1_mask)
            eo_s0 = self._group_rate(preds, (sens == 0) & y1_mask)
            eo = abs(eo_s1 - eo_s0)

            return sp, eo

    def fit(self, graph, model):
        """Train ``model`` with Adam and early stopping on validation loss.

        Each epoch takes one full-batch gradient step on ``self.train_nodes``
        and then scores ``self.val_nodes``. A deep copy is kept of the model
        with the lowest validation loss (ties go to the later epoch). Training
        stops once more than ``args.patience`` consecutive epochs fail to
        improve it.

        Side effects: sets ``self.best_val_loss`` and ``self.best_val_acc``.

        Returns:
            A deep copy of the best model. If no epoch produced a comparable
            validation loss (``args.epochs == 0`` or only NaN losses), a deep
            copy of ``model`` in its final state is returned.
        """
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.args.lr,
                                     weight_decay=self.args.dr)

        iterator = trange(self.args.epochs, desc='Val loss: ', leave=False)

        step_counter = 0
        best_model = None
        self.best_val_acc = 0
        self.best_val_loss = np.inf

        for _ in iterator:
            model.train()
            optimizer.zero_grad()

            log_probs = F.log_softmax(model(graph.x, graph.edge_index), dim=1)
            # Same unweighted NLL criterion that score() reports.
            loss = self.criterion(log_probs[self.train_nodes], graph.y[self.train_nodes])
            loss.backward()
            optimizer.step()

            val_acc, val_loss, _, _ = self.score(graph, model, self.val_nodes)
            iterator.set_description("Val Loss: {:.4f}".format(val_loss))

            if val_loss <= self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_val_acc = val_acc
                best_model = copy.deepcopy(model)
                step_counter = 0
            else:
                step_counter += 1
                if step_counter > self.args.patience:
                    iterator.close()
                    break

        if best_model is None:
            # No epoch beat +inf (zero epochs or NaN losses): return the current
            # weights instead of failing with an unbound local.
            best_model = copy.deepcopy(model)

        return best_model

    def eval(self, graph, best_model):
        """Report split accuracies plus test-set fairness gaps for ``best_model``.

        Returns:
            ``(train_acc, val_acc, test_acc, test_corr, sp, eo)`` where
            ``test_corr`` is the per-node correctness tensor for the test nodes
            and ``sp``/``eo`` come from :meth:`bias` on the test nodes.
        """
        train_acc = self.score(graph, best_model, self.train_nodes)[0]
        val_acc = self.score(graph, best_model, self.val_nodes)[0]
        test_acc, _, test_corr, _ = self.score(graph, best_model, self.test_nodes)
        sp, eo = self.bias(graph, best_model, self.test_nodes)

        return train_acc, val_acc, test_acc, test_corr, sp, eo
