import bisect
import os
import shutil
from collections import deque
from typing import List

import torch
from torch.cuda import amp
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

import dataset
import logger
from args import TrainerArguments
from logger import AvgMetricTracker, MetricTracker, BestMetricTracker
from model import DSLFMKGC
from utils import compute_hits, move_to_cuda, get_model_obj, get_autocast_context_manager


class Trainer:
    """Drive training, periodic validation and checkpointing of a ``DSLFMKGC`` model.

    The trainer owns the data loaders, the AdamW optimizer, a warmup LR
    scheduler and (when ``train_args.use_amp`` is set) a gradient scaler.
    Two training paths are supported, selected by
    ``train_args.gradient_accumulation_steps``:

    * ``== 1``: a plain forward/backward per batch.
    * ``> 1``: representations of several micro-batches are first cached
      without gradients, then each micro-batch is re-encoded with gradients
      and scored against the cached representations of the others, so the
      contrastive loss sees the whole accumulated batch.
    """

    def __init__(
        self,
        train_args: TrainerArguments,
        model: DSLFMKGC,
        train_dataset: dataset.Dataset,
        valid_dataset: dataset.Dataset,
    ) -> None:
        """Build data loaders, optimizer, scheduler and AMP helpers.

        Args:
            train_args: Hyper-parameters and paths used throughout training.
            model: The model to optimize; it is moved to GPU(s) when available.
            train_dataset: Dataset iterated (shuffled) during training.
            valid_dataset: Dataset iterated (in order) during evaluation.
        """
        self.args = train_args
        self.global_step = 0
        self.epoch_start = 0

        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(model).cuda()
        elif torch.cuda.is_available():
            self.model = model.cuda()
        else:
            # Without this fallback ``self.model`` would never be bound on a
            # machine without CUDA and the optimizer construction would fail.
            self.model = model
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=train_args.batch_size,
            shuffle=True,
            collate_fn=dataset.collate,
            pin_memory=True,
            num_workers=train_args.workers,
            drop_last=True,
        )
        self.valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=train_args.batch_size * 4,
            shuffle=False,
            collate_fn=dataset.collate,
            pin_memory=True,
            num_workers=train_args.workers,
        )
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=train_args.lr, weight_decay=train_args.weight_decay
        )
        # One optimizer step is taken per ``gradient_accumulation_steps`` batches.
        self.num_steps_per_epoch = len(self.train_dataloader) // train_args.gradient_accumulation_steps
        self.num_training_steps = train_args.epochs * self.num_steps_per_epoch
        # Warmup never exceeds 10% of the total number of optimizer steps.
        warmup = min(train_args.warmup, self.num_training_steps // 10)
        if train_args.scheduler == "cosine":
            self.scheduler = get_cosine_schedule_with_warmup(self.optimizer, warmup, self.num_training_steps)
        else:
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer, warmup, num_training_steps=self.num_training_steps
            )
        # Always define the attribute so checkpoint save/restore can test it
        # instead of raising AttributeError when AMP is disabled.
        self.scaler = amp.GradScaler() if train_args.use_amp else None
        self.autocast_context_manager = get_autocast_context_manager(train_args.use_amp)

        # number of models with best evaluation metrics to save
        self.n_best = train_args.n_best_to_track
        self.best_metric_list = deque()
        # Created by ``train_loop`` (or lazily by ``eval``).
        self.best_metric_tracker = None

    def train_loop(self) -> None:
        """Train for ``args.epochs`` epochs, evaluating after each, then save the last weights."""
        self.best_metric_tracker = BestMetricTracker("best_Hit@1", "valid")
        for epoch in range(self.epoch_start, self.args.epochs + self.epoch_start):
            # train for one epoch
            self.train_epoch(epoch)
            self.eval(epoch + 1)
        # save the last model
        model = get_model_obj(self.model)
        torch.save(model.state_dict(), self.args.last_model_path)

    def train_epoch(self, epoch: int) -> None:
        """Run one pass over the training loader.

        Logs every ``args.log_step`` optimizer steps and evaluates every
        ``args.eval_step`` optimizer steps.

        Args:
            epoch: Index of the current epoch, forwarded to ``eval``.
        """
        loss_tracker = AvgMetricTracker("loss", "train")
        loss_kl_tracker = AvgMetricTracker("loss_kl", "train")
        loss_recon_tracker = AvgMetricTracker("loss_recon", "train")
        loss_contrastive_tracker = AvgMetricTracker("loss_contrastive", "train")
        lr_tracker = MetricTracker("lr", "train")

        accum_steps = self.args.gradient_accumulation_steps
        inputs_accum, triples_accum = [], []
        hr_vectors_accum, tail_vectors_accum, head_vectors_accum = [], [], []
        for bid, inputs in enumerate(self.train_dataloader):
            # ``eval`` may have switched the model to eval mode; switch back.
            model = self.model.train()

            if accum_steps == 1:
                self.global_step += 1
                total_loss, total_loss_dict = self._single_batch_loss(model, inputs)
                self.backward(total_loss)
            else:
                triples_accum.extend(inputs.pop("triples"))
                inputs = move_to_cuda(inputs)
                inputs_accum.append(inputs)
                # First, cache the features without any gradient tracking.
                with torch.inference_mode(), self.autocast_context_manager():
                    output_dict = model(**inputs)
                hr_vectors_accum.append(output_dict["hr_lfrm"]["X_recon"])
                tail_vectors_accum.append(output_dict["tail_lfrm"]["X_recon"])
                head_vectors_accum.append(
                    output_dict["head_lfrm"]["X_recon"] if output_dict["head_lfrm"] else None
                )

                # Keep collecting until a full group of micro-batches is cached.
                if (bid + 1) % accum_steps > 0:
                    continue

                self.global_step += 1
                total_loss, total_loss_dict = self._accumulated_loss(
                    model,
                    inputs_accum,
                    triples_accum,
                    hr_vectors_accum,
                    tail_vectors_accum,
                    head_vectors_accum,
                )
                self.backward(total_loss)
                # Start a fresh accumulation group.
                inputs_accum, triples_accum = [], []
                hr_vectors_accum, tail_vectors_accum, head_vectors_accum = [], [], []

            batch_size = total_loss_dict["logits"].size(0)
            loss_tracker.update(total_loss.item(), batch_size)
            loss_kl_tracker.update(total_loss_dict["loss_kl"].item(), batch_size)
            loss_recon_tracker.update(total_loss_dict["loss_recon"].item(), batch_size)
            loss_contrastive_tracker.update(total_loss_dict["loss_contrastive"].item(), batch_size)
            lr_tracker.update(self.scheduler.get_last_lr()[0])

            if self.global_step % self.args.log_step == 0:
                epoch_prop = self.global_step / self.num_steps_per_epoch
                logger.log(
                    f"Step: {self.global_step}/{self.num_training_steps}, "
                    + f"Epoch: {epoch_prop:.2f}/{self.args.epochs}"
                )
                logger.log_metric(
                    loss_tracker,
                    loss_kl_tracker,
                    loss_recon_tracker,
                    loss_contrastive_tracker,
                    lr_tracker,
                    step=self.global_step,
                )
            if self.global_step % self.args.eval_step == 0:
                self.eval(epoch)

    def _build_logit_masks(self, triples) -> dict:
        """Build the contrastive-loss masks for ``triples`` and move them to the device."""
        return move_to_cuda(
            {
                "logits_mask": dataset.construct_logits_mask(triples),
                "self_neg_logits_mask": dataset.construct_self_negative_mask(triples)
                if self.args.use_self_negative
                else None,
            }
        )

    def _single_batch_loss(self, model, inputs) -> tuple:
        """Compute the loss for one batch when no gradient accumulation is used.

        Returns:
            ``(total_loss, total_loss_dict)`` where ``total_loss`` is
            ``loss_lfrm + loss_contrastive`` and the dict merges the outputs of
            the model's KL and contrastive loss functions.
        """
        triples = inputs.pop("triples")
        inputs = move_to_cuda(inputs)
        logit_masks = self._build_logit_masks(triples)

        with self.autocast_context_manager():
            output_dict = model(**inputs)
            head_vector = output_dict["head_lfrm"]["X_recon"] if output_dict["head_lfrm"] else None

            loss_model = get_model_obj(model)
            total_loss_dict = loss_model.compute_kl_loss(**output_dict)
            total_loss_dict.update(
                loss_model.compute_contrastive_loss(
                    hr_vector=output_dict["hr_lfrm"]["X_recon"],
                    tail_vector=output_dict["tail_lfrm"]["X_recon"],
                    head_vector=head_vector,
                    **logit_masks,
                )
            )
            total_loss = total_loss_dict["loss_lfrm"] + total_loss_dict["loss_contrastive"]
        return total_loss, total_loss_dict

    def _accumulated_loss(
        self,
        model,
        inputs_accum: list,
        triples_accum: list,
        hr_vectors_accum: List[torch.Tensor],
        tail_vectors_accum: List[torch.Tensor],
        head_vectors_accum: list,
    ) -> tuple:
        """Compute the loss over a full group of cached micro-batches.

        Each micro-batch is re-encoded with gradients enabled; its fresh
        vectors replace the cached ones at the same position so the
        contrastive loss is computed over the whole accumulated batch.

        Both the LFRM and the contrastive terms are averaged over the
        micro-batches so their relative weight matches the single-batch path.
        (Previously the first micro-batch's KL terms were stored unscaled,
        later ones were added in place, and the running sum was re-added to
        the total on every iteration.)

        Returns:
            ``(total_loss, total_loss_dict)``; the dict holds the mean of each
            KL-loss entry plus the contrastive outputs of the last micro-batch.
        """
        accum_steps = self.args.gradient_accumulation_steps
        logit_masks = self._build_logit_masks(triples_accum)
        loss_model = get_model_obj(model)

        total_loss = 0
        total_loss_dict = {}
        contrastive_loss_dict = {}
        for i, micro_inputs in enumerate(inputs_accum):
            with self.autocast_context_manager():
                output_dict = model(**micro_inputs)
                kl_loss_dict = loss_model.compute_kl_loss(**output_dict)
                # Out-of-place running mean: never mutate tensors that belong
                # to an autograd graph.
                for key, value in kl_loss_dict.items():
                    total_loss_dict[key] = total_loss_dict.get(key, 0) + value / accum_steps
                total_loss = total_loss + kl_loss_dict["loss_lfrm"] / accum_steps

                hr_vector = self._get_accumulated_inputs(
                    hr_vectors_accum, output_dict["hr_lfrm"]["X_recon"], i
                )
                tail_vector = self._get_accumulated_inputs(
                    tail_vectors_accum, output_dict["tail_lfrm"]["X_recon"], accum_idx=i
                )
                # Same presence test as when the vectors were cached.
                head_vector = (
                    self._get_accumulated_inputs(
                        head_vectors_accum,
                        output_dict["head_lfrm"]["X_recon"],
                        i,
                    )
                    if output_dict["head_lfrm"]
                    else None
                )
                contrastive_loss_dict = loss_model.compute_contrastive_loss(
                    hr_vector,
                    tail_vector,
                    head_vector=head_vector,
                    **logit_masks,
                )
                total_loss = total_loss + (
                    contrastive_loss_dict["loss_contrastive"] / accum_steps
                )

                # Drop the concatenated copies before the next micro-batch.
                del hr_vector, tail_vector, head_vector

        total_loss_dict.update(contrastive_loss_dict)
        return total_loss, total_loss_dict

    @torch.inference_mode()
    def eval(self, epoch: int) -> None:
        """Evaluate on the validation loader and log loss and Hit@{1,3,10}.

        Args:
            epoch: Current epoch; only needed by the (currently disabled)
                best-checkpoint saving call.
        """
        model = self.model.eval()
        if self.best_metric_tracker is None:
            # Allows calling ``eval`` without going through ``train_loop``.
            self.best_metric_tracker = BestMetricTracker("best_Hit@1", "valid")
        loss_tracker = AvgMetricTracker("loss_contrastive", "valid")
        hit1_tracker = AvgMetricTracker("Hit@1", "valid")
        hit3_tracker = AvgMetricTracker("Hit@3", "valid")
        hit10_tracker = AvgMetricTracker("Hit@10", "valid")

        dataset.set_eval_mode(True)
        try:
            for inputs in self.valid_dataloader:
                triples = inputs.pop("triples")
                inputs = move_to_cuda(inputs)
                extra_inputs = move_to_cuda(
                    {
                        "logits_mask": dataset.construct_logits_mask(triples),
                        # The i-th query is paired with the i-th candidate.
                        "labels": torch.arange(len(triples)),
                    }
                )
                with self.autocast_context_manager():
                    output_dict = model(**inputs)
                    total_loss_dict = get_model_obj(model).compute_contrastive_loss(
                        hr_vector=output_dict["hr_lfrm"]["X_recon"],
                        tail_vector=output_dict["tail_lfrm"]["X_recon"],
                        logits_mask=extra_inputs["logits_mask"],
                    )

                logits, loss = total_loss_dict["logits"], total_loss_dict["loss_contrastive"]
                hit1, hit3, hit10 = compute_hits(logits, extra_inputs["labels"], topk=(1, 3, 10))
                batch_size = logits.size(0)
                loss_tracker.update(loss.item(), batch_size)
                hit1_tracker.update(hit1, batch_size)
                hit3_tracker.update(hit3, batch_size)
                hit10_tracker.update(hit10, batch_size)
            self.best_metric_tracker.update(hit1_tracker.value, self.global_step)

            logger.log_metric(
                loss_tracker,
                hit1_tracker,
                hit3_tracker,
                hit10_tracker,
                self.best_metric_tracker,
                step=self.global_step,
            )
            # NOTE: best-checkpoint saving is currently disabled.
            # self._maybe_save(self.best_metric_tracker.metric, epoch)
        finally:
            # Always restore the dataset flag, even if evaluation fails.
            dataset.set_eval_mode(False)

    def backward(self, loss: torch.FloatTensor) -> None:
        """Run one optimizer step for ``loss``.

        Zeroes gradients, back-propagates (through the gradient scaler when
        AMP is enabled), clips the gradient norm to ``args.grad_clip``, steps
        the optimizer and finally steps the LR scheduler.
        """
        self.optimizer.zero_grad()
        if self.args.use_amp:
            self.scaler.scale(loss).backward()
            # Unscale first so clipping operates on the true gradient values.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()
        self.scheduler.step()

    def save_checkpoint(self, path: str, epoch: int) -> None:
        """Serialize the full training state to ``path``.

        The ``"scaler"`` entry is ``None`` when AMP is disabled.
        """
        sd = {
            "model": get_model_obj(self.model).state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "scaler": self.scaler.state_dict() if self.scaler is not None else None,
            "epoch": epoch,
            "global_step": self.global_step,
            "best_metric_list": self.best_metric_list,
        }
        torch.save(sd, path)

    def restore_checkpoint(self, path: str) -> None:
        """Restore the training state written by ``save_checkpoint``.

        Entries that older checkpoints may lack (``global_step``,
        ``best_metric_list``) fall back to the current values.
        """
        sd = torch.load(path, self.args.device)
        model = get_model_obj(self.model)
        model.load_state_dict(sd["model"])
        self.optimizer.load_state_dict(sd["optimizer"])
        self.scheduler.load_state_dict(sd["scheduler"])
        # Only restore scaler state when both sides actually use AMP.
        if self.scaler is not None and sd.get("scaler") is not None:
            self.scaler.load_state_dict(sd["scaler"])
        self.epoch_start = sd["epoch"]
        self.global_step = sd.get("global_step", self.global_step)
        self.best_metric_list = deque(sd.get("best_metric_list", self.best_metric_list))
        self.num_training_steps = (self.args.epochs) * self.num_steps_per_epoch

    def _maybe_save(self, best_metric: BestMetricTracker.BestMetric, epoch: int) -> None:
        """Track ``best_metric`` and keep checkpoints for the ``n_best`` best entries.

        ``best_metric_list`` is kept sorted ascending, so the leftmost entry
        is the worst one and the rightmost the best. ``best_metric`` must be
        orderable and expose ``get()`` returning ``(metric, step)``.
        """
        if best_metric in self.best_metric_list:
            return

        bisect.insort(self.best_metric_list, best_metric)
        # remove the model state with the worst metric
        path_template = os.path.join(self.args.model_save_dir, "best{m}_step{s}.pth")
        if len(self.best_metric_list) > self.n_best:
            metric, step = self.best_metric_list.popleft().get()
            path = path_template.format(m=metric, s=step)
            # The evicted entry may be the one just inserted, in which case
            # no checkpoint was ever written for it.
            if os.path.exists(path):
                os.remove(path)
        # save models in the list
        for metric, step in (entry.get() for entry in self.best_metric_list):
            path = path_template.format(m=metric, s=step)
            if not os.path.exists(path):
                self.save_checkpoint(path, epoch)
        # copy the checkpoint of the best model to the well-known path
        metric, step = self.best_metric_list[-1].get()
        shutil.copy(path_template.format(m=metric, s=step), self.args.best_model_path)

    def _get_accumulated_inputs(
        self, vector_accum: List[torch.Tensor], vector_input: torch.Tensor, accum_idx: int
    ) -> torch.Tensor:
        """Concatenate ``vector_accum`` along dim 0 with entry ``accum_idx`` replaced by ``vector_input``."""
        return torch.cat(
            vector_accum[:accum_idx] + [vector_input] + vector_accum[accum_idx + 1 :],
            dim=0,
        )