import torch
import numpy as np
import math
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.metrics import r2_score
import random
import matplotlib as mpl
import os
import gc
import pandas as pd
import csv
from numpy import *
from torch.utils.tensorboard import SummaryWriter
from datetime import date
import time
import builtins
from sklearn.metrics import balanced_accuracy_score, confusion_matrix


# amino_acid = np.load('./model/categorical_variables.npy', allow_pickle=True)
# amino_acid = amino_acid.tolist()

   
class dataset(Dataset) :
    def __init__(self,ohe, classes,seq_len,output, n_samples) :
        # data loading
        self.ohe = torch.from_numpy(ohe.astype(np.float32))
        self.seq_len = torch.from_numpy(seq_len.astype(int64))
        self.classes = torch.from_numpy(classes.astype(int64)) 
        self.output = torch.from_numpy(output.astype(np.float32)).reshape((n_samples,1))
        self.n_samples = n_samples
        
        
    def __getitem__(self,index) :
        return self.ohe[index], self.classes[index], self.seq_len[index], self.output[index]

    def __len__(self):    
        return self.n_samples      

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.03, max_len: int = 600):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        if d_model%2 != 0:
            div_term = torch.exp(torch.arange(0, d_model+1, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, 1, d_model+1)
        else:
            div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, 1, d_model)
            
        position = torch.arange(max_len).unsqueeze(1)
        # div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        
        if d_model%2!=0:
            pe = pe[:,:,0:-1]
        self.register_buffer('pe', pe)

    def forward(self, x, rank):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        # print(x.size(), self.pe[:x.size(0)].size())
        x = x + self.pe[:x.size(0)].to(rank)
        return self.dropout(x)

def initial_q(grad):
        # grad = torch.mean(grad, dim=1)
        # return grad
        initial_out.append(grad)

def initial_d(grad):
        # grad = torch.mean(grad, dim=1)
        # return grad
        initial_pool.append(grad)

def extract(grad):
        # grad = torch.mean(grad, dim=1)
        # return grad
        cam.append(grad)
          
class complor_network(nn.Module):
    def __init__(self, num_classes, d_model, d_out, max_m, rank):
        super(complor_network, self).__init__()
        self.max_m = max_m
        self.rank = rank
        self.d_model = d_model ## equivalent to q
        self.d_out = d_out ## equivalent to d
        # self.d

        ##n layers
        # cnn layers   
        self.positional_encoding = PositionalEncoding(self.d_model)
        self.cnn1 = nn.Sequential( nn.Conv1d(self.d_model,64,1, stride=1), 
                                   nn.ReLU(),
                                   nn.Conv1d(64,128,1, stride=1),
                                   nn.ReLU(),
                                   nn.Conv1d(128,int(self.d_out),1,stride=1),
                                   ) ##16 size motifs

        
        self.maxpool1 = nn.Sequential( nn.AvgPool1d(1, stride=1),
                                       nn.AvgPool1d(1, stride=1),
                                       nn.AvgPool1d(1, stride=1),
                                   )
        
        self.motif_size_1 = 1*1

        self.nn = nn.Sequential(
                                # nn.Linear(int(self.d_model*self.d_out*self.max_m),32),
                                nn.Linear(int(self.d_model*self.d_out*self.max_m),num_classes),
                                # nn.ReLU(),
                                # nn.Linear(32,8),
                                # nn.ReLU(),
                                # nn.Linear(8,num_classes),
 
                                )
        
    def make_p(self, out, sax, seq_len,m):
        cnn_net = getattr(self, f'cnn{m}')
        pool_net = getattr(self, f'maxpool{m}')
        # ln_net = getattr(self, f'ln{m}')
        mo_size = getattr(self, f'motif_size_{m}')
        
        q =  cnn_net(out) ##[N,f,L]  
        # q = torch.permute(ln_net(torch.permute(q,(0,2,1))),(0,2,1))
        
        store_q = q
        
        ## changing pooling
        d = pool_net(sax)#*mo_size ## [N,f,L]
        
        for i in range(out.size(0)):
            d[i,...] = d[i,...]/seq_len[i]  
        
        q = torch.matmul(d, q.permute(0,2,1))
        q = q.reshape((q.size(0), q.size(1), q.size(2),1)) 
        
        
        return store_q, d, q
        
        
    def forward(self, x, seq_len):
        'x: [batch, seq_len, feature], classes: [N,L]'
        
        out = x
        out = torch.permute(out,(0,2,1)) ## making it (N, f, L)
        _,_,out_1 = self.make_p(out, x.permute(0,2,1), seq_len,1)
        heat_map = out_1     
        heat_map = heat_map.permute(0,2,3,1)
        heat_map = nn.Flatten()(heat_map) ## removing amino acid contribution from heat map as there are none        
        heat_map = self.nn(heat_map)
      
        return heat_map
    
            
    def assigning_importance(self,mo_level, kernel_size,unwrapped_len):   
        reduced_len = mo_level.size(-1)
        sequence_importance = torch.zeros((mo_level.size(0), unwrapped_len)).to(self.rank)
        
        for i in range(reduced_len):
            sequence_importance[:,i:(i+kernel_size)] += \
                mo_level[:,i].unsqueeze(-1)
        
        return sequence_importance
        
    def sum_subsequent_n(self,temp, n):
        if n > len(temp):
            raise ValueError("n should be less than or equal to the length of temp")
        
        windowed_view = np.lib.stride_tricks.sliding_window_view(temp, n)
        
        return windowed_view.sum(axis=1)   
    
    def calculate_motif_level(self,dp,m_i):
        d_comp = getattr(self, f'pool_{m_i}') #self.pool_1[:,q,:]
        q_comp = getattr(self, f'out_{m_i}')
        # m_size = initial_cam.size(1)-(d_comp.size(-1)-1)
    
        total = d_comp.size(0)
        max_len = d_comp.size(-1)
        q_id = q_comp.size(1)
        d_id = d_comp.size(1)
        
        all_motif_importance = torch.zeros((total, max_len)).to(self.rank)
        # temp = torch.zeros((d_comp.size(-1))).to(self.rank)
        for i in range(q_id): # i is related to size of latent rep
            for j in range(d_id): # j in related to num of categorical var
                for prot in range(total): # prot defines the protein number in a batch
                    var_imp = dp[prot,j,i].unsqueeze(-1)
                    var_sign = torch.sign(var_imp)
                    # q_yp = initial_cam[prot,i,:].squeeze()
                    # d_yp = initial_pool[prot,j,:].squeeze()
                    # temp_short = self.sum_subsequent_n(temp.to('cpu'), m_size)
                    # temp_short = torch.from_numpy(temp_short).to(self.rank)
                    l = int(self.seq_len[prot])
                    if l == 0:
                        continue  # skip empty sequences
                    # temp = abs(var_sign*d_comp[prot,j,0:l]*q_comp[prot,i,0:l]) # taking abs because the magnitude matters
                    temp = nn.ReLU()(var_sign*d_comp[prot,j,0:l]*q_comp[prot,i,0:l]) # taking abs because the magnitude matters
                    temp = (temp-torch.min(temp))/(torch.max(temp)-torch.min(temp)+1E-18)
                    # temp[:] = (temp)/(torch.max(temp)+1E-18)
                    all_motif_importance[prot,0:l] += temp*abs(var_imp)

                    # del temp, var_imp
                    # gc.collect()
                # print('Done Proteins')
        #     print('Done j')
        # print('Done i')
                    # torch.cuda.empty_cache()

        return all_motif_importance
    
    def forward_motif_importance(self, x,  seq_len):
        'x: [batch, seq_len, feature], classes: [N,L]'
        global cam
        global initial_out, initial_pool
        cam = []
        initial_out = []
        initial_pool = []
        self.Lin = x.size(1)
        self.seq_len = seq_len
        # self.trace_visitation = trace_visitation
        # self.importance =  importance
        # self.overall_imp_segments = overall_imp_segments
        out = x
        
        
        out = torch.permute(out,(0,2,1)) ## making it (N, f, L)
        
        self.out_1, self.pool_1,p_1 = self.make_p(out,x.permute(0,2,1),seq_len,1)
        self.out_1.register_hook(initial_q)
  
        heat_map = p_1
        heat_map.register_hook(extract)
    
        heat_map = heat_map.permute(0,2,3,1)    

        
        heat_map = nn.Flatten()(heat_map) ## removing amino acid contribution from heat map as there are none        
        heat_map = self.nn(heat_map)


        
        return heat_map, cam,initial_out,None