import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import copy
from tqdm import tqdm
import xgboost as xgb
import os
import platform
import wandb 
import argparse 

from sklearn.preprocessing import StandardScaler
from federated_dataset import create_federated_loaders
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# Silence FutureWarning messages for the whole process.
warnings.filterwarnings("ignore", category=FutureWarning)

# Directory the .npy arrays (X_*, y_*, edge_ids_*) are loaded from.
DATA_PATH_PREFIX = 'processed'
# Directory that receives the plots and the text report.
OUTPUT_DIR = 'final_industrial_results'

# Training hyper-parameters.
NUM_ROUNDS = 100         # federated rounds
NUM_LOCAL_EPOCHS = 1     # passes over each edge loader per round
BATCH_SIZE = 64          # client batch size; evaluation loaders use twice this
LEARNING_RATE = 1e-5     # learning rate of the client-side Adam optimizer
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Moment-decay rates and denominator epsilon of the server-side
# FedAdam / FedYogi update.
BETA_1 = 0.9
BETA_2 = 0.999
EPSILON = 1e-8

# Equipment name -> edge (client) id. The Korean keys are runtime data and
# must stay as-is; approximate English glosses (for readers only):
# instrument transformer, single-phase oil-immersed transformer,
# power oil-immersed transformer, 7.2 kV / 22.9 kV switchgear panel,
# 25.8 kV GIS.
EQUIPMENT_MAP = {
    'ACSR-OC': 0, 'CNCV-W': 1, 'TFR-CV': 2,
    '계기용변압기': 3, '단상유입변압기': 4, '전력용유입변압기': 5,
    '7.2kV배전반': 6, '22.9kV배전반': 7, '25.8kVGIS': 8
}
class Simple1DCNN(nn.Module):
    """Small two-stage 1D CNN producing ``num_classes`` logits per window.

    Inputs are laid out as (batch, time, features). ``forward`` concatenates
    ``x`` and ``dif_x`` along the feature axis, so the two together must
    provide ``input_features`` channels. ``fc1`` is sized for a time axis of
    length 4 after the two stride-2 poolings, i.e. an input time length of
    16 to 19 steps.

    Args:
        input_features: number of channels fed to the first convolution
            (features of ``x`` plus features of ``dif_x``).
        num_classes: number of output logits.
    """

    def __init__(self, input_features=12, num_classes=1):
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation stays
        # reproducible.
        self.conv1 = nn.Conv1d(in_channels=input_features, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(64)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # 64 channels x 4 remaining time steps.
        self.fc1 = nn.Linear(64 * 4, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.4)

    def forward(self, x, dif_x=None):
        """Return raw logits of shape (batch, num_classes).

        Args:
            x: tensor laid out as (batch, time, features).
            dif_x: optional tensor with the same batch/time layout whose
                features are appended to those of ``x``. When omitted, ``x``
                must already carry all ``input_features`` channels.
        """
        # The signature advertises dif_x as optional; torch.cat would raise
        # on None, so only concatenate when it is actually supplied.
        combined_x = x if dif_x is None else torch.cat((x, dif_x), dim=2)
        # Conv1d expects (batch, channels, time).
        x = combined_x.permute(0, 2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class FocalLoss(nn.Module):
    """Focal loss for multi-class logits, with an extra binary-logit mode.

    Two call shapes are supported:

    * multi-class: ``inputs`` is (batch, classes) and ``targets`` holds
      integer class ids. The per-class binary cross-entropy of the softmax
      probabilities is scaled by ``alpha * (1 - p_t) ** gamma``.
    * binary: ``inputs`` holds one logit per sample (0-d, 1-D or (batch, 1))
      and ``targets`` holds 0/1 values with the same number of elements.
      The sigmoid cross-entropy is scaled by the same factor. ``alpha`` is
      applied uniformly to both classes, matching the multi-class branch.

    Args:
        alpha: constant scale applied to every loss term.
        gamma: focusing exponent; larger values down-weight easy samples.
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.epsilon = 1e-8

    def _reduce(self, loss):
        """Apply the configured reduction to an unreduced loss tensor."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def _binary_focal(self, logits, targets):
        """Unreduced focal loss for one sigmoid logit per sample."""
        # Flatten both sides so a squeezed batch of one (0-d logits) still
        # lines up with its (1,) target.
        logits = logits.reshape(-1)
        targets = targets.reshape(-1).to(logits.dtype)
        probs = torch.sigmoid(logits)
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        # Probability assigned to the true class of each sample.
        pt = probs * targets + (1 - probs) * (1 - targets)
        return self.alpha * (1 - pt).pow(self.gamma) * bce

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``."""
        # One logit per sample cannot go through log_softmax(dim=1) and
        # F.one_hot (which needs integer ids); those calls used to raise for
        # this input, so route it to the binary formulation instead.
        one_logit_per_sample = inputs.numel() == targets.numel()
        if one_logit_per_sample and (inputs.dim() < 2 or targets.is_floating_point()):
            return self._reduce(self._binary_focal(inputs, targets))

        log_probs = F.log_softmax(inputs, dim=1)
        probs = torch.exp(log_probs)

        one_hot_targets = F.one_hot(targets, num_classes=inputs.size(1)).float()

        # Clamp before log so a saturated probability cannot produce -inf.
        BCE_loss = - (one_hot_targets * torch.log(probs.clamp(min=self.epsilon)) +
                      (1 - one_hot_targets) * torch.log((1 - probs).clamp(min=self.epsilon)))

        pt = probs.gather(1, targets.view(-1, 1)).squeeze(1)

        focal_weight = self.alpha * (1 - pt).pow(self.gamma)
        # Broadcast the per-sample weight across that sample's class terms.
        F_loss = focal_weight.view(-1, 1) * BCE_loss

        return self._reduce(F_loss)

def evaluate_committee_model(anchor_models, data_loader, device):
    """Score the anchor committee by soft voting and return (accuracy, macro-F1).

    Every anchor model emits one sigmoid probability per sample; the
    predicted class of a sample is the index of the anchor with the highest
    probability. The models are switched to eval mode for the pass and put
    back into train mode before returning.

    Args:
        anchor_models: sequence of models, one per class, each called as
            ``model(batch_data, batch_diff_data)``.
        data_loader: iterable of ``(data, diff_data, labels)`` batches; the
            labels are read with ``.numpy()`` and so must live on the CPU.
        device: device the input tensors are moved to.

    Returns:
        Tuple ``(accuracy, macro_f1)``.
    """
    for model in anchor_models:
        model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch_data, batch_diff_data, batch_multi_class_labels in data_loader:
            batch_data, batch_diff_data = batch_data.to(device), batch_diff_data.to(device)
            # squeeze(-1) drops only the logit axis. A bare squeeze() would
            # also drop the batch axis of a one-sample batch, after which
            # stacking along dim=1 fails.
            all_anchor_probs = [
                torch.sigmoid(model(batch_data, batch_diff_data)).squeeze(-1)
                for model in anchor_models
            ]
            stacked_probs = torch.stack(all_anchor_probs, dim=1)
            _, final_preds = torch.max(stacked_probs, 1)
            all_preds.extend(final_preds.cpu().numpy())
            all_labels.extend(batch_multi_class_labels.numpy())
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    for model in anchor_models:
        model.train()
    return accuracy, f1

def plot_learning_curves(history, output_dir, name=""):
    """Draw the per-round accuracy, F1 and loss curves, save them and log to wandb.

    Args:
        history: dict with equally long ``'acc'``, ``'f1'`` and ``'loss'``
            sequences, one entry per round.
        output_dir: directory the PNG is written to.
        name: experiment tag used in the title, file name and wandb key.
    """
    plt.figure(figsize=(12, 6))
    round_axis = range(1, len(history['acc']) + 1)

    # (history key, line format, plot keyword arguments), drawn in this order.
    curve_specs = (
        ('acc', 'o-', {'label': 'Validation Accuracy', 'markersize': 4}),
        ('f1', 's--', {'label': 'Validation F1-score', 'markersize': 4}),
        ('loss', '^-', {'label': 'Avg Train Loss', 'markersize': 4, 'alpha': 0.6}),
    )
    for key, line_format, plot_options in curve_specs:
        plt.plot(round_axis, history[key], line_format, **plot_options)

    plt.title(f'Learning Curves ({name})')
    plt.xlabel('Rounds')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    save_path = os.path.join(output_dir, f'learning_curve_{name}.png')
    plt.savefig(save_path)
    print(f"\nLearning curve plot saved to {save_path}")
    wandb.log({f'learning_curve_{name}': wandb.Image(save_path)})
    plt.close()

def full_analysis_report(anchor_models, val_loader, test_loader, device, class_names, output_dir, per_edge_f1_report, name=""):
    """Write the final soft-voting / stacking report and confusion matrix.

    The per-anchor sigmoid probabilities serve as meta-features. Soft voting
    takes their argmax on the test loader; the stacking ensemble is an
    XGBoost classifier fitted on the validation meta-features and applied to
    the test meta-features. Both classification reports plus
    ``per_edge_f1_report`` are written to a text file, and the stacking
    confusion matrix is saved as a PNG; both artefacts are sent to wandb.
    The anchor models are left in eval mode.

    Args:
        anchor_models: sequence of models, one per class.
        val_loader: loader used to fit the stacking model.
        test_loader: loader both ensembles are evaluated on.
        device: device the input tensors are moved to.
        class_names: display names; entry ``i`` is used for label id ``i``.
        output_dir: directory receiving the report and the plot.
        per_edge_f1_report: preformatted text appended to the report.
        name: experiment tag used in file names, titles and wandb keys.
    """
    print("\n--- Starting Full Performance Analysis ---")
    for model in anchor_models:
        model.eval()

    def generate_meta_features(data_loader, desc_text):
        """Return (per-anchor probabilities, true labels) for a whole loader."""
        all_anchor_probs, all_true_labels = [], []
        with torch.no_grad():
            for batch_data, batch_diff_data, batch_labels in tqdm(data_loader, desc=desc_text):
                batch_data, batch_diff_data = batch_data.to(device), batch_diff_data.to(device)
                # squeeze(-1) keeps the batch axis even for a one-sample
                # batch; a bare squeeze() would make the stack below fail.
                anchor_probs_per_batch = [
                    torch.sigmoid(model(batch_data, batch_diff_data)).squeeze(-1)
                    for model in anchor_models
                ]
                stacked_probs = torch.stack(anchor_probs_per_batch, dim=1)
                all_anchor_probs.append(stacked_probs.cpu().numpy())
                all_true_labels.append(batch_labels.numpy())
        return np.concatenate(all_anchor_probs, axis=0), np.concatenate(all_true_labels, axis=0)

    meta_features_val, true_labels_val = generate_meta_features(val_loader, "  Generating val meta-features")
    meta_features_test, true_labels_test = generate_meta_features(test_loader, "  Generating test meta-features")

    # Pin the label set so class_names always lines up with the rows, even
    # when some class is absent from the test split or the predictions.
    # Label ids are taken to be 0..len(class_names)-1, the same indexing the
    # soft-voting argmax over anchors produces.
    label_ids = list(range(len(class_names)))

    soft_voting_preds = np.argmax(meta_features_test, axis=1)
    report_soft = classification_report(true_labels_test, soft_voting_preds, labels=label_ids,
                                        target_names=class_names, zero_division=0)

    stacking_model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='mlogloss', random_state=42)
    stacking_model.fit(meta_features_val, true_labels_val)
    stacking_preds = stacking_model.predict(meta_features_test)
    report_stack = classification_report(true_labels_test, stacking_preds, labels=label_ids,
                                         target_names=class_names, zero_division=0)

    report_path = os.path.join(output_dir, f'final_report_{name}.txt')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"--- FINAL ANALYSIS REPORT ({name}) ---\n\n")
        f.write(f"--- Soft-Voting ---\n{report_soft}\n\n--- Stacking Ensemble ---\n{report_stack}\n\n")
        f.write(f"--- F1-Score per Edge ---\n{per_edge_f1_report}")
    print(f"\nClassification reports saved to: {report_path}")
    wandb.save(report_path)

    # Same pinned label set, so the tick labels match the matrix rows/columns.
    cm = confusion_matrix(true_labels_test, stacking_preds, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix (Stacking) - {name}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    cm_path = os.path.join(output_dir, f'confusion_matrix_stacking_{name}.png')
    plt.savefig(cm_path)
    plt.close()
    print(f"Confusion matrix saved to {cm_path}")
    wandb.log({f"confusion_matrix_{name}": wandb.Image(cm_path)})

def analyze_per_edge_performance(anchor_models, X_test_main, X_test_diff, y_test, edge_ids_test, device, equipment_map):
    """Report the committee's macro-F1 separately for every edge.

    For each edge id in ``equipment_map`` that occurs in ``edge_ids_test``,
    the matching test samples are evaluated with
    ``evaluate_committee_model``; the score is printed, logged to wandb as
    ``edge_<id>_f1`` and appended to the returned text. Edges without test
    samples are skipped.

    Args:
        anchor_models: sequence of models, one per class.
        X_test_main: NumPy array of test inputs, indexed by sample on axis 0.
        X_test_diff: NumPy array of differenced test inputs, same indexing.
        y_test: NumPy array of integer class labels.
        edge_ids_test: NumPy array with the edge id of every test sample.
        device: device used for evaluation.
        equipment_map: mapping of equipment name -> edge id.

    Returns:
        The report as one string, one line per evaluated edge.
    """
    print("\n--- F1-Score Analysis per Edge ---")
    # Resolve names through the id stored in the map rather than through the
    # position of the key, so the labels stay right even if the dict is not
    # written in id order.
    id_to_name = {edge_id: edge_name for edge_name, edge_id in equipment_map.items()}
    report_lines = []
    for edge_id in sorted(id_to_name):
        indices = np.where(edge_ids_test == edge_id)[0]
        if len(indices) == 0:
            continue
        edge_dataset = TensorDataset(
            torch.from_numpy(X_test_main[indices]).float(),
            torch.from_numpy(X_test_diff[indices]).float(),
            torch.from_numpy(y_test[indices]).long(),
        )
        edge_loader = DataLoader(edge_dataset, batch_size=BATCH_SIZE * 2)
        _, f1 = evaluate_committee_model(anchor_models, edge_loader, device)
        result_line = f"  Edge {edge_id} ({id_to_name[edge_id]}): \t F1-Score = {f1:.4f}\n"
        print(result_line, end='')
        report_lines.append(result_line)
        wandb.log({f"edge_{edge_id}_f1": f1})
    return "".join(report_lines)


def run_experiment(exp_name, algorithm, param_value, fedavg_enabled, fedprox_enabled, scaffold_enabled, fedadam_enabled, fedyogi_enabled, fednova_enabled):
    """Run one federated-learning experiment end to end and log it to wandb.

    Steps: load the arrays under ``DATA_PATH_PREFIX``, standardise them,
    split off a validation set, train one-vs-rest "anchor" CNNs (anchor ``i``
    learns ``label == i``) for ``NUM_ROUNDS`` rounds over the edge loaders,
    aggregate the client deltas on the server, and finally write the
    learning curve, the per-edge F1 scores and the analysis report to
    ``OUTPUT_DIR``.

    Server aggregation uses the unweighted mean of the client deltas:

    * fedavg / fedprox / scaffold / fednova: the mean delta is added to the
      global model. SCAFFOLD and FedNova have no dedicated logic here and
      therefore behave like FedAvg; a warning is printed.
    * fedadam / fedyogi: the mean delta is treated as a pseudo-gradient for
      an Adam- or Yogi-style server step with bias correction.

    Args:
        exp_name: wandb run name, also used in output file names.
        algorithm: algorithm name; decides how ``param_value`` is used.
        param_value: server learning rate for fedadam/fedyogi, proximal
            coefficient for fedprox, ignored otherwise.
        fedavg_enabled, fedprox_enabled, scaffold_enabled, fedadam_enabled,
        fedyogi_enabled, fednova_enabled: flags selecting the local-loss and
            aggregation branches.
    """
    wandb.init(project="Industrial_FL_Analysis", name=exp_name, reinit=True)

    server_learning_rate = param_value if algorithm in ['fedadam', 'fedyogi'] else 0.005
    proximal_mu = param_value if algorithm == 'fedprox' else 0.0

    wandb.config.update({
        "algorithm": algorithm,
        "param_value": param_value,
        "num_rounds": NUM_ROUNDS,
        "local_epochs": NUM_LOCAL_EPOCHS,
        "batch_size": BATCH_SIZE,
        "client_learning_rate": LEARNING_RATE,
        "server_learning_rate": server_learning_rate,
        "beta_1": BETA_1,
        "beta_2": BETA_2,
        "epsilon": EPSILON,
        "proximal_mu": proximal_mu
    })

    if scaffold_enabled or fednova_enabled:
        # Make the fallback visible instead of silently running FedAvg.
        print(f"WARNING: '{algorithm}' has no dedicated implementation in this script; "
              "client updates are aggregated with the plain FedAvg rule.")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("\n--- Loading & Preprocessing Data ---")
    X_train_orig = np.load(os.path.join(DATA_PATH_PREFIX, 'X_train.npy'))
    y_train_orig = np.load(os.path.join(DATA_PATH_PREFIX, 'y_train.npy'))
    edge_ids_train_orig = np.load(os.path.join(DATA_PATH_PREFIX, 'edge_ids_train.npy'))
    # The held-out test split is read from the *_val.npy files; the
    # validation split used during training is carved out of the train
    # arrays below.
    X_test = np.load(os.path.join(DATA_PATH_PREFIX, 'X_val.npy'))
    y_test = np.load(os.path.join(DATA_PATH_PREFIX, 'y_val.npy'))
    edge_ids_test = np.load(os.path.join(DATA_PATH_PREFIX, 'edge_ids_val.npy'))

    # Standardise per feature (last axis): flatten every time step into a
    # row, scale, then restore the original layout. The scaler is fitted on
    # the full training array, i.e. before the train/validation split.
    scaler = StandardScaler()
    X_train_orig_scaled = scaler.fit_transform(X_train_orig.reshape(-1, X_train_orig.shape[2])).reshape(X_train_orig.shape)
    X_test_scaled = scaler.transform(X_test.reshape(-1, X_test.shape[2])).reshape(X_test.shape)
    X_train, X_val, y_train, y_val, edge_ids_train, edge_ids_val = train_test_split(
        X_train_orig_scaled, y_train_orig, edge_ids_train_orig, test_size=0.2, random_state=42, stratify=y_train_orig)

    # Each sample is fed as two aligned views: the window without its last
    # step and the first difference along the time axis (same length).
    X_train_main, X_train_diff = X_train[:, :-1, :], np.diff(X_train, axis=1)
    train_edge_loaders = create_federated_loaders(X_train_main, X_train_diff, y_train, edge_ids_train, BATCH_SIZE)
    X_val_main, X_val_diff = X_val[:, :-1, :], np.diff(X_val, axis=1)
    val_dataset = TensorDataset(torch.from_numpy(X_val_main).float(), torch.from_numpy(X_val_diff).float(), torch.from_numpy(y_val).long())
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False)
    X_test_main, X_test_diff = X_test_scaled[:, :-1, :], np.diff(X_test_scaled, axis=1)
    test_dataset = TensorDataset(torch.from_numpy(X_test_main).float(), torch.from_numpy(X_test_diff).float(), torch.from_numpy(y_test).long())
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, shuffle=False)

    # One anchor model per class; the count follows the class list instead
    # of being repeated as a literal throughout the function.
    class_names_en = ['Normal', 'Noise', 'Surface', 'Corona', 'Void']
    num_anchors = len(class_names_en)

    anchor_models = [Simple1DCNN().to(DEVICE) for _ in range(num_anchors)]
    criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='mean')
    history = {'acc': [], 'f1': [], 'loss': []}

    # First/second moment estimates of the server optimizer, one dict per
    # anchor, covering the floating-point state entries only.
    server_m = [{k: torch.zeros_like(v, device=DEVICE) for k, v in model.state_dict().items() if v.dtype.is_floating_point} for model in anchor_models]
    server_v = [{k: torch.zeros_like(v, device=DEVICE) for k, v in model.state_dict().items() if v.dtype.is_floating_point} for model in anchor_models]

    print(f"\n--- Starting {algorithm.upper()} Federated Learning Training on {DEVICE} ---")
    for r in range(NUM_ROUNDS):
        round_training_losses = []
        all_model_deltas = [[] for _ in range(num_anchors)]

        for edge_id, edge_loader in enumerate(tqdm(train_edge_loaders, desc=f"Round {r+1}/{NUM_ROUNDS}")):
            # Entries may be None (presumably edges without training data).
            if edge_loader is None:
                continue

            # Every client starts the round from the current global weights.
            local_anchor_models = [copy.deepcopy(model) for model in anchor_models]

            for anchor_id in range(num_anchors):
                model = local_anchor_models[anchor_id]
                optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

                global_model_state = anchor_models[anchor_id].state_dict()
                # Detached views of the global weights for the FedProx term,
                # so backward() does not accumulate gradients on the global
                # model's parameters.
                global_params = [p.detach() for p in anchor_models[anchor_id].parameters()]

                local_steps = 0
                for local_epoch in range(NUM_LOCAL_EPOCHS):
                    for batch_data, batch_diff, multi_labels in edge_loader:
                        batch_data, batch_diff, multi_labels = batch_data.to(DEVICE), batch_diff.to(DEVICE), multi_labels.to(DEVICE)
                        # One-vs-rest target for this anchor.
                        binary_labels = (multi_labels == anchor_id).float()

                        optimizer.zero_grad()
                        # squeeze(-1) drops only the logit axis, keeping the
                        # batch axis aligned with the labels for a batch of one.
                        outputs = model(batch_data, batch_diff).squeeze(-1)
                        loss = criterion(outputs, binary_labels)

                        # Add the FedProx proximal term.
                        if fedprox_enabled:
                            proximal_term = 0.0
                            for w, w_glob in zip(model.parameters(), global_params):
                                proximal_term += torch.sum(torch.pow(w - w_glob, 2))
                            loss = loss + (proximal_mu / 2) * proximal_term

                        # Skip the step entirely when the loss is NaN.
                        if not torch.isnan(loss):
                            loss.backward()
                            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                            optimizer.step()
                            round_training_losses.append(loss.item())
                            local_steps += 1

                # Only clients that took at least one step contribute a delta.
                if local_steps > 0:
                    local_model_state_after = model.state_dict()
                    float_keys = [k for k, v in global_model_state.items() if v.dtype.is_floating_point]
                    model_delta = {k: local_model_state_after[k] - global_model_state[k] for k in float_keys}
                    all_model_deltas[anchor_id].append(model_delta)

        for anchor_idx in range(num_anchors):
            client_deltas = all_model_deltas[anchor_idx]
            if not client_deltas:
                continue

            # Unweighted mean of the client deltas.
            avg_delta = {k: torch.stack([d[k] for d in client_deltas]).mean(0)
                         for k in client_deltas[0]}

            # state_dict() tensors share storage with the model, so the
            # in-place add_ below updates the global model; fetch it once.
            global_state = anchor_models[anchor_idx].state_dict()

            if fedavg_enabled or fedprox_enabled or scaffold_enabled or fednova_enabled:
                with torch.no_grad():
                    for k in avg_delta:
                        global_state[k].add_(avg_delta[k])

            elif fedadam_enabled or fedyogi_enabled:
                m = server_m[anchor_idx]
                v = server_v[anchor_idx]
                t = r + 1  # 1-based step count for bias correction
                with torch.no_grad():
                    for k in avg_delta:
                        m[k] = BETA_1 * m[k] + (1 - BETA_1) * avg_delta[k]
                        if fedyogi_enabled:
                            # Yogi: v <- v - (1 - beta2) * sign(v - g^2) * g^2
                            g_t_squared = torch.pow(avg_delta[k], 2)
                            v[k].addcmul_(torch.sign(v[k] - g_t_squared), g_t_squared, value=-(1 - BETA_2))
                        elif fedadam_enabled:
                            v[k] = BETA_2 * v[k] + (1 - BETA_2) * torch.pow(avg_delta[k], 2)

                        m_hat = m[k] / (1 - BETA_1**t)
                        v_hat = v[k] / (1 - BETA_2**t)

                        update_val = server_learning_rate * m_hat / (torch.sqrt(v_hat) + EPSILON)
                        global_state[k].add_(update_val)

        val_acc, val_f1 = evaluate_committee_model(anchor_models, val_loader, DEVICE)
        avg_loss = np.mean(round_training_losses) if round_training_losses else 0
        history['acc'].append(val_acc)
        history['f1'].append(val_f1)
        history['loss'].append(avg_loss)

        print(f"\n--- Round {r+1} Summary ---")
        print(f"  Avg Train Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")
        wandb.log({"round": r + 1, "train_loss": avg_loss, "val_accuracy": val_acc, "val_f1": val_f1})

    print("\nFederated Learning training finished!")

    plot_learning_curves(history, OUTPUT_DIR, name=exp_name)
    per_edge_report = analyze_per_edge_performance(anchor_models, X_test_main, X_test_diff, y_test, edge_ids_test, DEVICE, EQUIPMENT_MAP)
    full_analysis_report(anchor_models, val_loader, test_loader, DEVICE, class_names_en, OUTPUT_DIR, per_edge_report, name=exp_name)
    wandb.finish()


def main():
    """Parse the command line, set up a Korean-capable plot font and run one experiment.

    ``--algo`` selects the federated algorithm and ``--param`` its
    algorithm-specific value; the run is named ``<algo>_<param>``.
    """
    parser = argparse.ArgumentParser(description="Run various Federated Learning algorithms.")
    parser.add_argument('--algo', type=str, required=True,
                        choices=['fedadam', 'fedavg', 'fednova', 'fedprox', 'fedyogi', 'scaffold'],
                        help="Choose the FL algorithm to run.")
    parser.add_argument('--param', type=float, default=0.01,
                        help="Algorithm-specific parameter (e.g., server_lr or proximal_mu).")
    args = parser.parse_args()

    exp_name = f"{args.algo}_{args.param}"

    system_name = platform.system()
    if system_name == 'Windows':
        plt.rc('font', family='Malgun Gothic')
    elif system_name == 'Darwin':  # macOS
        plt.rc('font', family='AppleGothic')
    else:  # Linux and everything else
        # plt.rc() only stores the family name and never raises for a
        # missing font, so probe the font manager explicitly instead of
        # wrapping rc() in a catch-all except.
        from matplotlib import font_manager
        try:
            font_manager.findfont('NanumGothic', fallback_to_default=False)
        except ValueError:
            print("NanumGothic font not found. Please install it (`sudo apt-get install fonts-nanum*`)")
        else:
            plt.rc('font', family='NanumGothic')
    # Render the minus sign with an ASCII hyphen.
    plt.rcParams['axes.unicode_minus'] = False

    # Exactly one flag is true: the one matching the chosen algorithm.
    run_experiment(exp_name=exp_name, algorithm=args.algo, param_value=args.param,
                   fedavg_enabled=(args.algo == 'fedavg'),
                   scaffold_enabled=(args.algo == 'scaffold'),
                   fednova_enabled=(args.algo == 'fednova'),
                   fedprox_enabled=(args.algo == 'fedprox'),
                   fedadam_enabled=(args.algo == 'fedadam'),
                   fedyogi_enabled=(args.algo == 'fedyogi'))

# Start the command-line entry point only when the file is run as a script,
# not when it is imported.
if __name__ == '__main__':
    main()