import os
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import random
import torch.nn.functional as F
from utils.loader import load_seed, load_device, load_data, load_model_params, load_model_optimizer, load_loss_fn, \
                         load_simple_model_params, load_simple_model_optimizer, load_simple_loss_fn
from utils.logger import Logger, set_log, start_log, train_log
import dgl


def augment_node_features(graph, node_features, use_sgc_features, use_identity_features, use_adjacency_features,
                            do_not_use_original_features):
    """Concatenate optional structural features onto a node-feature matrix.

    Starting from ``node_features`` (or from an empty ``(n, 0)`` matrix when
    ``do_not_use_original_features`` is set), each enabled option appends extra
    columns, in this order: SGC features, identity (one-hot node id) features,
    dense adjacency rows.

    Args:
        graph: Graph object exposing ``num_nodes()``; it is also passed to
            ``dgl.remove_self_loop`` when adjacency features are requested.
        node_features: Per-node feature tensor with one row per node.
        use_sgc_features: Append the output of ``Dataset.compute_sgc_features``
            computed from the *original* features.
        use_identity_features: Append an ``n x n`` identity matrix.
        use_adjacency_features: Append the dense adjacency matrix of the graph
            with self-loops removed.
        do_not_use_original_features: Drop the original feature columns and
            keep only the appended ones.

    Returns:
        The augmented feature tensor. When no option is enabled the input
        tensor itself is returned (no copy).
    """
    n = graph.num_nodes()
    original_node_features = node_features

    if do_not_use_original_features:
        # Zero-column placeholder so later concatenations work unchanged. It is
        # created on the device of the input features (when they have one) so
        # the result does not silently fall back to the CPU, and it keeps a 2-D
        # shape even for n == 0.
        node_features = torch.empty((n, 0), device=getattr(original_node_features, 'device', None))

    if use_sgc_features:
        # NOTE(review): `Dataset` is not imported in the visible part of this
        # file; this branch raises NameError unless it is brought into scope
        # elsewhere -- confirm the intended import.
        sgc_features = Dataset.compute_sgc_features(graph, original_node_features)
        node_features = torch.cat([node_features, sgc_features], dim=1)

    if use_identity_features:
        # Build the identity on the same device as the features to avoid a
        # device mismatch in torch.cat.
        identity = torch.eye(n, device=node_features.device)
        node_features = torch.cat([node_features, identity], dim=1)

    if use_adjacency_features:
        graph_without_self_loops = dgl.remove_self_loop(graph)
        adj_matrix = graph_without_self_loops.adjacency_matrix().to_dense()
        # The graph may live on a different device than the features.
        adj_matrix = adj_matrix.to(node_features.device)
        node_features = torch.cat([node_features, adj_matrix], dim=1)

    return node_features



class Trainer(object):
    """Train a diffusion-based node classifier with an EM-style loop.

    Construction loads the data, loss functions and graphs described by
    ``config``. ``train`` first pre-trains a mean-field GNN, then alternates
    between an expectation step (drawing label samples into a buffer) and a
    maximization step (one optimizer step on a label sample from the buffer).
    """

    # Number of graph copies / label samples produced per expectation step.
    # The batched graph built in __init__ and the replicated tensors built in
    # train() must agree on this value, so it is defined once here.
    N_SAMPLES = 5
    # Maximum number of label samples kept in the buffer.
    BUFFER_SIZE = 50

    def __init__(self, config):
        """Load data, losses and graphs for the experiment described by ``config``.

        Args:
            config: Experiment configuration; it is forwarded to the project
                loader helpers (``set_log``, ``load_data``, ``load_loss_fn``, ...).
        """
        super().__init__()

        self.config = config
        self.log_folder_name, self.log_dir = set_log(self.config)
        self.seed = load_seed(self.config.seed)
        self.device = load_device()
        self.x, self.y, self.adj, self.train_mask, self.valid_mask, self.test_mask = load_data(self.config)
        self.losses = load_loss_fn(self.config, self.device)
        self.simple_losses = load_simple_loss_fn(self.config, self.device)

        # At this point self.adj is the raw edge tensor returned by load_data;
        # row 0 / row 1 are used as the two endpoint lists of each edge.
        num_nodes = self.x.shape[0]
        edge_index = self.adj

        # N_SAMPLES disjoint copies of the graph: copy i has its node ids
        # shifted by i * num_nodes.
        batched_edge_index = torch.cat(
            [edge_index + num_nodes * i for i in range(self.N_SAMPLES)], dim=1)

        self.adj = self._build_graph(edge_index, num_nodes)
        self.adjs = self._build_graph(batched_edge_index, num_nodes * self.N_SAMPLES)

    def _build_graph(self, edge_index, num_nodes):
        """Build a bidirected DGL graph from ``edge_index`` on the first configured device."""
        graph = dgl.graph((edge_index[0, :], edge_index[1, :]), num_nodes=num_nodes, idtype=torch.int)
        # The graph is moved to the CPU before dgl.to_bidirected, as in the
        # original construction code.
        graph = dgl.to_bidirected(graph.cpu())
        return graph.to(self.device[0])

    def _pretrain_mean_field(self):
        """Pre-train the mean-field GNN and return its best temperature-scaled estimate.

        The model is evaluated every 10 epochs; the estimate with the highest
        validation accuracy (earliest one on ties) is kept, divided by
        ``config.diffusion.temp``.

        Raises:
            ValueError: If no estimate could be selected (e.g. zero
                pre-training epochs).
        """
        # Start below any achievable accuracy so the first evaluation is always
        # accepted; with a start value of 0 an accuracy of exactly 0 would
        # leave best_est unset.
        best_valid, best_est = -1.0, None
        print('Pretrain mean-field GNN...')
        for i in range(self.config.train.pre_train_epochs):
            self.simple_model.train()
            self.simple_optimizer.zero_grad()

            loss = self.simple_loss_fn(self.simple_model, self.x, self.adj, self.y, self.train_mask)
            loss.backward()
            self.simple_optimizer.step()

            # Evaluate mean-field GNN
            if i % 10 == 0:
                self.simple_model.eval()
                y_est = self.simple_estimator(self.simple_model, self.x, self.adj, self.y, self.train_mask)
                pred = torch.argmax(y_est, dim=-1)
                label = torch.argmax(self.y, dim=-1)
                valid_acc = torch.mean((pred == label)[self.valid_mask].float()).item()
                if valid_acc > best_valid:
                    best_valid = valid_acc
                    # Detach: the estimate is only used as a sampling
                    # distribution, so the autograd graph need not be kept.
                    best_est = (y_est / self.config.diffusion.temp).detach()

        print('Done!')

        if best_est is None:
            raise ValueError(
                'Mean-field pre-training produced no usable label estimate; check that '
                'config.train.pre_train_epochs >= 1 and that the validation mask is non-empty.')
        return best_est

    def train(self, ts):
        """Run the full training procedure for the experiment named ``ts``.

        Args:
            ts: Experiment name; stored in ``config.exp_name`` and used as the
                log file stem.

        Raises:
            ValueError: If mean-field pre-training yields no label estimate.
        """
        self.config.exp_name = ts
        self.ckpt = f'{ts}'
        print('\033[91m' + f'{self.ckpt}' + '\033[0m')

        # Prepare model, optimizer, and logger
        self.params = load_model_params(self.config)
        self.simple_params = load_simple_model_params(self.config)
        self.simple_model, self.simple_optimizer, self.simple_scheduler = load_simple_model_optimizer(self.simple_params, self.config.train, self.device)
        self.model, self.optimizer, self.scheduler = load_model_optimizer(self.params, self.config.train, self.device)
        self.loss_fn = self.losses.loss_fn
        self.simple_loss_fn = self.simple_losses.loss_fn
        self.estimator = self.losses.estimate
        self.mc_estimator = self.losses.mc_estimate
        self.simple_estimator = self.simple_losses.estimate

        logger = Logger(str(os.path.join(self.log_dir, f'{self.ckpt}.log')), mode='a')
        logger.log(f'{self.ckpt}', verbose=False)
        start_log(logger, self.config)
        train_log(logger, self.config)

        # Pre-train mean-field GNN
        best_est = self._pretrain_mean_field()

        # Prepare expectation step: replicate the inputs once per sample so
        # they line up with the batched graph self.adjs.
        n_samples, buffer_size = self.N_SAMPLES, self.BUFFER_SIZE
        num_nodes = self.y.shape[0]
        xs = torch.cat([self.x] * n_samples, dim=0)
        ys = torch.cat([self.y] * n_samples, dim=0)
        masks = torch.cat([self.train_mask] * n_samples, dim=0)  # (n_samples*number of data, )
        # exp() turns the stored estimate into unnormalized probabilities
        # (presumably the estimator returns log-probabilities -- confirm).
        best_prob = torch.exp(torch.cat([best_est] * n_samples, dim=0))
        adjs = self.adjs
        buffer = None

        # Train the model
        best_valid, best_test = 0, 0
        for epoch in range(self.config.train.num_epochs):

            # Expectation step
            if epoch % self.config.train.load_interval == 0:
                if epoch > self.config.train.load_start:  # Use manifold-constrained sampling of DPM-GSP
                    expected_y_set = self.mc_estimator(self.model, xs, adjs, ys, masks, temp=self.config.diffusion.temp, coef=self.config.diffusion.coef)
                else:  # Use mean-field GNN
                    expected_y_set = torch.distributions.categorical.Categorical(best_prob).sample()
                    expected_y_set = F.one_hot(expected_y_set, best_prob.shape[1]).float()

                # Fill the buffer: split the flat batch back into one entry per sample
                expected_y_set = torch.cat(
                    [expected_y_set[i * num_nodes:(i + 1) * num_nodes].view(1, num_nodes, -1)
                     for i in range(n_samples)],
                    dim=0)  # (n_samples, number of data, number of classes)
                if buffer is None:
                    buffer = expected_y_set
                else:
                    buffer = torch.cat([buffer, expected_y_set], dim=0)

                # Keep the buffer size: retain only the most recent entries
                buffer = buffer[max(buffer.shape[0] - buffer_size, 0):]

            # Maximization step
            # Sample uniformly from the buffer. (Drawing from [0, len] and
            # subtracting 1 would also produce index -1, i.e. the last entry a
            # second time, biasing the draw.)
            sample_idx = np.random.randint(buffer.shape[0])
            # Clone so that clamping the training labels does not write
            # through into the buffer.
            y_train = buffer[sample_idx].clone()
            y_train[self.train_mask] = self.y[self.train_mask]

            self.model.train()
            self.optimizer.zero_grad()
            loss_subject = (self.x, self.adj, y_train, self.train_mask, self.config.train.time_batch)
            loss = self.loss_fn(self.model, *loss_subject)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.grad_norm)
            self.optimizer.step()
            if self.config.train.lr_schedule:
                self.scheduler.step()

            # Evaluate the model
            if epoch % self.config.train.print_interval == 0 and epoch > 0:

                # Manifold-constrained sampling
                y_est = self.mc_estimator(self.model, self.x, self.adj, self.y, self.train_mask, temp=0.001, coef=self.config.diffusion.coef)
                pred, label = torch.argmax(y_est, dim=-1), torch.argmax(self.y, dim=-1)
                valid_acc = torch.mean((pred == label)[self.valid_mask].float()).item()
                test_acc = torch.mean((pred == label)[self.test_mask].float()).item()

                # Ties go to the later epoch.
                if valid_acc >= best_valid:
                    best_valid, best_test = valid_acc, test_acc
                    # NOTE: checkpoint saving is currently disabled.

                # Sampling without the manifold constraint
                with torch.no_grad():
                    y_est = self.estimator(self.model, self.x, self.adj, self.y, self.train_mask)
                pred = torch.argmax(y_est, dim=-1)
                # NOTE(review): train_acc is computed but not logged; the
                # estimator call is kept in case it affects random state.
                train_acc = torch.mean((pred == label)[self.train_mask].float()).item()

                # Log intermediate performance
                logger.log(f'{epoch+1:03d} | val: {valid_acc:.3e} | test: {test_acc:.3e}  | best val: {best_valid:.3e} | best test: {best_test:.3e}', verbose=False)
                print(f'[Epoch {epoch+1:05d}] | val: {valid_acc:.3e} | test: {test_acc:.3e}  | best val: {best_valid:.3e} | best test: {best_test:.3e}', end = '\r')

        print(' ')
