import argparse
from utils import *
from dataloader import *
from model import *
from itertools import chain
from metrics import *
import torch
import os
import numpy as np
import matplotlib.pyplot as plt

Dataname = 'COIL20'


def _str2bool(value):
    """Convert a command-line token to ``bool``.

    argparse with ``type=bool`` would call ``bool(str)``, making
    ``--save_model False`` evaluate to ``True``; this accepts the usual
    spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line configuration. Every option carries an explicit ``type=`` so
# values given on the command line are converted; without it argparse hands
# back raw strings (e.g. ``--pre_train 5`` would break ``range(args.pre_train)``).
parser = argparse.ArgumentParser(description='train')
parser.add_argument('--dataset', default=Dataname, help='[CCV, RGB-D, Cora, ALOI-100, Hdigit, Digit-Product]')
parser.add_argument('--save_model', default=True, type=_str2bool, help='Saving the model after training.')
parser.add_argument("--learning_rate", default=0.0001, type=float)  # Adam learning rate

parser.add_argument("--pre_train", default=10, type=int)  # number of pre-training epochs
parser.add_argument("--pretrain_iter", default=10, type=int)  # optimizer steps per pre-training epoch
parser.add_argument("--fine_train", default=20, type=int)  # number of fine-tuning epochs

parser.add_argument("--hidden_dim", default=256, type=int)
parser.add_argument("--output_dim", default=64, type=int)

# Neighbour count for graph construction: starts at neighbor_init, grows by
# neighbor_incr per pre-training epoch, capped at neighbor_max.
parser.add_argument("--neighbor_init", default=5, type=int)
parser.add_argument("--neighbor_incr", default=1, type=int)
parser.add_argument("--neighbor_max", default=15, type=int)

# Loss weights (may be fractional, hence float).
parser.add_argument("--lambda1", default=1, type=float)
parser.add_argument("--lambda2", default=1, type=float)
parser.add_argument("--lambda3", default=0, type=float)
parser.add_argument("--lambda4", default=1, type=float)

parser.add_argument('--gpu', default='0', type=str, help='GPU device idx.')  # GPU device id(s) to expose
args = parser.parse_args()

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
# initialised, so assign it before the first torch.cuda query.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Random seed per dataset. Every dataset tuned so far uses 5. Datasets that are
# not listed fall back to DEFAULT_SEED instead of leaving ``args.seed`` unset.
DEFAULT_SEED = 5
DATASET_SEEDS = {
    "RGB-D": 5,
    "BDGP": 5,
    "Caltech-3V": 5,
    "Caltech-4V": 5,
    "Caltech-5V": 5,
    "BBCSport": 5,
    "MNIST_USPS": 5,
    "Caltech101-7": 5,
    "handwritten": 5,
    "COIL20": 5,
    "HW1256": 5,
}


def resolve_seed(dataset, default=DEFAULT_SEED):
    """Return the random seed for ``dataset``.

    Args:
        dataset: dataset name as passed to ``--dataset``.
        default: seed used when ``dataset`` has no entry in DATASET_SEEDS.

    Returns:
        The integer seed.
    """
    return DATASET_SEEDS.get(dataset, default)


def best_epoch_summary(acc_list, nmi_list, ari_list):
    """Select the epoch with the highest ACC and its companion metrics.

    Args:
        acc_list, nmi_list, ari_list: per-epoch metric values of equal length.

    Returns:
        ``(best_idx, acc, nmi, ari)`` with a 0-based ``best_idx``, or ``None``
        when no epochs were recorded (e.g. ``--fine_train 0``). Ties resolve
        to the earliest epoch, because ``np.argmax`` returns the first maximum.
    """
    if not acc_list:
        return None
    best_idx = int(np.argmax(acc_list))
    return best_idx, acc_list[best_idx], nmi_list[best_idx], ari_list[best_idx]


if __name__ == "__main__":
    args.seed = resolve_seed(args.dataset)

    print("==================================\nArgs:{}\n==================================".format(args))
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    neighbor_num = args.neighbor_init
    mv_data = MultiviewData(args.dataset, device)
    num_views = mv_data.num_views
    num_clusters = np.unique(mv_data.labels).size

    weights_mv, raw_weights_mv, laplacian_mv = update_graph(mv_data.data_views, neighbor_num)
    network = AdaGAEMV(mv_data.data_views, [args.hidden_dim, args.output_dim], device)
    # Optimise the parameters of each per-view autoencoder in network.gae_list.
    optimizer = torch.optim.Adam(
        chain.from_iterable(network.gae_list[v].parameters() for v in range(num_views)),
        lr=args.learning_rate,
    )

    # Pre-training: only the reconstruction term and the lambda1-weighted
    # tr_loss term are optimised. After each epoch the graphs are rebuilt from
    # the learned embeddings, using a neighbour count that grows up to
    # neighbor_max.
    print("start pre_training...")
    for epoch in range(args.pre_train):
        for i in range(args.pretrain_iter):
            embedding_list, recons_w_list, fused_embedding, norm_weights = network.forward(mv_data.data_views, laplacian_mv)
            total_loss, re_loss, tr_loss, mmd_loss, cluster_loss = network.cal_loss(
                raw_weights_mv,
                recons_w_list,
                weights_mv,
                embedding_list,
                fused_embedding,
                args.lambda1,
                args.lambda2,
                args.lambda3,
                num_clusters
            )
            loss = re_loss + args.lambda1 * tr_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Epoch {}, iter {}, Loss {}'.format(epoch+1, i+1, loss.item()))
        weights_mv, raw_weights_mv, laplacian_mv = update_graph(
            embedding_list, neighbor_num
        )
        neighbor_num = min(neighbor_num + args.neighbor_incr, args.neighbor_max)

    # Fine-tuning: the full loss from cal_loss plus a pairwise MSE consistency
    # term between every pair of view embeddings, weighted by lambda4.
    print("start fine_training...")
    mse_loss_func = torch.nn.MSELoss()
    acc_list, nmi_list, ari_list = [], [], []  # per-epoch clustering metrics
    for epoch in range(args.fine_train):
        embedding_list, recons_w_list, fused_embedding, norm_weights = network.forward(mv_data.data_views, laplacian_mv)
        total_loss, re_loss, tr_loss, mmd_loss, cluster_loss = network.cal_loss(
            raw_weights_mv,
            recons_w_list,
            weights_mv,
            embedding_list,
            fused_embedding,
            args.lambda1,
            args.lambda2,
            args.lambda3,
            num_clusters
        )
        con_loss = 0
        for vi in range(num_views):
            for vj in range(vi + 1, num_views):
                con_loss += mse_loss_func(embedding_list[vi], embedding_list[vj])
        loss = total_loss + args.lambda4 * con_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Epoch {}, Loss {}'.format(epoch + 1, loss.item()))

        # Cluster the fused embedding from this epoch's forward pass and record its metrics.
        _, metrics = get_cluster_labels_and_metrics(fused_embedding, mv_data.labels, num_clusters, device)
        acc_list.append(metrics['ACC'])
        nmi_list.append(metrics['NMI'])
        ari_list.append(metrics['ARI'])

    # Plot the metric curves over the fine-tuning epochs.
    epochs = range(1, args.fine_train + 1)
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, acc_list, label='ACC', marker='o')
    plt.plot(epochs, nmi_list, label='NMI', marker='s')
    plt.plot(epochs, ari_list, label='ARI', marker='^')
    plt.title('Metrics during fine_training')
    plt.xlabel('Epoch')
    plt.ylabel('Metrics')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Report the epoch with the best ACC; guard against zero fine-tuning epochs.
    summary = best_epoch_summary(acc_list, nmi_list, ari_list)
    if summary is None:
        print("No fine-training epochs were run; no metrics to report.")
    else:
        best_idx, best_acc, best_nmi, best_ari = summary
        print(f"Best ACC: {best_acc:.4f}  (Epoch {best_idx + 1})")
        print(f"Corresponding NMI: {best_nmi:.4f}")
        print(f"Corresponding ARI: {best_ari:.4f}")