import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler

from src.main_model import DiffPO
from src.utils import train, train_sid, evaluate, evaluate_sid
import wandb

# Start the Weights & Biases run. This executes at import time, before the
# command line below has been parsed.
wandb.init(project="DiffPO", notes="synthetic experiment")
# === DEVICE ===
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import argparse

# Command-line interface: --case names the synthetic scenario to run and is
# used in the main script to look up the matching generator.
parser = argparse.ArgumentParser()
parser.add_argument('--case', type=str, choices=['Case1', 'Case2', 'Case3', 'Case4'], default='Case1')
# NOTE: parsed at module level, so merely importing this file reads sys.argv.
args = parser.parse_args()

def normalize_with_standardscaler(X, y):
    """Standardize covariates and outcome with independent StandardScalers.

    Args:
        X: 2-D array of covariates; each column is scaled on its own.
        y: 1-D array of outcomes; reshaped to a single column for scaling.

    Returns:
        Tuple ``(X_norm, y_norm)`` where ``y_norm`` is flattened back to 1-D.
        The fitted scalers are discarded, so the transform cannot be inverted
        by the caller.
    """
    X_norm = StandardScaler().fit_transform(X)

    # StandardScaler expects a 2-D input, hence the temporary column shape.
    y_column = y.reshape(-1, 1)
    y_norm = StandardScaler().fit_transform(y_column).flatten()

    return X_norm, y_norm

# === SYNTHETIC DATA GENERATORS ===
def generate_case1(n=1000, d=5):
    """Generate synthetic case 1: a rarely assigned binary treatment.

    Covariates are standard normal, the treatment is Bernoulli(0.1) and drawn
    independently of the covariates, and the outcome combines the first
    covariate, the treatment, their interaction and Gaussian noise with
    standard deviation 0.1.

    Args:
        n: Number of samples.
        d: Number of covariates.

    Returns:
        Tuple ``(X, z, y)``: standardized covariates, raw treatment and
        standardized outcome.
    """
    covariates = np.random.normal(0, 1, (n, d))
    z = np.random.binomial(1, 0.1, n)  # treated very rare

    x0 = covariates[:, 0]
    noise = np.random.normal(0, 0.1, n)
    outcome = 2 * x0 + 3 * z + 0.5 * x0 * z + noise

    X, y = normalize_with_standardscaler(covariates, outcome)
    return X, z, y

def generate_case2(n=1000, d=5):
    """Generate synthetic case 2: treatment whose distribution depends on X.

    Covariates are standard normal. The treatment is drawn from
    Beta(alpha, 0.5) with alpha = 0.1 where the first covariate is positive
    and 0.5 otherwise. The outcome combines the first covariate, the
    treatment, their interaction and Gaussian noise with standard deviation
    0.1.

    NOTE(review): unlike the other cases, ``z`` here is a continuous value in
    [0, 1] rather than a 0/1 indicator - confirm this is intended.

    Args:
        n: Number of samples.
        d: Number of covariates.

    Returns:
        Tuple ``(X, z, y)``: standardized covariates, raw treatment and
        standardized outcome.
    """
    covariates = np.random.normal(0, 1, (n, d))
    x0 = covariates[:, 0]

    # Treatment depends on the covariate through the Beta shape parameter.
    alpha = np.where(x0 > 0, 0.1, 0.5)
    z = np.random.beta(alpha, 0.5, n)

    noise = np.random.normal(0, 0.1, n)
    outcome = 2 * x0 + 3 * z + 0.5 * x0 * z + noise

    X, y = normalize_with_standardscaler(covariates, outcome)
    return X, z, y

def generate_case3(n=1000, d=5):
    """Generate synthetic case 3: covariate distribution depends on treatment.

    The treatment is Bernoulli(0.5). Every covariate of a treated unit is
    drawn from N(1, 1) and every covariate of an untreated unit from
    N(-1, 1). The outcome combines the first covariate, the treatment, their
    interaction and Gaussian noise with standard deviation 0.1.

    Args:
        n: Number of samples.
        d: Number of covariates.

    Returns:
        Tuple ``(X, z, y)``: standardized covariates, raw treatment and
        standardized outcome.
    """
    z = np.random.binomial(1, 0.5, n)

    # Draw both candidate matrices, then pick row-wise by treatment group.
    treated_draws = np.random.normal(1, 1, (n, d))
    control_draws = np.random.normal(-1, 1, (n, d))
    is_treated = z[:, np.newaxis] == 1
    covariates = np.where(is_treated, treated_draws, control_draws)

    x0 = covariates[:, 0]
    noise = np.random.normal(0, 0.1, n)
    outcome = 2 * x0 + 3 * z + 0.5 * x0 * z + noise

    X, y = normalize_with_standardscaler(covariates, outcome)
    return X, z, y



def generate_case4(n=1000, d=5):
    """Generate synthetic case 4: strong confounding on the first covariate.

    The treatment is Bernoulli(0.5). The first covariate is uniform on [0, 2)
    for treated units and on [-2, 0) for untreated units; the remaining
    ``d - 1`` covariates are standard normal. The outcome combines the first
    covariate, the treatment, their interaction and Gaussian noise with
    standard deviation 0.1.

    Args:
        n: Number of samples.
        d: Total number of covariates, including the shifted first one.

    Returns:
        Tuple ``(X, z, y)``: standardized covariates, raw treatment and
        standardized outcome.
    """
    z = np.random.binomial(1, 0.5, n)

    # First covariate is shifted according to the treatment group.
    treated_draws = np.random.uniform(0, 2, n)
    control_draws = np.random.uniform(-2, 0, n)
    first_column = np.where(z == 1, treated_draws, control_draws)

    other_columns = np.random.normal(0, 1, (n, d - 1))
    covariates = np.column_stack((first_column, other_columns))

    x0 = covariates[:, 0]
    noise = np.random.normal(0, 0.1, n)
    outcome = 2 * x0 + 3 * z + 0.5 * x0 * z + noise

    X, y = normalize_with_standardscaler(covariates, outcome)
    return X, z, y

def train_test_split(observed_data, test_ratio=0.2):
    """Randomly split every entry of a data dict into train and test parts.

    The sample count is taken from ``observed_data["observed_data"]`` and one
    random permutation is applied to all entries, so rows stay aligned across
    keys.

    Args:
        observed_data: Dict of arrays/tensors indexable along their first axis.
        test_ratio: Fraction of samples assigned to the test split.

    Returns:
        Tuple ``(train_data, test_data)`` of dicts with the same keys as the
        input.
    """
    total = observed_data["observed_data"].shape[0]
    shuffled = np.random.permutation(total)

    # The first `cut` shuffled indices are the training rows, the rest test.
    cut = int(total * (1 - test_ratio))
    train_idx = shuffled[:cut]
    test_idx = shuffled[cut:]

    train_data, test_data = {}, {}
    for key, values in observed_data.items():
        train_data[key] = values[train_idx]
        test_data[key] = values[test_idx]
    return train_data, test_data

# === BUILD OBSERVED DATA ===
def prepare_observed_data(X, z, y):
    """Pack treatment, outcome and covariates into the model's input dict.

    The data matrix has the column layout ``z, y0, y1, mu0, mu1, X``; the
    single observed outcome ``y`` is copied into all four outcome columns.

    Args:
        X: 2-D covariate array with one row per sample.
        z: 1-D treatment array.
        y: 1-D outcome array.

    Returns:
        Dict with float32 ``observed_data``, all-ones ``observed_mask`` and
        ``gt_mask`` of the same shape, and ``timepoints`` holding a single
        zero per sample.
    """
    n, _ = X.shape

    # Column order: z, y0, y1, mu0, mu1, X
    columns = [z, y, y, y, y, X]
    observed_tensor = torch.tensor(np.column_stack(columns), dtype=torch.float32)

    result = {"observed_data": observed_tensor}
    # Two independent all-ones masks (not aliases of one tensor).
    result["observed_mask"] = torch.ones_like(observed_tensor)
    result["gt_mask"] = torch.ones_like(observed_tensor)
    # Placeholder time axis: one timepoint (index 0) for every sample.
    result["timepoints"] = torch.arange(1).repeat(n, 1)
    return result

class DictTensorDataset(torch.utils.data.Dataset):
    """Dataset that serves one sample at a time from a dict of tensors.

    Every value in ``data_dict`` is indexed along its first axis with the
    same index, so a sample is a dict with the same keys as ``data_dict``.
    """

    def __init__(self, data_dict):
        """Store the dict and record the sample count.

        Args:
            data_dict: Mapping of names to tensors; must contain an
                "observed_data" entry, whose first dimension defines the
                dataset length.
        """
        self.data_dict = data_dict
        # Length is fixed at construction time from "observed_data".
        self.length = data_dict["observed_data"].shape[0]

    def __len__(self):
        """Return the number of samples recorded at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of every tensor, keyed like the input."""
        sample = {}
        for key, value in self.data_dict.items():
            sample[key] = value[idx]
        return sample

def create_dataloader(observed_data, batch_size=128):
    """Wrap a data dict in a shuffling DataLoader.

    Args:
        observed_data: Dict of tensors accepted by ``DictTensorDataset``.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` over ``DictTensorDataset(observed_data)`` that
        reshuffles the samples on every pass.
    """
    return DataLoader(
        DictTensorDataset(observed_data),
        batch_size=batch_size,
        shuffle=True,
    )

def _median_prediction(model, batch, n_samples):
    """Return the median over samples of ``model.evaluate`` for one batch.

    NOTE(review): assumes ``model.evaluate(batch, n_samples)`` returns a tuple
    whose first element is the sample tensor with the sample axis at dim 1 -
    confirm against src.main_model.
    """
    samples, *_ = model.evaluate(batch, n_samples)
    return torch.median(samples, dim=1).values[:, 0].cpu().numpy()


def plot_y_distributions(pretrain_model, sid_model, test_loader, device, case_name):
    """Plot kernel-density estimates of the observed and predicted outcome.

    For every batch of ``test_loader`` the observed outcome is read from
    column 1 of ``batch["observed_data"]`` and each model is asked for 50
    samples; the median over the sample axis is used as that model's point
    prediction. The three KDE curves are drawn on one figure, which is saved
    to ``{case_name}_y_distributions.png`` and then shown.

    NOTE(review): the previous version called ``src.utils.evaluate`` and
    ``evaluate_sid`` with a single batch and unpacked five return values,
    which contradicts how ``run_experiment`` calls the same helpers (model(s)
    plus a loader, returning metrics). The per-batch sampling is therefore
    routed through each model's own ``evaluate`` method; that method's
    signature and return layout are assumptions to verify.

    Args:
        pretrain_model: Pretrained model; must provide ``eval()`` and a
            per-batch ``evaluate(batch, n_samples)``.
        sid_model: SID model providing the same two methods.
        test_loader: Iterable of batch dicts with a 2-D "observed_data"
            tensor.
        device: Not used by this function; kept so existing callers keep
            working.
        case_name: Label used in the plot title and the output file name.
    """
    n_samples = 50

    pretrain_model.eval()
    sid_model.eval()
    y_true = []
    y_pretrain = []
    y_sid = []

    with torch.no_grad():
        for batch in test_loader:
            # Column 1 of observed_data is the observed outcome y.
            y_true.append(batch["observed_data"][:, 1].cpu().numpy())
            y_pretrain.append(_median_prediction(pretrain_model, batch, n_samples))
            y_sid.append(_median_prediction(sid_model, batch, n_samples))

    y_true = np.concatenate(y_true)
    y_pretrain = np.concatenate(y_pretrain)
    y_sid = np.concatenate(y_sid)

    plt.figure(figsize=(10,6))
    sns.kdeplot(y_true, label='True y', linewidth=2)
    sns.kdeplot(y_pretrain, label='Pretrain Pred y', linewidth=2)
    sns.kdeplot(y_sid, label='SID Pred y', linewidth=2)
    plt.title(f'Distribution of y on {case_name}')
    plt.xlabel('y')
    plt.ylabel('Density')
    plt.legend()
    plt.savefig(f'{case_name}_y_distributions.png')
    plt.show()

# === MAIN EXPERIMENT LOOP ===
def run_experiment(config, observed_data, case_name):
    """Train the pretrain and SID models on one case and evaluate both.

    Steps: split the data 80/20, pretrain a ``DiffPO`` model, save its
    weights to ``pretrain_{case_name}.pth``, run SID training from that
    checkpoint, evaluate both models on the test split and plot the outcome
    distributions. Uses the module-level ``device``.

    Args:
        config: Experiment configuration. ``config["train"]`` is passed to
            ``train`` and supplies ``epochs`` for SID training and, when
            present, ``batch_size`` for the data loaders.
        observed_data: Data dict as produced by ``prepare_observed_data``.
        case_name: Label used for the checkpoint and plot file names.

    Returns:
        Tuple ``(pretrain_metrics, sid_metrics)`` exactly as returned by
        ``evaluate`` and ``evaluate_sid``.
    """
    train_data, test_data = train_test_split(observed_data, test_ratio=0.2)

    # Honour the configured batch size; it used to be ignored, so the loaders
    # silently ran with create_dataloader's default of 128.
    batch_size = config["train"].get("batch_size", 128)
    train_loader = create_dataloader(train_data, batch_size=batch_size)
    test_loader = create_dataloader(test_data, batch_size=batch_size)

    # Pretrain
    pretrain_model = DiffPO(config, device).to(device)
    train(pretrain_model, config["train"], train_loader)

    pretrain_path = f'pretrain_{case_name}.pth'
    torch.save(pretrain_model.state_dict(), pretrain_path)

    # SID training starts from the checkpoint written just above.
    sid_model = train_sid(DiffPO, config, pretrain_path, num_epochs=config['train']['epochs'],
                          train_loader=train_loader, valid_loader=None, device=device)

    # Evaluate on test set
    pretrain_metrics = evaluate(pretrain_model, test_loader, nsample=50)
    sid_metrics = evaluate_sid(pretrain_model, sid_model, test_loader, nsample=50, device=device)

    # Reuse the existing test loader instead of building a third one.
    plot_y_distributions(pretrain_model, sid_model, test_loader, device, case_name)

    return pretrain_metrics, sid_metrics

# === MAIN SCRIPT ===
if __name__ == '__main__':
    # Configuration passed to the model constructor and the training helpers
    # via run_experiment.
    train_cfg = {
        "epochs": 2000,
        "batch_size": 512,
        "lr": 0.0001,
        "valid_epoch_interval": 200,
    }
    diffusion_cfg = {
        "layers": 4,
        "channels": 64,
        "f_dim": 180,
        "cond_dim": 6,
        "hidden_dim": 128,
        "side_dim": 33,
        "nheads": 2,
        "diffusion_embedding_dim": 128,
        "beta_start": 0.0001,
        "beta_end": 0.5,
        "num_steps": 100,
        "schedule": "quad",
        "mixed": False,
    }
    model_cfg = {
        "is_unconditional": 0,
        "timeemb": 32,
        "featureemb": 32,
        "target_strategy": "random",
        "mixed": False,
    }
    config = {
        "dataset": {"data_name": "synthetic"},
        "train": train_cfg,
        "diffusion": diffusion_cfg,
        "model": model_cfg,
    }

    # Map each --case choice to its synthetic data generator.
    cases = {
        'Case1': generate_case1,
        'Case2': generate_case2,
        'Case3': generate_case3,
        'Case4': generate_case4,
    }

    # Select single case
    case_name = args.case
    gen_func = cases[case_name]

    print(f'\n=== Running {case_name} ===')
    X, z, y = gen_func()
    observed_data = prepare_observed_data(X, z, y)

    pretrain_metrics, sid_metrics = run_experiment(config, observed_data, case_name)

    # One row per model. The pretrain metrics are read by position, the SID
    # metrics by their 'IWDD ...' keys.
    results = [
        {
            'Case': case_name,
            'Model': 'Pretrain',
            'RMSE_y0': pretrain_metrics[0],
            'RMSE_y1': pretrain_metrics[1],
            'PEHE': pretrain_metrics[2],
        },
        {
            'Case': case_name,
            'Model': 'SID',
            'RMSE_y0': sid_metrics['IWDD RMSE_y0'],
            'RMSE_y1': sid_metrics['IWDD RMSE_y1'],
            'PEHE': sid_metrics['IWDD PEHE'],
        },
    ]

    # === PLOT RESULTS ===
    df = pd.DataFrame(results)

    # One bar chart per metric, saved to disk and then shown.
    for metric in ('RMSE_y0', 'RMSE_y1', 'PEHE'):
        plt.figure()
        sns.barplot(data=df, x='Model', y=metric)
        plt.title(f'{metric} on {case_name}')
        plt.savefig(f'{case_name}_{metric}.png')
        plt.show()