import os
import random
import sys
import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from learning.model.curve import CURVE

# Silence FutureWarning messages. Warning filters are process-global, so this also
# hides FutureWarnings raised by every other module imported in this process.
warnings.simplefilter(action="ignore", category=FutureWarning)
# Append the parent of sys.path[0] (typically the directory of the entry script) to the
# import search path.
# NOTE(review): this runs after the `learning.model.curve` import above, so it cannot help
# that import resolve; it only affects imports executed later — confirm intent.
sys.path.append(os.path.dirname(sys.path[0]))


class AleatoricCrossEntropyLoss(nn.Module):
    """Monte-Carlo cross-entropy over logits perturbed by learned Gaussian noise.

    Given per-logit means ``logits_mu`` and log-variances ``logits_logvar``, the loss draws
    ``num_samples`` reparameterised samples ``mu + eps * exp(0.5 * logvar)`` with
    ``eps ~ N(0, 1)`` and returns the average of the (optionally class-weighted)
    cross-entropy of each sample.

    Args:
        weight: Optional per-class weight tensor forwarded to ``F.cross_entropy``.
        num_samples: Number of Monte-Carlo samples per forward pass; must be >= 1.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None, num_samples: int = 10):
        super().__init__()
        # A non-positive sample count would make forward() divide by zero; fail early instead.
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.weight = weight
        self.num_samples = num_samples

    def forward(self, logits_mu: torch.Tensor, logits_logvar: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean sampled cross-entropy.

        ``logits_mu`` and ``logits_logvar`` must have the same shape, and together with
        ``target`` must be accepted by ``F.cross_entropy`` (e.g. ``(N, C)`` logits with
        ``(N,)`` class indices).
        """
        std = torch.exp(0.5 * logits_logvar)
        loss = 0.0
        for _ in range(self.num_samples):
            # Reparameterisation trick: the noise carries no gradient, so gradients flow
            # to both the mean and the log-variance through `mu + eps * std`.
            eps = torch.randn_like(std)
            sampled_logits = logits_mu + eps * std
            loss = loss + F.cross_entropy(sampled_logits, target, weight=self.weight)
        return loss / float(self.num_samples)


class Trainer:
    """Base trainer handling seeding, model/optimizer/loss construction and checkpoint I/O.

    Subclasses provide ``split_dataset``, ``train``, ``inference`` and ``cross_valid``;
    this base class raises ``NotImplementedError`` for them.

    Args:
        config: Configuration object exposing ``training_config`` and ``model_config``
            mappings (and ``location_data`` for ``learn``). ``config.seed`` is set here.
        wandb_a: Optional Weights & Biases handle; logging is enabled iff it is not None.
    """

    def __init__(self, config, wandb_a=None):
        self.config = config
        self.log = wandb_a is not None
        self.wandb = wandb_a if self.log else None

        # Draw a fresh seed when none is configured and record it on the config so it is
        # included when the config is dumped by save_model().
        seed = self.config.training_config.get("seed", None)
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        self.config.seed = seed

        # Seed every RNG the training stack may touch: Python's `random`, NumPy, and torch.
        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        # Despite the name, tensors are placed on whatever model_config["device"] names.
        self.toGPU = lambda x, dtype: torch.as_tensor(x, dtype=dtype, device=self.config.model_config["device"])
        self.initialize_best_metrics()

    @staticmethod
    def gaussian_kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return KL(N(mu, exp(logvar)) || N(0, 1)) averaged over all elements.

        Uses the closed form ``0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)`` per element and
        takes the mean (not the sum) over every element of the inputs.
        """
        return 0.5 * torch.mean(torch.exp(logvar) + mu.pow(2) - 1.0 - logvar)

    def initialize_best_metrics(self):
        """Reset the best-so-far validation metrics (and clip registry) to initial values."""
        # Large sentinel so the first real validation loss replaces it
        # (presumably compared with `<` by subclasses — confirm).
        self.best_val_loss = 99999.0
        self.best_epoch = 0
        self.best_val_acc = 0.0
        self.best_val_auc = 0.0
        self.best_val_confusion = []
        self.best_val_f1 = 0.0
        # -1 is the lowest possible Matthews correlation coefficient.
        self.best_val_mcc = -1.0
        self.best_val_acc_balanced = 0.0
        # NOTE(review): not a metric; reset here alongside them — contents defined by subclasses.
        self.unique_clips = {}

    def split_dataset(self):
        """Split data into training/validation sets. Implemented by subclasses."""
        raise NotImplementedError

    def _resolve_loss_weight(self) -> Optional[torch.Tensor]:
        """Return per-class loss weights as float32 on the model device, or None.

        ``self.class_weights`` is optional (presumably set by a subclass, e.g. in
        ``split_dataset``). It is used only when it is array-like with at least one
        dimension covering two or more classes; tensors and NumPy arrays are accepted.
        """
        class_weights = getattr(self, "class_weights", None)
        shape = getattr(class_weights, "shape", None)
        # 0-d tensors/arrays have an empty shape, so shape[0] would raise IndexError.
        if shape is None or len(shape) == 0 or shape[0] < 2:
            return None
        # torch.as_tensor handles both tensors and NumPy arrays (NumPy has no `.float()`).
        return torch.as_tensor(class_weights, dtype=torch.float32, device=self.config.model_config["device"])

    def build_model(self):
        """Create the CURVE model, its Adam optimizer and the aleatoric cross-entropy loss.

        Reads ``learning_rate``/``weight_decay`` from ``training_config`` and ``device`` from
        ``model_config``. When logging is enabled and ``model_config["load_model"]`` is falsy,
        calls ``wandb.watch(model, log="all")``.
        """
        self.model = CURVE(self.config).to(self.config.model_config["device"])

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config.training_config["learning_rate"],
            weight_decay=self.config.training_config["weight_decay"],
        )

        self.loss_func = AleatoricCrossEntropyLoss(weight=self._resolve_loss_weight(), num_samples=10)

        if self.log and not self.config.model_config.get("load_model", False):
            self.wandb.watch(self.model, log="all")

    def learn(self):
        """Run plain training, or k-fold cross-validation when ``n_fold > 1``.

        A configured ``location_data["transfer_path"]`` always forces plain training.
        """
        if self.config.training_config["n_fold"] <= 1 or self.config.location_data.get("transfer_path", None) is not None:
            print("\nRunning Standard Training Loop\n")
            self.train()
        else:
            print(f"\nRunning {self.config.training_config['n_fold']}-Fold Cross Validation Training Loop\n")
            self.cross_valid()

    def train(self):
        """Run the standard training loop. Implemented by subclasses."""
        raise NotImplementedError

    def inference(self):
        """Run inference. Implemented by subclasses."""
        raise NotImplementedError

    def cross_valid(self):
        """Run the k-fold cross-validation loop. Implemented by subclasses."""
        raise NotImplementedError

    def save_model(self, suffix: str = ""):
        """Save ``self.model``'s state_dict plus a text dump of the config and command line.

        Args:
            suffix: Optional text inserted between the stem and extension of
                ``model_save_path`` (``"_best"`` turns ``model.pt`` into ``model_best.pt``).

        ``model_parameters.txt`` is written next to the checkpoint and is overwritten on
        every call, regardless of ``suffix``.
        """
        original_path = Path(self.config.model_config["model_save_path"]).resolve()
        if suffix:
            saved_path = original_path.parent / (original_path.stem + suffix + original_path.suffix)
        else:
            saved_path = original_path

        saved_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), str(saved_path))

        with open(saved_path.parent / "model_parameters.txt", "w", encoding="utf-8") as f:
            f.write(str(self.config))
            f.write("\n")
            f.write(" ".join(sys.argv))

    def load_model(self):
        """Build a fresh model and load weights from ``model_load_path``, tolerating mismatches.

        Checkpoint entries whose key is unknown to the model or whose shape differs are
        skipped and reported. Model entries absent from the checkpoint keep their freshly
        initialised values and are reported too. The model is left in eval mode.

        Raises:
            FileNotFoundError: If ``model_load_path`` does not exist.
        """
        saved_path = Path(self.config.model_config["model_load_path"]).resolve()
        if not saved_path.exists():
            raise FileNotFoundError(f"Failed to load model. Model load path does not exist: {saved_path}")

        self.build_model()

        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
        # (consider weights_only=True where the installed torch supports it).
        state_dict = torch.load(str(saved_path), map_location=self.config.model_config["device"])
        model_state_dict = self.model.state_dict()

        filtered_state_dict = {
            k: v for k, v in state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape
        }
        discarded_keys = [k for k in state_dict if k not in filtered_state_dict]
        if discarded_keys:
            print(f"Warning: Discarded mismatched layers: {discarded_keys}")

        result = self.model.load_state_dict(filtered_state_dict, strict=False)
        # strict=False silently leaves uncovered parameters at their initial values; surface them.
        missing_keys = getattr(result, "missing_keys", None)
        if missing_keys:
            print(f"Warning: Layers missing from checkpoint (left at initial values): {missing_keys}")
        self.model.eval()

