from pathlib import Path
# Absolute path of the directory containing this file; teacher checkpoint
# paths from the config are resolved against it.
project_root = Path(__file__).parent.absolute()

import torch

import hydra
from omegaconf import OmegaConf

import kd
from pl_runner import pl_train
from train import LightningModel
from tee import StdoutTee, StderrTee


class DistillLightningModel(LightningModel):
    """Student model trained with knowledge distillation from a single teacher.

    The teacher is loaded from ``train_cfg.teacher_checkpoint_path`` (resolved
    against ``project_root``) and frozen. The distillation loss is built from
    ``train_cfg.kd`` via ``hydra.utils.instantiate``.
    """

    def __init__(self, model_cfg, dataset_cfg, train_cfg):
        """Build the student, then load and freeze the teacher.

        Args:
            model_cfg: Model configuration, forwarded to ``LightningModel``.
            dataset_cfg: Dataset configuration, forwarded to ``LightningModel``.
            train_cfg: Training configuration; must provide
                ``teacher_checkpoint_path`` and ``kd``.
        """
        super().__init__(model_cfg, dataset_cfg, train_cfg)
        path = project_root / train_cfg.teacher_checkpoint_path
        self.teacher = LightningModel.load_from_checkpoint(str(path))
        self.teacher.freeze()
        self.kd_loss = hydra.utils.instantiate(self.train_cfg.kd)

    def training_step(self, batch, batch_idx):
        """Compute the distillation loss for one batch.

        Args:
            batch: Pair ``(batch_x, batch_y)`` of inputs and targets.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the distillation ``loss``, the batch ``size``, the student
            output ``out``, the ``target`` and the task metrics (under both
            ``progress_bar`` and ``log``).
        """
        batch_x, batch_y = batch
        # The teacher is a registered submodule, so any ``.train()`` call on
        # this module recurses into it and undoes the eval mode set by
        # ``freeze()``. Force eval mode here so dropout / batch-norm in the
        # teacher stay in inference behaviour while producing targets.
        self.teacher.eval()
        with torch.no_grad():
            teacher_out = self.teacher.model(batch_x)
        out = self.model(batch_x)
        # Plain task loss of the student; handed to the KD loss alongside the
        # student and teacher outputs.
        loss_og = self.task.loss(out, batch_y)
        loss = self.kd_loss(out, teacher_out, batch_y, loss_og)
        metrics = self.task.metrics(out, batch_y)
        return {'loss': loss, 'size': batch_x.shape[0], 'out': out, 'target': batch_y,
                'progress_bar': metrics, 'log': metrics}


class DistillCrossfitLightningModel(LightningModel):
    """Student model distilled from several cross-fitted teachers.

    ``train_cfg.crossfit_size`` teachers are loaded from checkpoints whose file
    names are derived from ``train_cfg.teacher_checkpoint_path`` by inserting
    the teacher index before ``.ckpt``. During training each example is routed
    to the teacher selected by ``index % crossfit_size``, where ``index`` is the
    third element of the batch (presumably the dataset sample index -- confirm
    against the dataset implementation).
    """

    def __init__(self, model_cfg, dataset_cfg, train_cfg):
        """Build the student, then load and freeze all teachers.

        Args:
            model_cfg: Model configuration, forwarded to ``LightningModel``.
            dataset_cfg: Dataset configuration; ``return_indices`` is set to
                ``True`` on it before it is forwarded to ``LightningModel``.
            train_cfg: Training configuration; must provide
                ``teacher_checkpoint_path``, ``crossfit_size`` and ``kd``.
        """
        # training_step unpacks per-example indices from the batch.
        dataset_cfg.return_indices = True
        super().__init__(model_cfg, dataset_cfg, train_cfg)
        path = project_root / train_cfg.teacher_checkpoint_path
        self.crossfit_size = self.train_cfg.crossfit_size
        teachers = []
        for i in range(self.crossfit_size):
            # Only rewrite the file name: replacing on the full path string
            # would also alter any directory whose name contains '.ckpt'.
            teacher_path = path.with_name(path.name.replace('.ckpt', f'{i}.ckpt'))
            t = LightningModel.load_from_checkpoint(str(teacher_path))
            t.freeze()
            teachers.append(t)
        # Need to register so they move to GPU device
        self.teachers = torch.nn.ModuleList(teachers)
        self.kd_loss = hydra.utils.instantiate(self.train_cfg.kd)

    def training_step(self, batch, batch_idx):
        """Compute the distillation loss for one batch.

        The batch is regrouped by teacher, so the returned ``out`` / ``target``
        follow the regrouped order rather than the incoming batch order.

        Args:
            batch: Triple ``(batch_x, batch_y, indices)``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the distillation ``loss``, the batch ``size``, the student
            output ``out``, the ``target`` and the task metrics (under both
            ``progress_bar`` and ``log``).
        """
        batch_x, batch_y, indices = batch
        # The teachers are registered submodules, so any ``.train()`` call on
        # this module recurses into them and undoes the eval mode set by
        # ``freeze()``. Force eval mode so dropout / batch-norm in the teachers
        # stay in inference behaviour while producing targets.
        self.teachers.eval()
        mod = indices % self.crossfit_size
        x_list, y_list, teacher_out = [], [], []
        for i, teacher in enumerate(self.teachers):
            mask = mod == i
            # Skip teachers that have no example in this batch.
            if mask.any():
                x_list.append(batch_x[mask])
                y_list.append(batch_y[mask])
                with torch.no_grad():
                    teacher_out.append(teacher.model(x_list[-1]))
        # Inputs, targets and teacher outputs are concatenated in the same
        # per-teacher order, so rows stay aligned with each other.
        batch_x, batch_y, teacher_out = torch.cat(x_list), torch.cat(y_list), torch.cat(teacher_out)
        out = self.model(batch_x)
        loss_og = self.task.loss(out, batch_y)
        loss = self.kd_loss(out, teacher_out, batch_y, loss_og)
        metrics = self.task.metrics(out, batch_y)
        return {'loss': loss, 'size': batch_x.shape[0], 'out': out, 'target': batch_y,
                'progress_bar': metrics, 'log': metrics}


@hydra.main(config_path="cfg", config_name="distill.yaml")
def main(cfg: OmegaConf):
    """Run distillation training with the runner named in ``cfg.runner.name``.

    Stdout and stderr are tee'd to ``train.stdout`` / ``train.stderr`` for the
    duration of the run. The single-teacher module is used when
    ``cfg.train.crossfit_size`` is absent or equal to 1; otherwise the
    cross-fit module is used.
    """
    with StdoutTee('train.stdout'), StderrTee('train.stderr'):
        print(OmegaConf.to_yaml(cfg))

        if cfg.train.get('crossfit_size', 1) == 1:
            pl_module_cls = DistillLightningModel
        else:
            pl_module_cls = DistillCrossfitLightningModel

        if cfg.runner.name == 'pl':
            pl_train(cfg, pl_module_cls)
            return

        assert cfg.runner.name == 'ray', 'Only pl and ray runners are supported'
        # Imported lazily: ray shouldn't need to be installed unless doing
        # distributed training.
        from ray_runner import ray_train
        ray_train(cfg, pl_module_cls)


# Script entry point. `main` is wrapped by @hydra.main above, so it is invoked
# here without arguments.
if __name__ == "__main__":
    main()
