import os
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score, auc, confusion_matrix, roc_curve, f1_score
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from utils import seed_everything

__all__ = ["ContrastiveMLP", "train", "validate"]

# Import-time seeding: torch's CPU generator and the current CUDA device's
# generator are both seeded with ``random_seed`` (the CUDA call is silently
# ignored when CUDA is unavailable). ``train`` takes its own ``random_seed``
# argument, which shadows this module-level value inside that function.
random_seed = 42
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(random_seed)
del _seed_fn


class ContrastiveMLP(nn.Module):
    """Per-position scorer: a two-layer MLP that maps each embedding to a probability.

    The network ends in ``nn.Sigmoid``, so every per-position score produced by
    :meth:`forward` already lies in ``[0, 1]``. A sequence-level score is the
    maximum of those per-position probabilities over the valid (non-padded)
    positions, which is also how the training loss and validation in this
    module reduce scores.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, embeddings):
        """Score every position of ``embeddings``.

        Args:
            embeddings: tensor whose last dimension is ``input_dim``; within this
                module it is a padded ``(batch_size, seq_len, input_dim)`` tensor.

        Returns:
            Per-position probabilities with the trailing size-1 dimension
            squeezed away, e.g. ``(batch_size, seq_len)``.
        """
        return self.mlp(embeddings).squeeze(-1)

    def get_score(
        self, sequences: Union[torch.Tensor, List[torch.Tensor]]
    ) -> torch.Tensor:
        """Return one score per sequence: its maximum per-position probability.

        The result is on the same scale as the max-pooled scores used by
        ``validate``/``calculate_auroc``, so thresholds reported there can be
        applied to it directly. No gradients are tracked.

        Args:
            sequences: list of ``(seq_len_i, input_dim)`` tensors (lengths may
                differ), or a tensor whose first dimension indexes sequences.
                Inputs are moved to the model's device.

        Returns:
            1-D tensor of ``len(sequences)`` scores in ``[0, 1]``; a zero-length
            sequence scores ``0.0``.
        """
        device = next(self.parameters()).device
        with torch.no_grad():
            sequences = [s.to(device) for s in sequences]
            padded = pad_sequence(sequences, batch_first=True)
            lengths = torch.tensor([s.size(0) for s in sequences], device=device)

            scores = self.mlp(padded).squeeze(-1)  # (batch_size, seq_len)
            # True for real positions, False for padding added by pad_sequence.
            mask = torch.arange(scores.size(1), device=device).unsqueeze(0) < lengths.unsqueeze(1)

            # Padding becomes -inf so it can never be the max. The MLP already
            # ends in a Sigmoid, so no further squashing is applied here.
            max_scores, _ = scores.masked_fill(~mask, -float("inf")).max(dim=-1)
            # Only fully-padded (empty) rows stay at -inf; report them as 0.0.
            return max_scores.clamp_min(0.0)


@dataclass
class TrainResult:
    """Outcome of a training run.

    Holds the trained ``model``, the ``args_dict`` describing how the run was
    configured, and summary metrics for the train and validation splits.
    Every metric field defaults to ``0.0`` when it is not supplied.
    """

    model: ContrastiveMLP
    args_dict: dict

    # Training-split metrics.
    train_acc: float = 0.0
    train_auroc: float = 0.0
    train_threshold: float = 0.0
    train_f1: float = 0.0

    # Validation-split metrics.
    valid_acc: float = 0.0
    valid_auroc: float = 0.0
    valid_threshold: float = 0.0
    valid_f1: float = 0.0


@dataclass
class AurocResult:
    """ROC analysis plus threshold-dependent accuracy/F1 for one dataset split."""

    # ROC curve: parallel arrays, one entry per candidate threshold.
    fpr: np.ndarray
    tpr: np.ndarray
    thresholds: np.ndarray
    # Accuracy and F1 evaluated at each entry of ``thresholds``.
    accuracies: np.ndarray
    f1_scores: np.ndarray
    # Area under the ROC curve.
    roc_auc: float
    # Threshold selected for accuracy and the accuracy/F1 reported for it.
    best_threshold: float
    best_acc: float
    best_f1: float
    # Threshold selected for F1.
    best_f1_threshold: float
    # Confusion matrix at ``best_threshold``.
    confusion_df: pd.DataFrame



def _prepare_data(num_samples, input_dim):
    """Build a random synthetic dataset of ``(sequence, flag)`` pairs.

    Each sequence is a standard-normal ``(length, input_dim)`` tensor whose
    length is drawn uniformly from ``[5, 20)``; each flag is a 0-d bool tensor.
    Lengths and sequences are drawn alternately, then all flags at the end.
    """
    sequences = []
    for _ in range(num_samples):
        length = torch.randint(5, 20, ())
        sequences.append(torch.randn(length, input_dim))
    labels = torch.randint(0, 2, (num_samples,)).bool()
    return [(seq, label) for seq, label in zip(sequences, labels)]

def _threshold_metrics(y_true, y_scores, thresholds):
    """Return (accuracies, f1_scores) arrays for the rule ``score > t`` at each ``t``.

    Vectorised equivalent of calling ``accuracy_score`` and
    ``f1_score(zero_division=0)`` once per threshold. Positive and negative
    scores are sorted once, and ``np.searchsorted`` counts how many exceed each
    threshold. This costs O((n + T) log n) instead of O(n * T).
    """
    thresholds = np.asarray(thresholds)
    is_pos = np.asarray(y_true).astype(bool)
    pos_sorted = np.sort(y_scores[is_pos])
    neg_sorted = np.sort(y_scores[~is_pos])
    n_pos, n_neg = pos_sorted.size, neg_sorted.size

    # side="right" counts scores <= t, so the remainder are strictly > t.
    tp = n_pos - np.searchsorted(pos_sorted, thresholds, side="right")
    fp = n_neg - np.searchsorted(neg_sorted, thresholds, side="right")
    fn = n_pos - tp
    tn = n_neg - fp

    accuracies = (tp + tn) / (n_pos + n_neg)
    # F1 = 2TP / (2TP + FP + FN); defined as 0 when the denominator is 0,
    # matching zero_division=0.
    denom = 2 * tp + fp + fn
    f1_scores = np.divide(
        2 * tp,
        denom,
        out=np.zeros(thresholds.shape, dtype=np.float64),
        where=denom > 0,
    )
    return accuracies, f1_scores


def calculate_auroc(scores, flags, acc_threshold=None) -> AurocResult:
    """Compute the ROC curve, AUROC, and threshold-dependent accuracy/F1.

    Non-finite scores are reported and then replaced (NaN -> 0.0, +/-Inf ->
    float32 max/min) before any metric is computed.

    Args:
        scores: per-sample scores (assumed 1-D); higher means more likely positive.
        flags: matching binary ground-truth labels (bool or 0/1).
        acc_threshold: if ``None``, choose the ROC threshold that maximises
            accuracy (and, separately, the one maximising F1). Otherwise report
            accuracy and F1 at this fixed threshold. The fixed threshold is used
            for both ``best_threshold`` and ``best_f1_threshold``.

    Returns:
        AurocResult with the ROC curve, per-threshold accuracy/F1, the chosen
        threshold(s), and a 2x2 confusion matrix evaluated at ``best_threshold``.

    Note:
        Predictions use ``score > threshold``. scikit-learn's ``roc_curve``
        defines its thresholds with ``score >= threshold``. When scores tie a
        threshold, the accuracy reported for ``thresholds[i]`` therefore does
        not correspond exactly to the ROC point ``(fpr[i], tpr[i])``.
    """
    y_scores = scores.detach().cpu().numpy()
    y_true = flags.detach().cpu().numpy()

    if not np.all(np.isfinite(y_scores)):
        nan_count = np.isnan(y_scores).sum()
        inf_count = np.isinf(y_scores).sum()
        print(f"Warning: The input score contains {nan_count} Nans and {inf_count} Inf values")

        max_val = np.finfo(np.float32).max
        min_val = np.finfo(np.float32).min

        y_scores = np.nan_to_num(
            y_scores,
            nan=0.0,
            posinf=max_val,
            neginf=min_val
        )

        print(f"The illegal values have been replaced with safe values: NaN→0.0, +Inf→{max_val:.3e}, -Inf→{min_val:.3e}")

    fpr, tpr, thresholds = roc_curve(y_true=y_true, y_score=y_scores)
    roc_auc = auc(fpr, tpr)

    accuracies, f1_scores = _threshold_metrics(y_true, y_scores, thresholds)

    if acc_threshold is None:
        best_acc_idx = int(np.argmax(accuracies))
        best_threshold = thresholds[best_acc_idx].item()
        best_acc = float(accuracies[best_acc_idx])

        best_f1_idx = int(np.argmax(f1_scores))
        best_f1_threshold = thresholds[best_f1_idx].item()
        best_f1 = float(f1_scores[best_f1_idx])
    else:
        best_threshold = acc_threshold
        best_f1_threshold = acc_threshold
        fixed_acc, fixed_f1 = _threshold_metrics(y_true, y_scores, np.asarray([acc_threshold]))
        best_acc = float(fixed_acc[0])
        best_f1 = float(fixed_f1[0])

    y_true_labels = y_true.astype(bool).astype(int)
    y_pred_labels = (y_scores > best_threshold).astype(int)

    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from both
    # y_true and y_pred. Without it, the 2x2 DataFrame below would fail.
    cm = confusion_matrix(y_true_labels, y_pred_labels, labels=[0, 1])
    confusion_df = pd.DataFrame(
        cm, index=["True 0", "True 1"], columns=["Pred 0", "Pred 1"]
    )

    return AurocResult(
        fpr=fpr,
        tpr=tpr,
        thresholds=thresholds,
        accuracies=np.asarray(accuracies),
        f1_scores=np.asarray(f1_scores),
        roc_auc=roc_auc,
        best_threshold=best_threshold,
        best_acc=best_acc,
        best_f1=best_f1,
        best_f1_threshold=best_f1_threshold,
        confusion_df=confusion_df,
    )


def validate(model, data, batch_size: int = None, acc_threshold=None) -> AurocResult:
    """Score ``data`` with ``model`` and compute ROC / accuracy / F1 metrics.

    Each sequence's score is the maximum of the model's per-position outputs
    over its valid (non-padded) positions. Those scores and the flags are then
    passed to ``calculate_auroc``.

    Args:
        model: module mapping a padded ``(batch, seq_len, dim)`` tensor to
            ``(batch, seq_len)`` per-position scores (e.g. ``ContrastiveMLP``).
        data: list of ``(sequence_tensor, flag_tensor)`` pairs.
        batch_size: pairs scored per forward pass; ``None`` scores all of
            ``data`` in a single batch.
        acc_threshold: forwarded to ``calculate_auroc``; when given, metrics are
            reported at this fixed threshold instead of the best one.

    Returns:
        The ``AurocResult`` produced by ``calculate_auroc``.

    The model is put in eval mode for scoring, and its previous train/eval mode
    is restored afterwards, even if scoring raises.
    """
    device = next(model.parameters()).device
    if batch_size is None:
        batches = [data]
    else:
        batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            score_chunks = []
            flag_chunks = []
            for batch in batches:
                sentences = [s.to(device) for s, _ in batch]
                flag_chunks.append(torch.stack([f for _, f in batch]).to(device))

                padded = pad_sequence(sentences, batch_first=True)
                lengths = torch.tensor([s.size(0) for s in sentences], device=device)

                scores = model(padded)
                # Padding is set to -inf so it can never be a sequence's max.
                mask = torch.arange(scores.size(1), device=device).unsqueeze(0) < lengths.unsqueeze(1)
                batch_max, _ = scores.masked_fill(~mask, -float("inf")).max(dim=-1)
                score_chunks.append(batch_max)

            auroc_res = calculate_auroc(
                torch.cat(score_chunks), torch.cat(flag_chunks), acc_threshold=acc_threshold
            )
    finally:
        model.train(was_training)
    return auroc_res


def plt_auroc(fpr, tpr, roc_auc, save_path):
    """Plot a ROC curve against the chance diagonal and save it to ``save_path``.

    Args:
        fpr: false-positive rates (x-axis).
        tpr: true-positive rates (y-axis).
        roc_auc: area under the curve, shown in the legend with 2 decimals.
        save_path: destination file for ``savefig``.

    The figure is closed after saving (also on error) so that repeated calls do
    not accumulate open matplotlib figures.
    """
    fig = plt.figure()
    try:
        plt.plot(
            fpr, tpr, color="darkorange", lw=2, label="ROC curve (area = %0.2f)" % roc_auc
        )
        plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Receiver Operating Characteristic")
        plt.legend(loc="lower right")
        fig.savefig(save_path)
    finally:
        plt.close(fig)


def plt_loss(step_list, loss_list, title, save_path):
    """Plot ``loss_list`` against ``step_list`` and save the figure to ``save_path``.

    The figure is closed after saving (also on error) so that repeated calls do
    not accumulate open matplotlib figures.
    """
    fig = plt.figure()
    try:
        plt.plot(step_list, loss_list)
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.title(title)
        fig.savefig(save_path)
    finally:
        plt.close(fig)


def _save_training_artifacts(folder, epoch, train_res, valid_res, loss_history):
    """Write ROC plots, the loss curve and per-threshold metric CSVs to ``folder``."""
    os.makedirs(folder, exist_ok=True)
    plt_auroc(
        fpr=valid_res.fpr,
        tpr=valid_res.tpr,
        roc_auc=valid_res.roc_auc,
        save_path=os.path.join(folder, f"[epoch{epoch + 1}]valid.png"),
    )
    plt_auroc(
        fpr=train_res.fpr,
        tpr=train_res.tpr,
        roc_auc=train_res.roc_auc,
        save_path=os.path.join(folder, f"[epoch{epoch + 1}]train.png"),
    )
    plt_loss(
        step_list=range(len(loss_history)),
        loss_list=loss_history,
        title="Loss",
        save_path=os.path.join(folder, "train_loss.png"),
    )
    # One CSV per split and metric: the metric's value at each ROC threshold.
    for split, res in (("train", train_res), ("valid", valid_res)):
        for suffix, column, values in (
            ("acc", "accuracies", res.accuracies),
            ("f1", "f1_scores", res.f1_scores),
        ):
            np.savetxt(
                os.path.join(folder, f"{split}_{suffix}.csv"),
                np.column_stack((res.thresholds, values)),
                delimiter=",",
                header=f"thresholds,{column}",
                comments="",
                fmt="%f",
            )


def _print_training_summary(args_dict, train_res, valid_res):
    """Print the run's arguments, final metrics and confusion matrices.

    The ``train_data``/``valid_data`` entries of ``args_dict`` are replaced in
    place by ``"len(N)"`` strings, so the dict no longer holds the datasets.
    """
    print("=" * 10)

    print("Training parameters:")
    for param_name, param_value in args_dict.items():
        if param_name in ("train_data", "valid_data"):
            param_value = f"len({len(param_value)})"
            args_dict[param_name] = param_value
        print(f"{param_name}: {param_value}")
    print("-" * 10)
    print(
        f"Train AUROC:{train_res.roc_auc:.2%} Threshold: {train_res.best_threshold:.3f} Acc: {train_res.best_acc:.2%}"
    )
    print(
        f"Valid AUROC:{valid_res.roc_auc:.2%} Threshold: {valid_res.best_threshold:.3f} Acc: {valid_res.best_acc:.2%} "
    )
    print("-" * 10)
    print("train_confusion_matrix:")
    print(train_res.confusion_df)
    print("-" * 10)
    print("valid_confusion_matrix:")
    print(valid_res.confusion_df)
    print("=" * 10)


def train(
    train_data: List[torch.Tensor],
    valid_data: List[torch.Tensor],
    input_dim: int,
    hidden_dim: int,
    batch_size: int,
    step_log: int,
    epochs: int,
    epoch_log: int = 1,
    max_grad_norm: float = 1.0,
    random_seed: int = 455,
    lr: float = 1e-3,
    weight_decay: float = 3e-4,
    neft_alpha: float = 0.0,
    device=None,
    auroc_img_save_folder: str = None,
    model_cache: ContrastiveMLP = None
) -> TrainResult:
    """Train a ``ContrastiveMLP`` with a max-pooled binary cross-entropy loss.

    Each sequence is scored by the max of its per-position probabilities. Every
    ``epoch_log`` epochs (and after the last epoch), the model is validated on
    both splits. The validation split is evaluated at the accuracy-optimal
    threshold found on the training split.

    Args:
        train_data, valid_data: lists of ``(sequence, flag)`` pairs. ``sequence``
            is a non-empty ``(seq_len, input_dim)`` tensor; ``flag`` is the
            binary label (presumably a 0-d bool/int tensor, since flags are
            stacked into a ``(batch,)`` vector).
        input_dim, hidden_dim: MLP sizes (ignored when ``model_cache`` is given).
        batch_size: pairs per optimisation step (the last batch may be smaller).
        step_log: print the loss every ``step_log`` optimiser steps.
        epochs: number of passes over ``train_data``.
        epoch_log: validate every ``epoch_log`` epochs.
        max_grad_norm: gradient-norm clipping bound.
        random_seed: passed to ``seed_everything`` before the model is built.
        lr, weight_decay: Adam hyper-parameters; the LR is cosine-annealed over
            the total number of optimiser steps.
        neft_alpha: if > 0, uniform noise of magnitude
            ``neft_alpha / sqrt(seq_len * input_dim)`` is added to the padded
            inputs during training (NEFTune-style scaling, judging by the name).
        device: torch device; defaults to CUDA when available, else CPU.
        auroc_img_save_folder: if given, ROC/loss plots and per-threshold
            accuracy/F1 CSVs are written there after training.
        model_cache: existing model to continue training instead of a new one.

    Returns:
        TrainResult with the model, the call arguments (datasets replaced by
        ``"len(N)"``), and final train/valid accuracy, AUROC, threshold and F1.
        ``valid_f1`` is measured at the training split's accuracy threshold.

    Raises:
        AssertionError: if any train/valid sequence is empty.
        ValueError: if ``epochs``, ``batch_size``, ``step_log`` or ``epoch_log``
            is < 1, or ``train_data`` is empty.
    """
    # Copy: on CPython < 3.13 ``locals()`` is the frame's live dict, which
    # tracers/debuggers refresh with every later local variable.
    args_dict = dict(locals())

    assert all(s.size(0) > 0 for s, _ in train_data), "训练集包含空序列"
    assert all(s.size(0) > 0 for s, _ in valid_data), "验证集包含空序列"
    for name, value in (
        ("epochs", epochs),
        ("batch_size", batch_size),
        ("step_log", step_log),
        ("epoch_log", epoch_log),
    ):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    if len(train_data) == 0:
        raise ValueError("train_data must not be empty")

    seed_everything(seed=random_seed)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ContrastiveMLP(input_dim, hidden_dim=hidden_dim).to(device) if model_cache is None else model_cache.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # scheduler.step() runs once per batch, and the last batch of each epoch may
    # be partial, so round up. An undershooting T_max makes the cosine LR climb
    # back up, and a T_max of 0 raises ZeroDivisionError.
    steps_per_epoch = (len(train_data) + batch_size - 1) // batch_size
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs * steps_per_epoch)

    total_step = 0
    total_loss_list = []
    for epoch in tqdm(range(epochs)):
        indices = torch.randperm(len(train_data))
        loss_list = []
        for batch_idx in range(0, len(train_data), batch_size):
            model.train()
            batch_indices = indices[batch_idx : batch_idx + batch_size]
            batch = [train_data[i] for i in batch_indices]

            sentences = [s.to(device) for s, _ in batch]
            flags = torch.stack([f for _, f in batch]).to(device)

            padded = pad_sequence(sentences, batch_first=True)
            lengths = torch.tensor([s.size(0) for s in sentences], device=device)

            if neft_alpha > 0 and model.training:
                # Uniform noise in [-m, m] with m = alpha / sqrt(seq_len * dim).
                dims = torch.tensor(padded.size(1) * padded.size(2), device=device)
                mag_norm = neft_alpha / torch.sqrt(dims)
                padded = padded + torch.zeros_like(padded).uniform_(-mag_norm, mag_norm)

            scores = model(padded)
            # Padding is set to -inf so it can never be a sequence's max.
            mask = torch.arange(scores.size(1), device=device).expand(
                len(lengths), -1
            ) < lengths.unsqueeze(1)
            masked_scores = scores.masked_fill(~mask, -float("inf"))
            max_scores, _ = masked_scores.max(dim=-1)

            # Binary cross-entropy on the max-pooled probability; 1e-6 avoids log(0).
            loss_elements = -(
                flags.float() * torch.log(max_scores + 1e-6)
                + (1 - flags.float()) * torch.log(1 - max_scores + 1e-6)
            )
            loss = loss_elements.mean()

            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
            scheduler.step()

            total_step += 1
            if total_step % step_log == 0:
                print(
                    f"Step {total_step} | Loss: {loss.item():.4f} | LR: {scheduler.get_last_lr()[0]:.2e}"
                )
        print(
            f"Epoch {epoch + 1} | Loss: {sum(loss_list) / len(loss_list):.4f}| LR: {scheduler.get_last_lr()[0]:.2e}"
        )
        total_loss_list.extend(loss_list)

        # The final epoch always validates, so both results exist after the loop.
        if (epoch + 1) % epoch_log == 0 or epoch == epochs - 1:
            print("Validating...")
            train_auroc_res = validate(
                model,
                train_data,
                batch_size=batch_size,
                acc_threshold=None,
            )
            # Validation is scored at the threshold chosen on the training split.
            valid_auroc_res = validate(
                model,
                valid_data,
                batch_size=batch_size,
                acc_threshold=train_auroc_res.best_threshold,
            )
            print(
                f"Train AUROC:{train_auroc_res.roc_auc:.2%} Acc: {train_auroc_res.best_acc:.2%}(Th={train_auroc_res.best_threshold:.3f}) F1: {train_auroc_res.best_f1:.2%}(Th={train_auroc_res.best_f1_threshold:.3f})"
            )
            print(
                f"Valid AUROC:{valid_auroc_res.roc_auc:.2%} Acc: {valid_auroc_res.best_acc:.2%}(Th={valid_auroc_res.best_threshold:.3f}) F1: {valid_auroc_res.best_f1:.2%}(Th={valid_auroc_res.best_f1_threshold:.3f})"
            )

    if auroc_img_save_folder is not None:
        _save_training_artifacts(
            auroc_img_save_folder, epoch, train_auroc_res, valid_auroc_res, total_loss_list
        )

    _print_training_summary(args_dict, train_auroc_res, valid_auroc_res)

    return TrainResult(
        model=model,
        args_dict=args_dict,
        train_acc=train_auroc_res.best_acc,
        train_auroc=train_auroc_res.roc_auc,
        train_threshold=train_auroc_res.best_threshold,
        train_f1=train_auroc_res.best_f1,
        valid_acc=valid_auroc_res.best_acc,
        valid_auroc=valid_auroc_res.roc_auc,
        valid_threshold=valid_auroc_res.best_threshold,
        valid_f1=valid_auroc_res.best_f1,
    )

