import torch
import os
import logging
import pickle
import math

class ModelHandler(object):
    """Train, checkpoint, validate and decode a next-token sequence model.

    The ``dataset`` collaborator is only known through how it is used here:
    ``num_batches(batch_size, split)``, ``get_batch(...)``, ``pad_idx``,
    ``vocab`` (``vocab["2"]`` is used as the end-of-output token),
    ``max_out_len``, ``tensor_data[split]``, ``all_data[split]`` and
    ``is_in_train_distribution(example)``.

    NOTE(review): the training loss is the negated sum of the gathered model
    outputs and beam search adds those outputs up as sequence scores, so the
    model is presumably expected to return per-token log-probabilities
    (e.g. ending in ``log_softmax``) -- confirm against the model code.
    """

    def __init__(self, model, optim, dataset, device, save_location,
        model_name="best_val.model"):
        """Store the collaborators.

        Args:
            model: module called as ``model(x)`` on a 2-D batch of token ids;
                its output is indexed as ``(batch, seq, vocab)``.
            optim: optimizer; presumably built over ``model``'s parameters.
            dataset: data source (see the class docstring).
            device: device every input batch is moved to.
            save_location: directory the best-validation checkpoint is
                written to.
            model_name: checkpoint file name; also used to tag logged
                accuracies and saved result files.
        """
        self.model = model
        self.optim = optim
        self.dataset = dataset
        self.device = device
        self.save_location = save_location
        self.model_name = model_name
        # Lowest validation loss seen; (re)set by train_model / load_model.
        self.best_val_loss = float('inf')

    def _batch_loss(self, x):
        """Return the summed next-token negative log-likelihood of batch ``x``.

        Model output at position t is scored against token t+1 of ``x``
        (teacher forcing); targets equal to ``dataset.pad_idx`` are masked
        out. The result is a scalar tensor summed over the whole batch.
        """
        logits = self.model(x)
        targets = x[:, 1:]
        # Pick, at every position, the model's score for the actual next token.
        scores = torch.gather(input=logits[:, :-1],
            index=targets.unsqueeze(-1), dim=2).squeeze(-1)
        pad_mask = (targets != self.dataset.pad_idx).long()
        return -1 * torch.sum(torch.sum(pad_mask * scores, dim=1), dim=0)

    def _validation_loss(self, batch_size):
        """Return the loss over all 'val' batches, scaled per sequence.

        Each batch's summed loss is divided by ``batch_size`` and by the
        number of validation batches (the per-sequence mean when every batch
        is full). The model runs in eval mode without gradients; its previous
        train/eval mode is restored afterwards.
        """
        val_batches = self.dataset.num_batches(batch_size, 'val')
        val_loss = 0
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for _ in range(val_batches):
                    x = self.dataset.get_batch(batch_size, 'val').to(self.device)
                    val_loss += (self._batch_loss(x).item()
                        / batch_size / val_batches)
        finally:
            self.model.train(was_training)
        return val_loss

    def load_model(self, model_file, batch_size):
        """Load a whole pickled model from ``model_file`` and score it on 'val'.

        The loaded model replaces ``self.model`` (tensors are mapped onto
        ``self.device``), the validation loss is printed, and
        ``best_val_loss`` is set to it.

        NOTE(review): ``self.optim`` still refers to the parameters of the
        previous model; rebuild the optimizer before training a loaded model.
        """
        import inspect

        # SECURITY: this unpickles arbitrary objects -- only load trusted files.
        load_kwargs = {"map_location": self.device}
        # torch>=2.6 defaults to weights_only=True, which rejects the
        # whole-module pickles written by train_model; torch<1.13 has no
        # weights_only parameter at all.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        # Passing the path lets torch open and close the file itself.
        self.model = torch.load(model_file, **load_kwargs)
        val_loss = self._validation_loss(batch_size)
        print("Validation Loss: {:.4f}".format(val_loss))
        self.best_val_loss = val_loss

    def train_model(self, batch_size, epochs, test_method="beam_search",
        beam_width=5, test_every=100, results_dir="results"):
        """Train for ``epochs`` epochs, checkpointing on best validation loss.

        Args:
            batch_size: sequences per batch.
            epochs: number of passes over the 'train' batches.
            test_method: decoding method for periodic testing
                ('greedy' or 'beam_search').
            beam_width: beam width used when ``test_method='beam_search'``.
            test_every: run the test-set evaluation (and a 1% train-split
                check) every this many epochs; a falsy value disables it.
            results_dir: directory the per-example test results are pickled
                into.

        Batches whose loss is not finite are logged and skipped so they
        cannot corrupt the weights.
        """
        num_batches = self.dataset.num_batches(batch_size, 'train')
        one_fifth = int(0.2 * num_batches + 1)
        self.best_val_loss = float('inf')
        for i in range(epochs):
            self.model.train()
            epoch_loss = 0
            for j in range(num_batches):
                self.optim.zero_grad()
                x = self.dataset.get_batch(batch_size).to(self.device)
                loss = self._batch_loss(x)
                loss_value = loss.item()
                if math.isfinite(loss_value):
                    epoch_loss += loss_value / batch_size / num_batches
                    loss.backward()
                    self.optim.step()
                else:
                    # Check *before* stepping: one NaN/inf update would
                    # poison every parameter for the rest of training.
                    logging.warning(
                        "%s: non-finite loss %s at epoch %d batch %d; "
                        "skipping update", self.model_name, loss_value, i, j)
                if (j + 1) % one_fifth == 0:
                    print("{}/{} batches".format(j + 1, num_batches))
            print("EPOCH {}: {:.4f}".format(i, epoch_loss))
            val_loss = self._validation_loss(batch_size)
            print("Validation Loss: {:.4f}".format(val_loss))
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.model,
                    os.path.join(self.save_location, self.model_name))
            if test_every and (i + 1) % test_every == 0:
                test_acc, results = self.test_model(test_method, beam_width,
                    record_results=True)
                logging.info('{}: Epoch {}: Accuracy: {:.2f}'.format(
                    self.model_name, i, test_acc))
                os.makedirs(results_dir, exist_ok=True)
                results_path = os.path.join(results_dir,
                    '{}_epoch_{}_results.pickle'.format(self.model_name, i))
                with open(results_path, 'wb') as f:
                    pickle.dump(results, f)
                # Quick accuracy check on a 1% sample of the training split.
                self.test_model(test_method, beam_width, split="train",
                    portion=0.01)

    def test_model(self, method, beam_width=5, split="test",
        portion=1, record_results=False):
        """Decode examples of ``split`` one at a time and report exact-match accuracy.

        Args:
            method: 'greedy' or 'beam_search'.
            beam_width: beam width for 'beam_search'.
            split: dataset split to evaluate.
            portion: fraction of the split's examples to evaluate.
            record_results: also return a per-example result dict keyed by
                example index.

        Returns:
            Accuracy (``nan`` when no example was evaluated), or
            ``(accuracy, results)`` when ``record_results`` is true.

        Raises:
            ValueError: if ``method`` is not a known decoding method.
        """
        if method not in ("greedy", "beam_search"):
            raise ValueError("Unknown decoding method {!r}; expected 'greedy' "
                "or 'beam_search'".format(method))
        max_length = self.dataset.max_out_len + 1
        num_examples = int(portion * len(self.dataset.tensor_data[split]))
        one_fifth = int(1 + 0.2 * portion * len(self.dataset.tensor_data[split]))
        correct = total = ood_total = ood_correct = 0
        results = {} if record_results else None
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for j in range(num_examples):
                    if (j + 1) % one_fifth == 0 and total > 0:
                        print("{}/{} batches".format(j + 1, num_examples))
                        print("Running accuracy: {:.2f}".format(correct / total))
                    x_in, y_out, idx = self.dataset.get_batch(1, split, True,
                        with_idx=True)
                    if method == "greedy":
                        y = self.greedy(x_in, max_length)
                    else:
                        y = self.beam_search(x_in, max_length, beam_width)
                    key = idx.item()
                    info = self.dataset.all_data[split][key]
                    ood = not self.dataset.is_in_train_distribution(info)
                    target = y_out.to(self.device)
                    # Exact match: same length and every token equal.
                    is_correct = bool(y.shape[1] == target.shape[1]
                        and torch.all(y == target))
                    total += 1
                    correct += is_correct
                    if ood:
                        ood_total += 1
                        ood_correct += is_correct
                    if record_results:
                        results[key] = {'is_correct': is_correct,
                            "input": info['input'],
                            "correct_output": info['correct_output'],
                            "seq_length": info['seq_length'],
                            "total_length": info['total_length'],
                            "dep_length_end": info['dep_length_end'],
                            "ood": ood}
        finally:
            self.model.train(was_training)
        # An empty evaluation (e.g. a tiny split with portion=0.01) would
        # otherwise divide by zero.
        accuracy = correct / total if total else float('nan')
        print("{} accuracy: {:.2f}".format(split.capitalize(), accuracy))
        if ood_total > 0:
            print("{} OOD accuracy: {:.2f}".format(
                split.capitalize(), ood_correct / ood_total))
            if total - ood_total > 0:
                print("{} in-dist accuracy: {:.2f}".format(split.capitalize(),
                    (correct - ood_correct) / (total - ood_total)))
        if record_results:
            return accuracy, results
        return accuracy

    def greedy(self, x_in, max_length):
        """Greedily extend ``x_in`` (a batch of one sequence).

        Appends the arg-max token at the last position until ``vocab["2"]``
        is produced or more than ``max_length`` tokens have been generated.

        Returns:
            A ``(1, n)`` tensor of the generated tokens, including the end
            token when it was produced.
        """
        eos = self.dataset.vocab["2"]
        x = x_in.to(self.device)
        y = torch.empty((x.shape[0], 0), dtype=torch.long, device=self.device)
        while True:
            logits = self.model(x)
            next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
            x = torch.cat([x, next_tok], dim=1)
            y = torch.cat([y, next_tok], dim=1)
            if next_tok.item() == eos or y.shape[1] > max_length:
                return y

    def beam_search(self, x_in, max_length, beam_width):
        """Beam-search a continuation of ``x_in`` (a batch of one sequence).

        Hypothesis scores are the running sum of the model outputs for the
        chosen tokens. A hypothesis is finished when it emits
        ``vocab["2"]`` or exceeds ``max_length`` generated tokens; finished
        hypotheses keep competing for beam slots. Search stops once every
        hypothesis in the beam is finished.

        Returns:
            A ``(1, n)`` tensor of the best hypothesis' generated tokens.
        """
        eos = self.dataset.vocab["2"]
        start = x_in.to(self.device)
        # Seed with ONE empty hypothesis: seeding with beam_width identical
        # copies makes every expansion yield beam_width duplicates of the
        # best child, which collapses the search to greedy decoding.
        beam = [{"x": start,
                 "y": torch.empty((start.shape[0], 0), dtype=torch.long,
                     device=self.device),
                 "score": 0.0, "done": False}]
        while not all(hyp["done"] for hyp in beam):
            candidates = []
            for hyp in beam:
                if hyp["done"]:
                    candidates.append(hyp)
                    continue
                logits = self.model(hyp["x"])
                k = min(beam_width, logits.shape[-1])
                vals, top_idx = torch.topk(logits[:, -1], k, dim=1)
                for i in range(k):
                    next_tok = top_idx[:, i:i + 1]
                    new_y = torch.cat([hyp["y"], next_tok], dim=1)
                    candidates.append({
                        "x": torch.cat([hyp["x"], next_tok], dim=1),
                        "y": new_y,
                        "score": hyp["score"] + vals[0, i].item(),
                        "done": (next_tok.item() == eos
                            or new_y.shape[1] > max_length),
                    })
            candidates.sort(reverse=True, key=lambda hyp: hyp["score"])
            beam = candidates[:beam_width]
        return beam[0]["y"]