import os
import torch
import torchvision
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from rich.console import Console
from rich.progress import Progress
from sklearn.metrics import (
    confusion_matrix, f1_score, roc_auc_score, roc_curve
)
from sklearn.preprocessing import label_binarize
import argparse
import model.networks as nets
from dataset import classify_data

console = Console()

def get_args():
    """Parse the command-line options for an evaluation run.

    Every option is optional; an omitted option is ``None`` on the returned
    namespace.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser(description='test parameters')
    # (flag, value type, help text) -- declaration order is kept so the
    # generated --help output is unchanged.
    option_table = (
        ('--num_class', int, None),
        ('--model_name', str, None),
        ('--device', str, None),
        ('--batch_size', int, None),
        ('--test_data_csv', str, None),
        ('--image_size', int, None),
        ('--in_channels', int, None),
        ('--parameter', str, None),
        ('--checkpoint', str, None),
        ('--dataset_name', str, '|sign_number|sign_mnist|sign_mnist_real'),
        ('--output_dir', str, None),
        ('--test_type', str, '|confusion_matrix_and_ROC|scatter_img'),
    )
    for flag, value_type, help_text in option_table:
        parser.add_argument(flag, type=value_type, help=help_text)
    return parser.parse_args()

def get_model(num_class, model_name, in_size, in_channels, parameter):
    """Build the classification network named by *model_name*.

    Args:
        num_class: Number of output classes (``classify_num``).
        model_name: Architecture name; only ``'RacoNetClassify'`` is supported.
        in_size: Base input size; the coder receives ``in_size * 2``.
        in_channels: Input channel count for both coder and classifier.
        parameter: Passed through as ``coder_parameters``.

    Returns:
        A freshly constructed ``nets.RacoNetClassify`` instance.

    Raises:
        ValueError: if *model_name* is not a supported architecture.
    """
    if model_name != 'RacoNetClassify':
        raise ValueError(f'ERROR!!! NO {model_name} !!!')

    # NOTE(review): coder_in_size is doubled relative to in_size -- confirm
    # this matches the resolution the data pipeline actually feeds in.
    return nets.RacoNetClassify(
        coder_in_channels=in_channels,
        coder_out_channels=8,
        coder_in_size=in_size * 2,
        coder_out_size=64,
        coder_parameters=parameter,
        classify_in_channels=in_channels,
        classify_num=num_class,
        classify_out_channels=8,
        classify_out_size=64,
    )

def _resolve_device(device):
    """Return the torch device to run on.

    ``'cuda'`` or ``'cuda:N'`` (case-insensitive) selects that GPU when CUDA
    is available; anything else, including ``None``, falls back to the CPU.
    """
    requested = str(device).lower() if device is not None else 'cpu'
    is_cuda_request = requested == 'cuda' or requested.startswith('cuda:')
    if is_cuda_request and torch.cuda.is_available():
        return torch.device(requested)
    return torch.device('cpu')


def _remap_checkpoint_keys(state_dict):
    """Rename checkpoint keys so they match the current model's parameter names.

    Strips a leading ``module.`` prefix (presumably left by saving a
    ``DataParallel``-wrapped model) and maps the legacy ``.freq_conv.`` /
    ``.freq_norm.`` names onto ``.freq_module.conv.`` / ``.freq_module.norm.``.
    """
    remapped = {}
    for key, value in state_dict.items():
        # Strip the wrapper prefix first, and only at the start of the key.
        # A global ``replace('module.', '')`` after the rename would also
        # match the new ``freq_module.`` segment and undo it.
        if key.startswith('module.'):
            key = key[len('module.'):]
        if '.freq_conv.' in key:
            key = key.replace('.freq_conv.', '.freq_module.conv.')
        elif '.freq_norm.' in key:
            key = key.replace('.freq_norm.', '.freq_module.norm.')
        remapped[key] = value
    return remapped


def _run_inference(model, data_loader, device, model_name):
    """Run *model* over *data_loader* and collect predictions.

    Returns:
        ``(all_preds, all_labels, all_probs)``: per-sample argmax class ids,
        ground-truth labels, and softmax probability vectors.
    """
    all_preds = []
    all_labels = []
    all_probs = []
    with torch.no_grad(), Progress() as progress:
        test_task = progress.add_task("[green]Testing Model...", total=len(data_loader))
        # The dataset yields 3-tuples; the third element is not used here.
        for imgs, labs, _truths in data_loader:
            imgs = imgs.to(device, dtype=torch.float32)
            if model_name == 'CNNTRClassify':
                # This architecture returns three outputs; the logits are last.
                _, _, outputs = model(imgs)
            else:
                outputs = model(imgs)
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            # Labels are only collected, so there is no need to move them
            # to the compute device.
            all_labels.extend(labs.long().cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            progress.update(test_task, advance=1)
    return all_preds, all_labels, all_probs


def _save_confusion_and_roc(all_labels, all_preds, all_probs, num_class, output_dir):
    """Save the confusion matrix, ROC curves, per-class F1 and overall metrics."""
    class_ids = list(range(num_class))

    # Pin the label set so the matrix is always num_class x num_class and its
    # axis indices equal the class ids even when a class never occurs.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', annot_kws={"size": 16})
    plt.title('Confusion Matrix', fontsize=16)
    plt.xlabel('Predicted', fontsize=16)
    plt.ylabel('True', fontsize=16)
    cm_path = os.path.join(output_dir, 'confusion_matrix.png')
    plt.savefig(cm_path)
    plt.close()

    f1_macro = f1_score(all_labels, all_preds, average='macro')
    # labels= guarantees one entry per class, so the bar chart's x and y
    # sequences below always have the same length.
    f1_per_class = f1_score(all_labels, all_preds, labels=class_ids, average=None)

    y_true = label_binarize(all_labels, classes=class_ids)
    if num_class == 2:
        # label_binarize collapses the two-class case to a single column;
        # expand it so column i is the indicator of class i, as for >2 classes.
        y_true = np.hstack([1 - y_true, y_true])
    y_score = np.array(all_probs)
    try:
        auc_macro = roc_auc_score(y_true, y_score, average='macro', multi_class='ovr')
    except ValueError as e:
        print(f"Error computing AUC: {e}")
        auc_macro = None

    fpr, tpr, roc_auc = {}, {}, {}
    for i in class_ids:
        fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_score[:, i])
        try:
            roc_auc[i] = roc_auc_score(y_true[:, i], y_score[:, i])
        except ValueError:
            # Raised e.g. when class i has no positive samples in y_true.
            roc_auc[i] = 0.0
    print(f"F1 Score (Macro): {f1_macro}")
    if auc_macro is not None:
        print(f"AUC (Macro): {auc_macro}")

    plt.figure(figsize=(12, 8))
    for i in class_ids:
        plt.plot(fpr[i], tpr[i], lw=2, label=f'Class {i} (AUC = {roc_auc[i]:0.2f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('')
    plt.ylabel('')
    plt.title('')
    plt.legend(loc='lower right', fontsize=9, ncol=2)
    roc_path = os.path.join(output_dir, 'roc_curves.png')
    plt.savefig(roc_path)
    plt.close()

    plt.figure(figsize=(14, 8))
    sns.barplot(x=class_ids, y=f1_per_class)
    plt.xlabel('Class')
    plt.ylabel('F1 Score')
    plt.title('F1 Score per Class')
    plt.ylim(0, 1)
    for index, value in enumerate(f1_per_class):
        plt.text(index, value + 0.01, f"{value:.2f}", ha='center', fontsize=26)
    f1_path = os.path.join(output_dir, 'f1_scores.png')
    plt.savefig(f1_path)
    plt.close()

    metrics = ['F1 Score (Macro)', 'AUC (Macro)']
    values = [f1_macro, auc_macro if auc_macro is not None else 0]
    plt.figure(figsize=(6, 4))
    sns.barplot(x=metrics, y=values)
    for index, value in enumerate(values):
        plt.text(index, value + 0.01, f"{value:.4f}", ha='center', fontsize=20)
    plt.title('Overall Evaluation Metrics')
    plt.ylim(0, 1)
    metrics_path = os.path.join(output_dir, 'overall_metrics.png')
    plt.savefig(metrics_path)
    plt.close()

    print(f"Confusion matrix, ROC curves, F1 scores, and metrics have been saved to:\n{output_dir}")


def _save_scatter_plots(all_probs, all_labels, output_dir):
    """Project the softmax outputs to 2-D with PCA, t-SNE and UMAP; save one plot each."""
    import inspect

    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
    import umap.umap_ as umap

    # The embedded "features" are the per-sample softmax probability vectors.
    features = np.array(all_probs)
    labels = np.array(all_labels)
    scatter_font_size = 24
    scatter_legend_font_size = 22

    # scikit-learn 1.5 renamed TSNE's ``n_iter`` to ``max_iter`` and later
    # dropped the old name; pass whichever keyword this installation accepts.
    tsne_params = inspect.signature(TSNE.__init__).parameters
    tsne_iter_kw = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'

    methods = {
        'PCA': PCA(n_components=2),
        't-SNE': TSNE(
            n_components=2,
            random_state=42,
            perplexity=6,
            early_exaggeration=50,
            learning_rate=200,
            **{tsne_iter_kw: 1000},
        ),
        'UMAP': umap.UMAP(
            n_components=2,
            n_neighbors=15,
            min_dist=3,
            spread=6,
            random_state=50
        )
    }

    # tab10 has 10 colours; they repeat for class ids >= 10.
    cmap = plt.get_cmap("tab10")
    for method_name, reducer in methods.items():
        reduced_features = reducer.fit_transform(features)

        plt.figure(figsize=(20, 20))
        ax = plt.gca()

        for cls in np.unique(labels):
            idx = labels == cls
            ax.scatter(
                reduced_features[idx, 0],
                reduced_features[idx, 1],
                label=f'Class {cls}',
                color=cmap(int(cls) % 10),
                alpha=0.7,
                edgecolors='w',
                s=80
            )
            # Mark the per-class centroid in the projected space.
            center = reduced_features[idx].mean(axis=0)
            ax.plot(center[0], center[1], marker='x', markersize=15, color='red', mew=3)

        plt.title(f"{method_name}", fontsize=scatter_font_size)
        plt.xlabel("Dimension 1", fontsize=scatter_font_size)
        plt.ylabel("Dimension 2", fontsize=scatter_font_size)
        plt.legend(fontsize=scatter_legend_font_size, bbox_to_anchor=(1.05, 1), loc='upper left')
        # Leave the right 15% of the figure for the legend placed outside the axes.
        plt.tight_layout(rect=[0, 0, 0.85, 1])

        ax.set_aspect('equal', adjustable='box')

        save_path = os.path.join(output_dir, f"scatter_{method_name}.png")
        plt.savefig(save_path)
        console.log(f"[green]saved {method_name} : {save_path}")
        plt.close()


def test_model(
    num_class,
    model_name,
    device,
    batch_size,
    test_data_csv,
    image_size,
    in_channels,
    parameter,
    checkpoint,
    dataset_name,
    output_dir,
    test_type
):
    """Evaluate a trained classifier on a CSV-described test set and save plots.

    Args:
        num_class: Number of classes the model predicts.
        model_name: Architecture name passed to ``get_model``.
        device: ``'cuda'`` / ``'cuda:N'`` to use a GPU when available; any
            other value (or ``None``) runs on the CPU.
        batch_size: DataLoader batch size.
        test_data_csv: CSV consumed by ``classify_data.ClassifyDataSet``.
        image_size: Forwarded to ``get_model`` as ``in_size``.
        in_channels: Number of input image channels.
        parameter: Forwarded to ``get_model`` as ``parameter``.
        checkpoint: Path to a saved state dict for the model.
        dataset_name: Used only to name the output sub-directory.
        output_dir: Root directory; results are written under
            ``output_dir/test_type/model_name/dataset_name``.
        test_type: ``'confusion_matrix_and_ROC'`` or ``'scatter_img'``.
    """
    device = _resolve_device(device)

    # Global matplotlib style for every figure this run produces.
    plt.rcParams.update({
        'font.size': 24,
        'axes.titlesize': 28,
        'axes.labelsize': 26,
        'xtick.labelsize': 24,
        'ytick.labelsize': 24,
        'legend.fontsize': 22,
        'figure.titlesize': 30
    })

    # NOTE(review): the resize is fixed at 256x256 regardless of image_size,
    # while get_model sizes the coder as image_size * 2 -- confirm they agree.
    transform_img = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((256, 256))
    ])

    test_dataset = classify_data.ClassifyDataSet(
        data_csv=test_data_csv,
        transform_img=transform_img
    )
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=int(batch_size),
        shuffle=False,
        drop_last=False,
        num_workers=4
    )

    model = get_model(num_class, model_name, image_size, in_channels, parameter)

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    old_weights = torch.load(checkpoint, map_location='cpu')
    load_result = model.load_state_dict(_remap_checkpoint_keys(old_weights), strict=False)
    # strict=False ignores mismatches silently; report them so a bad key
    # mapping does not go unnoticed as a model with untrained layers.
    if load_result.missing_keys:
        console.log(f"[yellow]{len(load_result.missing_keys)} model keys missing from checkpoint:",
                    load_result.missing_keys[:5])
    if load_result.unexpected_keys:
        console.log(f"[yellow]{len(load_result.unexpected_keys)} checkpoint keys unused by model:",
                    load_result.unexpected_keys[:5])
    model.to(device)
    model.eval()

    output_dir = os.path.join(output_dir, test_type, model_name, dataset_name)
    os.makedirs(output_dir, exist_ok=True)
    console.log("[green]start testing model...")
    all_preds, all_labels, all_probs = _run_inference(model, test_data_loader, device, model_name)

    if test_type == "confusion_matrix_and_ROC":
        _save_confusion_and_roc(all_labels, all_preds, all_probs, num_class, output_dir)
    elif test_type == "scatter_img":
        _save_scatter_plots(all_probs, all_labels, output_dir)
    else:
        console.log(f"[yellow]unknown test_type {test_type!r}; no plots were produced")


if __name__ == '__main__':
    # Every argparse destination in get_args() has the same name as a
    # test_model parameter, so the parsed namespace maps onto it directly.
    cli_options = vars(get_args())
    test_model(**cli_options)
