import torch
import torch.nn as nn
import torch.nn.functional as F

from components.dataset import Batch
from components.vocab import VocabEntry

from grammar.transition_system import ApplyRuleAction, GenTokenAction
# from models.ASN import EmbeddingLayer, LuongAttention
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

import numpy as np

def _make_np_feature(ex, indexer):
    _, branch, features = ex
    labels = [1 if x.label == 'error' else 0 for x in branch.actions]
    labels = np.array(labels)
    new_feats = []
    for single_feat in features:
        val_feat = [.0] * len(indexer)
        for k, v in single_feat.data.items():
            if k in indexer:
                val_feat[indexer[k]] = v
        new_feats.append((val_feat))
    new_feats = np.array(new_feats)
    return new_feats, labels

# class RandomForestLocator:
#     def __init__(self, feature_indexer):
#         self.feat_indexer = feature_indexer
#         self.model = RandomForestClassifier(n_estimators=75, class_weight='balanced', random_state=33)

#     def pre_index_examples(self, examples):
#         new_exs = [_make_np_feature(x, self.feat_indexer) for x in examples]
#         return new_exs

#     def fit(self, examples):
#         all_x = []
#         all_y = []

#         for feats, labels in examples:
#             all_x.append(feats)
#             all_y.append(labels)
            
#         all_x = np.concatenate(all_x, axis=0)
#         all_y = np.concatenate(all_y, axis=0)
#         self.model.fit(all_x, all_y)
    
#     def predict(self, ex):
#         y = self.model.predict_proba(ex[0])
#         return y[:,1]
    
#     def interpretablity(self):
#         feat_imp = self.model.feature_importances_
#         imp_idx = np.argsort(-feat_imp)
#         for i in imp_idx[:20]:
#             print(self.feat_indexer.get_word(i))

class RandomForestLocator:
    """Error locator backed by a scikit-learn ``MLPClassifier``.

    NOTE(review): despite its name, this class does not use a random forest;
    the name is kept so that existing callers keep working.
    """

    def __init__(self, feature_indexer):
        """Store the feature indexer and build an untrained MLP classifier.

        Args:
            feature_indexer: Maps feature names to column indices; it is
                passed to ``_make_np_feature`` by ``pre_index_examples``.
        """
        self.feat_indexer = feature_indexer
        # NOTE: scikit-learn only uses learning_rate_init for the 'sgd' and
        # 'adam' solvers, so it has no effect with solver='lbfgs'. It is kept
        # so the configuration stays unchanged.
        self.model = MLPClassifier(
            solver='lbfgs',
            alpha=1e-5,
            hidden_layer_sizes=(200,),
            learning_rate_init=0.005,
            verbose=True,
            random_state=1,
            max_iter=100,
        )
        # self.model = LogisticRegression(solver='lbfgs', max_iter=25)

    def pre_index_examples(self, examples):
        """Featurize raw examples into ``(float32 features, labels)`` pairs."""
        indexed = []
        for ex in examples:
            feats, labels = _make_np_feature(ex, self.feat_indexer)
            indexed.append((feats.astype('float32'), labels))
        return indexed

    def fit(self, examples):
        """Train the classifier on the rows of all pre-indexed examples.

        Args:
            examples: Iterable of ``(feats, labels)`` pairs as produced by
                ``pre_index_examples``. The rows of every example are stacked
                into one training matrix.
        """
        all_x = []
        all_y = []
        for feats, labels in examples:
            all_x.append(feats)
            all_y.append(labels)
        self.model.fit(np.concatenate(all_x, axis=0), np.concatenate(all_y, axis=0))

    def predict(self, ex):
        """Return the per-row probability of label 1 (``'error'``).

        Args:
            ex: A pre-indexed ``(feats, labels)`` pair; only ``feats`` is used.

        Returns:
            1-D array with one probability per row of ``feats``. It is all
            zeros if label 1 was never seen during training.
        """
        proba = self.model.predict_proba(ex[0])
        # predict_proba columns follow self.model.classes_. Locate label 1
        # explicitly instead of assuming it is column 1, which breaks when
        # training data did not contain both labels.
        positive = np.flatnonzero(self.model.classes_ == 1)
        if positive.size == 0:
            return np.zeros(proba.shape[0])
        return proba[:, positive[0]]

    def interpretablity(self):
        """No-op: an MLP exposes no ``feature_importances_`` to report."""
        return None

    # Correctly spelled alias; the original (misspelled) name is kept for
    # existing callers.
    interpretability = interpretablity

class GradBoostingLocator:
    """Error locator backed by a scikit-learn ``AdaBoostClassifier``.

    NOTE(review): the name says gradient boosting, but the model is AdaBoost;
    the name is kept so that existing callers keep working.
    """

    def __init__(self, feature_indexer):
        """Store the feature indexer and build an untrained AdaBoost model.

        Args:
            feature_indexer: Maps feature names to column indices. It is
                passed to ``_make_np_feature`` and must provide ``get_word``
                for ``interpretablity``.
        """
        self.feat_indexer = feature_indexer
        self.model = AdaBoostClassifier(random_state=33, n_estimators=100)

    def pre_index_examples(self, examples):
        """Featurize raw examples into ``(features, labels)`` array pairs."""
        return [_make_np_feature(ex, self.feat_indexer) for ex in examples]

    def fit(self, examples):
        """Train the classifier on the rows of all pre-indexed examples.

        Args:
            examples: Iterable of ``(feats, labels)`` pairs as produced by
                ``pre_index_examples``. The rows of every example are stacked
                into one training matrix.
        """
        all_x = []
        all_y = []
        for feats, labels in examples:
            all_x.append(feats)
            all_y.append(labels)
        self.model.fit(np.concatenate(all_x, axis=0), np.concatenate(all_y, axis=0))

    def predict(self, ex):
        """Return the per-row probability of label 1 (``'error'``).

        Args:
            ex: A pre-indexed ``(feats, labels)`` pair; only ``feats`` is used.

        Returns:
            1-D array with one probability per row of ``feats``. It is all
            zeros if label 1 was never seen during training.
        """
        proba = self.model.predict_proba(ex[0])
        # predict_proba columns follow self.model.classes_. With a single
        # training class there is only one column, so y[:, 1] would raise;
        # locate label 1 explicitly instead.
        positive = np.flatnonzero(self.model.classes_ == 1)
        if positive.size == 0:
            return np.zeros(proba.shape[0])
        return proba[:, positive[0]]

    def interpretablity(self, top_k=20):
        """Print the names of the ``top_k`` most important features.

        Args:
            top_k: Number of features to print, in decreasing order of
                ``feature_importances_``. Defaults to 20.
        """
        importances = self.model.feature_importances_
        for idx in np.argsort(-importances)[:top_k]:
            print(self.feat_indexer.get_word(idx))

    # Correctly spelled alias; the original (misspelled) name is kept for
    # existing callers.
    interpretability = interpretablity

# class SeqToSeqLocator(nn.Module):
#     def __init__(self, args, parser=None):
#         super().__init__()
#         self.args = args
#         self.parser = parser
#         self.dropout = nn.Dropout(0.3)
#         self.v_state_size = 100
#         self.rnn = nn.LSTM(self.v_state_size, self.v_state_size, num_layers=1, batch_first=True, bidirectional=True)
#         # self.ffnn = FeedForwadNN(self.v_state_size, 100, 2)
#         self.trans = nn.Linear(200, 100)
#         self.ffnn = nn.Linear(300, 2)
#         # self.ffnn = nn.Linear(self.v_state_size, 2)
#         self.scaling = torch.Tensor([0.2, 0.8])

#     def score(self, examples):
#         scores = []
#         for x in examples:
#             scores.append(self._score(x))
#         return torch.stack(scores).mean()

#     def _score(self, ex):
#         activations, labels = self.forward(ex)
#         loss = F.cross_entropy(activations, labels, weight=self.scaling,reduction='sum')
#         return loss
    
#     def pre_index_examples(self, examples):
#         new_exs = []
#         for ex in examples:
#             nl_code, branch = ex
#             v_states, labels = self.build_inputs_from_branch(branch)

#             with torch.no_grad():
#                 batch = Batch([nl_code], self.parser.grammar, self.parser.vocab)
#                 context_vecs, _ = self.parser.encode(batch)
#             context_vecs = context_vecs.detach()
            
#             new_exs.append((v_states, labels, context_vecs))
#         return new_exs

#     def build_inputs_from_branch(self, branch):
#         v_state_seq = [x.v_state for x in branch.actions]
#         v_state_seq = torch.cat([x[0] for x in v_state_seq]).detach(), torch.cat([x[1] for x in v_state_seq]).detach()
#         labels = [x.label for x in branch.actions]
#         labels = torch.LongTensor([1 if x == 'error' else 0 for x in labels])
#         return v_state_seq, labels

#     # simplest version
#     def forward(self, ex):
#         v_states, labels, context_vecs = ex
#         # vanilla 
#         x, _ = self.rnn(v_states[0].unsqueeze(0))
#         x = x.squeeze(0)
#         x = self.trans(x)
#         # x = v_states[0]
#         contexts = self.parser.attn(x.unsqueeze(0), context_vecs).squeeze(0)
#         x = torch.cat((x, contexts), dim=1)
#         # print(x.size(), v_states[0].size())
#         activations = self.ffnn(self.dropout(x))
#         return activations, labels

#     def predict(self, ex):
#         activations, _ = self.forward(ex)
#         predictions = F.softmax(activations, dim=1)
#         return predictions[:,1]

