import collections
import os
from logging import getLogger
from pathlib import Path
from typing import List

import torch
import torch.nn as nn

from src.criterion import KLDivLoss
from src.evaluator import DistributedEvaluator
from src.modeling_abstract import DistributedModule
from src.tokenizer import Tokenizer
from src.utils import barrier, reconstruct_logits_from_dicts

logger = getLogger()


class DistributedTrainer:
    """Trainer for instruction tuning and logit distillation of a ``DistributedModule``.

    Handles tokenization/truncation of (instruction, output) pairs, the
    cross-entropy (``train``) and cross-entropy + KL (``distill``) update steps
    with gradient accumulation, evaluation through ``DistributedEvaluator``, and
    saving/loading one optimizer state file per rank.
    """

    # Result containers, created once per class instead of on every call.
    # Type names ('Outputs' / 'Output') and field orders match the previous
    # per-call definitions, so callers see the same field names.
    _TrainingExample = collections.namedtuple('Outputs', ['tokens', 'labels'])
    _DistillingExample = collections.namedtuple(
        'Outputs', ['teacher_logits', 'teacher_probs', 'label_masks'])
    _TrainOutput = collections.namedtuple('Output', ['loss', 'logits'])
    _DistillOutput = collections.namedtuple(
        'Output', ['loss', 'distill_loss', 'logits', 'distill_loss2'])

    def __init__(self,
                 model: DistributedModule,
                 tokenizer: Tokenizer,
                 optimizer: torch.optim.Optimizer,
                 eval_batch_size: int,
                 max_gen_len: int = 512,
                 accumulation_steps: int = 1,
                 log_dir: str = "log/"):
        """
        :param model: distributed model; its ``local_rank``, ``world_size`` and
            ``params.max_seq_len`` are read here.
        :param tokenizer: tokenizer providing ``encode``/``decode``, ``pad_id`` and ``n_words``.
        :param optimizer: optimizer stepped once every ``accumulation_steps`` backward passes.
        :param eval_batch_size: batch size handed to the evaluator's ``generate``.
        :param max_gen_len: stored on the trainer; not used by the methods of this class.
        :param accumulation_steps: number of backward passes per optimizer step.
        :param log_dir: directory that evaluation output files are written into.
        """
        self.model = model
        self.local_rank = model.local_rank
        self.world_size = model.world_size
        self.max_seq_len = self.model.params.max_seq_len
        self.max_gen_len = max_gen_len
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.evaluator = DistributedEvaluator(self.model, tokenizer)
        # Label value -100 marks positions that do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)
        self.criterion_kl = KLDivLoss()
        self.step = 0  # number of backward passes so far; drives gradient accumulation
        self.accumulation_steps = accumulation_steps
        self.eval_batch_size = eval_batch_size
        self.log_dir = log_dir

    def _truncating_strategy(self, instruction_ids, output_ids):
        """Clip a tokenized pair so that together it fits in ``max_seq_len`` positions.

        The instruction is clipped first (printing a warning when it alone fills
        the window); the output is then clipped from its end to whatever room is
        left, which may be nothing.
        :return: (instruction_ids, output_ids)
        """
        instruction_length = len(instruction_ids)
        if instruction_length >= self.max_seq_len:
            print(f'WARNING: Length of instruction {instruction_length} '
                  f'exceeds the max input length {self.max_seq_len}')
            instruction_ids = instruction_ids[:self.max_seq_len]
        # Room left for the output; 0 when the instruction occupies the whole window.
        output_budget = self.max_seq_len - len(instruction_ids)
        return instruction_ids, output_ids[:output_budget]

    def _encode_pair(self, instruction: str, output: str):
        """Tokenize one (instruction, output) pair and truncate it to ``max_seq_len``.

        The instruction is encoded with ``bos=True`` and the output with
        ``eos=True``. Every method that maps output tokens to sequence positions
        goes through here, so they all agree on where the output span starts.
        :return: (instruction_ids, output_ids)
        """
        instruction_ids = self.tokenizer.encode(instruction, bos=True, eos=False)
        output_ids = self.tokenizer.encode(output, bos=False, eos=True)
        return self._truncating_strategy(instruction_ids, output_ids)

    def _back_propagation(self, loss: torch.Tensor):
        """Backpropagate ``loss`` and step the optimizer every ``accumulation_steps`` calls.

        The loss is divided by ``accumulation_steps`` so that the accumulated
        gradient is the mean over the accumulated micro-batches.
        """
        self.step += 1
        loss = loss / self.accumulation_steps
        loss.backward()
        if self.step % self.accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Token-level cross-entropy of ``logits`` against ``labels`` (-100 entries ignored).

        ``logits`` is flattened to ``(-1, vocab)`` and ``labels`` to ``(-1,)``, so
        ``logits`` is expected to be ``(bsz, max_seq_len, vocab)`` for a
        ``(bsz, max_seq_len)`` label tensor.
        """
        return self.criterion.forward(
            input=logits.view(-1, logits.size(-1)),
            target=labels.view(-1).to(logits.device)
        )

    def _distill_loss(self, logits: torch.Tensor, instructions, outputs, logits_dicts_list) -> torch.Tensor:
        """KL loss between the student ``logits`` and one teacher's logits over the output spans."""
        example = self.prepare_for_distilling(instructions, outputs, logits_dicts_list)
        return self.criterion_kl.forward(
            logits=logits,
            targets=example.teacher_probs.to(logits.device),
            masks=example.label_masks.to(logits.device)
        )

    def prepare_for_training(self, instructions, outputs):
        """Build padded input tokens and next-token labels for a batch.

        Row ``i`` of ``tokens`` holds ``instruction + output`` followed by
        ``pad_id``. ``labels`` is -100 everywhere except positions
        ``instr_len - 1 .. instr_len + output_len - 2``, which hold the output
        tokens -- i.e. the label at position p is the input token at p + 1.
        :return tokens, labels  -- LongTensors of shape (bsz, max_seq_len)
        """
        bsz = len(instructions)
        tokens = torch.full((bsz, self.max_seq_len), self.tokenizer.pad_id).long()
        labels = torch.full((bsz, self.max_seq_len), -100).long()
        for i, (instruction, output) in enumerate(zip(instructions, outputs)):
            instruction_ids, output_ids = self._encode_pair(instruction, output)
            instr_len, output_len = len(instruction_ids), len(output_ids)
            tokens[i, :instr_len + output_len] = torch.tensor(instruction_ids + output_ids).long()
            labels[i, instr_len - 1: instr_len - 1 + output_len] = torch.tensor(output_ids).long()
        return self._TrainingExample(tokens=tokens, labels=labels)

    def prepare_for_distilling(self, instructions, outputs, logits_dicts_list):
        """Densify teacher logits onto the same positions the training labels occupy.

        :param logits_dicts_list: one entry per example; each entry is sliced to
            the (truncated) output length and passed to
            ``reconstruct_logits_from_dicts`` -- presumably one dict per output
            token; confirm against the data producer.
        :return teacher_logits (bsz, max_seq_len, n_words) float32,
                teacher_probs  softmax of teacher_logits over the last dim,
                label_masks    (bsz, max_seq_len) bool, True on output positions
        :raises ValueError: if an example has fewer teacher entries than output tokens.
        """
        bsz = len(instructions)
        labels = torch.full((bsz, self.max_seq_len), -100).long()
        # Allocate float32 directly; `torch.full(..., 0).float()` first built a full
        # int64 tensor of the same (bsz, seq, vocab) size and then copied it.
        # Positions outside the output spans stay all-zero; label_masks is False there.
        logits = torch.zeros((bsz, self.max_seq_len, self.tokenizer.n_words), dtype=torch.float32)
        for i, (instruction, output, logits_dicts) in enumerate(zip(instructions, outputs, logits_dicts_list)):
            instruction_ids, output_ids = self._encode_pair(instruction, output)
            instr_len, output_len = len(instruction_ids), len(output_ids)
            # Drop teacher entries for output tokens removed by truncation.
            logits_dicts = logits_dicts[: output_len]
            if len(logits_dicts) != output_len:
                raise ValueError(
                    f'Example {i}: got {len(logits_dicts)} teacher logits entries '
                    f'for {output_len} output tokens')
            span = slice(instr_len - 1, instr_len - 1 + output_len)
            logits[i, span, :] = reconstruct_logits_from_dicts(logits_dicts, self.tokenizer.n_words)
            labels[i, span] = torch.tensor(output_ids).long()
        label_masks = (labels != -100)
        probs = torch.softmax(logits, dim=-1)
        return self._DistillingExample(teacher_logits=logits, teacher_probs=probs, label_masks=label_masks)

    @torch.no_grad()
    def predict(self, logits, instructions: List[str], outputs: List[str]) -> List[dict]:
        """Greedy (argmax) decode of ``logits`` over each example's output span.

        The span is located with the same tokenization/truncation used by
        ``prepare_for_training``, so it lines up with the training labels.
        :return: one dict per row of ``logits`` with keys ``instruction`` (the
            original text) and ``output`` (the decoded argmax tokens).
        """
        datalist = []
        for i in range(int(logits.shape[0])):
            instruction_ids, output_ids = self._encode_pair(instructions[i], outputs[i])
            instr_len, output_len = len(instruction_ids), len(output_ids)
            predict_ids = torch.argmax(logits[i], dim=-1)[instr_len - 1: instr_len - 1 + output_len].tolist()
            datalist.append(dict(instruction=instructions[i], output=self.tokenizer.decode(predict_ids)))
        return datalist

    def train(self, instructions: List[str], outputs: List[str]):
        """ Instruction tuning: one cross-entropy forward/backward pass.

        :return Output(loss, logits) -- ``loss`` is the unscaled cross-entropy.
        """
        example = self.prepare_for_training(instructions=instructions, outputs=outputs)
        logits = self.model.forward(example.tokens)
        loss = self._cross_entropy(logits, example.labels)
        self._back_propagation(loss)
        return self._TrainOutput(logits=logits, loss=loss)

    def distill(
            self,
            instructions: List[str],
            outputs: List[str],
            logits_dicts_list: List[dict],
            alpha: float,
            logits_dicts_list2: List[dict] = None,
            beta: float = None
    ):
        """One distillation step: ``ce + alpha * kl(teacher1)`` [``+ beta * kl(teacher2)``].

        :param alpha: weight of the KL term against ``logits_dicts_list``.
        :param logits_dicts_list2: optional second teacher, weighted by ``beta``.
        :param beta: required when ``logits_dicts_list2`` is given.
        :return Output(loss, distill_loss, logits, distill_loss2). NOTE: ``loss``
            reports the cross-entropy term only, not the combined objective that
            is backpropagated; ``distill_loss2`` is None without a second teacher.
        :raises ValueError: if ``logits_dicts_list2`` is given without ``beta``.
        """
        # Validate before the (expensive) forward pass instead of failing on `None * Tensor`.
        if logits_dicts_list2 is not None and beta is None:
            raise ValueError('`beta` must be provided together with `logits_dicts_list2`')

        forward_example = self.prepare_for_training(instructions=instructions, outputs=outputs)
        logits = self.model.forward(forward_example.tokens)
        ce_loss = self._cross_entropy(logits, forward_example.labels)

        distill_loss = self._distill_loss(logits, instructions, outputs, logits_dicts_list)
        loss = ce_loss + alpha * distill_loss

        distill_loss2 = None
        if logits_dicts_list2 is not None:
            distill_loss2 = self._distill_loss(logits, instructions, outputs, logits_dicts_list2)
            loss = loss + beta * distill_loss2

        self._back_propagation(loss)
        return self._DistillOutput(logits=logits, distill_loss=distill_loss, loss=ce_loss, distill_loss2=distill_loss2)

    def evaluate(
            self,
            task: str,
            label_file,
            output_file,
    ):
        """Generate on ``label_file`` and score the generations.

        :param task: forwarded to ``DistributedEvaluator.evaluating``.
        :param label_file: forwarded to ``DistributedEvaluator.generate``.
        :param output_file: file name, joined onto ``log_dir``.
        :return: whatever ``DistributedEvaluator.evaluating`` returns.
        """
        # Only rank 0 creates the directory; barrier() presumably makes the
        # other ranks wait for it (assumed to be a cross-rank barrier).
        if self.local_rank == 0:
            os.makedirs(self.log_dir, exist_ok=True)
        barrier()
        output_file = os.path.join(self.log_dir, output_file)
        datalist = self.evaluator.generate(
            label_file=label_file,
            batch_size=self.eval_batch_size,
            max_seq_len=self.max_seq_len
        )
        return self.evaluator.evaluating(
            datalist=datalist,
            task=task,
            output_file=output_file
        )

    def save_distributed_optimizer(self, save_path: str):
        """Write this rank's optimizer state to ``save_path/optimizer.0{local_rank}.bin``."""
        if self.local_rank == 0:
            os.makedirs(save_path, exist_ok=True)
        print(f'Saving optimizer to {save_path} ......')
        barrier()
        # File naming kept as-is for compatibility with existing checkpoints;
        # `_optimizer_shard_key` orders these names numerically on load.
        torch.save(self.optimizer.state_dict(), os.path.join(
            save_path, f'optimizer.0{self.local_rank}.bin'))
        barrier()
        print('Saving done !')

    @staticmethod
    def _optimizer_shard_key(path: Path):
        """Sort key ordering ``optimizer.<rank>.bin`` files by numeric rank.

        Lexicographic order puts ``optimizer.010.bin`` before ``optimizer.02.bin``,
        which would hand ranks >= 2 the wrong shard once there are more than ten
        ranks. Names whose middle part is not a number sort last, by name.
        """
        rank = path.name[len('optimizer.'):-len('.bin')]
        if rank.isdecimal():
            return 0, int(rank), path.name
        return 1, 0, path.name

    def load_distributed_optimizer(self, save_path: str):
        """Load this rank's optimizer shard from ``save_path``; no-op if none exist.

        :raises ValueError: if the number of shard files differs from ``world_size``.
        """
        checkpoints = sorted(Path(save_path).glob("optimizer.*.bin"), key=self._optimizer_shard_key)
        if len(checkpoints) == 0:
            return
        print(f'Loading optimizer from {save_path} .....')
        if self.world_size != len(checkpoints):
            raise ValueError(
                f"Loading a optimizer for MP={len(checkpoints)} but world size is {self.world_size}")
        optim_file = checkpoints[self.local_rank]
        state_dict = torch.load(optim_file)
        self.optimizer.load_state_dict(state_dict)
        # Move every tensor in the optimizer state onto the current CUDA device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
        print('Loading done !')

    def save_distributed_model(self, save_path: str):
        """Delegate model saving to ``model.save``."""
        self.model.save(save_path)

    def load_distributed_model(self, save_path: str):
        """Delegate model loading to ``model.load``."""
        self.model.load(save_path)

    def load(self, save_path: str):
        """Load optimizer shard then model weights; skipped for None / ``"none"``."""
        if save_path is None or save_path.lower() == "none":
            print("WARNING: Not loading model because `save_path` is None")
            return
        self.load_distributed_optimizer(save_path)
        self.load_distributed_model(save_path)

    def save(self, save_path: str):
        """Save optimizer shard then model weights; skipped for None / ``"none"``."""
        if save_path is None or save_path.lower() == "none":
            print("WARNING: Not saving model because `save_path` is None")
            return
        self.save_distributed_optimizer(save_path)
        self.save_distributed_model(save_path)
