import os, pickle
import torch

from hydra.utils import instantiate, get_original_cwd
from omegaconf import DictConfig
from transformers import AutoTokenizer

from src.model.base_model import BaseModel
from src.model.task_lm import TaskLanguageModel

from src.utils.losses import kd_criterion_dict, calc_task_loss
from src.utils.metrics import init_best_metrics, init_perf_metrics, calc_preds, process_outputs
from src.utils.optim import setup_scheduler, setup_optimizer_params
from src.utils.logging import log_step_losses, log_epoch_losses, log_epoch_metrics


class LanguageModel(BaseModel):
    """Task language model with an optional auxiliary LM for knowledge distillation.

    The module holds up to two ``TaskLanguageModel`` instances:

    * ``task_lm`` -- built from ``arch`` with ``io_mode``; ``None`` when
      ``aux_lm_only`` is set.
    * ``aux_lm`` -- built from ``aux_arch`` (or ``arch`` if ``aux_arch`` is
      None) with ``aux_io_mode`` when distillation (``kd_input`` /
      ``kd_target``) is enabled or ``aux_lm_only`` is set; otherwise ``None``.
      Its parameters are only handed to the optimizer when ``aux_lm_only``.

    Under distillation the training objective is
    ``task_loss + kd_loss_wt * (kd_input_loss + kd_target_loss)``, where each
    KD term is ``kd_criterion`` applied to the matching states of both models.

    NOTE(review): ``io_mode`` strings look like "<inputs>-<outputs>"
    (I = input, R = rationale?, O = output) -- confirm. In this class
    'I-O'/'IR-O' are scored from ``.logits`` against ``batch['label']``, while
    'I-OR'/'I-RO' use the model's own ``.loss`` for training and parse
    ``.sequences`` with ``process_outputs`` at evaluation time.
    """

    def __init__(
            self, arch: str, model_max_length: int, dataset: str, num_classes: int,
            optimizer: DictConfig, scheduler: DictConfig,
            evaluate_ckpt: bool, io_mode: str, aux_io_mode: str,
            kd_input: bool, kd_target: bool, kd_criterion: str, kd_loss_wt: float,
            aux_lm_only: bool, no_task_loss: bool = False, aux_arch: str = None,
            no_bottleneck: bool = False,
            ftr_dropout_rate: float = None,
            **kwargs,
        ):
        """Build the task LM, the optional auxiliary LM, tokenizer and metrics.

        Args:
            arch: pretrained model name/path for the task LM and the tokenizer.
            model_max_length: ``model_max_length`` passed to the tokenizer.
            dataset: dataset name, stored as ``self.dataset``.
            num_classes: number of labels used to initialise perf metrics.
            optimizer: optimizer config, ``instantiate``-d in
                ``configure_optimizers``.
            scheduler: scheduler config; ``scheduler.lr_scheduler`` selects
                the schedule.
            evaluate_ckpt: stored as ``self.evaluate_ckpt``.
            io_mode: input/output format of the task LM.
            aux_io_mode: input/output format of the auxiliary LM.
            kd_input: distill on the models' ``kd_input_states``.
            kd_target: distill on the models' ``kd_target_states``.
            kd_criterion: key into ``kd_criterion_dict``; must be None
                when no distillation is enabled.
            kd_loss_wt: weight of the KD loss; must be >= 0 with
                distillation and None without it.
            aux_lm_only: use (and optimize) only the auxiliary LM; requires
                'IR-O' for both io modes and no distillation.
            no_task_loss: when True, the task loss is multiplied by zero in
                the 'I-O'/'IR-O' training objective.
            aux_arch: separate architecture for the auxiliary LM; must differ
                from ``arch`` when given.
            no_bottleneck: forwarded to both ``TaskLanguageModel``s.
            ftr_dropout_rate: forwarded to the task LM only (the auxiliary
                LM receives None).
            **kwargs: accepted and not used by this constructor.
        """
        super().__init__()

        self.save_hyperparameters()

        # Shuffled/replaced variants are handled like plain 'IR-O' here.
        if io_mode in ['IshuffledR-O', 'IreplacedR-O']:
            io_mode = 'IR-O'
        if aux_io_mode in ['IshuffledR-O', 'IreplacedR-O']:
            aux_io_mode = 'IR-O'

        self.dataset = dataset
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.evaluate_ckpt = evaluate_ckpt
        self.io_mode = io_mode
        self.aux_io_mode = aux_io_mode
        self.kd_input, self.kd_target = kd_input, kd_target
        self.kd_loss_wt = kd_loss_wt
        self.aux_lm_only = aux_lm_only
        self.no_task_loss = no_task_loss

        assert aux_arch is None or aux_arch != arch
        self.aux_arch = aux_arch

        # Check that args pass asserts
        if kd_input or kd_target:
            # Distillation pairs an 'I-O' task model with an 'IR-O' aux model.
            assert io_mode == 'I-O' and aux_io_mode == 'IR-O'
            assert kd_loss_wt >= 0
            assert not aux_lm_only
        else:
            assert kd_criterion is None
            assert kd_loss_wt is None

        if aux_lm_only:
            assert io_mode == 'IR-O'
            assert aux_io_mode == 'IR-O'
            assert not (kd_input or kd_target)
            assert aux_arch is None
        else:
            assert io_mode in ['I-O', 'IR-O', 'I-OR', 'I-RO']

        # Initialize task LM
        if not aux_lm_only:
            self.task_lm = TaskLanguageModel(arch, optimizer, io_mode, 'task', kd_input, kd_target, no_bottleneck, ftr_dropout_rate)
        else:
            self.task_lm = None

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(arch, model_max_length=model_max_length)

        # Initialize evaluation metrics
        self.best_metrics = init_best_metrics()
        self.perf_metrics = {'task': init_perf_metrics(num_classes)}

        # Initialize auxiliary LM
        if (kd_input or kd_target) or aux_lm_only:
            self.aux_lm = TaskLanguageModel(arch if aux_arch is None else aux_arch, optimizer, aux_io_mode, 'aux', kd_input, kd_target, no_bottleneck, None)
            if kd_input or kd_target:
                self.kd_criterion = kd_criterion_dict[kd_criterion]()
                self.perf_metrics['aux'] = init_perf_metrics(num_classes)
                # Give the task LM access to the aux LM's dimension and head;
                # a separate aux architecture uses ``set_aux_lm``, a shared
                # one ``set_aux_head``.
                if aux_arch is not None:
                    self.task_lm.set_aux_lm(self.aux_lm.lm_dim, self.aux_lm.lm.lm_head)
                else:
                    self.task_lm.set_aux_head(self.aux_lm.lm_dim, self.aux_lm.lm.lm_head)
        else:
            self.aux_lm = None

    def forward(self, batch, split):
        """Run the model(s) on ``batch``.

        Returns:
            dict with ``'task'`` -> output of ``task_lm`` (or of ``aux_lm``
            when ``aux_lm_only``) and, under distillation, ``'aux'`` ->
            output of ``aux_lm`` on the same batch.
        """
        outputs = {}
        outputs['task'] = self.aux_lm(batch, split) if self.aux_lm_only else self.task_lm(batch, split)
        if self.kd_input or self.kd_target:
            outputs['aux'] = self.aux_lm(batch, split)
        return outputs

    def run_step(self, batch, split, batch_idx):
        """Forward ``batch`` and compute predictions and/or losses.

        Args:
            batch: dict that must contain ``'split'`` and ``'label'`` (plus
                whatever the underlying LMs consume).
            split: current stage; for ``'train'`` the batch's own split must
                also be ``'train'``.
            batch_idx: unused here.

        Returns:
            dict with ``'eval_split'`` and ``'label'``, plus:
            * 'I-OR'/'I-RO' at eval: ``'pred_label'`` parsed from sequences
              (no loss is computed);
            * 'I-OR'/'I-RO' in training: only the losses;
            * 'I-O'/'IR-O': ``'pred_label'`` (and ``'aux_pred_label'`` under
              distillation) plus the losses.
            Whenever losses are computed, the dict is passed through
            ``log_step_losses``.
        """
        eval_split = batch['split']
        assert not (split == 'train' and split != eval_split)
        ret_dict, loss_dict = {}, {}
        ret_dict['eval_split'] = eval_split

        outputs = self.forward(batch, split)
        generative = self.io_mode in ['I-OR', 'I-RO']

        if generative and split != 'train':
            ret_dict['pred_label'], _ = process_outputs(outputs['task'].sequences, self.io_mode, self.tokenizer)
        else:
            if generative:
                loss = loss_dict['task_loss'] = outputs['task'].loss
            else:
                task_loss = loss_dict['task_loss'] = calc_task_loss(outputs['task'].logits, batch['label'])
                ret_dict['pred_label'] = calc_preds(outputs['task'].logits)
                if self.kd_input or self.kd_target:
                    # The aux loss is logged only; it is not part of ``loss``.
                    loss_dict['aux_loss'] = calc_task_loss(outputs['aux'].logits, batch['label'])
                    ret_dict['aux_pred_label'] = calc_preds(outputs['aux'].logits)
                    if self.kd_input:
                        task_states, aux_states = outputs['task'].kd_input_states, outputs['aux'].kd_input_states
                        kd_input_loss = loss_dict['kd_input_loss'] = self.kd_criterion(task_states, aux_states)
                    else:
                        kd_input_loss = 0.0

                    if self.kd_target:
                        task_states, aux_states = outputs['task'].kd_target_states, outputs['aux'].kd_target_states
                        kd_target_loss = loss_dict['kd_target_loss'] = self.kd_criterion(task_states, aux_states)
                    else:
                        kd_target_loss = 0.0

                    kd_loss = loss_dict['kd_loss'] = self.kd_loss_wt * (kd_input_loss + kd_target_loss)
                else:
                    kd_loss = 0.0

                # Multiply rather than drop the term so ``loss`` remains a
                # tensor attached to the graph even when ``no_task_loss``.
                loss = (not self.no_task_loss) * task_loss + kd_loss

            loss_dict['loss'] = loss
            ret_dict = log_step_losses(self, loss_dict, ret_dict, eval_split)

        ret_dict['label'] = batch['label']

        return ret_dict

    def aggregate_epoch(self, outputs, split):
        """Log epoch-level losses and metrics from collected step outputs.

        Args:
            outputs: for ``'dev'``, a sequence of per-dataloader output lists
                ordered (dev, test); for ``'train'`` / ``'test'``, a single
                list of step outputs (for ``'test'`` the evaluated split is
                read from ``outputs[0]['eval_split']``).
            split: ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            ValueError: if ``split`` is not one of the supported values.
        """
        if split == 'train':
            splits = ['train']
        elif split == 'dev':
            splits = ['dev', 'test']
        elif split == 'test':
            splits = [outputs[0]['eval_split']]
        else:
            raise ValueError(f"Unsupported split: {split!r}")
        outputs_list = outputs if split == 'dev' else [outputs]

        classification = self.io_mode in ['I-O', 'IR-O']
        for dataset_idx, eval_split in enumerate(splits):
            split_outputs = outputs_list[dataset_idx]
            # Generative modes only compute losses in training (see run_step).
            if classification or eval_split == 'train':
                log_epoch_losses(self, split_outputs, eval_split)
            # Generative modes only produce predictions outside training.
            if classification or eval_split != 'train':
                log_epoch_metrics(self, split_outputs, eval_split, 'task')
                if self.kd_input or self.kd_target:
                    log_epoch_metrics(self, split_outputs, eval_split, 'aux')

    def configure_optimizers(self):
        """Build the optimizer (and, if configured, the scheduler).

        Only the parameters of ``aux_lm`` (when ``aux_lm_only``) or
        ``task_lm`` are optimized. The configured learning rate is multiplied
        by ``trainer.world_size`` and written back into
        ``self.optimizer['lr']``.

        Returns:
            ``[optimizer], [scheduler]`` for 'linear_with_warmup' /
            'constant_with_warmup', or ``[optimizer]`` for 'fixed'.

        Raises:
            NotImplementedError: for any other ``scheduler.lr_scheduler``.
        """
        model = self.aux_lm if self.aux_lm_only else self.task_lm
        optimizer_params = list(setup_optimizer_params(model, self.optimizer))

        # Scale from the unscaled lr captured on the first call. Scaling the
        # stored value in place would otherwise compound every time this
        # method runs again (e.g. a second ``trainer.fit``).
        base_lr = getattr(self, '_base_lr', None)
        if base_lr is None:
            base_lr = self._base_lr = self.optimizer['lr']
        self.optimizer['lr'] = base_lr * self.trainer.world_size

        optimizer = instantiate(
            self.optimizer, params=optimizer_params,
            _convert_='partial'
        )

        if self.scheduler.lr_scheduler in ('linear_with_warmup', 'constant_with_warmup'):
            scheduler = setup_scheduler(self.scheduler, self.total_steps, optimizer)
            return [optimizer], [scheduler]
        elif self.scheduler.lr_scheduler == 'fixed':
            return [optimizer]
        else:
            raise NotImplementedError

    def _load_from_checkpoint(self, ckpt_path, load_aux_lm_only = False):
        """Copy weights from the checkpoint at ``ckpt_path`` into this module.

        The checkpoint is first restored (non-strictly, on CPU) into a
        temporary ``LanguageModel``. With ``load_aux_lm_only``, only that
        model's ``aux_lm`` weights are copied: into ``self.aux_lm`` when it
        exists, and into ``self.task_lm`` when no separate ``aux_arch`` is
        used. Otherwise the whole state dict is copied strictly.
        """
        loaded_model = LanguageModel.load_from_checkpoint(ckpt_path, strict = False, map_location = 'cpu')
        if load_aux_lm_only:
            if self.aux_lm is not None:
                self.aux_lm.load_state_dict(loaded_model.aux_lm.state_dict(), strict = True)
            if self.task_lm is not None and self.aux_arch is None:
                self.task_lm.load_state_dict(loaded_model.aux_lm.state_dict(), strict = True)
        else:
            self.load_state_dict(loaded_model.state_dict(), strict = True)
        # Drop the temporary model so its weights can be freed.
        del loaded_model