import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# ==========================================
# 1. Configuration & Setup
# ==========================================
CONFIG = {
    'num_seeds': 5,               # Number of random seeds to average over
    'num_clients': 10,            # Number of clients to simulate
    'pretrain_epochs': 10,        # Epochs to train the feature extractor
    'rounds': 300,                # FL rounds
    'local_epochs': 1,            # Not read by the runners in this file (one full-batch step per round)
    'batch_size': 32,             # Not read by the runners in this file (they use full-batch gradients)
    'lr_local': 0.002,            # Step size for the clients' linear heads (also FedAvg's local step)
    'lr_server': 0.002,           # Server step size for AffPCL's central variables
    'samples_per_client': 70,     # Cap on training samples per client (None = use all)
    'warmup': 0,                  # AffPCL rounds before clients start updating their heads
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'leaf_root': './leaf/data/femnist',
}
# Backward-compatible alias for the historical misspelled key, which older code reads.
CONFIG['sanmples_per_client'] = CONFIG['samples_per_client']

def setup_leaf_data():
    """Download and preprocess LEAF FEMNIST into CONFIG['leaf_root'] unless it already exists.

    The LEAF repository is cloned two directories above ``CONFIG['leaf_root']``
    (``./leaf`` for the default ``./leaf/data/femnist``), so the clone location
    always agrees with the path the data is read from later. The clone is
    skipped when ``leaf_root`` already exists (e.g. a re-run after a failed
    preprocessing step). Installing LEAF's ``requirements.txt`` stays
    best-effort: a failure is reported but does not stop setup.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
        RuntimeError: if preprocessing does not produce ``<leaf_root>/data/train``.
    """
    import subprocess
    import sys

    leaf_root = CONFIG['leaf_root']
    train_path = os.path.join(leaf_root, 'data', 'train')
    if os.path.exists(train_path):
        print("Data already exists. Skipping download.")
        return

    print("Downloading and preprocessing LEAF FEMNIST (5% sample)...")
    # leaf_root is expected to be <repo>/data/femnist; derive <repo> from it.
    repo_dir = os.path.dirname(os.path.dirname(os.path.normpath(leaf_root))) or '.'

    if not os.path.isdir(leaf_root):
        subprocess.run(
            ["git", "clone", "https://github.com/TalwalkarLab/leaf.git", repo_dir],
            check=True,
        )

    # Install into the running interpreter; kept non-fatal as before.
    pip = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-r", os.path.join(repo_dir, "requirements.txt")]
    )
    if pip.returncode != 0:
        print("Warning: installing LEAF requirements failed; continuing with the current environment.")

    # Run the preprocessing script from inside leaf_root (it uses relative paths).
    subprocess.run(
        ["bash", "./preprocess.sh", "-s", "niid", "--sf", "0.05", "-k", "0", "-t", "sample"],
        cwd=leaf_root,
    )
    # The produced data, not the script's exit status, is the success criterion.
    if not os.path.exists(train_path):
        raise RuntimeError(f"LEAF preprocessing finished but {train_path} was not created.")
    print("Data setup complete.")

# ==========================================
# 2. Data Loading & Transforms
# ==========================================
class FEMNIST_LEAF(Dataset):
    """FEMNIST samples read from LEAF-format ``.json`` shards.

    Args:
        data_dir: directory of ``.json`` shards, each holding ``users`` (a list of
            ids) and ``user_data[uid]`` with ``x`` (flattened 28x28 images) and
            ``y`` (integer labels).
        user_id: load only this user; ``None`` loads every user of every shard.

    ``self.x`` is a float tensor of shape (N, 28, 28) and ``self.y`` a long tensor
    of shape (N,). When nothing matches, both are empty tensors of those shapes.
    Items are returned as a (1, 28, 28) image and its label.
    """

    def __init__(self, data_dir, user_id=None):
        xs, ys = [], []

        files = [f for f in os.listdir(data_dir) if f.endswith('.json')]
        for f_name in files:
            with open(os.path.join(data_dir, f_name), 'r') as f:
                data = json.load(f)

            users_to_load = data['users'] if user_id is None else [user_id]
            for u in users_to_load:
                if u in data['users']:
                    raw_x = data['user_data'][u]['x']
                    raw_y = data['user_data'][u]['y']
                    xs.append(torch.tensor(raw_x).view(-1, 28, 28))
                    ys.append(torch.tensor(raw_y).long())

            # A single requested user was found; remaining shards need not be parsed.
            if user_id is not None and xs:
                break

        if xs:
            self.x = torch.cat(xs).float()
            self.y = torch.cat(ys).long()
        else:
            # Typed empty tensors keep the attributes consistent with the non-empty case.
            self.x = torch.empty(0, 28, 28)
            self.y = torch.empty(0, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Add the channel dimension expected by Conv2d.
        return self.x[idx].unsqueeze(0), self.y[idx]

class ContinuousTaskTransform:
    """Maps integer class labels to continuous per-client regression targets b^i.

    Each label yields two +/-1 attributes:
      * t_a = +1 if the label is in ``DIGITS`` (classes 0-9), else -1;
      * t_b = +1 if the label is in the hand-picked ``CURVED`` set, else -1.
    A client with mixing weight ``lam`` regresses onto ``lam * t_a + (1 - lam) * t_b``.
    """

    def __init__(self):
        self.DIGITS = set(range(0, 10))
        # Heuristic "curved glyph" classes. Assuming FEMNIST's usual label order
        # (0-9 digits, 10-35 'A'-'Z', 36-61 'a'-'z') these are:
        # 0, 3, 6, 8, 9, B, C, D, G, O, P, R, S, U, b, d, e, g, o, p, q
        self.CURVED = {0, 3, 6, 8, 9, 11, 12, 13, 16, 24, 25, 27, 28, 30, 37, 39, 40, 42, 50, 51, 52}

    @staticmethod
    def _signed_membership(labels, members):
        """Return +1.0 where ``labels`` is in ``members`` and -1.0 elsewhere (float32, on labels' device)."""
        lookup = torch.tensor(sorted(members), dtype=labels.dtype, device=labels.device)
        return torch.isin(labels, lookup).float() * 2.0 - 1.0

    def get_target(self, labels, lam):
        """Return ``lam * t_a + (1 - lam) * t_b`` for a tensor of integer labels.

        Vectorised with ``torch.isin`` (this is called for every gradient and
        evaluation step), so no per-element Python loop is needed. The result
        lives on the same device as ``labels``.
        """
        t_a = self._signed_membership(labels, self.DIGITS)
        t_b = self._signed_membership(labels, self.CURVED)
        return lam * t_a + (1.0 - lam) * t_b

    def get_central_target(self, labels):
        """Target for the equal mix of both attributes (``lam = 0.5``)."""
        return self.get_target(labels, lam=0.5)

# ==========================================
# 3. Model & Pre-training
# ==========================================
class SimpleCNN(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images.

    The 84-d output of ``fc2`` (ReLU, then BatchNorm) is the feature vector
    handed to the downstream linear heads; ``classifier`` is a 62-way head used
    for pre-training.
    """

    def __init__(self):
        super().__init__()
        # Layer creation order determines parameter-initialisation RNG use and
        # state_dict order, so it is kept fixed.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.fc1 = nn.Linear(32 * 4 * 4, 120)  # spatial size 28 -> 24 -> 12 -> 8 -> 4
        self.fc2 = nn.Linear(120, 84)          # feature dimension
        self.norm = nn.BatchNorm1d(84)         # normalises features before any head
        self.classifier = nn.Linear(84, 62)    # pre-training head

    def _embed(self, x):
        """Map a batch of images to normalised 84-d feature vectors."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        hidden = F.relu(self.fc1(x.flatten(start_dim=1)))
        return self.norm(F.relu(self.fc2(hidden)))

    def forward(self, x, return_features=False):
        feats = self._embed(x)
        return feats if return_features else self.classifier(feats)


def train_feature_extractor(train_dir):
    """Pre-train a SimpleCNN on 62-way classification and return it in eval mode.

    The training data is every user in the first two ``.json`` shards of
    ``train_dir`` (in ``os.listdir`` order). Each shard is parsed exactly once and
    all of its users' samples are stacked into one tensor dataset.

    Args:
        train_dir: LEAF training directory containing ``.json`` shards.

    Returns:
        The trained model, switched to ``eval()`` so BatchNorm uses running statistics.

    Raises:
        RuntimeError: if the selected shards contain no samples.
    """
    print("\n--- Phase 1: Pre-training Feature Extractor on Real Data ---")
    # Load a subset of global data for pre-training: the first 2 shards.
    files = [f for f in os.listdir(train_dir) if f.endswith('.json')][:2]

    print("Loading pre-training data...")
    xs, ys = [], []
    for f_name in files:
        with open(os.path.join(train_dir, f_name)) as json_file:
            data = json.load(json_file)
        # Read every user straight from the parsed shard rather than building a
        # per-user dataset, which would re-parse the shard files for each user.
        for u in data['users']:
            record = data['user_data'][u]
            xs.append(torch.tensor(record['x']).view(-1, 1, 28, 28))
            ys.append(torch.tensor(record['y']).long())

    if not xs or sum(len(y) for y in ys) == 0:
        raise RuntimeError(f"No pre-training samples found in {train_dir}")

    global_set = torch.utils.data.TensorDataset(torch.cat(xs).float(), torch.cat(ys).long())
    loader = DataLoader(global_set, batch_size=64, shuffle=True)

    model = SimpleCNN().to(CONFIG['device'])
    opt = optim.Adam(model.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()

    model.train()
    for ep in range(CONFIG['pretrain_epochs']):
        total_loss = 0
        for x, y in loader:
            x, y = x.to(CONFIG['device']), y.to(CONFIG['device'])
            opt.zero_grad()
            out = model(x)
            loss = crit(out, y)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        print(f"Epoch {ep+1}/{CONFIG['pretrain_epochs']} Loss: {total_loss/len(loader):.4f}")

    model.eval()
    return model

# ==========================================
# 4. Unified Client
# ==========================================
class Client:
    """A federated client reduced to linear least squares on frozen features.

    At construction the feature extractor ``phi`` is run once over the client's
    train and test sets. After that, every method works only on the cached
    feature matrices. The client's target is
    ``b^i = transformer.get_target(labels, lam)``. Its model is a linear head
    ``w`` of shape (feature_dim, 1) with objective ``||X w - b^i||^2 / (2N)``.

    Args:
        user_id: LEAF user id; stored for bookkeeping and not used in the maths.
        lam: task-mixing weight passed to ``transformer.get_target``.
        train_ds, test_ds: datasets yielding (image, label) pairs.
        phi: feature extractor called as ``phi(x, return_features=True)``. It is
            presumably already in eval mode; this class does not switch it.
        transformer: object providing ``get_target(labels, lam)`` and
            ``get_central_target(labels)``.
    """

    def __init__(self, user_id, lam, train_ds, test_ds, phi, transformer):
        self.user_id = user_id
        self.lam = lam
        self.transformer = transformer
        self.device = CONFIG['device']
        # Prefer the correctly spelled key; fall back to the legacy misspelling.
        max_samples = CONFIG.get('samples_per_client', CONFIG.get('sanmples_per_client'))

        # Precompute features once so training and evaluation are pure matrix ops.
        full_train_x, full_train_y = self._precompute(train_ds, phi)
        self.test_x, self.test_y = self._precompute(test_ds, phi)

        if max_samples is not None and max_samples < len(full_train_x):
            # Random subset drawn from torch's global RNG (follows torch.manual_seed).
            indices = torch.randperm(len(full_train_x))[:max_samples]
            self.train_x = full_train_x[indices]
            self.train_y = full_train_y[indices]
        else:
            self.train_x = full_train_x
            self.train_y = full_train_y

        # Personalized linear head, initialised to 0 and sized from phi's feature dim.
        self.local_model = torch.zeros(self.train_x.shape[1], 1, device=self.device)

    def _precompute(self, dataset, phi):
        """Return (phi features on self.device, labels as yielded by the DataLoader)."""
        loader = DataLoader(dataset, batch_size=256)
        feats, labels = [], []
        with torch.no_grad():
            for x, y in loader:
                feats.append(phi(x.to(self.device), return_features=True))
                labels.append(y)
        return torch.cat(feats), torch.cat(labels)

    def _targets(self, labels):
        """Personalized targets b^i for ``labels`` as an (N, 1) column on self.device."""
        return self.transformer.get_target(labels, self.lam).view(-1, 1).to(self.device)

    @staticmethod
    def _lsq_grad(X, w, b):
        """Gradient of ||X w - b||^2 / (2N) with respect to w, i.e. X^T (X w - b) / N."""
        return (X.T @ (X @ w - b)) / len(X)

    def evaluate(self, model_weight=None):
        """Test MSE of ``model_weight`` (default: this client's head) against b^i."""
        w = model_weight if model_weight is not None else self.local_model
        with torch.no_grad():
            preds = self.test_x @ w
            loss = F.mse_loss(preds, self._targets(self.test_y))
        return loss.item()

    def get_batch_grads(self, w_eval):
        """Full-batch gradient X^T (X w_eval - b^i) / N on the training set."""
        return self._lsq_grad(self.train_x, w_eval, self._targets(self.train_y))

    # --- AFFPCL ---
    def get_affpcl_grads(self, global_xc, global_theta):
        """Return the four full-batch gradients used by AffPCL.

        Returns:
            g_coe: X^T (b^c - b^i) / N, the server's direction for theta.
            g_cdl: X^T (X xc - b^i) / N, the server's direction for xc.
            g_bias_local: X^T (X xc - b^c) / N, this client's bias-correction term.
            g_loc: X^T (X w_i - b^i) / N, the gradient for the client's own head.

        ``b^c`` is the exact central target from ``transformer.get_central_target``.
        ``global_theta`` is accepted for interface compatibility but is not used
        here; the estimate ``b^c ~= X @ global_theta`` is not applied.
        """
        X = self.train_x
        b_i = self._targets(self.train_y)
        b_c_hat = self.transformer.get_central_target(self.train_y).view(-1, 1).to(self.device)
        N = len(X)

        g_coe = (X.T @ (b_c_hat - b_i)) / N
        g_cdl = self._lsq_grad(X, global_xc, b_i)
        g_bias_local = self._lsq_grad(X, global_xc, b_c_hat)
        g_loc = self._lsq_grad(X, self.local_model, b_i)
        return g_coe, g_cdl, g_bias_local, g_loc

    def update_affpcl(self, g_loc, g_bias_local, g_bias_global):
        """In-place bias-corrected step: w_i -= lr_local * (g_loc + g_bias_global - g_bias_local)."""
        grad = g_loc + g_bias_global - g_bias_local
        self.local_model -= CONFIG['lr_local'] * grad

# ==========================================
# 5. Runners (track mean test loss per round)
# ==========================================
def run_independent(clients, rounds):
    """Baseline: every client runs plain full-batch gradient descent on its own head.

    All heads are first reset to zero in place.

    Returns:
        (history, final_losses): the mean test MSE over clients after each round,
        and each client's test MSE after the last round.
    """
    for client in clients:
        client.local_model.zero_()

    history = []
    for _ in range(rounds):
        for client in clients:
            client.local_model -= CONFIG['lr_local'] * client.get_batch_grads(client.local_model)
        history.append(np.mean([client.evaluate() for client in clients]))

    return history, [client.evaluate() for client in clients]

def run_fedavg(clients, rounds):
    """FedAvg with a single full-batch gradient step per client per round.

    Starting from a zero 84x1 global model, each client takes one ``lr_local``
    step from the current global weights. The server then sets the global model
    to the unweighted mean of those client weights. The global model is scored on
    every client's personalized test objective.

    Returns:
        (history, final_losses): the mean test MSE over clients after each round,
        and each client's test MSE for the final global model.
    """
    global_w = torch.zeros(84, 1).to(CONFIG['device'])
    history = []

    for _ in range(rounds):
        candidates = [
            global_w - CONFIG['lr_local'] * client.get_batch_grads(global_w)
            for client in clients
        ]
        global_w = torch.stack(candidates).mean(dim=0)
        history.append(np.mean([client.evaluate(global_w) for client in clients]))

    return history, [client.evaluate(global_w) for client in clients]

def _affpcl_average(tensors, like, count):
    """Add ``tensors`` one at a time onto zeros shaped like ``like``, then divide by ``count``."""
    total = torch.zeros_like(like)
    for t in tensors:
        total += t
    return total / count


def run_affpcl(clients, rounds):
    """Run AffPCL and track the clients' mean personalized test MSE.

    The server state is ``global_xc`` and ``global_theta`` (both 84x1, starting
    at zero). Client heads are reset to zero. Every round:

      1. each client reports ``get_affpcl_grads`` at the current server state;
      2. the server steps ``global_theta`` and ``global_xc`` by ``lr_server``
         along the client-averaged ``g_coe`` / ``g_cdl``, and averages the
         clients' bias terms into ``g_bias_global``;
      3. from round ``CONFIG['warmup']`` on, each client calls ``update_affpcl``
         with its own ``g_loc`` / ``g_bias_local`` and the shared ``g_bias_global``.

    Returns:
        (history, final_losses): the mean test MSE over clients after each round,
        and each client's test MSE after the last round.
    """
    global_xc = torch.zeros(84, 1).to(CONFIG['device'])
    global_theta = torch.zeros(84, 1).to(CONFIG['device'])
    for client in clients:
        client.local_model.zero_()

    history = []
    for rnd in range(rounds):
        # Every report is taken at the pre-update server state.
        reports = [client.get_affpcl_grads(global_xc, global_theta) for client in clients]
        n_clients = len(clients)

        # Server step on both central variables, plus the averaged bias term.
        global_theta -= CONFIG['lr_server'] * _affpcl_average(
            (rep[0] for rep in reports), global_theta, n_clients)
        global_xc -= CONFIG['lr_server'] * _affpcl_average(
            (rep[1] for rep in reports), global_xc, n_clients)
        g_bias_global = _affpcl_average((rep[2] for rep in reports), global_xc, n_clients)

        # Personalized bias-corrected steps once the warm-up is over.
        if rnd >= CONFIG['warmup']:
            for client, (_, _, g_bias_local, g_loc) in zip(clients, reports):
                client.update_affpcl(g_loc, g_bias_local, g_bias_global)

        history.append(np.mean([client.evaluate() for client in clients]))

    return history, [client.evaluate() for client in clients]

# ==========================================
# 6. Main Experiment Block
# ==========================================

# %% 0. Setup
setup_leaf_data()
train_dir = os.path.join(CONFIG['leaf_root'], 'data/train')
test_dir = os.path.join(CONFIG['leaf_root'], 'data/test')

# Seed before pre-training so the shared feature extractor (weight init and
# data shuffling) is reproducible. It is trained once and reused by every seed.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
phi = train_feature_extractor(train_dir)
transformer = ContinuousTaskTransform()

# Get pool of users (same os.listdir shard order as used for pre-training)
all_files = [f for f in os.listdir(train_dir) if f.endswith('.json')]
user_ids = []
for f in all_files:
    with open(os.path.join(train_dir, f)) as jf:
        user_ids.extend(json.load(jf)['users'])
    if len(user_ids) >= CONFIG['num_clients']:
        break
selected_users = user_ids[:CONFIG['num_clients']]
if len(selected_users) < CONFIG['num_clients']:
    # Otherwise zip() below would silently drop clients, and the per-client
    # result arrays would no longer match CONFIG['num_clients'].
    raise RuntimeError(
        f"Found only {len(selected_users)} users in {train_dir}; "
        f"CONFIG['num_clients'] is {CONFIG['num_clients']}."
    )

# Raw per-user data depends on neither the seed nor the scenario, so load it
# once instead of re-parsing the JSON shards for every (seed, scenario) pair.
train_sets = {uid: FEMNIST_LEAF(train_dir, uid) for uid in selected_users}
test_sets = {uid: FEMNIST_LEAF(test_dir, uid) for uid in selected_users}

# Four scenarios defining the spread of task-mixing weights (lambdas)
scenarios = {
    "Homogeneous":             0.5 * np.ones(CONFIG['num_clients']),   # all 0.5
    "Low Heterogeneity":       np.linspace(0.4, 0.6, CONFIG['num_clients']),
    "Medium Heterogeneity":    np.linspace(0.2, 0.8, CONFIG['num_clients']),
    "High Heterogeneity":      np.linspace(0.0, 1.0, CONFIG['num_clients']),
}

# Result storage: per scenario and method, one entry appended per seed
results_convergence = {s: {'Indep': [], 'FedAvg': [], 'AffPCL': []} for s in scenarios}
results_final = {s: {'Indep': [], 'FedAvg': [], 'AffPCL': []} for s in scenarios}
ROUNDS = CONFIG['rounds']

# %% Run
for seed in range(CONFIG['num_seeds']):
    print(f"\n=== Random Seed: {seed} ===")

    for name, lambdas in scenarios.items():
        print(f"\n--- Scenario: {name} ---")
        # Re-seed per scenario. Client draws its training subsample with
        # torch.randperm, so this gives every scenario of a seed the same data
        # and isolates the effect of task heterogeneity.
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        # 1. Instantiate clients for this scenario
        clients = [
            Client(uid, lam, train_sets[uid], test_sets[uid], phi, transformer)
            for uid, lam in zip(selected_users, lambdas)
        ]

        # 2. Run algorithms
        hist_indep, final_indep = run_independent(clients, ROUNDS)
        hist_fedavg, final_fedavg = run_fedavg(clients, ROUNDS)
        hist_affpcl, final_affpcl = run_affpcl(clients, ROUNDS)

        # 3. Store results
        for method, hist, final in (
            ('Indep', hist_indep, final_indep),
            ('FedAvg', hist_fedavg, final_fedavg),
            ('AffPCL', hist_affpcl, final_affpcl),
        ):
            results_convergence[name][method].append(hist)
            results_final[name][method].append(final)

# %% Plot
# (method, colour, marker) for every curve
METHOD_STYLES = [('Indep', 'C0', 'o'), ('FedAvg', 'C1', '^'), ('AffPCL', 'C2', 'D')]
Z_90 = 1.64  # normal quantile for a ~90% band around the mean (std / sqrt(#runs))

# Figure 1: convergence
fig1, axes1 = plt.subplots(1, 4, figsize=(12, 4))
# Figure 2: final test errors across agents
fig2, axes2 = plt.subplots(1, 4, figsize=(12, 4))

for idx, (name, runs_by_method) in enumerate(results_convergence.items()):
    ax1, ax2 = axes1[idx], axes2[idx]
    for method, color, style in METHOD_STYLES:
        # Convergence: mean over seeds of the per-round mean test MSE.
        all_hist = np.array(runs_by_method[method])            # (num_seeds, rounds)
        mean_hist = all_hist.mean(axis=0)
        band = Z_90 * all_hist.std(axis=0) / np.sqrt(len(all_hist))
        ax1.plot(mean_hist, label=method, color=color, marker=style, markersize=7, alpha=0.9,
                 markevery=60, markerfacecolor='none', linewidth=2)
        ax1.fill_between(range(len(mean_hist)), mean_hist - band, mean_hist + band,
                         color=color, alpha=0.2, label=None)

        # Final per-client test MSE, mean over seeds.
        final_errors = np.array(results_final[name][method])  # (num_seeds, num_clients)
        mean_final = final_errors.mean(axis=0)
        band_final = Z_90 * final_errors.std(axis=0) / np.sqrt(len(final_errors))
        client_ids = np.arange(1, final_errors.shape[1] + 1)
        ax2.plot(client_ids, mean_final, label=method, color=color, marker=style, markersize=7,
                 alpha=0.9, linewidth=2)
        ax2.fill_between(client_ids, mean_final - band_final, mean_final + band_final,
                         color=color, alpha=0.2)

    # Style each axis once, after all curves are drawn (the aspect depends on data limits).
    ax1.set_yscale('log')  # must precede set_aspect and the minor formatter (it resets formatters)
    for ax in (ax1, ax2):
        ax.grid(False)
        ax.tick_params(axis='both', which='both', length=0)
        ax.set_title(name, fontsize=14)
        ax.set_aspect(1. / ax.get_data_ratio())
    ax1.yaxis.set_minor_formatter(mticker.ScalarFormatter())

fig1.supxlabel('# Rounds', fontsize=14, y=0.11)
fig1.supylabel('Test MSE', fontsize=14)
handles, labels = axes1[0].get_legend_handles_labels()
fig1.legend(handles, labels, loc='lower center', ncol=3, fontsize=14, frameon=False, bbox_to_anchor=(0.5, 0))
fig1.tight_layout(rect=[0, -0.03, 1, 0.94])
fig1.savefig('comp_real.png', dpi=300, bbox_inches='tight')

fig2.supxlabel('Client index', fontsize=14, y=0.11)
fig2.supylabel('Final test MSE', fontsize=14)
handles2, labels2 = axes2[0].get_legend_handles_labels()
fig2.legend(handles2, labels2, loc='lower center', ncol=3, fontsize=14, frameon=False, bbox_to_anchor=(0.5, 0))
fig2.tight_layout(rect=[0, -0.03, 1, 0.94])
fig2.savefig('final_errors_heterogeneity_sweep.png', dpi=300, bbox_inches='tight')

plt.show()
print("Experiment complete. See comp_real.png and final_errors_heterogeneity_sweep.png")
