import torch
import numpy as np

from trainers.base import BaseTrainer
from baselines.timeMAE.timeMAE import TimeMAE, Align, Reconstruct


class TimeMAETrainer(BaseTrainer):
    """Trainer for TimeMAE pre-training and embedding extraction.

    The training objective is a weighted sum of an alignment loss and a
    reconstruction loss, weighted by ``config.model_args.alpha`` and
    ``config.model_args.beta`` respectively (see ``run_one_batch``).
    """

    def __init__(self, config, train_data, val_data):
        """Build the TimeMAE model and its two loss helpers.

        Args:
            config: Experiment configuration; forwarded to ``BaseTrainer`` and
                ``TimeMAE``. ``self.config.device`` and
                ``self.config.model_args.alpha`` / ``beta`` are read by this
                class.
            train_data: Forwarded unchanged to ``BaseTrainer``.
            val_data: Forwarded unchanged to ``BaseTrainer``.
        """
        super().__init__(config, train_data, val_data)

        # nn.Module.to() moves the module in place, so one call is enough.
        self.model = TimeMAE(config).to(self.config.device)

        # Loss helpers used by run_one_batch.
        # NOTE(review): these are neither moved to the device nor registered
        # in all_modules -- confirm they hold no parameters or buffers.
        self.align = Align()
        self.reconstruct = Reconstruct()

        # Only the model itself is registered here.
        self.all_modules = {"model": self.model}

    def run_one_epoch(self, loader, train: bool):
        """Run one pass over ``loader`` and return the mean batch loss.

        Args:
            loader: Iterable of batches that also supports ``len()``; each
                batch is handed to ``run_one_batch``.
            train: When true, gradients are enabled and the optimizer and
                scheduler are stepped after every batch. When false, the
                model is put in eval mode and no parameters are updated.

        Returns:
            A ``(mean_loss, extra_metrics)`` tuple; ``extra_metrics`` is
            always an empty dict.

        Raises:
            ZeroDivisionError: If ``loader`` has length zero.
        """
        self.model.train(train)

        epoch_loss = 0.0
        with torch.set_grad_enabled(train):
            for batch in loader:
                if train:
                    # Gradients only exist (and only need clearing) when
                    # training.
                    self.optimizer.zero_grad()

                # run_one_batch moves the batch to the configured device.
                loss = self.run_one_batch(batch)

                if train:
                    loss.backward()
                    # Clamp every gradient element to [-5, 5] before stepping.
                    torch.nn.utils.clip_grad_value_(self.model.parameters(), 5)
                    self.optimizer.step()
                    # The scheduler is stepped once per batch, not per epoch.
                    self.scheduler.step()

                epoch_loss += loss.item()

        return epoch_loss / len(loader), {}

    def run_one_batch(self, batch):
        """Compute the combined pre-training loss for a single batch.

        Args:
            batch: Input tensor; moved to ``self.config.device`` here.

        Returns:
            ``alpha * align_loss + beta * reconstruct_loss`` where both
            weights come from ``self.config.model_args``.
        """
        batch = batch.to(self.config.device)
        (rep_mask, rep_mask_prediction), (token_prediction_prob, tokens) = (
            self.model.pretrain_forward(batch)
        )

        align_loss = self.align.compute(rep_mask, rep_mask_prediction)
        # Reconstruct.compute returns three values; only the first is used.
        reconstruct_loss, _, _ = self.reconstruct.compute(token_prediction_prob, tokens)

        weights = self.config.model_args
        return weights.alpha * align_loss + weights.beta * reconstruct_loss

    def evaluate(self, dataloader, labels=None):
        """Embed every batch of ``dataloader`` with the model in eval mode.

        Batches may be either a bare input tensor or an ``(inputs, labels)``
        pair (list or tuple). Inputs are moved to ``self.config.device``
        before the forward pass.

        Args:
            dataloader: Iterable of batches.
            labels: Optional labels for the whole dataset, used only when the
                batches themselves carry no labels. May be a tensor or
                anything ``np.asarray`` accepts.

        Returns:
            A dict with:
                ``"embed"``: NumPy array of all batch embeddings concatenated
                    along the first axis.
                ``"labels"``: NumPy array of labels (per-batch labels
                    concatenated, or the supplied ``labels`` used once), or
                    ``None`` when no labels are available.

        Raises:
            ValueError: If ``dataloader`` yields no batches (nothing to
                concatenate).
        """
        device = self.config.device
        embeds = []
        batch_labels = []

        self.model.eval()
        with torch.no_grad():
            for batch in dataloader:
                if isinstance(batch, (list, tuple)):
                    batch, batch_y = batch
                    batch_labels.append(batch_y.cpu())

                embed = self.model(batch.to(device))
                embeds.append(embed.cpu())

        if batch_labels:
            all_labels = np.concatenate(batch_labels)
        elif labels is not None:
            # Dataset-level labels are used exactly once; repeating them per
            # batch would make them longer than the embeddings.
            if isinstance(labels, torch.Tensor):
                labels = labels.detach().cpu()
            all_labels = np.asarray(labels)
        else:
            all_labels = None

        return {"embed": np.concatenate(embeds), "labels": all_labels}

    def encode_downstream(self, batch):
        """Return ``(context, None)`` for ``batch`` for downstream use.

        ``context`` is the first element of the second pair returned by
        ``self.model.pretrain_forward`` (the same slot that ``run_one_batch``
        names ``token_prediction_prob``).

        NOTE(review): this goes through ``pretrain_forward`` rather than the
        plain ``self.model(batch)`` used by ``evaluate`` -- confirm that is
        the intended representation for downstream tasks.
        """
        (_rep_mask, _rep_mask_prediction), (context, _tokens) = (
            self.model.pretrain_forward(batch)
        )
        return context, None

    def get_encoder(
        self,
    ):
        """Wrap the trained model in a ``TimeMAEFinetuneWrapper`` and return it."""
        return TimeMAEFinetuneWrapper(self.model)


import torch.nn as nn


class TimeMAEFinetuneWrapper(nn.Module):
    """Adapter exposing a TimeMAE model as a module returning ``(context, None)``.

    ``forward`` delegates to the wrapped model's ``pretrain_forward`` and
    returns the first element of its second output pair.
    """

    def __init__(self, model: "TimeMAE"):
        """Store ``model`` as a sub-module so its parameters are registered.

        Args:
            model: Object providing ``pretrain_forward(x)`` that returns two
                pairs, ``(a, b), (context, c)``.
        """
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run ``pretrain_forward`` on ``x`` and return ``(context, None)``.

        Args:
            x: Input passed unchanged to ``self.model.pretrain_forward``.

        Returns:
            A 2-tuple whose first item is the ``context`` output and whose
            second item is always ``None``.
        """
        (_rep_mask, _rep_mask_prediction), (context, _tokens) = (
            self.model.pretrain_forward(x)
        )
        return context, None

