"""
Transformer for RF Source Separation.
"""

import os
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import wandb
from ml_collections import ConfigDict
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from einops import rearrange

from rfcutils2.utils import get_demod_soi
import utils.data_transforms as data_transforms
from models.transformer_decoder import Transformer as TransformerDecoder
from models.transformer import Transformer
from models.transformer_quant import QuantOutputTransformer
from models.wavenet import Wave
from utils.class_builder import ClassBuilder
from utils.lr_schedulers import (
    CosineAnnealingWarmUp,
    IdentityScheduler,
    ReduceLROnPlateau,
)
from utils.rf_dataset import RFDataset, RFMixtureDataset, DeterministicDataset
from utils.utils import (
    MyDistributedDataParallel,
    get_train_val_dataset,
    nested_to_device,
    describe_tensor,
    limit_iterable,
)
from configs.base_configs import ModelType
from eval import eval_model, load_meta
from utils.plots import plot_mse, plot_ber

# Enable cuDNN benchmark mode for the whole process.
torch.backends.cudnn.benchmark = True


# Name -> class lookup for the models that can be built from `cfg.model_config`.
MODELS_REGISTER = {
    "Transformer": Transformer,
    "TransformerDecoder": TransformerDecoder,
    "QuantOutputTransformer": QuantOutputTransformer,
    "Wave": Wave,
}
# NOTE: despite its name, this builder constructs *models* (it wraps
# MODELS_REGISTER and is called with `cfg.model_config` further down).
image_transforms_builder = ClassBuilder(MODELS_REGISTER)


# Name -> class lookup for the optimizers selectable via `cfg.optimizer_config`.
OPTIMIZER_REGISTER = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
}
optimizer_builder = ClassBuilder(OPTIMIZER_REGISTER)


# Name -> class lookup for the schedulers selectable via `cfg.lr_scheduler_config`.
LR_SCHEDULER_REGISTER = {
    "IdentityScheduler": IdentityScheduler,
    "CosineAnnealingWarmUp": CosineAnnealingWarmUp,
    "ReduceLROnPlateau": ReduceLROnPlateau,
}
lr_scheduler_builder = ClassBuilder(LR_SCHEDULER_REGISTER)


# Name -> class lookup for the datasets selectable via `cfg.dataset_config`
# (and the optional `cfg.val_dataset`).
DATASET_REGISTER = {
    "RFDataset": RFDataset,
    "RFMixtureDataset": RFMixtureDataset,
    "DeterministicDataset": DeterministicDataset,
}
dataset_register = ClassBuilder(DATASET_REGISTER)


def token_cross_entropy(softmax, target):
    """Cross-entropy averaged over every token position.

    All leading dimensions of ``softmax`` are collapsed so that each row holds
    the class scores of one token; ``target`` is flattened to a matching
    vector of class indices.

    NOTE(review): the scores are handed directly to ``F.cross_entropy``, which
    applies its own log-softmax, so despite the parameter name the callers are
    presumably passing raw logits -- confirm.
    """
    num_classes = softmax.size(-1)
    flat_scores = softmax.view(-1, num_classes)
    flat_target = target.view(-1)
    return F.cross_entropy(flat_scores, flat_target)


class Learner:
    def __init__(self, model: nn.Module, cfg: ConfigDict):
        # Store some important variables
        self.rank = int(os.environ["LOCAL_RANK"])
        self.cfg = cfg
        self.step = 0

        # Instantiate the dataloaders
        self.build_dataloaders()

        # Store the model
        self.model = model

        # Build the optimizer
        self.build_optimizer()

        # Instantiate the leanring rate scheduler
        self.lr_scheduler, _ = lr_scheduler_builder.build(
            cfg.lr_scheduler_config, optimizer=self.optimizer, last_epoch=self.step - 1
        )

        assert not (
            cfg.trainer_config.fp16 and cfg.trainer_config.bf16
        ), "Cannot use both FP16 and BF16 at the same time."

        enable_mixed_precision = cfg.trainer_config.fp16 or cfg.trainer_config.bf16
        dtype = torch.float16 if cfg.trainer_config.fp16 else torch.bfloat16
        self.autocast = torch.cuda.amp.autocast(
            enabled=enable_mixed_precision, dtype=dtype
        )
        self.scaler = torch.cuda.amp.GradScaler(enabled=enable_mixed_precision)
        self.best_metrics = {}

        if self.is_master:
            wandb.login(key=os.environ.get("WANDB_API_KEY"))
            wandb.init(
                config=cfg.to_dict(),
                dir=cfg.trainer_config.model_dir,
                entity="wandb-entity-fill-me",
                project=self.get_project_name(),
                name=cfg.trainer_config.model_dir.split("/")[-1],
                mode="online",
                save_code=True,
            )

        self.cfg.trainer_config.model_dir = os.path.expandvars(self.cfg.trainer_config.model_dir)
        if os.path.exists(self.last_ckpt_filename()):
            self.load(self.last_ckpt_filename())
        if self.is_master:
            os.makedirs(cfg.trainer_config.model_dir, exist_ok=True)

    def get_project_name(self):
        if self.cfg.testmode:
            return "Testrun"
        elif self.cfg.model_type == ModelType.QUANTIZED_LLM or self.cfg.model_type == ModelType.QUANTIZED:
            return "Quantized_Transformer"
        elif self.cfg.model_type == ModelType.WAVENET:
            return "Wavenet"
        elif self.cfg.model_type == ModelType.WINDOWS:
            return "Transformer_Decoder"
        elif self.cfg.model_type == ModelType.WINDOWS_LLM:
            return "Transformer"
        else:
            raise ValueError("Unknown model type")

    @property
    def is_master(self):
        return self.rank == 0

    def build_optimizer(self):
        # Create param dict and filter out the ones with no grad
        param_dict = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        # Karpathy says to apply weight decay only to 2D parameters, i.e., not layer
        # normalization and bias terms
        decay_params = [p for _, p in param_dict.items() if p.dim() >= 2]
        no_decay_params = [p for _, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {
                "params": decay_params,
                "weight_decay": self.cfg.optimizer_config[1].weight_decay,
            },
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        if self.rank == 0:
            num_decay_params = sum(p.numel() for p in decay_params)
            num_no_decay_params = sum(p.numel() for p in no_decay_params)
            print(
                f"Number of decay parameter tensors: {len(decay_params)} "
                f"with {num_decay_params} parameters."
            )
            print(
                f"Number of no decay parameter tensors: {len(no_decay_params)} "
                f"with {num_no_decay_params} parameters."
            )

        # Create the optimizer
        self.optimizer, _ = optimizer_builder.build(
            self.cfg.optimizer_config,
            params=optim_groups,
            fused=True,
        )

    def build_dataloaders(self):
        self.dataset, _ = dataset_register.build(
            self.cfg.dataset_config,
        )

        if "val_dataset" in self.cfg:
            print("Validation dataset detected")
            self.train_dataset = self.dataset
            self.val_dataset, _ = dataset_register.build(self.cfg.val_dataset)
        else:
            print("Splitting train and validation")
            self.train_dataset, self.val_dataset = get_train_val_dataset(
                self.dataset, self.cfg.trainer_config.train_fraction
            )

        def get_dataloader(dataset):
            return DataLoader(
                dataset,
                batch_size=self.cfg.trainer_config.batch_size,
                shuffle=not self.cfg.trainer_config.distributed,
                num_workers=(
                    self.cfg.trainer_config.num_workers
                    if self.cfg.trainer_config.distributed
                    else 0
                ),
                sampler=DistributedSampler(
                    dataset,
                    num_replicas=self.cfg.trainer_config.world_size,
                    rank=self.rank,
                )
                if self.cfg.trainer_config.distributed
                else None,
                pin_memory=True,
            )

        self.train_dataloader = get_dataloader(self.train_dataset)
        self.val_dataloader = get_dataloader(self.val_dataset)

    def state_dict(self):
        return {
            "step": self.step,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "cfg": self.cfg.to_dict(),
            "best_metrics": self.best_metrics
        }

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        self.step = state_dict["step"]
        self.best_metrics = state_dict["best_metrics"]

    def load(self, filename):
        print("Loading from", filename)
        self.load_state_dict(torch.load(filename, map_location=self.get_device(), weights_only=False))
        print("Finished loading")

    def save(self, filename):
        if self.is_master:
            print("Saving to", filename)
            torch.save(self.state_dict(), filename)
            print("Finished saving")

    def last_ckpt_filename(self):
        return f"{self.cfg.trainer_config.model_dir}/model_last.pt"

    def step_ckpt_filename(self):
        return f"{self.cfg.trainer_config.model_dir}/model_step_{self.step}.pt"

    def best_ckpt_filename(self, metric):
        metric = metric.replace("/", "_")
        return f"{self.cfg.trainer_config.model_dir}/model_best_{metric}.pt"

    def lr_scheduler_step(
        self, loss: torch.Tensor, after_epoch_validation: bool = False
    ):
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            if after_epoch_validation:
                self.lr_scheduler.step(loss)
            else:
                pass
        else:
            if after_epoch_validation:
                pass
            else:
                self.lr_scheduler.step()

    def grad_norms(self):
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1.0 / 2)
        return total_norm

    # List on i-th rank contains entries of indices i + k * world_size
    # Should be a small list, we don't worry about performance here
    def reduce_list_across_ranks(self, lst, target_size):
        device = self.get_device()
        world_size = self.cfg.trainer_config.world_size
        block_size = (target_size + world_size - 1) // world_size
        tensor = torch.zeros((block_size,), device=device)
        tensor_list = [torch.zeros((block_size,), device=device) for _ in range(world_size)]
        for i in range(len(lst)):
            tensor[i] = lst[i]
        dist.all_gather(tensor_list, tensor)
        result = [0.0 for _ in range(target_size)]
        for i in range(world_size):
            for k in range(block_size):
                idx = i + k * world_size
                if idx < target_size:
                    result[idx] = tensor_list[i][k].item()
        return result

    def test(self):
        if self.is_master:
            print("Start testing...")

        for dataset in self.cfg.test_datasets:
            label = dataset["label"]
            dataset_path = os.path.expandvars(dataset["dataset_path"])
            kwargs = {key: value for key, value in dataset.items() if key != "label" and key != "dataset_path"}
            mse_list, ber_list = eval_model(self.model,
                                            self.cfg,
                                            dataset_path,
                                            world_size=self.cfg.trainer_config.world_size,
                                            rank=self.rank,
                                            device=self.get_device(),
                                            **kwargs,
                                            silent=not self.is_master)
            sinrs = load_meta(dataset_path)["sinr"]
            sinr_cnt = len(sinrs)
            mse_list = self.reduce_list_across_ranks(mse_list, sinr_cnt)
            ber_list = self.reduce_list_across_ranks(ber_list, sinr_cnt)
            if self.is_master:
                print("MSE:", mse_list)
                print("BER:", ber_list)
                fig_mse = plot_mse(sinrs, mse_list, f"MSE for {label}")
                wandb.log({f"test/{label}_mse": fig_mse}, step=self.step)
                fig_ber = plot_ber(sinrs, ber_list, f"BER for {label}")
                wandb.log({f"test/{label}_ber": fig_ber}, step=self.step)


    def train(self):
        device = self.get_device()
        epochs_done = 0
        while True:
            iterable = (
                limit_iterable(self.train_dataloader, 1)
                if self.cfg.testmode
                else self.train_dataloader
            )
            if self.rank == 0:
                iterable = tqdm(
                    iterable,
                    desc=(
                        f"Training ({self.step}"
                        f" / {self.cfg.trainer_config.max_steps})"
                    ),
                )
            for _, inputs in enumerate(iterable):
                # TODO: Add option for gradient accumulation
                inputs = nested_to_device(inputs, device)
                loss = self.train_step(inputs, logging_rank=self.rank == 0)
                self.lr_scheduler_step(loss, after_epoch_validation=False)

                # Check for NaNs
                if torch.isnan(loss).any():
                    raise RuntimeError(f"Detected NaN loss at step {self.step}.")

                self.step += 1

                if self.step % self.cfg.trainer_config.save_every == 0:
                    self.save(self.step_ckpt_filename())

                if self.step == self.cfg.trainer_config.max_steps:
                    return

            loss = self.validate()
            self.lr_scheduler_step(loss, after_epoch_validation=True)

            epochs_done += 1
            if epochs_done % self.cfg.test_every_epochs == 0 or self.cfg.testmode:
                self.test()

    def get_loss_windows_llm(self, mixture, target):
        mixture = data_transforms.stacked_to_windows_with_context(mixture, self.cfg.window_size, self.cfg.context_size)
        input = data_transforms.stacked_to_windows_with_context(target, self.cfg.window_size, self.cfg.context_size)
        input = torch.roll(input, shifts=1, dims=1)
        preds = self.model(input=input, cond=mixture, apply_start_token=True)
        target = data_transforms.stacked_to_windows(target, self.cfg.window_size)
        return F.mse_loss(preds, target), {}

    def get_loss_windows(self, mixture, target):
        mixture = data_transforms.stacked_to_windows_with_context(mixture, self.cfg.window_size, self.cfg.context_size)
        target = data_transforms.stacked_to_windows(target, self.cfg.window_size)
        preds = self.model(input=mixture)
        return F.mse_loss(preds, target), {}

    def get_loss_quantized(self, mixture, target):
        target = data_transforms.stacked_to_interleaving(target)
        preds = self.model(cond=self.prepare_mixture(mixture))
        target_tokens = self.model.encode(target)
        recon = self.model.decode_logits(preds)
        return token_cross_entropy(preds, target_tokens), {
            "train/mse": F.mse_loss(recon, target)
        }

    def get_loss_quantized_llm(self, mixture, target):
        target = data_transforms.stacked_to_interleaving(target)
        target_tokens = self.model.encode(target)
        preds = self.model(cond=self.prepare_mixture(mixture), target=target_tokens)
        recon = self.model.decode_logits(preds)
        return token_cross_entropy(preds, target_tokens), {
            "train/shift_mse": F.mse_loss(recon, target)
        }

    def get_loss_wavenet(self, mixture, target):
        mixture = data_transforms.stacked_to_wavenet(mixture)
        target = data_transforms.stacked_to_wavenet(target)
        preds = self.model(mixture)
        return F.mse_loss(preds, target), {}

    def get_loss(self, mixture, target):
        if self.cfg.model_type == ModelType.WAVENET:
            return self.get_loss_wavenet(mixture, target)
        elif self.cfg.model_type == ModelType.WINDOWS:
            return self.get_loss_windows(mixture, target)
        elif self.cfg.model_type == ModelType.WINDOWS_LLM:
            return self.get_loss_windows_llm(mixture, target)
        elif self.cfg.model_type == ModelType.QUANTIZED:
            return self.get_loss_quantized(mixture, target)
        elif self.cfg.model_type == ModelType.QUANTIZED_LLM:
            return self.get_loss_quantized_llm(mixture, target)
        else:
            raise ValueError("Unsupported model type")

    def train_step(
        self,
        inputs: List[torch.Tensor],
        logging_rank: bool = False,
    ) -> torch.Tensor:
        self.optimizer.zero_grad()

        mixture = inputs["mixture"]
        target = inputs["target"]

        with self.autocast:
            loss, metrics = self.get_loss(mixture, target)
        self.scaler.scale(loss).backward()
        self.scaler.unscale_(self.optimizer)
        if logging_rank and self.step % self.cfg.trainer_config.log_every == 0:
            wandb.log(
                {"train/grad_norms": self.grad_norms()},
                self.step,
            )
        if self.cfg.trainer_config.clip_max_norm > 0:
            nn.utils.clip_grad_norm_(
                self.model.parameters(), self.cfg.trainer_config.clip_max_norm
            )
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if logging_rank and self.step % self.cfg.trainer_config.log_every == 0:
            wandb.log(
                {"train/loss": loss, "train/lr": self.optimizer.param_groups[0]["lr"]}
                | metrics,
                self.step,
            )

        return loss

    def val_step_windows(self, mixture, target):
        mixture = data_transforms.stacked_to_windows_with_context(mixture, self.cfg.window_size, self.cfg.context_size)
        target = data_transforms.stacked_to_windows(target, self.cfg.window_size)
        preds = self.model(input=mixture)
        mse_loss = F.mse_loss(preds, target)
        preds = data_transforms.windows_to_complex(preds)
        return {"val/loss": mse_loss}, preds

    def val_step_windows_llm(self, mixture, target):
        mixture = data_transforms.stacked_to_windows_with_context(mixture, self.cfg.window_size, self.cfg.context_size)
        input = data_transforms.stacked_to_windows_with_context(target, self.cfg.window_size, self.cfg.context_size)
        input = torch.roll(input, shifts=1, dims=1)
        preds = self.model(input=input, cond=mixture, apply_start_token=True)
        target = data_transforms.stacked_to_windows(target, self.cfg.window_size)
        shift_mse = F.mse_loss(preds, target)
        recon = self.model.generate(cond=mixture, window_size=self.cfg.window_size, context_size=self.cfg.context_size)
        recon_mse = F.mse_loss(recon, target)
        recon = data_transforms.windows_to_complex(recon)
        return {"val/loss": shift_mse, "val/recon_mse": recon_mse}, recon

    def prepare_mixture(self, mixture):
        if self.cfg.quantize_input:
            return data_transforms.stacked_to_interleaving(mixture)
        else:
            return data_transforms.stacked_to_windows_with_context(mixture, self.cfg.window_size, self.cfg.context_size)

    def val_step_quantized(self, mixture, target):
        preds = self.model(cond=self.prepare_mixture(mixture))
        target = data_transforms.stacked_to_interleaving(mixture)
        target_tokens = self.model.encode(target)
        cur_loss = token_cross_entropy(preds, target_tokens)
        output = self.model.decode_softmax(preds)
        cur_mse_loss = F.mse_loss(output, target)
        output = data_transforms.interleaving_to_complex(output)
        return {"val/mse": cur_mse_loss, "val/loss": cur_loss}, output

    def val_step_quantized_llm(self, mixture, target, beam_k=1):
        target = data_transforms.stacked_to_interleaving(target)
        target_tokens = self.model.encode(target)
        preds = self.model(cond=self.prepare_mixture(mixture), target=target_tokens)
        loss = token_cross_entropy(preds, target_tokens)
        output = self.model.decode_logits(preds)
        shift_mse = F.mse_loss(output, target)
        recon = self.model.generate(cond=self.prepare_mixture(mixture), beam_k=beam_k).float()
        recon_mse = F.mse_loss(recon, target)
        recon = data_transforms.interleaving_to_complex(recon)
        return {"val/loss": loss, "val/shift_mse": shift_mse, "val/recon_mse": recon_mse}, recon

    def val_step_wavenet(self, mixture, target):
        mixture = data_transforms.stacked_to_wavenet(mixture)
        target = data_transforms.stacked_to_wavenet(target)
        preds = self.model(mixture)
        loss = F.mse_loss(preds, target)
        preds = data_transforms.wavenet_to_complex(preds)
        return {"val/loss": loss}, preds

    def val_step(self, mixture, target):
        if self.cfg.model_type == ModelType.WAVENET:
            return self.val_step_wavenet(mixture, target)
        elif self.cfg.model_type == ModelType.WINDOWS_LLM:
            return self.val_step_windows_llm(mixture, target)
        elif self.cfg.model_type == ModelType.WINDOWS:
            return self.val_step_windows(mixture, target)
        elif self.cfg.model_type == ModelType.QUANTIZED:
            return self.val_step_quantized(mixture, target)
        elif self.cfg.model_type == ModelType.QUANTIZED_LLM:
            return self.val_step_quantized_llm(mixture, target, beam_k=self.cfg.beam_k)
        else:
            raise ValueError("Unsupported model type")

    def cut_to_sync_offset(self, x, offsets, ber_sync):
        # x has shape [B, s] of complex values.
        # starting pos of x[i] in original signal from the dataset is offsets[i]
        # to compute BER, we need ber_sync to divide starting pos and length
        # so we cut off prefix and suffix to satisfy that
        new_len = x.shape[1] - ber_sync
        start_pos = ber_sync - torch.remainder(offsets, ber_sync)
        slice_idx = start_pos[:, None] + torch.arange(new_len, device=x.device)
        return torch.take_along_dim(x, slice_idx, dim=1)


    def get_device(self):
        return next(self.model.parameters()).device

    def update_checkpoints(self, metrics):
        metrics_updated = []
        for metric, maximize in self.cfg.ckpt_metrics:
            if metric not in self.best_metrics or ((metrics[metric] < self.best_metrics[metric]) ^ maximize):
                metrics_updated.append(metric)
                self.best_metrics[metric] = metrics[metric]
        self.save(self.last_ckpt_filename())
        for metric in metrics_updated:
            self.save(self.best_ckpt_filename(metric))

    @torch.no_grad()
    def validate(self) -> Tuple[torch.Tensor, torch.Tensor]:
        device = self.get_device()
        self.model.eval()
        metrics = {}
        ber = 0.0

        iterable = (
            limit_iterable(self.val_dataloader, 1)
            if self.cfg.testmode
            else self.val_dataloader
        )
        if self.rank == 0:
            iterable = tqdm(
                iterable,
                desc=f"Running validation after step {self.step}",
            )
        for inputs in iterable:
            with self.autocast:
                inputs = nested_to_device(inputs, device)

                mixture = inputs["mixture"]
                target = inputs["target"]

                cur_metrics, pred_waveform = self.val_step(mixture, target)
                pred_waveform = pred_waveform.to(torch.complex64)
                for metric_name, metric_val in cur_metrics.items():
                    metric_val = metric_val * mixture.shape[0] / len(self.val_dataset)
                    if metric_name in metrics:
                        metrics[metric_name] += metric_val
                    else:
                        metrics[metric_name] = metric_val

            target_waveform = data_transforms.stacked_to_complex(target)

            ber_sync = self.cfg.ber_sync
            if ber_sync != 1:
                offsets = inputs["offset"].to(device)
                pred_waveform = self.cut_to_sync_offset(pred_waveform, offsets, ber_sync)
                target_waveform = self.cut_to_sync_offset(target_waveform, offsets, ber_sync)

            demod_soi = get_demod_soi(self.cfg.soi_type)
            bit_est, _ = demod_soi(pred_waveform.cpu().numpy())
            bit_gt, _ = demod_soi(target_waveform.cpu().numpy())
            ber += (
                np.mean(bit_est.numpy() != bit_gt.numpy())
                * bit_est.shape[0]
                / len(self.val_dataset)
            )

        metrics["val/ber"] = torch.tensor(ber, device=device)
        if self.cfg.trainer_config.distributed:
            for metric_name, val in metrics.items():
                dist.reduce(val, 0, dist.ReduceOp.SUM)

        if self.rank == 0:
            waveform_to_plot = (
                rearrange(torch.view_as_real(pred_waveform[0]), "a c -> c a")
                .cpu()
                .numpy()[0][:1024]
            )
            wandb.log(metrics, self.step)
            fig, ax = plt.subplots()
            ax.plot(waveform_to_plot)
            wandb.log({"val/waveform": fig}, self.step)

        self.model.train()

        if not self.cfg.testmode:
            self.update_checkpoints(metrics)

        loss = metrics["val/loss"]
        return loss


def _train_impl(model: nn.Module, cfg: ConfigDict):
    """Create a ``Learner`` for ``model`` and run its training loop."""
    torch.backends.cudnn.benchmark = True
    Learner(model, cfg).train()


def train(cfg: ConfigDict):
    """Build the model from ``cfg.model_config`` and train it on a single GPU."""
    built = image_transforms_builder.build(cfg.model_config)
    model, _ = built
    model.cuda()
    _train_impl(model, cfg)


def init_distributed():
    """Join the NCCL process group and bind this process to its local GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable.
    """
    torch.distributed.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)


def train_distributed(cfg: ConfigDict):
    """Build the model, wrap it for data-parallel training and train it.

    The process group is destroyed even when building or training raises, so
    a failing rank does not leave the group behind.
    """
    init_distributed()
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        model, _ = image_transforms_builder.build(cfg.model_config)
        model.to(local_rank)
        model = MyDistributedDataParallel(
            model, device_ids=[local_rank], find_unused_parameters=True
        )
        _train_impl(model, cfg)
    finally:
        torch.distributed.destroy_process_group()
