import numpy as np
import torch.nn as nn

# from utils.io import load_trainer
from torch.utils.data import TensorDataset, DataLoader
import torch

import trainers.all_trainers as all_trainers

from pathlib import Path
from utils.dataset import get_trainer_kwargs
from utils.logging import get_logger

from utils.dataset import get_eval_dataloader, SPLIT_LIST


from utils.io import get_full_config
from utils.dataset import get_data_from_config

from trainers.base import BaseTrainer
from pulse.mlp import MLP
from torch.optim.swa_utils import AveragedModel, SWALR

logger = get_logger()


class TransferTrainer(BaseTrainer):
    """Fine-tune a checkpointed encoder on a labelled classification task.

    The encoder restored by ``load_transfer_backbone_ckpt`` is wrapped between
    two task-specific heads:

    * ``in_head`` maps the new dataset's inputs to the input width the
      checkpoint was configured with (``nn.Identity`` when both widths match).
    * ``out_head`` is an ``MLP`` that maps the encoder embedding to class
      logits.

    The encoder starts frozen and is unfrozen once
    ``training_args.freeze_backbone_epochs`` training epochs have completed.
    """

    def __init__(self, transfer_config, train_data, train_labels, val_data, val_labels):
        """Build the transfer model around a checkpointed encoder.

        Args:
            transfer_config: Config of the transfer run. This class reads
                ``device``, ``data_args.input_dims``, ``data_args.num_classes``,
                ``training_args.batch_size`` and
                ``training_args.freeze_backbone_epochs`` from it;
                ``load_transfer_backbone_ckpt`` reads ``load_from_checkpoint``.
            train_data: Training inputs, forwarded to ``BaseTrainer``.
            train_labels: Labels belonging to ``train_data``.
            val_data: Validation inputs, forwarded to ``BaseTrainer``.
            val_labels: Labels belonging to ``val_data``.
        """
        # Stored before the base constructor runs -- presumably so it can reach
        # the labels while it sets things up; confirm against BaseTrainer.
        self.train_labels = train_labels
        self.val_labels = val_labels
        self.transfer_config = transfer_config

        super().__init__(transfer_config, train_data, val_data)

        self.ckpt_trainer, self.ckpt_config = load_transfer_backbone_ckpt(
            self.transfer_config
        )
        self.load_heads(self.transfer_config, self.ckpt_config)
        self.encoder = self.ckpt_trainer.get_encoder()

        self.all_modules = {
            "encoder": self.encoder,
            "in_head": self.in_head,
            "out_head": self.out_head,
        }

        # A single ModuleDict lets .to(), .train() and .parameters() reach the
        # encoder and both heads at once.
        self.model = nn.ModuleDict(self.all_modules)
        self.model.to(self.transfer_config.device)

        self.criterion = torch.nn.CrossEntropyLoss()

        # Number of completed *training* epochs; drives the unfreeze schedule
        # in run_one_epoch.
        self.current_epoch = 0

        # The encoder stays frozen until run_one_epoch unfreezes it.
        self.freeze_module(self.encoder)

    def run_one_batch(
        self,
        batch,
    ):
        """Run one forward pass and compute the classification loss.

        Args:
            batch: An ``(inputs, labels)`` pair as yielded by the loaders built
                in ``setup_dataloader``.

        Returns:
            A ``(loss, pred_logit, embed)`` tuple: the cross-entropy loss, the
            ``out_head`` output and the encoder embedding it was computed from.
        """
        inputs, labels = batch
        device = self.transfer_config.device
        inputs = inputs.to(device)
        # setup_dataloader stores labels as float tensors; cast to integer so
        # CrossEntropyLoss treats them as class indices.
        labels = labels.to(device).long()

        # The encoder returns a pair; only the first element is used here.
        embed, _ = self.encoder(self.in_head(inputs))
        pred_logit = self.out_head(embed)

        loss = self.criterion(pred_logit, labels)
        return loss, pred_logit, embed

    def run_one_epoch(self, loader, train):
        """Iterate once over ``loader``, optimising the model when ``train`` is true.

        The backbone-unfreeze check and the epoch counter are tied to training
        passes only, so calling this method with ``train=False`` (e.g. for a
        validation pass) does not advance the freeze schedule.

        Args:
            loader: Iterable of ``(inputs, labels)`` batches supporting ``len``.
            train: When true, run backward/optimizer/scheduler steps and put
                the model in training mode; otherwise run without gradients in
                evaluation mode.

        Returns:
            A ``(epoch_loss, extras)`` tuple where ``epoch_loss`` is the mean of
            the per-batch losses and ``extras`` is an empty dict.
        """
        freeze_epochs = self.transfer_config.training_args.freeze_backbone_epochs
        if train and self.current_epoch == freeze_epochs:
            logger.info(f"Unfreezing backbone at epoch {self.current_epoch}")
            self.unfreeze_module(self.encoder)

        self.model.train(train)
        epoch_loss = 0.0
        with torch.set_grad_enabled(train):
            for batch in loader:
                if train:
                    self.optimizer.zero_grad()

                loss, _, _ = self.run_one_batch(batch)

                if train:
                    loss.backward()
                    self.optimizer.step()
                    # The scheduler is stepped once per batch, not per epoch.
                    self.scheduler.step()

                epoch_loss += loss.item()

        epoch_loss /= len(loader)

        if train:
            self.current_epoch += 1

        return epoch_loss, dict()

    def setup_dataloader(self, data, labels, train: bool) -> DataLoader:
        """Wrap NumPy inputs and labels in a ``DataLoader``.

        Args:
            data: NumPy array of inputs; converted to a float tensor.
            labels: NumPy array of labels; also stored as a float tensor
                (``run_one_batch`` casts them to integers before the loss).
            train: Whether the loader shuffles its samples.

        Returns:
            A ``DataLoader`` over the paired tensors using
            ``training_args.batch_size``.
        """
        dataset = TensorDataset(
            torch.from_numpy(data).to(torch.float),
            torch.from_numpy(labels).to(torch.float),
        )
        return DataLoader(
            dataset,
            batch_size=self.transfer_config.training_args.batch_size,
            shuffle=train,
            # One worker per thread reported by torch.get_num_threads().
            num_workers=torch.get_num_threads(),
        )

    def load_heads(self, transfer_config, ckpt_config) -> None:
        """Create ``self.in_head`` and ``self.out_head`` for the transfer task.

        Args:
            transfer_config: Config of the transfer run; supplies
                ``data_args.input_dims`` and ``data_args.num_classes``.
            ckpt_config: Config of the checkpointed run; supplies
                ``data_args.input_dims`` and ``encoder_args.emb_dim``.
        """
        source_dims = transfer_config.data_args.input_dims
        target_dims = ckpt_config.data_args.input_dims

        # No adapter is needed when the new data already has the width the
        # checkpoint was configured with.
        if source_dims == target_dims:
            self.in_head = nn.Identity()
        else:
            self.in_head = nn.Linear(source_dims, target_dims)

        # The second argument (64) is presumably the hidden width of the MLP;
        # confirm against pulse.mlp.MLP.
        self.out_head = MLP(
            ckpt_config.encoder_args.emb_dim,
            64,
            transfer_config.data_args.num_classes,
            activation=torch.nn.Tanh(),
        )

    def freeze_module(self, module: nn.Module) -> None:
        """Disable gradient tracking for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def unfreeze_module(self, module: nn.Module) -> None:
        """Enable gradient tracking for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = True

    def evaluate(self, loader) -> dict:
        """Run the model over ``loader`` without gradients and collect outputs.

        Args:
            loader: Iterable of ``(inputs, labels)`` batches.

        Returns:
            A dict of NumPy arrays, each concatenated over all batches:
            ``"embed"`` (encoder embeddings), ``"labels"`` (labels as yielded
            by the loader), ``"pred_proba"`` (softmax over the logits) and
            ``"pred_labels"`` (argmax of the probabilities).
        """
        self.model.eval()
        collected = {
            "embed": [],
            "labels": [],
            "pred_proba": [],
            "pred_labels": [],
        }
        with torch.no_grad():
            for batch in loader:
                _, labels = batch
                _, logits, embed = self.run_one_batch(batch)

                probs = torch.softmax(logits.cpu(), dim=-1)
                preds = torch.argmax(probs, dim=-1)

                # Convert explicitly instead of relying on np.concatenate to
                # coerce tensors implicitly.
                collected["pred_proba"].append(probs.numpy())
                collected["pred_labels"].append(preds.numpy())
                collected["embed"].append(embed.cpu().numpy())
                collected["labels"].append(labels.cpu().numpy())

        return {key: np.concatenate(chunks) for key, chunks in collected.items()}

    def setup_optimizer(
        self,
    ):
        """Create the Adam optimizer and the exponential learning-rate scheduler.

        The learning rate is read from ``self.config.training_args.lr``;
        ``self.config`` is presumably set by ``BaseTrainer`` -- confirm there.
        The scheduler is stepped once per training batch in ``run_one_epoch``.
        """
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config.training_args.lr, weight_decay=1e-6
        )
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=0.9999
        )


def load_transfer_backbone_ckpt(transfer_config):
    """Rebuild the trainer stored in a checkpoint directory and load its weights.

    The directory named by ``transfer_config.load_from_checkpoint`` must hold a
    ``config.yaml`` describing the original run and a ``model_state.pt`` file
    mapping module names to their state dicts.

    Args:
        transfer_config: Config of the transfer run; only
            ``load_from_checkpoint`` is read here.

    Returns:
        A ``(ckpt_trainer, ckpt_config)`` tuple: the re-instantiated trainer
        with its modules' weights loaded, and the config it was built from.

    Raises:
        KeyError: If ``ckpt_config.model_type`` is not a registered trainer, or
            the saved state names a module missing from
            ``ckpt_trainer.all_modules``.
    """
    ckpt_dir = Path(transfer_config.load_from_checkpoint)
    ckpt_config = get_full_config(ckpt_dir / "config.yaml")

    # The checkpoint's own data is loaded only to derive the constructor
    # arguments of its trainer.
    ckpt_data, _ = get_data_from_config(ckpt_config, ckpt_config.data_args.mode)
    ckpt_trainer_kwargs = get_trainer_kwargs(ckpt_config, ckpt_data)

    trainer_cls = all_trainers.all_trainers[ckpt_config.model_type]
    ckpt_trainer = trainer_cls(ckpt_config, **ckpt_trainer_kwargs)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only point this at checkpoints from a trusted source.
    # Switching to weights_only=True is safer if the file holds only tensors --
    # TODO confirm what the saved state contains before changing it.
    state_dict = torch.load(
        ckpt_dir / "model_state.pt",
        weights_only=False,
        map_location=ckpt_trainer.config.device,
    )

    # The file stores one state dict per named module of the trainer.
    for module_name, module_state in state_dict.items():
        logger.info(f"Loading weights for {module_name}")
        ckpt_trainer.all_modules[module_name].load_state_dict(module_state)

    return ckpt_trainer, ckpt_config
