import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import copy
from tqdm import tqdm
import xgboost as xgb
import os
import platform
import wandb 
import argparse 
from sklearn.ensemble import RandomForestClassifier
from catboost import CatBoostClassifier
from sklearn.linear_model import LogisticRegression
import lightgbm as lgb
from sklearn.preprocessing import StandardScaler
from federated_dataset import create_federated_loaders
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import time

# Silence FutureWarning noise from third-party libraries during experiment runs.
warnings.filterwarnings("ignore", category=FutureWarning)

# --- Paths --------------------------------------------------------------------
DATA_PATH_PREFIX = 'processed'            # directory holding the preprocessed .npy arrays
OUTPUT_DIR = 'final_industrial_results'   # reports and figures are written here

# --- Federated training hyper-parameters -------------------------------------
NUM_ROUNDS = 100        # aggregation rounds
NUM_LOCAL_EPOCHS = 1    # local passes over each edge's data per round
BATCH_SIZE = 64
LEARNING_RATE = 1e-5    # client-side Adam learning rate
MU = 0.01               # FedProx proximal-term coefficient

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Equipment names; the list position is the integer id (see EQUIPMENT_MAP).
# Korean entries, in order: instrument transformer, single-phase oil-immersed
# transformer, oil-immersed power transformer, 7.2 kV and 22.9 kV distribution panels.
CLASS_NAMES = ['ACSR-OC', 'CNCV-W', 'TFR-CV', '계기용변압기', '단상유입변압기', '전력용유입변압기', '7.2kV배전반', '22.9kV배전반', '25.8kVGIS']
NUM_CLASSES = len(CLASS_NAMES)
EQUIPMENT_MAP = dict(zip(CLASS_NAMES, range(NUM_CLASSES)))

class Simple1DCNN(nn.Module):
    """Two-block 1D CNN classifier over a sequence and its companion difference sequence.

    ``forward`` concatenates ``x`` and ``dif_x`` along the last axis, then moves
    that axis into the channel position for ``Conv1d``. The combined last-axis
    size must therefore equal ``input_features``.

    Args:
        input_features: Channels seen by ``conv1`` (last-axis size of ``x``
            plus that of ``dif_x``).
        num_classes: Number of output logits.
        pooled_length: Sequence length left after the two stride-2 max-pools.
            ``fc1`` expects ``64 * pooled_length`` inputs. The default 4
            matches input sequences of 16-19 steps.
    """

    def __init__(self, input_features=12, num_classes=9, pooled_length=4):
        super().__init__()
        # Layer names/registration order are kept stable: they define the
        # state_dict keys and the parameters() order used elsewhere.
        self.conv1 = nn.Conv1d(in_channels=input_features, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(64)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # padding=1 convs keep the length; each pool halves it (floor).
        self.fc1 = nn.Linear(64 * pooled_length, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, dif_x=None):
        """Return logits of shape (batch, num_classes).

        Args:
            x: Tensor laid out as (batch, time, features).
            dif_x: Tensor with the same batch/time dims as ``x``. It is
                required despite the ``None`` default, which is kept only
                for signature compatibility.

        Raises:
            ValueError: If ``dif_x`` is None.
        """
        if dif_x is None:
            raise ValueError("Simple1DCNN.forward requires dif_x; it is concatenated with x on the feature axis")
        combined_x = torch.cat((x, dif_x), dim=2)
        h = combined_x.permute(0, 2, 1)  # -> (batch, channels, time) as Conv1d expects
        h = self.pool1(F.relu(self.bn1(self.conv1(h))))
        h = self.pool2(F.relu(self.bn2(self.conv2(h))))
        h = torch.flatten(h, start_dim=1)
        h = self.dropout(F.relu(self.fc1(h)))
        return self.fc2(h)

class FocalLoss(nn.Module):
    """Focal-weighted one-vs-rest binary cross-entropy on softmax probabilities.

    Take logits ``z`` of shape (batch, classes), integer targets ``y`` and
    ``p = softmax(z)``. Each element's loss is
    ``alpha * (1 - p_y) ** gamma * BCE(p_c, onehot(y)_c)``. The focal factor
    depends only on the true-class probability ``p_y`` and scales every class
    term of that sample.

    Args:
        alpha: Scalar weight applied to every term.
        gamma: Focusing exponent.
        reduction: 'mean' (over all batch x classes elements), 'sum', or 'none'.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        # Validate up front: previously any typo silently returned an
        # unreduced (non-scalar) loss.
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.epsilon = 1e-8

    def forward(self, inputs, targets):
        log_probs = F.log_softmax(inputs, dim=1)
        probs = torch.exp(log_probs)
        one_hot_targets = F.one_hot(targets, num_classes=inputs.size(1)).float()

        # True-class term: use log_softmax directly. Clamping exp(log_p) at
        # epsilon would zero the gradient exactly for confidently wrong
        # predictions (p_y < epsilon), the samples focal loss targets.
        positive_term = one_hot_targets * log_probs
        # Other classes: log(1 - p) keeps the epsilon clamp to avoid log(0).
        negative_term = (1 - one_hot_targets) * torch.log((1 - probs).clamp(min=self.epsilon))
        bce_loss = -(positive_term + negative_term)

        # squeeze(1) keeps a 1-D (batch,) tensor even when batch size is 1.
        pt = probs.gather(1, targets.view(-1, 1)).squeeze(1)
        focal_weight = self.alpha * (1 - pt).pow(self.gamma)
        focal_loss = focal_weight.view(-1, 1) * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


def client_update_fedprox(local_model, global_model, loader, criterion, lr, local_epochs, mu):
    """Run FedProx local training on one client and return the updated weights.

    Minimizes ``criterion(outputs, labels) + (mu / 2) * ||w_local - w_global||^2``
    with Adam, clipping the gradient norm to 1.0. Batches whose total loss is
    NaN or infinite are skipped.

    Args:
        local_model: Model trained in place (typically a copy of ``global_model``).
            Inputs are moved to the device of its parameters.
        global_model: Reference weights for the proximal term. It is read
            only: no gradients flow into it.
        loader: Iterable of ``(data, diff_data, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> scalar tensor``.
        lr: Adam learning rate.
        local_epochs: Number of passes over ``loader``.
        mu: Proximal-term coefficient.

    Returns:
        ``(state_dict, mean_loss)``: the local model's state dict and the mean
        total loss over the optimizer steps taken (0 if none were taken).
    """
    local_model.train()
    optimizer = optim.Adam(local_model.parameters(), lr=lr)
    device = next(local_model.parameters()).device
    # Detached snapshot of the global weights, taken once. Using the live
    # global parameters would back-propagate into the shared global model and
    # accumulate .grad tensors on it every step.
    global_params = [param.detach().to(device) for param in global_model.parameters()]
    step_losses = []
    for _ in range(local_epochs):
        for data, diff_data, labels in loader:
            data, diff_data, labels = data.to(device), diff_data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = local_model(data, diff_data)
            proximal_term = sum(
                (local_param - global_param).pow(2).sum()
                for local_param, global_param in zip(local_model.parameters(), global_params)
            )
            loss = criterion(outputs, labels) + (mu / 2) * proximal_term
            # Skip inf as well as NaN: a non-finite loss can yield non-finite
            # gradients that norm clipping cannot repair.
            if not torch.isfinite(loss):
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(local_model.parameters(), max_norm=1.0)
            optimizer.step()
            step_losses.append(loss.item())
    return local_model.state_dict(), np.mean(step_losses) if step_losses else 0

def evaluate_committee_model(anchor_models, data_loader, device):
    """Return (accuracy, macro-F1) of a committee that averages the models' raw logits.

    Each batch is scored by every model; the committee predicts the class with
    the highest mean logit. This averages logits, whereas the soft voting in
    ``full_analysis_report`` averages softmax probabilities, so the two can
    disagree. Each model's previous train/eval mode is restored afterwards,
    including when an exception is raised.

    Args:
        anchor_models: Models called as ``model(data, diff_data)``.
        data_loader: Iterable of ``(data, diff_data, labels)`` batches.
        device: Device the inputs are moved to.
    """
    previous_modes = [model.training for model in anchor_models]
    for model in anchor_models:
        model.eval()
    all_preds, all_labels = [], []
    try:
        with torch.no_grad():
            for batch_data, batch_diff_data, batch_labels in data_loader:
                batch_data, batch_diff_data = batch_data.to(device), batch_diff_data.to(device)
                all_anchor_outputs = [model(batch_data, batch_diff_data) for model in anchor_models]
                stacked_outputs = torch.stack(all_anchor_outputs, dim=1)
                _, final_preds = torch.max(torch.mean(stacked_outputs, dim=1), 1)
                all_preds.extend(final_preds.cpu().numpy())
                all_labels.extend(batch_labels.cpu().numpy())
    finally:
        # Restore the caller's modes instead of unconditionally forcing train().
        for model, was_training in zip(anchor_models, previous_modes):
            model.train(was_training)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    return accuracy, f1

def plot_learning_curves(history, output_dir, name=""):
    """Draw per-round curves from ``history``, save them as a PNG and log the image to wandb.

    Args:
        history: Mapping with per-round lists under 'acc', 'f1' and 'loss'.
            The x-axis length is taken from 'acc'.
        output_dir: Directory that receives ``learning_curve_{name}.png``.
        name: Suffix used in the title, the file name and the wandb key.
    """
    plt.figure(figsize=(12, 6))
    round_numbers = range(1, len(history['acc']) + 1)
    # (history key, format string, legend label, extra plot kwargs)
    curve_specs = (
        ('acc', 'o-', 'Validation Accuracy', {}),
        ('f1', 's--', 'Validation F1-score', {}),
        ('loss', '^-', 'Avg Train Loss', {'alpha': 0.6}),
    )
    for key, fmt, label, extra_kwargs in curve_specs:
        plt.plot(round_numbers, history[key], fmt, label=label, markersize=4, **extra_kwargs)

    plt.title(f'Learning Curves ({name})')
    plt.xlabel('Rounds')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    image_path = os.path.join(output_dir, f'learning_curve_{name}.png')
    plt.savefig(image_path)
    print(f"\nLearning curve plot saved to {image_path}")
    wandb.log({f'learning_curve_{name}': wandb.Image(image_path)})
    plt.close()

def _make_stacking_model(stacking_model_name):
    """Return an unfitted meta-classifier for ``stacking_model_name``.

    Raises:
        ValueError: If the name is not 'xgboost', 'randomforest', 'lightgbm',
            'catboost' or 'logisticregression'.
    """
    if stacking_model_name == "xgboost":
        return xgb.XGBClassifier(use_label_encoder=False, eval_metric='mlogloss', random_state=42)
    if stacking_model_name == "randomforest":
        return RandomForestClassifier(random_state=42, n_jobs=-1)
    if stacking_model_name == "lightgbm":
        return lgb.LGBMClassifier(random_state=42, n_jobs=-1)
    if stacking_model_name == "catboost":
        return CatBoostClassifier(random_state=42, verbose=0, thread_count=-1)
    if stacking_model_name == "logisticregression":
        return LogisticRegression(max_iter=1000, random_state=42, n_jobs=-1)
    raise ValueError(f"Unknown stacking model: {stacking_model_name}")


def _collect_anchor_probs(anchor_models, data_loader, device, desc_text):
    """Stack every anchor model's softmax output for each batch of ``data_loader``.

    Runs under ``torch.no_grad``. The caller is responsible for eval mode.

    Returns:
        ``(probs, labels)``. ``probs`` has shape (num_samples, num_models,
        num_classes); ``labels`` is the concatenation of the batch labels.
    """
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for batch_data, batch_diff_data, batch_labels in tqdm(data_loader, desc=desc_text):
            batch_data, batch_diff_data = batch_data.to(device), batch_diff_data.to(device)
            per_model = [F.softmax(model(batch_data, batch_diff_data), dim=1) for model in anchor_models]
            prob_batches.append(torch.stack(per_model, dim=1).cpu().numpy())
            label_batches.append(batch_labels.numpy())
    return np.concatenate(prob_batches, axis=0), np.concatenate(label_batches, axis=0)


def full_analysis_report(anchor_models, val_loader, test_loader, device, class_names, output_dir, per_edge_f1_report, name="", stacking_model_name="xgboost"):
    """Compare soft voting with a stacked meta-classifier and write the reports.

    The anchor models' softmax outputs on ``val_loader`` train the stacking
    model. Soft voting (mean of probabilities) and stacking are then evaluated
    on ``test_loader``. The text report, with ``per_edge_f1_report`` appended,
    and the stacking confusion matrix are saved under ``output_dir`` and
    logged to wandb. Each model's train/eval mode is restored afterwards.

    Raises:
        ValueError: For an unknown ``stacking_model_name``. This is checked
            before any inference runs.
    """
    print("\n--- Starting Full Performance Analysis ---")
    # Build the meta-classifier first so a bad name fails before the costly inference passes.
    stacking_model = _make_stacking_model(stacking_model_name)

    previous_modes = [model.training for model in anchor_models]
    for model in anchor_models:
        model.eval()
    try:
        meta_features_val, true_labels_val = _collect_anchor_probs(anchor_models, val_loader, device, "  Generating val meta-features")
        meta_features_test, true_labels_test = _collect_anchor_probs(anchor_models, test_loader, device, "  Generating test meta-features")
    finally:
        for model, was_training in zip(anchor_models, previous_modes):
            model.train(was_training)

    # Pin the label set to the supplied names. Otherwise sklearn infers labels
    # from the data, and raises (classification_report) or mislabels the
    # heatmap (confusion_matrix) whenever a class is absent or a prediction
    # index has no name.
    label_ids = list(range(len(class_names)))

    soft_voting_preds = np.argmax(np.mean(meta_features_test, axis=1), axis=1)
    report_soft = classification_report(true_labels_test, soft_voting_preds, labels=label_ids, target_names=class_names, zero_division=0)

    stacking_model.fit(meta_features_val.reshape(meta_features_val.shape[0], -1), true_labels_val)
    # ravel(): some classifiers (e.g. CatBoost, multiclass) return an (n, 1) column.
    stacking_preds = np.asarray(stacking_model.predict(meta_features_test.reshape(meta_features_test.shape[0], -1))).ravel()
    report_stack = classification_report(true_labels_test, stacking_preds, labels=label_ids, target_names=class_names, zero_division=0)

    report_path = os.path.join(output_dir, f'final_report_{name}_{stacking_model_name}.txt')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"--- FINAL ANALYSIS REPORT ({name} with {stacking_model_name}) ---\n\n")
        f.write(f"--- Soft-Voting ---\n{report_soft}\n\n--- Stacking Ensemble ---\n{report_stack}\n\n")
        f.write(f"--- F1-Score per Edge ---\n{per_edge_f1_report}")
    print(f"\nClassification reports saved to: {report_path}")
    wandb.save(report_path)

    cm = confusion_matrix(true_labels_test, stacking_preds, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix (Stacking) - {name}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    cm_path = os.path.join(output_dir, f'confusion_matrix_stacking_{name}.png')
    plt.savefig(cm_path)
    plt.close()
    print(f"Confusion matrix saved to {cm_path}")
    wandb.log({f"confusion_matrix_{name}": wandb.Image(cm_path)})

def analyze_per_edge_performance(anchor_models, X_test_main, X_test_diff, y_test, edge_ids_test, device, equipment_map):
    """Print, log to wandb and return the committee macro-F1 for each edge.

    Edge ids are the positions of the keys in ``equipment_map`` (0, 1, ...).
    An edge with no rows in ``edge_ids_test`` is skipped.

    Returns:
        The per-edge result lines concatenated into one string.
    """
    print("\n--- F1-Score Analysis per Edge ---")
    result_lines = []
    for edge_id, edge_name in enumerate(equipment_map):
        indices = np.where(edge_ids_test == edge_id)[0]
        if len(indices) == 0:
            continue
        edge_tensors = (
            torch.from_numpy(X_test_main[indices]).float(),
            torch.from_numpy(X_test_diff[indices]).float(),
            torch.from_numpy(y_test[indices]).long(),
        )
        edge_loader = DataLoader(TensorDataset(*edge_tensors), batch_size=BATCH_SIZE * 2)
        f1 = evaluate_committee_model(anchor_models, edge_loader, device)[1]
        line = f"  Edge {edge_id} ({edge_name}): \t F1-Score = {f1:.4f}\n"
        print(line, end='')
        result_lines.append(line)
        wandb.log({f"edge_{edge_id}_f1": f1})
    return "".join(result_lines)

def _split_main_and_diff(X):
    """Return ``(X[:, :-1, :], np.diff(X, axis=1))``; both have one fewer step on axis 1 than ``X``."""
    return X[:, :-1, :], np.diff(X, axis=1)


def _make_eval_loader(X_main, X_diff, y):
    """Wrap arrays in an unshuffled DataLoader of (float, float, long) tensors."""
    dataset = TensorDataset(torch.from_numpy(X_main).float(), torch.from_numpy(X_diff).float(), torch.from_numpy(y).long())
    return DataLoader(dataset, batch_size=BATCH_SIZE * 2, shuffle=False)


def _load_experiment_data():
    """Load, standardize and split the processed arrays, and build all data loaders.

    Returns:
        ``(train_edge_loaders, val_loader, test_loader, X_test_main,
        X_test_diff, y_test, edge_ids_test)``.
    """
    def load(file_name):
        return np.load(os.path.join(DATA_PATH_PREFIX, file_name))

    print("\n--- Loading & Preprocessing Data ---")
    X_train_orig, y_train_orig, edge_ids_train_orig = load('X_train.npy'), load('y_train.npy'), load('edge_ids_train.npy')
    # The *_val.npy files serve as the held-out test set. A validation split
    # is carved out of the training data below.
    X_test, y_test, edge_ids_test = load('X_val.npy'), load('y_val.npy'), load('edge_ids_val.npy')

    # Standardize each last-axis feature over all samples and time steps.
    # NOTE(review): the scaler is fit before the train/val split, so the
    # validation rows also contribute to the scaling statistics.
    scaler = StandardScaler()
    X_train_orig_scaled = scaler.fit_transform(X_train_orig.reshape(-1, X_train_orig.shape[2])).reshape(X_train_orig.shape)
    X_test_scaled = scaler.transform(X_test.reshape(-1, X_test.shape[2])).reshape(X_test.shape)

    X_train, X_val, y_train, y_val, edge_ids_train, _ = train_test_split(
        X_train_orig_scaled, y_train_orig, edge_ids_train_orig,
        test_size=0.2, random_state=42, stratify=y_train_orig)

    X_train_main, X_train_diff = _split_main_and_diff(X_train)
    train_edge_loaders = create_federated_loaders(X_train_main, X_train_diff, y_train, edge_ids_train, BATCH_SIZE)
    X_val_main, X_val_diff = _split_main_and_diff(X_val)
    val_loader = _make_eval_loader(X_val_main, X_val_diff, y_val)
    X_test_main, X_test_diff = _split_main_and_diff(X_test_scaled)
    test_loader = _make_eval_loader(X_test_main, X_test_diff, y_test)
    return train_edge_loaders, val_loader, test_loader, X_test_main, X_test_diff, y_test, edge_ids_test


def _average_state_dicts(states):
    """Element-wise mean (in float32) of state dicts that share the same keys."""
    return {key: torch.stack([state[key].float() for state in states], dim=0).mean(0) for key in states[0]}


def run_experiment(mu, stacking_model_name, num_anchors=5):
    """Train a FedProx committee of CNNs across edges, then evaluate and report.

    Each round, every anchor model is trained locally on every non-empty edge
    loader, starting from its current global weights. The resulting state
    dicts are averaged into that anchor's global model, and the committee is
    scored on the validation split. Afterwards, the learning curves, per-edge
    F1 scores and soft-voting/stacking reports are produced and logged to
    wandb.

    Args:
        mu: FedProx proximal coefficient.
        stacking_model_name: Meta-classifier name passed to ``full_analysis_report``.
        num_anchors: Committee size (previously hard-coded to 5).
    """
    exp_name = f"FedProx_mu{mu}_stacking_{stacking_model_name}"
    wandb.init(project="Industrial_FL_Analysis", name=exp_name, reinit=True)
    wandb.config.update({
        "algorithm": "FedProx", "mu": mu,
        "num_rounds": NUM_ROUNDS, "local_epochs": NUM_LOCAL_EPOCHS,
        "batch_size": BATCH_SIZE, "client_learning_rate": LEARNING_RATE,
        "stacking_model": stacking_model_name
    })

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    (train_edge_loaders, val_loader, test_loader,
     X_test_main, X_test_diff, y_test, edge_ids_test) = _load_experiment_data()

    global_anchor_models = [Simple1DCNN().to(DEVICE) for _ in range(num_anchors)]
    criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='mean')
    history = {'acc': [], 'f1': [], 'loss': []}

    print(f"\n--- Starting FedProx Federated Learning Training on {DEVICE} ---")
    for r in range(NUM_ROUNDS):
        round_training_losses = []
        collected_states = [[] for _ in range(num_anchors)]

        for edge_loader in tqdm(train_edge_loaders, desc=f"Round {r+1}/{NUM_ROUNDS}"):
            if edge_loader is None:
                continue
            for anchor_id, global_model in enumerate(global_anchor_models):
                local_model = copy.deepcopy(global_model)
                updated_state, avg_loss = client_update_fedprox(local_model, global_model, edge_loader, criterion, LEARNING_RATE, NUM_LOCAL_EPOCHS, mu)
                collected_states[anchor_id].append(updated_state)
                if avg_loss > 0:
                    round_training_losses.append(avg_loss)

        # NOTE(review): client updates are averaged with equal weight,
        # regardless of each edge's sample count; confirm this is intended
        # (FedAvg-style weighting would use len(loader.dataset)).
        for global_model, states in zip(global_anchor_models, collected_states):
            if states:
                global_model.load_state_dict(_average_state_dicts(states))

        val_acc, val_f1 = evaluate_committee_model(global_anchor_models, val_loader, DEVICE)
        avg_loss = np.mean(round_training_losses) if round_training_losses else 0
        history['acc'].append(val_acc)
        history['f1'].append(val_f1)
        history['loss'].append(avg_loss)
        print(f"\n--- Round {r+1} Summary ---")
        print(f"  Avg Train Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")
        wandb.log({"round": r+1, "train_loss": avg_loss, "val_accuracy": val_acc, "val_f1": val_f1})

    print("\nFederated Learning training finished!")

    # NOTE(review): 5 target names are used for the reports, while
    # Simple1DCNN() is built with its default num_classes=9; confirm the
    # label space.
    class_names_en = ['Normal', 'Noise', 'Surface', 'Corona', 'Void']
    plot_learning_curves(history, OUTPUT_DIR, name=exp_name)
    per_edge_report = analyze_per_edge_performance(global_anchor_models, X_test_main, X_test_diff, y_test, edge_ids_test, DEVICE, EQUIPMENT_MAP)
    full_analysis_report(global_anchor_models, val_loader, test_loader, DEVICE, class_names_en, OUTPUT_DIR, per_edge_report, name=exp_name, stacking_model_name=stacking_model_name)
    wandb.finish()

def main():
    """Select a Hangul-capable matplotlib font, parse the CLI, and run one FedProx experiment.

    Command line: ``--stacking_model`` (required), one of xgboost,
    randomforest, lightgbm, catboost, logisticregression.
    """
    system_name = platform.system()
    if system_name == 'Windows':
        plt.rc('font', family='Malgun Gothic')
    elif system_name == 'Darwin':
        plt.rc('font', family='AppleGothic')
    else:
        # plt.rc() accepts any family name without raising; matplotlib only
        # warns "findfont: Font family ... not found" at draw time. A
        # try/except around it therefore cannot detect a missing font, so
        # check the registered fonts explicitly.
        from matplotlib import font_manager
        installed_fonts = {entry.name for entry in font_manager.fontManager.ttflist}
        if 'NanumGothic' in installed_fonts:
            plt.rc('font', family='NanumGothic')
        else:
            print("NanumGothic font not found. Please install it (`sudo apt-get install fonts-nanum*`)")
    # Use ASCII '-' instead of U+2212 for negative numbers; presumably the
    # fonts above may lack that glyph.
    plt.rcParams['axes.unicode_minus'] = False

    parser = argparse.ArgumentParser(description="FedProx experiment with various stacking models.")
    parser.add_argument('--stacking_model', type=str, required=True,
                        choices=['xgboost', 'randomforest', 'lightgbm', 'catboost', 'logisticregression'],
                        help="Choose the stacking model for final analysis.")
    args = parser.parse_args()

    run_experiment(mu=MU, stacking_model_name=args.stacking_model)

if __name__ == '__main__':
    main()