import torch
from torch.utils.data import Dataset
from scipy.stats import ttest_rel
import copy
from torch.nn.parallel.distributed import DistributedDataParallel


class Baseline:
    """Interface for reward baselines; every hook here is a no-op default.

    Subclasses override whichever hooks they need. Only ``eval`` is mandatory:
    the default implementation raises ``NotImplementedError``.
    """

    def wrap_dataset(self, dataset):
        """Return *dataset* unchanged (no per-item baseline is attached)."""
        return dataset

    def unwrap_batch(self, batch):
        """Split *batch* into ``(data, baseline)``; here the baseline is ``None``."""
        return (batch, None)

    def eval(self, x, c):
        """Compute the baseline for input *x* with cost *c*; must be overridden."""
        raise NotImplementedError("Override this method")

    def get_learnable_parameters(self):
        """Return a fresh, empty list: this baseline has no trainable parameters."""
        return list()

    def epoch_callback(self, model, epoch):
        """Hook called at the end of an epoch; does nothing by default."""
        return None

    def state_dict(self):
        """Return a fresh, empty dict: there is no state to serialize."""
        return dict()

    def load_state_dict(self, state_dict):
        """Ignore *state_dict*; there is no state to restore."""
        return None


class RolloutBaseline(Baseline):
    """Baseline given by a frozen deep copy of the policy model.

    The snapshot's ``rollout`` values on a fixed list of validation datasets
    are cached in ``baseline_vals``. After each epoch, ``epoch_callback`` rolls
    out the current model on the same datasets. It replaces the snapshot only
    when the candidate's mean value is lower and a one-sided paired t-test is
    significant at ``bl_alpha``.
    """

    def __init__(self, model, datasets: list, device, batch_size=64, epoch=0, bl_alpha=0.05):
        """Snapshot *model* and compute its rollout values on *datasets*.

        Args:
            model: policy model, optionally wrapped in ``DistributedDataParallel``.
                It must expose ``rollout(data, batch_size=..., graph_nodes=...,
                device=...)`` returning a tensor.
            datasets: the validation datasets, as a list. Each element must contain
                graphs with the same number of vertices and expose ``graph_nodes``.
            device: device passed through to ``rollout``.
            batch_size: batch size passed through to ``rollout``.
            epoch: epoch number recorded for the initial snapshot.
            bl_alpha: significance level for replacing the baseline snapshot.
        """
        super().__init__()

        self.device = device
        self.bl_alpha = bl_alpha
        self.datasets = datasets
        self.batch_size = batch_size
        # Performs a full rollout over every validation dataset right away.
        self._update_model(model, epoch)

    def _update_model(self, model, epoch, candidate_vals=None):
        """Replace the snapshot with a deep copy of *model*.

        If *candidate_vals* is given, those values are reused as the new
        baseline values. Otherwise the new snapshot is rolled out over
        ``self.datasets`` and the per-dataset results are concatenated along
        dim 0.
        """
        if isinstance(model, DistributedDataParallel):
            model = model.module

        self.model = copy.deepcopy(model)

        if candidate_vals is None:
            self.baseline_vals = torch.cat([self.model.rollout(data, batch_size=self.batch_size,
                                                               graph_nodes=data.graph_nodes,
                                                               device=self.device) for data in self.datasets], 0)
        else:
            self.baseline_vals = candidate_vals

        self.mean = self.baseline_vals.mean().item()
        self.epoch = epoch

    def wrap_dataset(self, dataset):
        """Pair each item of *dataset* with the snapshot's rollout value.

        The rollout result is reshaped to a column (``view(-1, 1)``) before
        being wrapped in a ``BaselineDataset``.
        """
        rolled_out = self.model.rollout(dataset, batch_size=self.batch_size, graph_nodes=dataset.graph_nodes,
                                        device=self.device)
        return BaselineDataset(dataset, rolled_out.view(-1, 1))

    def unwrap_batch(self, batch):
        """Return *batch* as-is.

        Items of a ``BaselineDataset`` are already ``(data, baseline)`` pairs.
        """
        return batch

    def eval(self, x, c):
        """Run the snapshot on *x* without gradients.

        Returns:
            The snapshot's first output and ``0``. The ``0`` is presumably a
            zero baseline loss; confirm against the caller. *c* is unused.
        """
        with torch.no_grad():
            v, _ = self.model(x)

        return v, 0

    def epoch_callback(self, model, epoch):
        """Possibly replace the snapshot with *model* after an epoch.

        Returns:
            ``(updated, candidate_mean)``. ``updated`` is True iff the snapshot
            was replaced. ``candidate_mean`` is the mean of *model*'s rollout
            values on the validation datasets.
        """
        if isinstance(model, DistributedDataParallel):
            model = model.module

        candidate_vals = torch.cat([model.rollout(data, batch_size=self.batch_size,
                                                  graph_nodes=data.graph_nodes, device=self.device)
                                    for data in self.datasets], 0)

        candidate_mean = candidate_vals.mean().item()

        if __debug__:
            print("Epoch {} candidate mean {}, baseline epoch {} mean {}, difference {}".format(
                epoch, candidate_mean, self.epoch, self.mean, candidate_mean - self.mean))

        if candidate_mean - self.mean < 0:
            t, p = ttest_rel(candidate_vals.cpu(), self.baseline_vals.cpu())

            # ttest_rel reports a two-sided p-value.
            # Halving it gives the one-sided p-value for "candidate < baseline",
            # but only when t is negative.
            p_val = p / 2
            if __debug__:
                print("p-value: {}".format(p_val))
            # Explicit guard instead of an assert:
            # - scipy recomputes the mean difference in float64, which can
            #   disagree in sign with the torch means above when the means are
            #   nearly equal;
            # - t is NaN for a single sample.
            # In both cases p / 2 is not the right one-sided p-value, so keep
            # the current baseline instead of crashing (or, under -O, wrongly
            # updating).
            if not t < 0:
                if __debug__:
                    print("T-statistic {} is not negative; keeping baseline".format(t))
                return False, candidate_mean
            if p_val < self.bl_alpha:
                if __debug__:
                    print('Update baseline')
                self._update_model(model, epoch, candidate_vals)
                return True, candidate_mean

        return False, candidate_mean

    def state_dict(self):
        """Return the snapshot model (by reference) and its epoch for checkpointing."""
        return {
            'model': self.model,
            'dataset': [],
            'epoch': self.epoch
        }

    def load_state_dict(self, state_dict):
        """Restore from a dict produced by ``state_dict``.

        The saved model's weights are copied into a deep copy of the current
        snapshot. That copy is then installed via ``_update_model``, which
        re-runs the rollout on the validation datasets.
        """
        load_model = copy.deepcopy(self.model)
        load_model.load_state_dict(state_dict['model'].state_dict())
        self._update_model(load_model, state_dict['epoch'])


class BaselineDataset(Dataset):
    """Dataset that pairs every item of a wrapped dataset with a baseline value.

    ``ds[i]`` returns ``(dataset[i], baseline[i])``. The length is that of the
    wrapped dataset.
    """

    def __init__(self, dataset=None, baseline=None):
        """Wrap *dataset* together with a per-item *baseline* sequence.

        Args:
            dataset: sized, indexable dataset to wrap.
            baseline: sized, indexable sequence with one entry per dataset item.

        Raises:
            ValueError: if ``dataset`` and ``baseline`` differ in length.
        """
        super().__init__()

        self.dataset = dataset
        self.baseline = baseline
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``. A mismatch would then surface later as an IndexError
        # (baseline too short) or as silently ignored baseline entries
        # (baseline too long).
        if len(self.dataset) != len(self.baseline):
            raise ValueError(
                "dataset and baseline must have the same length, got {} and {}".format(
                    len(self.dataset), len(self.baseline)))

    def __getitem__(self, item):
        """Return the ``(data, baseline)`` pair at index *item*."""
        return self.dataset[item], self.baseline[item]

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)
