import torch
from pytorch_models.models.rnn.rnn_cells import *
from pytorch_models.utils.utils import *
import torch.nn as nn
from pytorch_models.trainer.base import ModelAdaptor


class SequentialEncoder(nn.Module):
    """Run a recurrent network one time step at a time and collect every hidden state.

    ``rnn_type`` selects ``torch.nn.LSTM`` ("lstm"), ``torch.nn.RNN`` ("elman") or
    ``torch.nn.GRU`` ("gru"); any other value is handed to the project's
    ``SequentialRNN``.
    """

    def __init__(self, in_dim, state_dim, num_layers = 1, rnn_type = "lstm", dropout = 0, hinit = 0):
        super(SequentialEncoder, self).__init__()
        self.in_dim = in_dim
        self.state_dim = state_dim
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        # Constant used to fill the initial hidden state h_0 (see get_hidden_init).
        self.hinit = hinit

        builtin_rnns = {"lstm": torch.nn.LSTM, "elman": torch.nn.RNN, "gru": torch.nn.GRU}
        rnn_cls = builtin_rnns.get(rnn_type)
        if rnn_cls is not None:
            self.rnn = rnn_cls(in_dim, state_dim, num_layers = num_layers, bias = True, dropout = dropout)
        else:
            self.rnn = SequentialRNN(in_dim, state_dim, num_layers = num_layers, rnn_type = rnn_type, bias = True, dropout = dropout)

    def get_hidden_init(self, shape, device):
        """Build the initial recurrent state of the given ``shape`` on ``device``.

        ``h_0`` is filled with ``self.hinit`` (so ``hinit=0`` gives zeros and ``hinit=1``
        gives ones). For LSTM-type RNNs (``"lstm"`` in ``rnn_type``) a ``(h_0, c_0)``
        tuple is returned, with ``c_0`` always zero.

        Raises:
            ValueError: if ``hinit`` cannot be converted to a float.
        """
        try:
            fill = float(self.hinit)
        except (TypeError, ValueError) as err:
            raise ValueError(f"hinit must be numeric, got {self.hinit!r}") from err
        # Allocate directly on the target device instead of building on CPU and copying.
        h0 = torch.full(shape, fill, device = device, requires_grad = True)
        if "lstm" in self.rnn_type:
            c0 = torch.zeros(shape, device = device, requires_grad = True)
            return (h0, c0)
        return h0

    def forward(self, inp, inp_len, hidden = None, _last_hidden = False):
        """Encode ``inp`` step by step and return the hidden state after every step.

        Args:
            inp: batch-first input; ``inp.shape[:2]`` is read as ``(batch, seq_len)`` and
                ``inp[:, t]`` is fed to the RNN as a length-1 sequence.
            inp_len: tensor of per-sample lengths; only used when ``_last_hidden`` is
                True, to pick the state at index ``inp_len - 1``.
            hidden: optional initial state in the wrapped RNN's
                ``(num_layers, batch, state_dim)`` layout (an ``(h, c)`` tuple for LSTM).
                Defaults to :meth:`get_hidden_init`.
            _last_hidden: also return the state at each sample's last valid step.

        Returns:
            ``all_states`` of shape ``(batch, seq_len, num_layers * state_dim)`` (a tuple
            of two such tensors for LSTM), or ``(last_hidden, all_states)`` when
            ``_last_hidden`` is True; ``last_hidden`` is ``(num_layers, batch, state_dim)``.

        Raises:
            ValueError: if ``inp`` has zero time steps.
        """
        device = inp.device
        n, ll = inp.shape[:2]
        nlb = self.num_layers
        s = self.state_dim
        if ll == 0:
            raise ValueError("SequentialEncoder.forward needs at least one time step")

        if hidden is None:
            hidden = self.get_hidden_init([nlb, n, s], device)
        elif isinstance(hidden, tuple):
            hidden = tuple(t.contiguous() for t in hidden)
        else:
            hidden = hidden.contiguous()

        step_states = []
        for t in range(ll):
            # The wrapped RNN is sequence-first, so one step is (1, batch, in_dim).
            _, hidden = self.rnn(inp[:, t].unsqueeze(dim = 0), hidden)
            step_states.append(hidden)

        def _batch_major(states):
            # (seq_len, layers, batch, state) -> (batch, seq_len, layers * state)
            return torch.stack(list(states)).permute(2, 0, 1, 3).reshape(n, ll, -1)

        is_tuple = isinstance(step_states[0], tuple)
        if is_tuple:
            all_states = tuple(_batch_major(component) for component in zip(*step_states))
        else:
            all_states = _batch_major(step_states)

        if not _last_hidden:
            return all_states

        # True exactly at position inp_len - 1 of each sample.
        mask = torch.arange(ll, device = device)[None] == (inp_len.to(device) - 1)[:, None]

        def _select_last(states):
            return masked_select_vectors(states, mask).reshape(n, nlb, s).transpose(1, 0)

        if is_tuple:
            last_hidden = tuple(_select_last(t) for t in all_states)
        else:
            last_hidden = _select_last(all_states)
        return last_hidden, all_states


class ClassificationModel(nn.Module):
    """Binary sequence classifier: one-hot embedding -> SequentialEncoder -> linear -> sigmoid.

    Token id 0 maps to an all-zero input vector; ids ``1 .. vocab_size - 1`` map to
    the rows of an identity matrix, so the RNN input size is ``vocab_size - 1``.
    """

    def __init__(self, state_dim, num_layers, rnn_type, dropout, vocab_size, hinit, partial_strings = False):
        super(ClassificationModel, self).__init__()

        self.in_dim = vocab_size - 1
        self.state_dim = state_dim
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        self.dropout = dropout
        self.vocab_size = vocab_size
        # When True every prefix position is scored, otherwise only the final one.
        self.partial_strings = partial_strings

        self.rnn = SequentialEncoder(self.in_dim, self.state_dim, self.num_layers, self.rnn_type, self.dropout, hinit=hinit)
        # Row 0 (padding id) is all zeros; remaining rows are one-hot vectors.
        embeddings = torch.cat([torch.zeros([1, self.in_dim]), torch.eye(self.in_dim)])
        self.register_buffer("embeddings", embeddings)
        self.classifier = torch.nn.Linear(self.state_dim, 1)

    def forward(self, input_seq, seq_len):
        """Score ``input_seq`` and return ``(probabilities, mask)``.

        ``mask`` has shape ``(batch, seq_len_max)``. With ``partial_strings`` it selects
        every position ``< seq_len``. Otherwise it selects only position ``seq_len - 2``,
        the step before the final ``<END>`` input. ``probabilities`` holds the sigmoid
        outputs at the selected positions, in the order returned by
        ``masked_select_vectors``.
        """
        device = self.embeddings.device
        pre_output = self.get_hidden(input_seq, seq_len)

        positions = torch.arange(input_seq.shape[1], device = device)[None]
        if self.partial_strings:
            mask = positions <= (seq_len - 1)[:, None]
        else:
            # final acceptance of the string is one before the <END> input
            mask = positions == (seq_len - 2)[:, None]

        pre_output = masked_select_vectors(pre_output, mask)
        logits = torch.sigmoid(self.classifier(pre_output))
        return logits, mask

    def get_hidden(self, input_seq, seq_len):
        """Return the top RNN layer's state at every step, shape ``(batch, steps, state_dim)``."""
        x = torch.nn.functional.embedding(input_seq, self.embeddings)
        n, ll, _ = x.shape
        hidden = self.rnn(x, seq_len)
        if isinstance(hidden, tuple):
            # LSTM: keep h, drop the cell state c.
            hidden = hidden[0]
        return hidden.reshape(n, ll, self.num_layers, self.state_dim)[:, :, -1, :]

    def one_step(self, input, hidden):
        """Advance the encoder from ``hidden`` over ``input`` and score each step.

        Args:
            input: token ids of shape ``(batch, steps)``; the LSTAR adaptor passes ``(1, 1)``.
            hidden: recurrent state in ``(num_layers, batch, state_dim)`` layout (an
                ``(h, c)`` tuple for LSTM).

        Returns:
            ``(probs, new_hidden)``. ``probs`` has shape ``(batch, steps, 1)`` and is
            computed from the top layer. ``new_hidden`` is the state after the last step,
            in the same layout as ``hidden``, so it can be passed straight back in
            (this also holds for ``num_layers > 1``).
        """
        x = torch.nn.functional.embedding(input, self.embeddings)
        n, ll, _ = x.shape
        lengths = torch.full((n,), ll, dtype = torch.long, device = x.device)
        all_states = self.rnn(x, lengths, hidden = hidden)

        def _layered(states):
            # (batch, steps, layers * state) -> (batch, steps, layers, state)
            return states.reshape(n, ll, self.num_layers, self.state_dim)

        def _final_state(states):
            # State after the last step, back in (layers, batch, state) layout.
            return _layered(states)[:, -1].transpose(0, 1).contiguous()

        h_all = all_states[0] if isinstance(all_states, tuple) else all_states
        probs = torch.sigmoid(self.classifier(_layered(h_all)[:, :, -1, :]))
        if isinstance(all_states, tuple):
            new_hidden = tuple(_final_state(t) for t in all_states)
        else:
            new_hidden = _final_state(all_states)
        return probs, new_hidden



class ClassificationModelAdaptor(ModelAdaptor):
    """Trainer adaptor for binary classifiers such as ``ClassificationModel``.

    Batches are ``(input_seq, out_seq, seq_len)`` triples. The wrapped model must
    return ``(probabilities, mask)``.
    """

    def __init__(self, model, device) -> None:
        super().__init__(model)
        self.device = device

    def register_loss_function(self, loss_func, l1_lambda = 0):
        """Set the data loss and the weight of the L1 penalty on all model parameters."""
        self.loss_func = loss_func
        self.l1_lambda = l1_lambda

    def calc_loss_and_eval(self, batch_data, iter):
        """Return ``(loss, metrics)`` for one batch; ``iter`` is not used here."""
        clogits, y = self.run_model(batch_data)
        loss = self.calc_loss(clogits, y)
        metrics = self.evaluate(clogits, y)
        return loss, metrics

    def calc_loss_and_eval_on_dataset(self, dataset, iter):
        """Evaluate on every batch of ``dataset`` without gradients.

        Returns the mean per-batch loss and metrics computed over all predictions
        pooled together. The model's previous train/eval mode is restored afterwards,
        even if a batch raises.

        Raises:
            ValueError: if ``dataset`` yields no batches.
        """
        was_training = self.model.training
        self.model.eval()
        all_logits = []
        all_targets = []
        total_loss = 0
        num_batches = 0
        try:
            with torch.no_grad():
                for batch_data in dataset:
                    clogits, y = self.run_model(batch_data)
                    total_loss += self.calc_loss(clogits, y).detach()
                    all_logits.append(clogits.detach().cpu())
                    all_targets.append(y.detach().cpu())
                    num_batches += 1
        finally:
            self.model.train(was_training)

        if num_batches == 0:
            raise ValueError("cannot evaluate on an empty dataset")
        metrics = self.evaluate(torch.cat(all_logits, dim = 0), torch.cat(all_targets, dim = 0))
        return total_loss / num_batches, metrics

    def calc_loss(self, clogits, y):
        """Data loss plus ``l1_lambda`` times the elementwise L1 norm of all parameters."""
        loss = self.loss_func(clogits, y)
        if self.l1_lambda:
            # Elementwise L1 (sum of |w|). torch.linalg.norm(p, 1) would instead give the
            # induced matrix 1-norm (max column sum) for 2-D weights.
            l1_norm = sum(p.abs().sum() for p in self.model.parameters())
            loss = loss + self.l1_lambda * l1_norm
        return loss

    def run_model(self, batch_data):
        """Move a batch to ``self.device``, run the model and align targets with outputs.

        Returns flat 1-D ``(probabilities, float targets)`` tensors.
        """
        input_seq, out_seq, seq_len = [item.to(self.device) for item in batch_data]
        clogits, mask = self.model(input_seq, seq_len)
        # Per-position targets (batch, steps > 1) are filtered with the model's mask;
        # per-sequence targets (batch,) / (batch, 1) are used as-is. Testing shape[1]
        # instead of squeeze().dim() keeps a batch of size 1 on the right branch.
        if out_seq.dim() > 1 and out_seq.shape[1] > 1:
            y = torch.masked_select(out_seq, mask)
        else:
            y = out_seq.type(torch.long)
        # reshape(-1) rather than squeeze(): a single-sample batch stays 1-D, so batches
        # can still be concatenated in calc_loss_and_eval_on_dataset.
        return clogits.reshape(-1), y.reshape(-1).type(torch.float)

    def evaluate(self, clogits, y):
        """Return the fraction of rounded probabilities that equal the targets."""
        klas = clogits.round().squeeze()
        cavg = torch.mean((y == klas).type(torch.float))
        return {"classification_avg" : float(cavg.detach())}

    def get_hidden(self, input_seq, seq_len):
        """Delegate to the wrapped model's ``get_hidden``."""
        return self.model.get_hidden(input_seq, seq_len)



class ClassificationModelLSTARAdaptor:
    """Expose a trained classifier as a state machine for L*-style automaton extraction.

    A recurrent state is a flat list of floats: ``h`` for Elman/GRU models, or ``h``
    followed by ``c`` for LSTM models. States are rebuilt as ``(1, 1, state_dim)``
    tensors, so only single-layer models queried one sequence at a time are supported.
    """

    def __init__(self, model, alphabet, vocab, device, start = True, end = True):
        self.model = model
        self.alphabet = alphabet
        self.device = device
        # Maps symbols (including "<START>" / "<END>" when used) to token ids.
        self.vocab = vocab
        self.start_sym = start
        self.end_sym = end

    @staticmethod
    def _accepts(logits):
        # A probability that rounds to 1 counts as "accept".
        return logits.squeeze().round().item() == 1

    def _hidden2state(self, hidden):
        """Flatten a hidden tensor (or ``(h, c)`` tuple) into a Python list of floats."""
        if isinstance(hidden, tuple):
            h, c = hidden
            # reshape(-1) keeps a 1-D vector even when state_dim == 1.
            flat = torch.cat([h.reshape(-1), c.reshape(-1)])
        else:
            flat = hidden.reshape(-1)
        return flat.detach().cpu().tolist()

    def _state2hidden(self, state):
        """Inverse of ``_hidden2state`` for a single-layer model.

        Raises:
            ValueError: if ``len(state)`` is neither ``state_dim`` nor ``2 * state_dim``.
        """
        dim = self.model.state_dim
        if len(state) == 2 * dim:
            h = torch.tensor(state[:dim], device = self.device).reshape(1, 1, dim)
            c = torch.tensor(state[dim:], device = self.device).reshape(1, 1, dim)
            return (h, c)
        if len(state) != dim:
            raise ValueError(f"state has length {len(state)}; expected {dim} or {2 * dim}")
        return torch.tensor(state, device = self.device).reshape(1, 1, dim)

    def get_first_RState(self):
        """Return ``(state, accepted)`` for the empty word (after ``<START>`` if enabled)."""
        # Results are detached into Python values, so no autograd graph is needed.
        with torch.no_grad():
            hidden = self.model.rnn.get_hidden_init([1, 1, self.model.state_dim], self.device)
            pred = True
            if self.start_sym:
                start = torch.tensor(self.vocab["<START>"], device = self.device).reshape(1, 1)
                logits, hidden = self.model.one_step(start, hidden)
                pred = self._accepts(logits)
        state = self._hidden2state(hidden)
        print("init state pred : ", pred)
        return state, pred

    def get_next_RState(self, state, char):
        """Feed ``char`` from ``state``; return ``(next_state, accepted)``."""
        with torch.no_grad():
            hidden = self._state2hidden(state)
            token = torch.tensor(self.vocab[char], device = self.device).reshape(1, 1)
            logits, hidden = self.model.one_step(token, hidden)
            pred = self._accepts(logits)
        return self._hidden2state(hidden), pred

    def classify_word(self, word):
        """Run the full model on ``word`` (wrapped with start/end symbols) and return acceptance."""
        symbols = list(word)
        if self.start_sym:
            symbols = ["<START>"] + symbols
        if self.end_sym:
            symbols = symbols + ["<END>"]

        with torch.no_grad():
            tokens = torch.tensor([self.vocab[ch] for ch in symbols], device = self.device).reshape(1, len(symbols))
            lengths = torch.tensor([len(symbols)], device = self.device)
            logits, _ = self.model(tokens, lengths)
            pred = self._accepts(logits)
        print(symbols, pred)
        return pred