import numpy as np
from tqdm import tqdm
import os,pickle,time,datetime
import torch
import torch.nn.functional as F
from dataset.data_gan import MyDataset
from prepare_data.construct_graphs_from_json import ntype
from discriminators.AttentionDiscriminator import Discriminator
from generators.masked_generator import Generator
from torch_geometric.loader import DataLoader as GraphDataLoader

import utils

class GANModel(object):
    '''
    Owns the Generator and Discriminator models together with their optimizers,
    and provides training, validation, testing and checkpoint save/load helpers.
    '''
    def __init__(self, num_nodes, node_feature_dim, node_info, config):
        # Model configurations.
        self.net_type = config.net_type
        # Training configurations.
        self.batch_size = config.bsz
        self.num_iters = config.num_iters
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.weight_cliping_limit = 0.1
        self.gp_weight =10
        self.d_attention = config.d_attention
        self.steps_per_validation = 100
        self.task = config.task
        self.use_gamma_generator = config.use_gamma_generator
        self.gamma = 0

        if config.gpu < 0:
            self.device = 'cpu'
        else:
            torch.cuda.set_device(config.gpu)
            self.device = 'cuda'

        # Directories.
        self.log_dir = config.log_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        self.node_feature_dim = node_feature_dim
        self.num_nodes = num_nodes
        self.build_model()



    def build_model(self):
        """Instantiate the generator, the discriminator and one RMSprop optimizer for each."""
        generator_kwargs = dict(
            device=self.device,
            net_type=self.net_type,
            feature_dim=self.node_feature_dim,
            batch_size=self.batch_size,
            num_layers=4,
            h_dim=32,
            emb_dim=32,
            use_embedding=True,
            dataset=self.task,
        )
        self.G = Generator(**generator_kwargs)

        discriminator_kwargs = dict(
            num_nodes=self.num_nodes,
            num_layers=4,
            h_dim=64,
            emb_dim=64,
            node_feature_dim=self.node_feature_dim,
            device=self.device,
            net_type=self.net_type,
            dropout_rate=self.dropout,
            attention_mech=self.d_attention,
            use_embedding=True,
            last_layer_sigmoid=True,
        )
        self.D = Discriminator(**discriminator_kwargs)

        # Each network gets its own optimizer and learning rate.
        generator_params = list(self.G.parameters())
        self.g_optimizer = torch.optim.RMSprop(generator_params, self.g_lr)
        self.d_optimizer = torch.optim.RMSprop(self.D.parameters(), self.d_lr)

        # Finally place both networks on the selected device.
        for network in (self.G, self.D):
            network.to(self.device)

    
    def to(self, device):
        '''
        Move models to device
            device: str name of device, as in pytorch models.
        '''
        self.D.to(device)
        self.G.to(device)


    def train(self,dataloader, valloader, epochs = 100):
        '''
            Trains the model for `epochs` passes over dataloader.
                dataloader: pytorch geometric dataloader with training data
                valloader: pytorch geometric dataloader with validation data
                epochs: number of passes over dataloader

            On every step the model is checkpointed when the global step is a
            multiple of self.model_save_step, the validation hook is called,
            and then one discriminator step and one generator step are run.
        '''
        completed_epochs = 0
        global_step = 0

        while completed_epochs < epochs:
            completed_epochs += 1
            progress = tqdm(dataloader, total=len(dataloader))

            for step_in_epoch, graph_batch in enumerate(progress):
                # Periodic checkpoint, named after the global step.
                if global_step % self.model_save_step == 0:
                    self.save_model(name=str(global_step))

                # The validation hook decides by itself whether this step is due.
                self.validate_with_data_generator_noise(valloader, global_step)

                global_step += 1
                self.d_optimizer.zero_grad()
                graph_batch = graph_batch.to(self.device)

                # Discriminator first, then the generator on the same batch.
                d_loss, g_penalty = self.train_discriminator(graph_batch)
                g_loss = self.train_generator(step_in_epoch, graph_batch)

                loss_dict = {
                    "g_loss": g_loss,
                    "d_loss": d_loss,
                    "d_loss_pred": d_loss - g_penalty,
                    "d_loss_gpenalty": g_penalty,
                    "gamma": self.gamma,
                }
                print(loss_dict)
                progress.set_postfix(loss_dict)

        print("Done Training!")


    def train_generator(self, epoch_step, graph_batch):
        '''
        One training step for Generator
            epoch_step: step in current epoch
            graph_batch: One batch from pytorch geometric dataloader
            
        '''

        self.g_optimizer.zero_grad()
        if not hasattr(self,'g_loss'):
            g_loss = torch.tensor(0)
        
        
        if (epoch_step+1) % self.n_critic == 0:
            #No need to keep gradients for discriminator
            for p in self.D.parameters():
                p.requires_grad = False
            for p in self.G.parameters():
                p.requires_grad = True

            #Run Generator, compute loss and optimize
            nodes_logit = self.G(x = graph_batch.x, edge_index = graph_batch.edge_index)
            output_fake, logits_fake, att_vec_fake = self.D(x = nodes_logit, edge_index = graph_batch.edge_index, batch=graph_batch.batch)
            if self.use_gamma_generator:
                g_loss = torch.square(torch.mean(output_fake) - self.gamma)
            else:
                g_loss = -torch.mean(output_fake)
            g_loss.backward()
            self.g_optimizer.step()
        return g_loss.detach().cpu().item()

    def train_discriminator(self, graph_batch):
        '''
        One training step for Discriminator
            graph_batch: One batch from pytorch geometric dataloader

        Side effects:
            Updates the discriminator weights, refreshes self.gamma from the
            scores of the generated samples and prints the score range.

        Returns:
            Discriminator loss (including GP)
            gradient penalty value

        '''
        # No need to keep gradients for the generator during the discriminator update.
        for p in self.D.parameters():
            p.requires_grad = True
        for p in self.G.parameters():
            p.requires_grad = False

        # Clear stale discriminator gradients here so this step is correct even
        # when the caller has not zeroed them (mirrors train_generator, which
        # zeroes its own optimizer). Zeroing twice is harmless.
        self.d_optimizer.zero_grad()

        edge_index = graph_batch.edge_index
        batch = graph_batch.batch

        # Score the real batch and a generated batch built on the same graph structure.
        output_real, _logits_real, _att_vec_real = self.D(
            x=graph_batch.x, edge_index=edge_index, batch=batch)
        nodes_logit = self.G(x=graph_batch.x, edge_index=edge_index)
        output_fake, _logits_fake, _att_vec_fake = self.D(
            x=nodes_logit.detach(), edge_index=edge_index, batch=batch)

        # Discriminator loss: score gap between generated and real, plus gradient penalty.
        gpenalty = self._gradient_penalty(graph_batch.x, nodes_logit, edge_index=edge_index, batch=batch)
        d_loss = torch.mean(output_fake) - torch.mean(output_real) + gpenalty
        g_penalty = gpenalty.detach().cpu().item()
        d_loss.backward()

        # Update gamma from the scores of the generated samples
        # (utils.percentile is called with 15 as its second argument).
        scores = output_fake.detach().view(-1, 1)
        self.gamma = utils.percentile(scores, 15)

        print('Max:',scores.max(),"Min:", scores.min(), 'Gamma:',self.gamma)

        self.d_optimizer.step()
        return d_loss.detach().cpu().item(), g_penalty

    
    def validate_with_data_generator_noise(self, valloader, train_step):
        '''
        Periodically scores real data, anomalies, generator samples and uniform noise.

            valloader: pytorch geometric data loader with real validation data
            train_step: current training step; validation only runs when it is
                        a multiple of self.steps_per_validation

        The scores (plus the current gamma) are printed for logging; nothing is returned.
        '''
        # Nothing to do unless this step is a validation step.
        if train_step % self.steps_per_validation != 0:
            return

        modes_to_score = ('real', 'anomalies', 'generator', 'noise')
        val_scores = {
            validation_mode: self.validate(valloader, mode=validation_mode)
            for validation_mode in modes_to_score
        }
        val_scores['gamma'] = self.gamma
        print("validation scores :", val_scores)



    def validate(self,valloader, mode = "real"):
        '''
        Runs the model on some data without gradients and computes mean score.
            valloader: pytorch geometric dataloader with validation data
            mode: one of 'real', 'anomalies', 'generator', 'real+noise', 'noise':
                real        - discriminator on the data, keeping entries with y == 0
                anomalies   - discriminator on the data, keeping entries with y == 1
                generator   - discriminator on the generator output for the data
                real+noise  - discriminator on the data plus uniform noise
                noise       - discriminator on uniform noise shaped like the data

        Returns:
            Mean discriminator score for data (NaN when no batch was scored).

        Raises:
            ValueError: if mode is not one of the supported strings.
        '''
        supported_modes = ('real', 'anomalies', 'generator', 'real+noise', 'noise')
        if mode not in supported_modes:
            # Fail early with a clear message instead of hitting an unbound
            # `output` inside the loop.
            raise ValueError(f"Unknown validation mode {mode!r}; expected one of {supported_modes}")

        pbar = tqdm(valloader, total=len(valloader))
        self.D.eval()
        self.G.eval()
        all_outputs = []

        try:
            with torch.no_grad(): #no need for gradients here
                for graph_batch in pbar:
                    graph_batch = graph_batch.to(self.device)
                    edge_index = graph_batch.edge_index

                    # Choose what the discriminator is fed with.
                    if mode == 'generator':
                        d_input = self.G(activation=None, x=graph_batch.x, edge_index=edge_index)
                    elif mode == 'real+noise':
                        d_input = graph_batch.x + torch.rand_like(graph_batch.x)
                    elif mode == 'noise':
                        d_input = torch.rand_like(graph_batch.x)
                    else:  # 'real' and 'anomalies' both score the untouched data
                        d_input = graph_batch.x

                    output, _logits, _att_vec = self.D(x=d_input, edge_index=edge_index, batch=graph_batch.batch)

                    # Label-based filtering of the scores.
                    if mode == 'real':
                        output = output[graph_batch.y == 0]
                    elif mode == 'anomalies':
                        output = output[graph_batch.y == 1]

                    all_outputs.append(output)
        finally:
            # Always hand the models back in training mode, even after an error.
            self.D.train()
            self.G.train()

        if not all_outputs:
            # Empty loader: torch.cat would raise on an empty list.
            return float('nan')

        return torch.cat(all_outputs, dim=0).mean().item()


    def _gradient_penalty(self, real_data, generated_data, edge_index, batch):
        '''
        Calculates weighted gradient penalty for discriminator

            real_data: nxf feature array for real batched graph
            generated_data: nxf feature array for generated batched graph
            edge_index: 2xm array of edges
            batch: nx1 graph index for each node in batch
        
        Returns:
            Weighted gradient penalty value. 
        '''
        batch_size = 1 # this is because all graphs are batched together. 

        # Calculate interpolation
        alpha = torch.rand_like(real_data).to(self.device)
        interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
        interpolated = torch.autograd.Variable(interpolated, requires_grad=True).to(self.device)

        # Calculate probability of interpolated examples
        outputs,logits,att_vec = self.D(x = interpolated, edge_index = edge_index, batch = batch)
        prob_interpolated = outputs

        # Calculate gradients of probabilities with respect to examples
        gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                        grad_outputs=torch.ones(prob_interpolated.size()).to(self.device),
                                        create_graph=True, retain_graph=True)[0]
                                        

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)
        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        gp = self.gp_weight * ((gradients_norm - 1) ** 2).mean()

        return gp

    def save_model(self,name = ""):
        '''
        Saves models on designated directory
            name: str appended after model type for filename.
        '''
        print(f"Saving model {name}")
        torch.save(self.G.state_dict(), os.path.join(self.model_save_dir,f"generator_{name}.pth"))
        torch.save(self.D.state_dict(), os.path.join(self.model_save_dir,f"discriminator_{name}.pth"))
        return

    def load_model(self,name = ""):
        '''
        Loads models
            name: str appended after model type
        '''
        #self.G.load_state_dict(torch.load(os.path.join(self.model_save_dir,f"generator_{name}.pth"),map_location='cpu'))
        self.D.load_state_dict(torch.load(os.path.join(self.model_save_dir,f"discriminator_{name}.pth"),map_location='cpu'))

        return
    
    def test(self, dataloader):
        '''
        Runs model on test data
            dataloader: pytorch geometric dataloader with test data and labels

        Returns (all as numpy arrays, concatenated over batches along dim 0):
            model predictions on data
            attention vector over nodes in each graph
            data labels from dataloader
        '''
        self.D.eval()
        self.D.to(self.device)
        pbar = tqdm(dataloader, total=len(dataloader))

        all_logits = []
        all_att_vec = []
        all_labels = []

        # Pure inference: disable autograd so no graph is recorded per batch.
        with torch.no_grad():
            for batch_data_list in pbar:
                # Use the returned object, as the training and validation loops do.
                batch_data_list = batch_data_list.to(self.device)
                outputs, _logits_real, att_vec = self.D(
                    x=batch_data_list.x,
                    edge_index=batch_data_list.edge_index,
                    batch=batch_data_list.batch)

                # Collect everything on the CPU so device memory is released per batch.
                all_logits.append(outputs.detach().cpu())
                all_att_vec.append(att_vec.detach().cpu())
                all_labels.append(batch_data_list.y.detach().cpu())

        all_logits = torch.cat(all_logits, dim=0)
        all_att_vec = torch.cat(all_att_vec, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        return all_logits.numpy(), all_att_vec.numpy(), all_labels.numpy()
     
     
