import torch
import torch.nn as nn
import numpy as np

from trainers.base import BaseTrainer
from transformers import PatchTSTConfig, PatchTSTForPretraining
from transformers.models.patchtst.modeling_patchtst import PatchTSTMasking


class PatchTSTTrainer(BaseTrainer):
    """Trainer that optimises a ``PatchTSTWrapper`` and extracts embeddings from it."""

    def __init__(self, config, train_data=None, val_data=None):
        """Build the wrapped model and move it to ``self.config.device``.

        Args:
            config: Experiment configuration; forwarded to ``BaseTrainer`` and
                to ``init_model``.
            train_data: Forwarded unchanged to ``BaseTrainer``.
            val_data: Forwarded unchanged to ``BaseTrainer``.
        """
        super().__init__(config, train_data, val_data)

        # NOTE(review): ``self.config`` (and ``self.optimizer`` used below) are
        # presumably set up by ``BaseTrainer`` -- confirm against the base class.
        self.model = self.init_model(config)
        self.all_modules = {"encoder": self.model}
        self.model.to(self.config.device)

        self.flatten = nn.Flatten(start_dim=1)

    def init_model(self, config):
        """Return a freshly constructed ``PatchTSTWrapper`` for ``config``."""
        return PatchTSTWrapper(config)

    def run_one_epoch(self, dataloader: torch.utils.data.DataLoader, train: bool):
        """Run one pass over ``dataloader`` and accumulate the loss.

        Args:
            dataloader: Yields rank-3 tensors (unpacked as ``b, t, c``;
                presumably batch, time, channels -- TODO confirm).
            train: When True, back-propagate and step the optimizer after
                every batch; when False, only the loss is computed.

        Returns:
            A ``(total_loss, {})`` tuple where ``total_loss`` is the sum over
            batches of ``out.loss / b``.
        """
        # The wrapper's ``train()`` switches input masking on and ``eval()``
        # switches it off. Train mode is used for validation passes as well,
        # presumably because the pretraining loss relies on the mask --
        # TODO confirm.
        self.model.train()

        self.optimizer.zero_grad()
        total_loss = 0
        # Skip autograd bookkeeping entirely on validation passes.
        with torch.set_grad_enabled(train):
            for batch in dataloader:
                batch = batch.to(self.config.device)
                b, _, _ = batch.shape  # also enforces a rank-3 input

                out = self.model.train_forward(batch)
                loss = out.loss / b

                if train:
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                total_loss += loss.item()

        return total_loss, {}

    def evaluate(self, dataloader, labels=None):
        """Compute embeddings for every batch of ``dataloader``.

        Args:
            dataloader: Yields either an input tensor or an
                ``(inputs, labels)`` pair (list or tuple).
            labels: Labels used when the batches themselves carry none. They
                are appended once per batch, exactly as passed.

        Returns:
            A dict with ``"embed"`` (concatenated embeddings) and ``"labels"``
            (concatenated labels, or ``None`` when no labels were available).
        """
        self.model.eval()
        embeds = []
        label_chunks = []

        with torch.no_grad():
            for batch in dataloader:
                if isinstance(batch, (list, tuple)):
                    batch, labels = batch

                # Inputs must live on the same device as the model, as in
                # ``run_one_epoch``.
                batch = batch.to(self.config.device)
                out = self.model.infer_forward(batch)

                embeds.append(out.cpu())
                if labels is not None:
                    label_chunks.append(labels.cpu())

        return {
            "embed": np.concatenate(embeds),
            "labels": np.concatenate(label_chunks) if label_chunks else None,
        }


class PatchTSTWrapper(nn.Module):
    """Wrap ``PatchTSTForPretraining`` with separate training and inference paths."""

    def __init__(self, config):
        """Build the PatchTST model from the project configuration.

        Args:
            config: Object providing ``data_args.input_dims``,
                ``data_args.subseq_size`` and a ``model_args`` mapping whose
                entries are passed on to ``PatchTSTConfig``.
        """
        super().__init__()
        self.patchtst_config = PatchTSTConfig(
            num_input_channels=config.data_args.input_dims,
            context_length=config.data_args.subseq_size,
            use_cls_token=True,
            **config.model_args
        )

        self.model = PatchTSTForPretraining(self.patchtst_config)
        # Convenience aliases to the sub-modules of the pretraining model.
        self.backbone = self.model.model
        self.head = self.model.head
        self.flatten = torch.nn.Flatten(start_dim=1, end_dim=-1)

    def infer_forward(self, batch):
        """Return one flat embedding vector per sample.

        Runs the backbone only, takes index 0 along dim 2 of
        ``last_hidden_state`` and flattens everything after the batch dim.
        """
        out = self.backbone(past_values=batch)
        embedding = out.last_hidden_state

        # NOTE(review): with ``use_cls_token=True`` position 0 along dim 2 is
        # presumably the CLS token -- confirm against the PatchTST docs.
        pooled_embedding = embedding[:, :, 0, :]
        return self.flatten(pooled_embedding)

    def train_forward(self, batch):
        """Run the full pretraining model on ``batch`` and return its output."""
        return self.model(past_values=batch)

    def train(self, mode: bool = True):
        """Set training mode and toggle input masking to match.

        Follows the ``nn.Module.train`` contract (accepts ``mode``, updates
        the ``training`` flags, returns ``self``) so the wrapper also works
        when a parent module propagates ``train(mode)`` to its children.

        Args:
            mode: True enables training mode with masking; False switches to
                evaluation mode with masking disabled.

        Returns:
            ``self``.
        """
        super().train(mode)
        self._set_masking(mode)
        return self

    def eval(self):
        """Switch to evaluation mode with input masking disabled; return ``self``."""
        return self.train(False)

    def _set_masking(self, enabled):
        """Enable or disable input masking on the inner model."""
        inner = self.model.model
        inner.do_mask_input = enabled
        if enabled:
            inner.masking = PatchTSTMasking(self.patchtst_config)
        else:
            inner.masking = nn.Identity()
