import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Sequence, Tuple, Optional, Dict
import os
import csv

# Seed the global NumPy and PyTorch random generators (and the CUDA generator
# when a GPU is present) for reproducibility. Everything below that calls
# np.random.* or initialises torch modules draws from these global streams.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
# Use float64 as the default floating-point dtype for tensors and module
# parameters created after this point.
torch.set_default_dtype(torch.float64)

# Get device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

def calculate_homogeneity_metrics(H: np.ndarray, y: Optional[np.ndarray] = None) -> Dict[str, float]:
    """Summarise how tightly the instances of each bag cluster around the bag mean.

    Args:
        H: Array of shape (N, K, D): N bags, K instances per bag, D features.
        y: Accepted so call sites can pass the targets along; it is not used
            by the computation.

    Returns:
        A dict with two entries, each averaged over the N bags:
        ``avg_in_bag_std`` is the per-feature standard deviation of the
        instances, averaged over features; ``avg_in_bag_rms`` is the
        root-mean-square Euclidean distance of the instances to the bag mean.

    Raises:
        ValueError: If ``H`` is not a 3-dimensional array.
    """
    H = np.asarray(H)
    if H.ndim != 3:
        raise ValueError(f"H must have shape (N, K, D); got shape {H.shape}")

    # Per-bag dispersion, computed for all bags at once.
    per_bag_std = np.std(H, axis=1).mean(axis=1)  # (N,)

    deviations = H - H.mean(axis=1, keepdims=True)  # (N, K, D)
    squared_distances = np.sum(deviations ** 2, axis=2)  # (N, K)
    per_bag_rms = np.sqrt(squared_distances.mean(axis=1))  # (N,)

    return {
        # In-bag dispersion
        'avg_in_bag_std': float(np.mean(per_bag_std)),
        'avg_in_bag_rms': float(np.mean(per_bag_rms)),
    }

def analyze_data_homogeneity(eta: float, nu: float, gamma_star: float,
                              N_train: int = 1000, N_val: int = 300,
                              D: int = 16, K: int = 10,
                              sigma_B: float = 1.0, sigma_floor: float = 0.01,
                              verbose: bool = True) -> Dict[str, Dict[str, float]]:
    """Generate a fresh train/validation pair and report its in-bag dispersion.

    Two ``KeyInstanceDGP`` generators are built from the given settings; the
    validation generator is made to reuse the training generator's
    ``w_star`` / ``w_perp`` directions before its data are drawn.

    Args:
        eta, nu, gamma_star, sigma_B, sigma_floor: Forwarded to ``KeyInstanceDGP``.
        N_train, N_val: Number of bags in the training / validation set.
        D, K: Feature dimension and instances per bag.
        verbose: When true, print a short report of the RMS dispersion.

    Returns:
        ``{'train': metrics, 'val': metrics}`` where each ``metrics`` is the
        dict returned by ``calculate_homogeneity_metrics``.
    """
    generator_settings = dict(D=D, K=K, eta=eta, nu=nu, gamma_star=gamma_star,
                              sigma_B=sigma_B, sigma_floor=sigma_floor)

    train_gen = KeyInstanceDGP(N=N_train, **generator_settings)
    train_bags, train_targets, _ = train_gen.generate_data()

    # The validation generator draws its own directions on construction;
    # they are replaced so both splits share the same ground truth.
    val_gen = KeyInstanceDGP(N=N_val, **generator_settings)
    val_gen.w_star, val_gen.w_perp = train_gen.w_star, train_gen.w_perp
    val_bags, val_targets, _ = val_gen.generate_data()

    report = {
        'train': calculate_homogeneity_metrics(train_bags, train_targets),
        'val': calculate_homogeneity_metrics(val_bags, val_targets),
    }

    if not verbose:
        return report

    rule = "=" * 60
    print(rule)
    print("HOMOGENEITY ANALYSIS")
    print(f"Data Parameters: η={eta:.3f}, ν={nu:.3f}, γ*={gamma_star:.2f}")
    print(rule)

    print("\nTRAINING SET:")
    print(f"  In-bag dispersion (RMS):     {report['train']['avg_in_bag_rms']:.4f}")

    print("\nVALIDATION SET:")
    print(f"  In-bag dispersion (RMS):     {report['val']['avg_in_bag_rms']:.4f}")

    return report

# ---------------------------------------------------------------
# 1. Synthetic Data Generation Process
# ---------------------------------------------------------------

class KeyInstanceDGP:
    """
    Generates synthetic bag data in which one "key" instance carries extra signal.

    Each of the N bags has a centre ``mu`` drawn from N(0, sigma_B^2 I) in D
    dimensions. The first K-1 instances are ``mu`` plus isotropic noise of scale
    ``sigma_floor``. The last instance is additionally shifted by
    ``c * w_star + nu * w_perp`` with ``c = eta * |z|``, ``z ~ N(0, 1)``.
    The target is ``mu . w_star + gamma_star * c`` plus Gaussian noise of
    standard deviation ``eps_std``.

    All randomness comes from the global ``np.random`` stream.

    Args:
        N: Number of bags.
        D: Feature dimension.
        K: Instances per bag (the last one is the key instance).
        eta: Scale of the key-instance shift along ``w_star``.
        nu: Size of the key-instance shift along ``w_perp``.
        gamma_star: Weight of the shift magnitude ``c`` in the target.
        sigma_B: Standard deviation of the bag centres.
        sigma_floor: Standard deviation of the per-instance noise.
        eps_std: Standard deviation of the target noise (previously fixed at 0.6).
    """
    def __init__(self, N=500, D=16, K=10, eta=1.0, nu=0.5, gamma_star=1.0, sigma_B=1.0, sigma_floor=0.01,
                 eps_std=0.6):
        self.N = N
        self.D = D
        self.K = K
        self.eta = eta
        self.nu = nu
        self.gamma_star = gamma_star
        self.sigma_B = sigma_B
        self.sigma_floor = sigma_floor
        self.eps_std = eps_std
        self.w_star, self.w_perp = self._define_directions()

    def _define_directions(self):
        """Draw a random unit vector ``w_star`` and a unit vector ``w_perp`` orthogonal to it.

        ``w_perp`` is a random vector with its ``w_star`` component removed. If
        that residual is numerically zero, the coordinate axes are tried in
        order instead; if none of them works either (e.g. D == 1) the
        near-zero residual is returned unnormalised.
        """
        w_star = np.random.randn(self.D)
        w_star /= np.linalg.norm(w_star)
        w_perp = np.random.randn(self.D)
        w_perp -= np.dot(w_perp, w_star) * w_star
        if np.linalg.norm(w_perp) > 1e-8:
            w_perp /= np.linalg.norm(w_perp)
        else:
            for i in range(self.D):
                v = np.zeros(self.D)
                v[i] = 1
                v -= np.dot(v, w_star) * w_star
                if np.linalg.norm(v) > 1e-8:
                    w_perp = v / np.linalg.norm(v)
                    break
        return w_star, w_perp

    def generate_data(self):
        """Draw one dataset.

        Returns:
            ``(H, y, w_star)`` where ``H`` has shape (N, K, D), ``y`` has shape
            (N,) and ``w_star`` is the generator's signal direction.
        """
        mu = np.random.randn(self.N, self.D) * self.sigma_B
        c = self.eta * np.abs(np.random.randn(self.N))
        H = np.zeros((self.N, self.K, self.D))
        # The per-bag loop is kept as-is: vectorising it would change the order
        # in which random numbers are consumed and hence the seeded output.
        for s in range(self.N):
            H[s, :-1, :] = mu[s] + np.random.randn(self.K-1, self.D) * self.sigma_floor
            H[s, -1, :] = (mu[s] + c[s] * self.w_star + self.nu * self.w_perp +
                          np.random.randn(self.D) * self.sigma_floor)
        y = mu @ self.w_star + self.gamma_star * c + np.random.randn(self.N) * self.eps_std
        return H, y, self.w_star

# ---------------------------------------------------------------
# 2. ECA Implementation
# ---------------------------------------------------------------

class ScalingNetwork(nn.Module):
    """Two-layer MLP mapping a D-dimensional vector to one scalar per input row.

    Args:
        D: Input feature dimension.
        hidden_dim: Width of the hidden layer. Defaults to ``D // 4``, but never
            less than 1.
    """
    def __init__(self, D, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            # D // 4 is 0 for D < 4, which would create an empty hidden layer
            # and reduce the network to its output bias.
            hidden_dim = max(1, D // 4)
        self.fc1 = nn.Linear(D, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

        # Start the output layer near zero so the initial output is close to 0.
        nn.init.normal_(self.fc2.weight, std=0.01)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, mu_s):
        """Return the scalar output for each row of ``mu_s`` (last dim squeezed away)."""
        x = F.relu(self.fc1(mu_s))
        beta_s = self.fc2(x).squeeze(-1)
        return beta_s

class AttentionRegressorECA(nn.Module):
    """Attention-pooling bag regressor with two optional extensions.

    * ``use_dats``: attention logits are divided by a per-bag temperature
      ``T_min + beta_temp * dispersion``.
    * ``use_sra``: the attention-weighted deviation from the bag mean is
      multiplied by a per-bag factor ``gamma = 1 + softplus(scaling_net(mean))``,
      clamped from above by ``gamma_max`` when that is not ``None``.

    Weights are drawn from the global ``np.random`` stream and are float64.
    ``device`` is only stored; the module is not moved by the constructor.
    """

    def __init__(self, D,
                 gamma_max=2.0,
                 lambda_gamma=0.01,
                 use_sra=True,
                 T_min=0.1,
                 beta_temp=1.0,
                 use_dats=True,
                 init_scale=0.1,
                 device='cpu'):
        super().__init__()
        self.D = D
        self.device = device

        # Settings of the gamma scaling branch.
        self.use_sra = use_sra
        self.gamma_max = gamma_max
        self.lambda_gamma = lambda_gamma

        # Settings of the temperature branch.
        self.use_dats = use_dats
        self.T_min = T_min
        self.beta_temp = beta_temp

        attn_init = np.random.randn(D, 1) * init_scale
        self.W_attn = nn.Parameter(torch.from_numpy(attn_init))

        if self.use_sra:
            self.scaling_net = ScalingNetwork(D)

        reg_init = np.random.randn(D, 1) * init_scale
        self.W_reg = nn.Parameter(torch.from_numpy(reg_init))
        self.b_reg = nn.Parameter(torch.zeros(1, dtype=torch.float64))

    def compute_dispersion(self, H, mu_s):
        """Return, per bag, the RMS Euclidean distance of the instances to ``mu_s``."""
        offsets = H - mu_s
        mean_squared_distance = offsets.pow(2).sum(dim=-1).mean(dim=1)
        return mean_squared_distance.sqrt()

    def forward(self, H):
        """Run the regressor on a batch of bags.

        Args:
            H: Tensor of shape (batch_size, K, D).

        Returns:
            ``(Y_hat, Alpha, V, gamma_s)``: predictions (batch_size,), attention
            weights (batch_size, K), pooled representations (batch_size, D) and
            the per-bag scaling factors (batch_size,), which are all ones when
            ``use_sra`` is off.
        """
        batch_size, _, _ = H.shape

        bag_mean = H.mean(dim=1, keepdim=True)  # (batch_size, 1, D)
        flat_mean = bag_mean.squeeze(1)  # (batch_size, D)

        logits = torch.einsum('bkd,do->bko', H, self.W_attn).squeeze(-1)  # (batch_size, K)
        if self.use_dats:
            temperature = self.T_min + self.beta_temp * self.compute_dispersion(H, bag_mean)
            logits = logits / temperature.unsqueeze(1)
        Alpha = F.softmax(logits, dim=1)  # (batch_size, K)

        # Attention-weighted deviation of the instances from the bag mean.
        weighted_shift = torch.einsum('bk,bkd->bd', Alpha, H - bag_mean)  # (batch_size, D)

        if not self.use_sra:
            gamma_s = torch.ones(batch_size, dtype=torch.float64, device=H.device)
            V = flat_mean + weighted_shift
        else:
            gamma_s = 1.0 + F.softplus(self.scaling_net(flat_mean))
            if self.gamma_max is not None:
                gamma_s = torch.clamp(gamma_s, max=self.gamma_max)
            V = flat_mean + gamma_s.unsqueeze(1) * weighted_shift  # (batch_size, D)

        Y_hat = (V @ self.W_reg + self.b_reg).squeeze(-1)  # (batch_size)

        return Y_hat, Alpha, V, gamma_s

    def compute_gamma_regularization(self, gamma_s):
        """Return ``lambda_gamma * mean((gamma_s - 1)^2)``, or 0.0 when the penalty is inactive."""
        if not self.use_sra or not (self.lambda_gamma > 0):
            return 0.0
        return self.lambda_gamma * ((gamma_s - 1.0) ** 2).mean()

class DispersionNormalizedPCCLoss(nn.Module):
    """``(1 - Pearson correlation)`` weighted by the relative spread of the predictions.

    The weight is ``std(y_pred) / sigma_0`` and is detached, so gradients flow
    only through the correlation term.

    Args:
        sigma_0: Reference standard deviation. When ``None`` the standard
            deviation of ``y_true`` in the current batch is used.
    """

    def __init__(self, sigma_0=None):
        super().__init__()
        self.sigma_0 = sigma_0

    def forward(self, y_true, y_pred):
        """Return the scalar loss for 1-D tensors ``y_true`` and ``y_pred``."""
        y_true_centered = y_true - y_true.mean()
        y_pred_centered = y_pred - y_pred.mean()

        # L2 norms of the centred vectors (not standard deviations): these are
        # the right denominators for a correlation built from a plain sum.
        numerator = (y_true_centered * y_pred_centered).sum()
        norm_true = torch.sqrt((y_true_centered ** 2).sum() + 1e-8)
        norm_pred = torch.sqrt((y_pred_centered ** 2).sum() + 1e-8)

        pcc = numerator / (norm_true * norm_pred)

        if self.sigma_0 is None:
            # Norm ratio == std ratio, because both share the same 1/sqrt(n-1).
            scaling_factor = norm_pred.detach() / norm_true.detach()
        else:
            # sigma_0 is a standard deviation, so the prediction norm has to be
            # converted to one as well; otherwise the factor is inflated by
            # sqrt(n - 1). The n - 1 divisor matches torch.Tensor.std()'s default.
            dof = max(y_pred.numel() - 1, 1)
            std_pred = norm_pred.detach() / dof ** 0.5
            scaling_factor = std_pred / self.sigma_0

        return (1 - pcc) * scaling_factor

class ECALoss(nn.Module):
    """Sum of MSE, a weighted dispersion-normalised PCC loss and a gamma penalty.

    Args:
        lambda_pcc: Weight of the PCC term in the total.
        sigma_0: Forwarded to ``DispersionNormalizedPCCLoss``.
    """

    def __init__(self, lambda_pcc=1.0, sigma_0=None):
        super().__init__()
        self.lambda_pcc = lambda_pcc
        self.mse_loss = nn.MSELoss()
        self.pcc_loss = DispersionNormalizedPCCLoss(sigma_0=sigma_0)

    def forward(self, y_true, y_pred, gamma_s, lambda_gamma):
        """Return ``(total, mse, pcc, gamma_penalty)``.

        The gamma penalty is ``lambda_gamma * mean((gamma_s - 1)^2)``, or the
        integer 0 when ``gamma_s`` is ``None``.
        """
        mse_term = self.mse_loss(y_pred, y_true)
        pcc_term = self.pcc_loss(y_true, y_pred)

        if gamma_s is None:
            gamma_term = 0
        else:
            gamma_term = lambda_gamma * ((gamma_s - 1.0) ** 2).mean()

        combined = mse_term + self.lambda_pcc * pcc_term + gamma_term
        return combined, mse_term, pcc_term, gamma_term

# ---------------------------------------------------------------
# 3. Standard Attention Baseline
# ---------------------------------------------------------------

class AttentionRegressorStandard(nn.Module):
    """Plain softmax-attention pooling over a bag followed by a linear regressor.

    Args:
        D: Feature dimension.
        init_scale: Standard deviation of the Gaussian weight initialisation.
        device: Stored on the instance only; the module is not moved here.
        seed: Seed of the private generator used for the initial weights.
    """

    def __init__(self, D, init_scale=0.1, device='cpu', seed=42):
        super().__init__()
        self.D = D
        self.device = device
        # A private RandomState gives the same initial weights as the earlier
        # np.random.seed(42) call did, without resetting the global NumPy
        # stream for everything constructed after this model.
        # NOTE(review): with the script's global seed also being 42, the first D
        # draws here coincide with the first D draws of the global stream —
        # confirm the baseline is not meant to use a different seed.
        rng = np.random.RandomState(seed)
        self.W_attn = nn.Parameter(torch.from_numpy(rng.randn(D, 1) * init_scale))
        self.W_reg = nn.Parameter(torch.from_numpy(rng.randn(D, 1) * init_scale))
        self.b_reg = nn.Parameter(torch.zeros(1, dtype=torch.float64))

    def forward(self, H):
        """Return ``(Y_hat, Alpha, V, None)`` for bags ``H`` of shape (N, K, D).

        The trailing ``None`` keeps the return arity equal to that of
        ``AttentionRegressorECA.forward`` (which returns gamma there).
        """
        Z = torch.einsum('nkd,do->nko', H, self.W_attn).squeeze(-1)
        Alpha = F.softmax(Z, dim=1)
        V = torch.einsum('nk,nkd->nd', Alpha, H)
        Y_hat = (V @ self.W_reg + self.b_reg).squeeze(-1)
        return Y_hat, Alpha, V, None


class Optimizer:
    """Hand-written Adam-style optimiser over a fixed list of parameters.

    The bias correction of both moment estimates is folded into the step size,
    and ``epsilon`` is added to the square root of the uncorrected second moment.
    """

    def __init__(self, params, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.t = 0  # number of steps taken so far
        # First and second moment estimates, one tensor per parameter.
        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]

    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        self.t += 1
        corrected = self.lr * np.sqrt(1 - self.beta2**self.t)
        lr_t = corrected / (1 - self.beta1**self.t)
        with torch.no_grad():
            for idx, (weight, m_prev, v_prev) in enumerate(zip(self.params, self.m, self.v)):
                grad = weight.grad
                if grad is None:
                    continue
                m_next = self.beta1 * m_prev + (1 - self.beta1) * grad
                v_next = self.beta2 * v_prev + (1 - self.beta2) * (grad**2)
                self.m[idx] = m_next
                self.v[idx] = v_next
                weight.data -= lr_t * m_next / (torch.sqrt(v_next) + self.epsilon)

    def zero_grad(self):
        """Zero the existing gradients in place; parameters without one are skipped."""
        for grad in (weight.grad for weight in self.params):
            if grad is not None:
                grad.zero_()

# ---------------------------------------------------------------
# 5. Training
# ---------------------------------------------------------------

def train_model(model, H_train, Y_train, H_val, Y_val,
                epochs=500, lr=0.001, lambda_pcc=1.0, use_eca_loss=False,
                use_dispersion_norm=True, device='cpu'):
    """Train ``model`` with full-batch updates and record per-epoch metrics.

    Every epoch runs one forward/backward pass over the whole training set,
    then re-evaluates the model on the training and validation sets.

    Args:
        model: Module whose forward returns ``(Y_hat, Alpha, V, gamma_s)``.
        H_train, Y_train, H_val, Y_val: NumPy arrays with bags and targets.
        epochs: Number of full-batch updates.
        lr: Step size handed to ``Optimizer``.
        lambda_pcc: Weight of the correlation term in the loss.
        use_eca_loss, use_dispersion_norm: ``ECALoss`` (with ``sigma_0`` set to
            the std of the training targets) is used only when both are true;
            otherwise the loss is MSE plus a plain ``1 - PCC`` term, plus the
            gamma penalty for models with ``use_sra`` enabled.
        device: Torch device for the data and the model.

    Returns:
        Dict with lists ``train_mse``, ``val_mse``, ``train_pcc``, ``val_pcc``,
        one entry per epoch.
    """
    def _to_device(array):
        return torch.from_numpy(array).double().to(device)

    H_train_torch = _to_device(H_train)
    Y_train_torch = _to_device(Y_train)
    H_val_torch = _to_device(H_val)
    Y_val_torch = _to_device(Y_val)

    model = model.to(device)
    optimizer = Optimizer(model.parameters(), lr=lr)

    eca_objective = use_eca_loss and use_dispersion_norm
    if eca_objective:
        criterion = ECALoss(lambda_pcc=lambda_pcc, sigma_0=Y_train_torch.std()).to(device)
    else:
        mse_loss = nn.MSELoss().to(device)

    def _plain_pcc_loss(target, prediction):
        # 1 - Pearson correlation, with a small constant guarding the roots.
        target_c = target - target.mean()
        prediction_c = prediction - prediction.mean()
        return 1 - torch.sum(target_c * prediction_c) / (
            torch.sqrt(torch.sum(target_c**2) + 1e-8) *
            torch.sqrt(torch.sum(prediction_c**2) + 1e-8)
        )

    def _objective(prediction, gamma_s):
        if eca_objective:
            total, _, _, _ = criterion(
                Y_train_torch, prediction, gamma_s,
                getattr(model, 'lambda_gamma', 0)
            )
            return total
        total = mse_loss(prediction, Y_train_torch) + lambda_pcc * _plain_pcc_loss(Y_train_torch, prediction)
        if getattr(model, 'use_sra', False) and gamma_s is not None:
            total = total + model.lambda_gamma * ((gamma_s - 1.0) ** 2).mean()
        return total

    history = {name: [] for name in ('train_mse', 'val_mse', 'train_pcc', 'val_pcc')}

    for epoch in range(epochs):
        # --- optimisation step on the full training set ---
        model.train()
        Y_hat_train, _, _, gamma_s = model(H_train_torch)
        loss = _objective(Y_hat_train, gamma_s)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # --- evaluation with the updated weights ---
        model.eval()
        with torch.no_grad():
            train_out, _, _, _ = model(H_train_torch)
            val_out, _, _, _ = model(H_val_torch)

            train_pred = train_out.cpu().numpy()
            val_pred = val_out.cpu().numpy()

            train_mse = np.mean((Y_train - train_pred)**2)
            val_mse = np.mean((Y_val - val_pred)**2)

            train_pcc = np.corrcoef(Y_train, train_pred)[0, 1]
            val_pcc = np.corrcoef(Y_val, val_pred)[0, 1]

            for name, value in (('train_mse', train_mse), ('val_mse', val_mse),
                                ('train_pcc', train_pcc), ('val_pcc', val_pcc)):
                history[name].append(value)

            first_epoch = epoch == 0
            if (epoch + 1) % 100 == 0 or first_epoch:
                print(f"Epoch {epoch+1}/{epochs} | "
                      f"Train MSE: {train_mse:.4f}, PCC: {train_pcc:.4f} | "
                      f"Val MSE: {val_mse:.4f}, PCC: {val_pcc:.4f}")

    return history

# ---------------------------------------------------------------
# 6. Visualization
# ---------------------------------------------------------------

def plot_attention_vs_eca(history_std, history_eca, ETA, NU, GAMMA_STAR, LAMBDA_PCC, title_suffix="", save_path=None,
                          homogeneity_rms=None):
    """Plot MSE and PCC training curves of the standard and ECA models on one figure.

    MSE is drawn on the left axis and PCC on a twin right axis. The title
    reports the in-bag homogeneity plus the best validation PCC / MSE of each model.

    Args:
        history_std, history_eca: Dicts with per-epoch lists under
            ``train_mse``, ``val_mse``, ``train_pcc`` and ``val_pcc``.
        ETA, NU, GAMMA_STAR: Data-generation settings; used only to compute the
            homogeneity figure when ``homogeneity_rms`` is not given.
        LAMBDA_PCC: Accepted for call-site compatibility; not used.
        title_suffix: Text appended to the first title line.
        save_path: If given, the figure is saved there and closed; otherwise
            it is shown.
        homogeneity_rms: In-bag RMS dispersion to print in the title. When
            ``None`` it is obtained from ``analyze_data_homogeneity``, which
            draws a new dataset with that function's default sizes (so it is
            not measured on the data the models were trained on) and prints
            its own report.
    """
    epochs = range(1, len(history_std['train_mse']) + 1)
    fig, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Mean Squared Error (MSE)')

    l1, = ax1.plot(epochs, history_std['train_mse'], color='green', linestyle='--', label='Std-Train MSE')
    l2, = ax1.plot(epochs, history_std['val_mse'],   color='green', linestyle='-',  label='Std-Val MSE')
    l3, = ax1.plot(epochs, history_eca['train_mse'], color='red',   linestyle='--', label='ECA-Train MSE')
    l4, = ax1.plot(epochs, history_eca['val_mse'],   color='red',   linestyle='-',  label='ECA-Val MSE')

    ax1.grid(True, which="both", ls="--", alpha=0.5)

    ax2 = ax1.twinx()
    ax2.set_ylabel('Pearson Correlation Coefficient (PCC)')
    # np.corrcoef yields NaN for constant predictions; axis limits must be
    # finite, so only finite PCC values take part in the lower bound.
    pcc_values = np.concatenate([
        np.asarray(history[key], dtype=float).ravel()
        for history in (history_std, history_eca)
        for key in ('train_pcc', 'val_pcc')
    ])
    finite_pcc = pcc_values[np.isfinite(pcc_values)]
    min_pcc = finite_pcc.min() if finite_pcc.size else -1.0
    ax2.set_ylim([min_pcc - 0.05, 1.05])

    l5, = ax2.plot(epochs, history_std['train_pcc'], color='green', linestyle=':', label='Std-Train PCC')
    l6, = ax2.plot(epochs, history_std['val_pcc'],   color='green', linestyle='-.',  label='Std-Val PCC')
    l7, = ax2.plot(epochs, history_eca['train_pcc'], color='red',   linestyle=':', label='ECA-Train PCC')
    l8, = ax2.plot(epochs, history_eca['val_pcc'],   color='red',   linestyle='-.',  label='ECA-Val PCC')

    lines = [l1, l2, l3, l4, l5, l6, l7, l8]
    labels = [ln.get_label() for ln in lines]
    ax1.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=3, frameon=True)

    # NaN-aware extrema: built-in max/min give order-dependent results with NaN.
    std_max_pcc_val   = np.nanmax(history_std['val_pcc'])
    eca_max_pcc_val   = np.nanmax(history_eca['val_pcc'])

    std_min_mse_val   = np.nanmin(history_std['val_mse'])
    eca_min_mse_val   = np.nanmin(history_eca['val_mse'])

    fig.tight_layout(rect=[0, 0, 1, 0.9])

    if homogeneity_rms is None:
        homogeneity_rms = analyze_data_homogeneity(ETA, NU, GAMMA_STAR)["train"]["avg_in_bag_rms"]

    title_text = (
        f'Standard Attention vs ECA - Training Curves{title_suffix}\n'
        f'Homogeneity: in-bag {homogeneity_rms:.4f}\n'
        f'Val Max PCC: Std={std_max_pcc_val:.4f} | ECA={eca_max_pcc_val:.4f} \n'
        f'Val Min MSE: Std={std_min_mse_val:.4f} | ECA={eca_min_mse_val:.4f} '
    )
    plt.title(title_text, pad=20)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Plot saved to {save_path}")
    else:
        plt.show()

# ---------------------------------------------------------------
# 7. Main Experiment
# ---------------------------------------------------------------

# Order in which one CSV record lists the run parameters.
_RUN_PARAMETER_ORDER = (
    'T_min', 'beta_temp', 'eta', 'gamma_max', 'gamma_star',
    'lambda_gamma', 'lambda_pcc', 'lr', 'nu',
)


def _parse_run_parameters(row):
    """Convert one CSV record into a ``{name: float}`` dict keyed by ``_RUN_PARAMETER_ORDER``.

    The cells are re-joined and split on whitespace, so records delimited by
    commas, tabs or spaces are all accepted.

    Raises:
        ValueError: If a value is not a number or the count is not as expected.
    """
    values = [float(token) for token in "\t".join(row).split()]
    if len(values) != len(_RUN_PARAMETER_ORDER):
        raise ValueError(
            f"expected {len(_RUN_PARAMETER_ORDER)} parameters "
            f"({', '.join(_RUN_PARAMETER_ORDER)}), got {len(values)}"
        )
    return dict(zip(_RUN_PARAMETER_ORDER, values))


def _run_experiment(run_number, params, output_dir):
    """Generate data, train both models and save the comparison plot for one run."""
    N_train = 2000
    N_val   = 300
    D = 16
    K = 10
    EPOCHS  = 1500

    ETA, NU, GAMMA_STAR = params['eta'], params['nu'], params['gamma_star']
    LAMBDA_PCC, LR = params['lambda_pcc'], params['lr']
    use_dispersion_norm = True
    eca_kwargs = {
        'use_sra': True,
        'use_dats': True,
        'device': device,
        'T_min': params['T_min'],
        'beta_temp': params['beta_temp'],
        'gamma_max': params['gamma_max'],
        'lambda_gamma': params['lambda_gamma'],
    }

    print("\n" + "="*60)
    print(f"Running Experiment {run_number} on: {device}")
    print("="*60)

    print("\nGenerating synthetic data...")
    dgp_train = KeyInstanceDGP(N=N_train, D=D, K=K, eta=ETA, nu=NU, gamma_star=GAMMA_STAR)
    H_train, Y_train, _ = dgp_train.generate_data()

    # The validation generator must share the training directions.
    dgp_val = KeyInstanceDGP(N=N_val, D=D, K=K, eta=ETA, nu=NU, gamma_star=GAMMA_STAR)
    dgp_val.w_star = dgp_train.w_star
    dgp_val.w_perp = dgp_train.w_perp
    H_val, Y_val, _ = dgp_val.generate_data()
    print(f"Data generated. Train shape: {H_train.shape}, Val shape: {H_val.shape}")

    print("="*60)
    print("Training with Standard Attention...")
    print("="*60)
    model_standard = AttentionRegressorStandard(D=D, device=device)
    history_standard = train_model(
        model_standard, H_train, Y_train, H_val, Y_val,
        epochs=EPOCHS, lr=LR, lambda_pcc=LAMBDA_PCC,
        use_eca_loss=False, use_dispersion_norm=False, device=device
    )

    print("\n" + "="*60)
    print("Training with ECA...")
    print("="*60)
    model_eca = AttentionRegressorECA(D=D, **eca_kwargs)
    history_eca = train_model(
        model_eca, H_train, Y_train, H_val, Y_val,
        epochs=EPOCHS, lr=LR, lambda_pcc=LAMBDA_PCC,
        use_eca_loss=True, use_dispersion_norm=use_dispersion_norm, device=device
    )

    save_filename = os.path.join(output_dir, f'run_{run_number}_plot.png')
    plot_attention_vs_eca(
        history_standard, history_eca,
        ETA, NU, GAMMA_STAR, LAMBDA_PCC,
        title_suffix=f" (Run {run_number})",
        save_path=save_filename
    )


def main():
    """Run one standard-vs-ECA experiment per data row of the parameter CSV.

    The first record of ``data_generation_parameters.csv`` is treated as a
    header and skipped. Blank rows are ignored and rows that cannot be parsed
    are reported and skipped. One plot per run is written to ``result_plots``.
    """
    import traceback

    CSV_FILE_PATH = 'data_generation_parameters.csv'
    OUTPUT_DIR = 'result_plots'

    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        print(f"Created directory: {OUTPUT_DIR}")

    # Only the file access sits in this try, so a missing-file error raised
    # later (e.g. while saving a plot) is not misreported as a missing CSV.
    try:
        with open(CSV_FILE_PATH, 'r', newline='') as f:
            records = list(csv.reader(f))
    except FileNotFoundError:
        print(f"Error: The file '{CSV_FILE_PATH}' was not found. Please make sure it's in the same directory.")
        return
    except OSError as e:
        print(f"An error occurred: {e}")
        return

    if not records:
        print(f"Error: The file '{CSV_FILE_PATH}' is empty; nothing to run.")
        return

    try:
        # records[0] is the header line.
        for i, row in enumerate(records[1:]):
            run_number = i + 1
            param_str = "\t".join(row)
            if not param_str.strip():
                continue  # blank line in the CSV

            print(f"\n{'='*30} RUN {run_number} {'='*30}")
            print(f"Using parameters: {param_str}")

            try:
                params = _parse_run_parameters(row)
            except ValueError as err:
                print(f"Skipping run {run_number}: {err}")
                continue

            _run_experiment(run_number, params, OUTPUT_DIR)
    except Exception as e:
        # Top-level boundary: report the failure with its traceback rather
        # than only the message, so the failing step can be located.
        print(f"An error occurred: {e}")
        traceback.print_exc()

# Run the experiments only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()