
import torch
import torch.nn as nn
import torch.nn.functional as F

from components.dataset import Batch
from components.vocab import VocabEntry

from grammar.transition_system import ApplyRuleAction, GenTokenAction
from models.ASN import EmbeddingLayer, LuongAttention

class BaselineLocator:
    """Locator that scores every action by its negated exponentiated score."""

    def __init__(self):
        pass

    def predict(self, ex):
        """Return ``-exp(score)`` for each action of the branch.

        Args:
            ex: indexable pair; ``ex[0]`` is the nl-code example (not read
                here) and ``ex[1]`` is a branch exposing ``actions``, each
                of which has a ``score`` attribute.

        Returns:
            A 1-D tensor with one entry per action.
        """
        branch = ex[1]
        raw = torch.Tensor([step.score for step in branch.actions])
        # presumably ``score`` is a log-probability, so exp() recovers the
        # probability -- TODO confirm against the parser.
        return torch.exp(raw).neg()
    
class GapLocator:
    """Locator that scores every action by the gap between the two best
    candidate scores of its hole."""

    def __init__(self):
        pass

    def predict(self, ex):
        """Return ``second_best - best`` candidate score for each action.

        Args:
            ex: indexable pair; ``ex[0]`` is the nl-code example (not read
                here) and ``ex[1]`` is a branch exposing ``actions``. Each
                action has a ``hole`` whose ``candidate_scores`` is a tensor
                (assumed 1-D -- TODO confirm).

        Returns:
            A 1-D tensor with one (non-positive) gap per action. A hole with
            fewer than two candidates has no runner-up and gets ``-inf``.
            An empty branch yields an empty tensor.
        """
        gaps = []
        for action in ex[1].actions:
            candidate_scores = action.hole.candidate_scores
            if candidate_scores.numel() < 2:
                # topk(..., 2) would raise here; there is no alternative to
                # compare against, so report the largest possible gap.
                gaps.append(candidate_scores.new_tensor(float('-inf')))
                continue
            top2, _ = torch.topk(candidate_scores, 2)
            gaps.append(top2[1] - top2[0])
        if not gaps:
            # torch.stack rejects an empty list.
            return torch.Tensor([])
        return torch.stack(gaps)

# class PeakNextLocator:
#     def __init__(self):
#         pass

#     def predict(self, ex):
#         # ex[0] nl-code example, ex[1] a branch
#         scores = []
#         acc_score = .0
#         for d, action in enumerate(ex[1].actions):
#             hole = action.hole
#             candidate_scores = hole.candidate_scores
#             top2, _ = torch.topk(candidate_scores, 2)
#             scores.append(acc_score + top2[1].item())
#             acc_score += top2[0].item() + 0.25
#         scores = torch.Tensor(scores)
#         return scores

class PeakNextLocator:
    """Locator that scores each action by the running score of the preceding
    actions plus the score of the next-best candidate at that action."""

    def __init__(self, step_offset=0.25):
        """
        Args:
            step_offset: constant added to the running total after every
                action (the original hard-coded value was 0.25).
        """
        self.step_offset = step_offset

    def next_val(self, val, l):
        """Return the candidate ranked immediately below ``val``.

        Args:
            val: the score to look up.
            l: list of candidate scores; it is not modified.

        Returns:
            The element following the first entry equal to ``val`` (within
            1e-10) in descending order, or ``-inf`` when ``val`` is the
            lowest-ranked candidate.

        Raises:
            ValueError: if no candidate matches ``val``.
        """
        # Sort a copy so the caller's list keeps its order.
        ranked = sorted(l, reverse=True)
        for i, v in enumerate(ranked):
            if abs(v - val) < 1e-10:
                if i + 1 < len(ranked):
                    return ranked[i + 1]
                return float('-inf')
        # Previously this fell through and returned None, which only
        # surfaced later as a TypeError inside predict().
        raise ValueError(
            'score {!r} not found among {} candidate scores'.format(val, len(ranked)))

    def predict(self, ex):
        """Return one score per action of the branch.

        Args:
            ex: indexable pair; ``ex[0]`` is the nl-code example (not read
                here) and ``ex[1]`` is a branch exposing ``actions``. Each
                action has a ``score`` and a ``hole`` whose
                ``candidate_scores`` supports ``tolist()``.

        Returns:
            A 1-D tensor; entry ``d`` is the running total of
            ``score + step_offset`` over actions before ``d`` plus the
            next-best candidate score at ``d``.
        """
        scores = []
        acc_score = .0
        for action in ex[1].actions:
            candidate_scores = action.hole.candidate_scores.tolist()
            next_v = self.next_val(action.score, candidate_scores)
            scores.append(acc_score + next_v)
            acc_score += action.score + self.step_offset
        return torch.Tensor(scores)

class AccLocator:
    """Locator that scores each action by the negated running sum of action
    scores up to and including that action."""

    def __init__(self):
        pass

    def predict(self, ex):
        """Return the negated prefix sums of the branch's action scores.

        Args:
            ex: indexable pair; ``ex[0]`` is the nl-code example (not read
                here) and ``ex[1]`` is a branch exposing ``actions``, each
                with a ``score`` attribute.

        Returns:
            A 1-D tensor with one entry per action.
        """
        running_total = .0
        prefix_sums = []
        for step in ex[1].actions:
            running_total += step.score
            prefix_sums.append(running_total)
        return torch.Tensor(prefix_sums).neg()

class AccHeuristicLocator:
    """Locator that scores each action by the negated running sum of
    ``score + step_offset`` up to and including that action."""

    def __init__(self, step_offset=0.25):
        """
        Args:
            step_offset: constant added to every action's score before it is
                accumulated (the original hard-coded value was 0.25).
        """
        self.step_offset = step_offset

    def predict(self, ex):
        """Return the negated, offset prefix sums of the action scores.

        Args:
            ex: indexable pair; ``ex[0]`` is the nl-code example (not read
                here) and ``ex[1]`` is a branch exposing ``actions``, each
                with a ``score`` attribute.

        Returns:
            A 1-D tensor with one entry per action.
        """
        scores = []
        acc_score = .0
        for action in ex[1].actions:
            acc_score += action.score + self.step_offset
            scores.append(acc_score)
        return -torch.Tensor(scores)

class SeqToSeqLocator(nn.Module):
    """Per-action binary classifier over a branch.

    A bidirectional LSTM runs over the sequence of per-action ``v_state``
    vectors, the result is projected back to ``v_state_size``, combined with
    the output of ``parser.attn`` over the encoder context, and fed to a
    2-way linear layer. Class 1 corresponds to actions labelled ``'error'``.
    """

    def __init__(self, args, parser=None):
        """
        Args:
            args: stored on the instance; not read by the code in this class.
            parser: object providing ``grammar``, ``vocab``, ``encode`` and
                ``attn``; needed by ``pre_index_examples`` and ``forward``.
        """
        super().__init__()
        self.args = args
        self.parser = parser
        self.dropout = nn.Dropout(0.3)
        self.v_state_size = 100
        self.rnn = nn.LSTM(self.v_state_size, self.v_state_size, num_layers=1, batch_first=True, bidirectional=True)
        # The bidirectional LSTM emits 2 * v_state_size features per step.
        self.trans = nn.Linear(2 * self.v_state_size, self.v_state_size)
        # 300 = v_state_size + width of parser.attn's output
        # (assumed to be 200 -- TODO confirm against the parser).
        self.ffnn = nn.Linear(300, 2)
        # Class weights for the loss: index 0 = non-error, index 1 = error.
        # Kept as a plain attribute (not a buffer) so state_dict is unchanged.
        self.scaling = torch.Tensor([0.2, 0.8])

    def score(self, examples):
        """Return the mean of the per-example summed losses."""
        return torch.stack([self._score(x) for x in examples]).mean()

    def _score(self, ex):
        """Return the class-weighted cross-entropy of one pre-indexed
        example, summed over its actions."""
        activations, labels = self.forward(ex)
        # ``scaling`` is not a parameter or buffer, so .to()/.cuda()/.double()
        # on the module do not move it; align it with the activations here.
        weight = self.scaling.to(device=activations.device, dtype=activations.dtype)
        return F.cross_entropy(activations, labels, weight=weight, reduction='sum')

    def pre_index_examples(self, examples):
        """Convert ``(nl_code, branch)`` pairs into model inputs.

        Returns:
            A list of ``(v_states, labels, context_vecs)`` tuples, where
            ``context_vecs`` is the detached first output of
            ``parser.encode`` on a single-example ``Batch``.
        """
        new_exs = []
        for nl_code, branch in examples:
            v_states, labels = self.build_inputs_from_branch(branch)

            # The parser is only used as a feature extractor here.
            with torch.no_grad():
                batch = Batch([nl_code], self.parser.grammar, self.parser.vocab)
                context_vecs, _ = self.parser.encode(batch)
            context_vecs = context_vecs.detach()

            new_exs.append((v_states, labels, context_vecs))
        return new_exs

    def build_inputs_from_branch(self, branch):
        """Collect detached state tensors and labels from a branch.

        Returns:
            ``((first, second), labels)``: ``first``/``second`` concatenate
            element 0 / element 1 of every action's ``v_state`` along dim 0
            (presumably LSTM hidden and cell states -- TODO confirm);
            ``labels`` is a LongTensor with 1 where the action's label is
            ``'error'`` and 0 otherwise.
        """
        v_state_seq = [x.v_state for x in branch.actions]
        v_state_seq = (
            torch.cat([x[0] for x in v_state_seq]).detach(),
            torch.cat([x[1] for x in v_state_seq]).detach(),
        )
        labels = torch.LongTensor([1 if x.label == 'error' else 0 for x in branch.actions])
        return v_state_seq, labels

    # simplest version
    def forward(self, ex):
        """Return ``(activations, labels)`` for one pre-indexed example.

        ``activations`` has one row of two logits per action; ``labels`` is
        passed through from ``ex`` unchanged.
        """
        v_states, labels, context_vecs = ex
        # Only the first element of v_states is used; add a batch dim of 1.
        x, _ = self.rnn(v_states[0].unsqueeze(0))
        x = self.trans(x.squeeze(0))
        contexts = self.parser.attn(x.unsqueeze(0), context_vecs).squeeze(0)
        x = torch.cat((x, contexts), dim=1)
        activations = self.ffnn(self.dropout(x))
        return activations, labels

    def predict(self, ex):
        """Return, per action, the softmax probability of class 1 (error)."""
        activations, _ = self.forward(ex)
        predictions = F.softmax(activations, dim=1)
        return predictions[:, 1]



class ScratchLocator(nn.Module):
    """Per-action binary classifier that embeds the action sequence itself.

    Each action is mapped to a word (constructor name, token, or primitive
    type name), embedded, run through a bidirectional LSTM, projected, added
    to the action's ``v_state``, combined with ``parser.attn`` over the
    encoder context, and classified by a 2-way linear layer. Class 1
    corresponds to actions labelled ``'error'``.
    """

    def __init__(self, args, parser=None):
        """
        Args:
            args: stored on the instance; not read by the code in this class.
            parser: object providing ``grammar``, ``vocab``, ``encode``,
                ``attn`` and ``eval``. Despite the ``None`` default it is
                required: it is dereferenced immediately.
        """
        super().__init__()
        self.args = args
        self.parser = parser
        # NOTE(review): if ``parser`` is an nn.Module, the assignment above
        # registers it as a submodule, so a later ``.train()`` on this
        # locator will switch the parser back to training mode -- confirm
        # this is intended.
        self.parser.eval()
        self.dropout = nn.Dropout(0.3)
        self.v_state_size = 100

        # Class weights for the loss: index 0 = non-error, index 1 = error.
        # Kept as a plain attribute (not a buffer) so state_dict is unchanged.
        self.scaling = torch.Tensor([0.2, 0.8])
        self.action_vocab = self.make_action_vocab()
        self.emb_size = self.v_state_size
        self.rnn = nn.LSTM(self.emb_size, self.emb_size, num_layers=1, batch_first=True, bidirectional=True)
        self.action_emb = EmbeddingLayer(self.emb_size, self.action_vocab.size(), 0.3)
        # The bidirectional LSTM emits 2 * emb_size features per step.
        self.trans = nn.Linear(2 * self.emb_size, self.v_state_size)
        # 300 = v_state_size + width of parser.attn's output
        # (assumed to be 200 -- TODO confirm against the parser).
        self.ffnn = nn.Linear(300, 2)

    def make_action_vocab(self):
        """Build the vocabulary of action words.

        It contains every production's constructor name; for the primitive
        types named ``'cc'`` and ``'csymbl'`` every word of that type's
        primitive vocabulary; and for all other primitive types the type
        name itself.
        """
        grammar = self.parser.grammar

        vocab = VocabEntry()
        corpus = [prod.constructor.name for prod in grammar.productions]
        for prim_type in grammar.primitive_types:
            if prim_type.name in ['cc', 'csymbl']:
                pvocab = self.parser.vocab.primitive_vocabs[prim_type]
                corpus.extend(pvocab.word_to_id)
            else:
                corpus.append(prim_type.name)

        for w in corpus:
            vocab.add(w)
        return vocab

    def pre_index_examples(self, examples):
        """Convert ``(nl_code, branch)`` pairs into model inputs.

        Returns:
            A list of ``(v_states, labels, context_vecs, action_idx)``
            tuples, where ``context_vecs`` is the detached first output of
            ``parser.encode`` on a single-example ``Batch`` and
            ``action_idx`` is a LongTensor of action-vocabulary ids.
        """
        new_exs = []
        for nl_code, branch in examples:
            v_states, labels = self.build_inputs_from_branch(branch)

            # The parser is only used as a feature extractor here.
            with torch.no_grad():
                batch = Batch([nl_code], self.parser.grammar, self.parser.vocab)
                context_vecs, _ = self.parser.encode(batch)
            context_vecs = context_vecs.detach()

            action_words = []
            for b_ac in branch.actions:
                ac = b_ac.action
                if isinstance(ac, ApplyRuleAction):
                    action_words.append(ac.production.constructor.name)
                elif isinstance(ac, GenTokenAction):
                    type_name = ac.type.name
                    if type_name in ['cc', 'csymbl']:
                        action_words.append(ac.token)
                    else:
                        action_words.append(type_name)
                # NOTE(review): any other action type is skipped, which would
                # leave action_idx shorter than v_states -- confirm no such
                # actions occur.
            action_idx = torch.LongTensor([self.action_vocab[w] for w in action_words])
            new_exs.append((v_states, labels, context_vecs, action_idx))
        return new_exs

    def score(self, examples):
        """Return the mean of the per-example summed losses."""
        return torch.stack([self._score(x) for x in examples]).mean()

    def _score(self, ex):
        """Return the class-weighted cross-entropy of one pre-indexed
        example, summed over its actions."""
        activations, labels = self.forward(ex)
        # ``scaling`` is not a parameter or buffer, so .to()/.cuda()/.double()
        # on the module do not move it; align it with the activations here.
        weight = self.scaling.to(device=activations.device, dtype=activations.dtype)
        return F.cross_entropy(activations, labels, weight=weight, reduction='sum')

    def build_inputs_from_branch(self, branch):
        """Collect detached state tensors and labels from a branch.

        Returns:
            ``((first, second), labels)``: ``first``/``second`` concatenate
            element 0 / element 1 of every action's ``v_state`` along dim 0
            (presumably LSTM hidden and cell states -- TODO confirm);
            ``labels`` is a LongTensor with 1 where the action's label is
            ``'error'`` and 0 otherwise.
        """
        v_state_seq = [x.v_state for x in branch.actions]
        v_state_seq = (
            torch.cat([x[0] for x in v_state_seq]).detach(),
            torch.cat([x[1] for x in v_state_seq]).detach(),
        )
        labels = torch.LongTensor([1 if x.label == 'error' else 0 for x in branch.actions])
        return v_state_seq, labels

    # simplest version
    def forward(self, ex):
        """Return ``(activations, labels)`` for one pre-indexed example.

        ``activations`` has one row of two logits per action; ``labels`` is
        passed through from ``ex`` unchanged.
        """
        v_states, labels, context_vecs, action_idx = ex

        action_vec = self.action_emb(action_idx)

        # Add a batch dim of 1 for the LSTM, then drop it again.
        x, _ = self.rnn(action_vec.unsqueeze(0))
        x = self.trans(x.squeeze(0))
        # Only the first element of v_states is used.
        x = x + v_states[0]
        contexts = self.parser.attn(x.unsqueeze(0), context_vecs).squeeze(0)
        x = torch.cat((x, contexts), dim=1)
        activations = self.ffnn(self.dropout(x))
        return activations, labels

    def predict(self, ex):
        """Return, per action, the softmax probability of class 1 (error)."""
        activations, _ = self.forward(ex)
        predictions = F.softmax(activations, dim=1)
        return predictions[:, 1]


class SpecificLocator(nn.Module):
    """Per-action error scorer based on a dot product.

    For every action, the action's ``v_state`` concatenated with the output
    of ``parser.attn`` is multiplied element-wise with an embedding of the
    action word, summed, and squashed with a sigmoid. Training uses a
    class-weighted binary log loss where label 1 means ``'error'``.
    """

    def __init__(self, args, parser=None):
        """
        Args:
            args: stored on the instance; not read by the code in this class.
            parser: object providing ``grammar``, ``vocab``, ``encode``,
                ``attn`` and ``eval``; it is dereferenced immediately.
        """
        super().__init__()
        self.args = args
        self.parser = parser
        self.parser.eval()
        self.dropout = nn.Dropout(0.3)
        self.v_state_size = 100

        # Loss weights: index 0 for non-error terms, index 1 for error terms.
        self.scaling = [0.2, 0.8]
        self.action_vocab = self.make_action_vocab()
        self.emb_size = self.v_state_size
        # Not referenced by forward() in this class.
        self.rnn = nn.LSTM(self.emb_size, self.emb_size, num_layers=1, batch_first=True, bidirectional=True)
        # Embedding width 300 matches the concatenated feature in forward().
        self.action_emb = EmbeddingLayer(300, self.action_vocab.size(), 0.3)
        # Not referenced by forward() in this class.
        self.ffnn = nn.Linear(300, 2)

    def make_action_vocab(self):
        """Build the vocabulary of action words.

        It contains every production's constructor name; for the primitive
        types named ``'cc'`` and ``'csymbl'`` every word of that type's
        primitive vocabulary; and for all other primitive types the type
        name itself.
        """
        grammar = self.parser.grammar

        vocab = VocabEntry()
        words = [production.constructor.name for production in grammar.productions]
        for primitive in grammar.primitive_types:
            if primitive.name not in ['cc', 'csymbl']:
                words.append(primitive.name)
                continue
            words.extend(self.parser.vocab.primitive_vocabs[primitive].word_to_id)

        for word in words:
            vocab.add(word)
        return vocab

    def pre_index_examples(self, examples):
        """Convert ``(nl_code, branch)`` pairs into model inputs.

        Returns:
            A list of ``(v_states, labels, context_vecs, action_idx)``
            tuples, where ``context_vecs`` is the detached first output of
            ``parser.encode`` on a single-example ``Batch`` and
            ``action_idx`` is a LongTensor of action-vocabulary ids.
        """
        indexed = []
        for example in examples:
            nl_code, branch = example
            v_states, labels = self.build_inputs_from_branch(branch)

            # The parser is only used as a feature extractor here.
            with torch.no_grad():
                single = Batch([nl_code], self.parser.grammar, self.parser.vocab)
                context_vecs, _ = self.parser.encode(single)
            context_vecs = context_vecs.detach()

            words = []
            for branch_action in branch.actions:
                action = branch_action.action
                if isinstance(action, ApplyRuleAction):
                    words.append(action.production.constructor.name)
                    continue
                if not isinstance(action, GenTokenAction):
                    # Other action types contribute no word.
                    continue
                type_name = action.type.name
                words.append(action.token if type_name in ['cc', 'csymbl'] else type_name)
            ids = torch.LongTensor([self.action_vocab[word] for word in words])
            indexed.append((v_states, labels, context_vecs, ids))
        return indexed

    def score(self, examples):
        """Return the mean of the per-example summed losses."""
        return torch.stack([self._score(example) for example in examples]).mean()

    def _score(self, ex):
        """Return the weighted binary log loss of one pre-indexed example."""
        probabilities, targets = self.forward(ex)
        return self.binary_classification_loss(probabilities, targets)

    def build_inputs_from_branch(self, branch):
        """Collect detached state tensors and labels from a branch.

        Returns:
            ``((first, second), labels)``: ``first``/``second`` concatenate
            element 0 / element 1 of every action's ``v_state`` along dim 0
            (presumably LSTM hidden and cell states -- TODO confirm);
            ``labels`` is a LongTensor with 1 where the action's label is
            ``'error'`` and 0 otherwise.
        """
        states = [step.v_state for step in branch.actions]
        first = torch.cat([pair[0] for pair in states]).detach()
        second = torch.cat([pair[1] for pair in states]).detach()
        flags = [1 if step.label == 'error' else 0 for step in branch.actions]
        return (first, second), torch.LongTensor(flags)

    # simplest version
    def forward(self, ex):
        """Return ``(probabilities, labels)`` for one pre-indexed example.

        ``probabilities`` holds one sigmoid output per action; ``labels`` is
        passed through from ``ex`` unchanged.
        """
        v_states, labels, context_vecs, action_idx = ex

        action_vec = self.action_emb(action_idx)

        # Only the first element of v_states is used.
        state = v_states[0]
        attended = self.parser.attn(state.unsqueeze(0), context_vecs).squeeze(0)
        features = torch.cat((state, attended), dim=1)
        # Row-wise dot product between features and the action embedding.
        logits = torch.sum(features * action_vec, dim=1)
        return torch.sigmoid(logits), labels

    def predict(self, ex):
        """Return the per-action sigmoid outputs of ``forward``."""
        return self.forward(ex)[0]

    def binary_classification_loss(self, x, y):
        """Return the summed, class-weighted binary log loss.

        Args:
            x: predicted values, clamped from below at 1e-10 before the log.
            y: 0/1 targets; converted to float.
        """
        y = y.float()
        floor = 1e-10
        positive = torch.clamp(x, min=floor)
        # Clamp the complement as well so log() never sees zero.
        negative = torch.clamp(1 - positive, min=floor)
        per_item = - y * torch.log(positive) * self.scaling[1] - (1 - y) * torch.log(negative) * self.scaling[0]
        return torch.sum(per_item)

class FeedForwadNN(nn.Module):
    """Two-layer perceptron: linear -> ReLU -> dropout -> linear."""

    def __init__(self, input_size, hidden_size, output_size, dropout=0.3):
        """
        Args:
            input_size: number of input features of the first linear layer.
            hidden_size: number of features between the two linear layers.
            output_size: number of output features of the second layer.
            dropout: dropout probability applied after the ReLU.
        """
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size

        # Layers are created in the same order as they are applied.
        self.w1 = nn.Linear(input_size, hidden_size)
        self.w2 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the network to ``x`` and return the second layer's output."""
        hidden = self.dropout(F.relu(self.w1(x)))
        return self.w2(hidden)
