import torch
import torch.nn as nn
import numpy as np

from einops import rearrange
from torch.utils.data import DataLoader, TensorDataset
from utils.common import (
    get_true_rolled,
    get_roll_mask,
    standardize_window,
    unstandardize_window,
)
from utils.dataset import TimeSeriesDataset

from pulse.augment import DynamicAugmentations
from pulse.timeVarying import TimeVaryingModule
from pulse.reconstruct import ReconstructionNet
from pulse.initialCondition import InitConditionEncoder, SharedInitConditionEncoder
from utils.common import shift_and_mask, get_pred_true
from trainers.base import BaseTrainer

from torch.utils.data import Dataset


class PULSEOracleTrainer(BaseTrainer):
    """Trainer for the PULSE "oracle" variant, which pairs samples by label.

    For every sample in a batch a partner sample carrying the same label is
    drawn from the training set (see ``get_pairs_ixs``). Depending on
    ``config.model_args.combine_inputs`` the two samples are either
    concatenated channel-wise and encoded together, or encoded separately
    (see ``run_one_batch``).

    NOTE(review): ``self.encoder``, ``self.config``, ``self.train_data``,
    ``self.optimizer``, ``self.scheduler`` and ``self.criterion`` are used
    here but not defined in this class; presumably ``BaseTrainer`` provides
    them — confirm.
    """

    def __init__(
        self,
        config,
        train_data,
        val_data,
        train_labels,
        val_labels,
    ):
        """Build all sub-modules and move them to ``config.device``.

        Args:
            config: Experiment configuration object.
            train_data: Training samples, forwarded to ``BaseTrainer``.
            val_data: Validation samples, forwarded to ``BaseTrainer``.
            train_labels: Labels of the training samples; used to pick
                same-label partners in ``get_pairs_ixs``.
            val_labels: Labels of the validation samples (stored only).
        """
        # Stored before the base constructor runs — presumably because the
        # base constructor already needs them (TODO confirm).
        self.train_labels = train_labels
        self.val_labels = val_labels
        super().__init__(config, train_data, val_data)

        self.context_norm = config.encoder_args.norm_last_layer
        self.init_norm = config.model_args.init_args.init_norm
        self.standardize_batch = config.training_args.standardize_batch

        if config.model_args.shared_f_init:
            self.init_encoder = SharedInitConditionEncoder(config)
        else:
            self.init_encoder = InitConditionEncoder(config)

        self.recon_net = ReconstructionNet(config)
        self.aug = DynamicAugmentations(config)

        dropout_rate = self.config.model_args.dropout_rate
        # dropout along time dimension
        self.dropout = nn.Dropout1d(dropout_rate) if dropout_rate > 0 else nn.Identity()

        # option is available only for dynamics
        use_time_varying = config.model_args.time_vary_args.include
        self.tv_module = TimeVaryingModule(config) if use_time_varying else None

        self.all_modules = {
            "encoder": self.encoder,
            "init_encoder": self.init_encoder,
            "recon_net": self.recon_net,
            "aug": self.aug,
            "tv_module": self.tv_module,
        }

        self.model = nn.ModuleDict(self.all_modules)
        self.model.to(self.config.device)

    def setup_dataloader(self, data, labels, train):
        """Wrap ``data`` and ``labels`` in a ``DataLoader``.

        Each batch is a ``(data, labels, indices)`` triple, where ``indices``
        are the positions of the samples inside ``data``.

        Args:
            data: NumPy array of samples (converted to ``torch.float``).
            labels: NumPy array of labels (converted to ``torch.long``).
            train: Whether to shuffle the samples.

        Returns:
            The configured ``DataLoader``.
        """
        ixs = np.arange(len(labels))
        dataset = TensorDataset(
            torch.from_numpy(data).to(torch.float),
            torch.from_numpy(labels).to(torch.long),
            torch.from_numpy(ixs).to(torch.long),
        )

        return DataLoader(
            dataset,
            batch_size=self.config.training_args.batch_size,
            shuffle=train,
            num_workers=torch.get_num_threads(),
        )

    def get_timevarying(self, context):
        """Run the time-varying module on ``context`` if it is enabled.

        Returns:
            The ``(tv, dtv)`` pair produced by ``self.tv_module``, or
            ``(None, None)`` when the module is disabled in the config.
        """
        if self.config.model_args.time_vary_args.include:
            return self.tv_module(context)
        return None, None

    def run_one_batch(self, batch, sample_init=False):
        """Run the forward pass for one ``(data, labels, indices)`` batch.

        Args:
            batch: Triple as produced by the loader of ``setup_dataloader``.
            sample_init: Forwarded to the initial-condition encoder and to
                ``get_pred_true``.

        Returns:
            ``(pred, true, (out, h0, hs, context, start_ix, dtv))``.

        NOTE(review): partners are always looked up through
        ``self.train_labels`` / ``self.train_data``, also when ``batch`` comes
        from the validation loader, whose indices refer to the validation
        set. Confirm that this is intended.
        """
        batch, _labels, ix = batch

        # ``self.dropout`` is not part of ``self.model``, so
        # ``self.model.train()`` / ``.eval()`` never reach it. Mirror the mode
        # here so dropout is switched off whenever the model is in eval mode.
        self.dropout.train(self.model.training)

        batch = batch.to(self.config.device)
        batch_ = batch.clone()

        pairs_ixs = self.get_pairs_ixs(ix)

        # Every partner must share the label of its sample, not just one.
        assert (self.train_labels[ix] == self.train_labels[pairs_ixs]).all()
        batch_specific = torch.as_tensor(
            self.train_data[pairs_ixs], dtype=torch.float
        ).to(self.config.device)  # b, t, c

        if self.config.model_args.combine_inputs:  # for oracle ablation
            # Target is the clean (pre-dropout) concatenation.
            batch_true = torch.cat([batch_, batch_specific], dim=-1)  # b, t, 2c

            # apply dropout to each input separately
            batch_ = self.dropout(batch_)  # b, t, c
            batch_specific = self.dropout(batch_specific)  # b, t, c

            # combine two channel inputs
            batch_ = torch.cat([batch_, batch_specific], dim=-1)  # b, t, 2c

            context, context_unpooled = self.encoder(batch_)
        else:
            batch_true = batch_specific

            # infer system info
            context, _ = self.encoder(batch_)

            # infer specific info
            _, context_unpooled = self.encoder(batch_specific)

        tv, dtv = self.get_timevarying(context_unpooled)

        # With a shared f_init, the init encoder consumes the output of the
        # f_sys encoder; otherwise it consumes the (possibly combined) input.
        init_input = (
            context_unpooled if self.config.model_args.shared_f_init else batch_
        )
        h0, start_ix, n_steps = self.init_encoder(
            init_input,
            sample_init=sample_init,
            sample_right_boundary=self.config.model_args.augmentation_args.sample_right_boundary,
        )  # h0: [gru_layers, b, h_dim]

        recon_inputs = self.aug.get_recon_inputs(context, n_steps.max())

        if self.config.model_args.time_vary_args.include:
            dtv, _mask = shift_and_mask(dtv, start_ix)
            recon_inputs = torch.dstack([recon_inputs, dtv]).contiguous()

        out, hs = self.recon_net(recon_inputs, h0)

        true, pred = get_pred_true(
            batch_true,
            out,
            start_ix=start_ix,
            sample_init=sample_init,
        )

        return pred, true, (out, h0, hs, context, start_ix, dtv)

    def run_one_epoch(self, loader, train: bool):
        """Iterate once over ``loader``, optimising the model when ``train``.

        The optimizer and the scheduler are both stepped once per batch.

        Args:
            loader: Loader yielding batches accepted by ``run_one_batch``.
            train: ``True`` to back-propagate and update parameters,
                ``False`` to only compute the loss without gradients.

        Returns:
            ``(mean_loss, stats)`` where ``stats`` holds formatted absolute
            maxima of ``h0[0]``, the context and the prediction of the last
            batch.

        Raises:
            ValueError: If ``loader`` contains no batches.
        """
        n_batches = len(loader)
        if n_batches == 0:
            raise ValueError("run_one_epoch received an empty loader")

        self.model.train(train)

        with torch.set_grad_enabled(train):
            epoch_loss = 0
            for batch in loader:
                self.optimizer.zero_grad()
                pred, true, (out, x0, hs, context, start_ix, cdtv) = self.run_one_batch(
                    batch, sample_init=self.config.model_args.sample_init
                )
                loss = self.criterion(pred, true)

                if train:
                    loss.backward()

                    torch.nn.utils.clip_grad_value_(self.model.parameters(), 5)

                    self.optimizer.step()
                    self.scheduler.step()

                epoch_loss += loss.item()

            epoch_loss /= n_batches

        return epoch_loss, dict(
            h0_max=f"{torch.abs(x0[0]).max():.4f}",
            context_max=f"{torch.abs(context).max():.4f}",
            out_max=f"{torch.abs(pred).max():.4f}",
        )

    def encode_downstream(self, batch):
        """Return the first two encoder outputs for ``batch``.

        Every other call site in this class unpacks two values from
        ``self.encoder``; slicing keeps this method working with that
        two-value return while still tolerating extra trailing outputs.
        """
        context_pool, context_all = self.encoder(batch)[:2]
        return context_pool, context_all

    def encode_init(self, batch):
        """Apply the init encoder's ``init_proj`` to ``batch``."""
        x0_all = self.init_encoder.init_proj(batch)  # b, t, n
        return x0_all

    def _train_class_ixs(self):
        """Return ``{label: training indices}``, cached per ``train_labels`` object."""
        cached = getattr(self, "_class_ixs_cache", None)
        if cached is None or cached[0] is not self.train_labels:
            cached = (self.train_labels, get_unique_labels_ix(self.train_labels))
            self._class_ixs_cache = cached
        return cached[1]

    def get_pairs_ixs(self, ix, seed=None):
        """Draw, for every index in ``ix``, a training index with the same label.

        Partners are sampled without replacement within each label and differ
        from the sample itself whenever the label has enough training samples.
        If a label has too few, the remaining partners are drawn with
        replacement, and a label with a single sample is paired with itself.

        Args:
            ix: One-dimensional tensor or array of training-set indices.
            seed: Optional seed. With ``None`` the draws come from NumPy's
                global RNG *without reseeding it*, so a globally seeded run
                stays reproducible. With an integer a private
                ``RandomState(seed)`` is used, which produces the same stream
                as ``np.random.seed(seed)`` without touching global state.

        Returns:
            ``np.ndarray`` of partner indices, aligned with ``ix``.
        """
        rng = np.random if seed is None else np.random.RandomState(seed)

        ix_np = ix.detach().cpu().numpy() if torch.is_tensor(ix) else np.asarray(ix)
        batch_labels = self.train_labels[ix_np]

        class_ixs = self._train_class_ixs()
        batch_class_ixs = get_unique_labels_ix(batch_labels)

        pairs_ixs = {}
        for u in np.unique(batch_labels):
            pool = class_ixs[u]
            members = ix_np[batch_class_ixs[u]]

            # Twice as many candidates as needed leaves one spare per sample
            # for the case where a sample draws itself; capped at the pool size
            # because sampling without replacement cannot exceed it.
            n_draw = min(2 * len(members), len(pool))
            candidate_set = rng.choice(pool, replace=False, size=n_draw).tolist()

            for i in members.tolist():
                c = candidate_set.pop() if candidate_set else None
                if c == i:
                    c = candidate_set.pop() if candidate_set else None
                if c is None:
                    # Pool exhausted: fall back to any other sample of the
                    # label, or to the sample itself if it is the only one.
                    others = pool[pool != i]
                    c = int(rng.choice(others)) if len(others) else i
                pairs_ixs[i] = c

        return np.array([pairs_ixs[i] for i in ix_np.tolist()], dtype=np.int64)

    def evaluate(self, dataloader, labels=None):  # modified to work with combined input
        """Embed every batch of ``dataloader`` with the encoder.

        Args:
            dataloader: Yields either bare input tensors or sequences whose
                first two entries are ``(inputs, labels)``; extra entries
                (such as the indices added by ``setup_dataloader``) are
                ignored.
            labels: Labels to use when the loader yields bare tensors.
                NOTE(review): these are appended once per batch, as in the
                previous implementation — confirm this is intended.

        Returns:
            dict with concatenated NumPy arrays under ``"embed"`` and
            ``"labels"``.

        Raises:
            ValueError: If no labels are available for a batch.
        """
        with torch.no_grad():
            self.model.eval()
            results = {"embed": [], "labels": []}

            for batch in dataloader:
                if isinstance(batch, (list, tuple)):
                    batch, labels, *_ = batch

                if labels is None:
                    raise ValueError(
                        "evaluate needs labels, either from the dataloader "
                        "or through the `labels` argument"
                    )

                # The model lives on config.device; inputs must follow it.
                batch = batch.to(self.config.device)

                if self.config.model_args.combine_inputs:
                    # The encoder expects 2c channels; duplicate the input.
                    batch = torch.cat([batch, batch], dim=-1)

                out, _ = self.encoder(batch)

                results["embed"].append(out.cpu())
                results["labels"].append(torch.as_tensor(labels).cpu())

            results["embed"] = np.concatenate(results["embed"])
            results["labels"] = np.concatenate(results["labels"])
            return results


def get_unique_labels_ix(arr):
    """Group the positions of ``arr`` by the value found at each position.

    Args:
        arr: One-dimensional array-like of labels.

    Returns:
        dict mapping each distinct value of ``arr`` to a NumPy array with the
        indices (in ascending order) at which that value occurs. An empty
        input yields an empty dict.
    """
    arr = np.asarray(arr)
    # One vectorised comparison per distinct label instead of a Python-level
    # loop over every element.
    return {label: np.flatnonzero(arr == label) for label in np.unique(arr)}


class OracleTimeSeriesDataset(Dataset):
    """Dataset wrapper that returns each item together with a label and its index.

    The label for item ``idx`` is ``labels[idx // num_windows]``, where
    ``num_windows`` is read from the wrapped dataset.
    NOTE(review): this suggests one label per group of ``num_windows``
    consecutive items — confirm against the wrapped dataset.
    """

    def __init__(self, timeseries_dataset, labels):
        """Store the wrapped dataset and the labels used for lookup."""
        self.labels = labels
        self.dataset = timeseries_dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(wrapped_item, label, idx)`` for position ``idx``."""
        windows_per_label = self.dataset.num_windows
        item = self.dataset[idx]
        label = self.labels[idx // windows_per_label]
        return item, label, idx