from utils import *
from models.gnn_models import *
from models.egnn_models import *
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from torch_geometric.loader import DataLoader
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import argparse

# Fix the random seed before anything else runs.
# NOTE(review): set_random_seed is not defined in this file; presumably it is
# provided by the `utils` star import — confirm which RNGs it actually seeds.
set_random_seed(42)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# Read as a module-level global by get_model / train_epoch / test_epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Command-line options. Parsing happens at module level, so `args` is a global
# that get_model() and run_experiment() rely on.
parser = argparse.ArgumentParser(description="Link Prediction with Various GNN Models")
# Which architecture get_model() builds.
parser.add_argument('--model', type=str, default='gcn',
                    choices=['gcn', 'sage', 'gat', 'egnn', 'equiformer'])
# Hidden width handed to every model as `hidden_channels`.
parser.add_argument('--hid_dim', type=int, default=32)
# NOTE(review): `args.r` is not read anywhere in the visible code.
parser.add_argument('--r', type=float, default=0.1, help="Ratio of edge removal for link prediction")
# NOTE(review): `args.dataset` is not read anywhere in the visible code;
# main() always calls get_polymer_network().
parser.add_argument('--dataset', type=str, default='nanofibres', choices=['nanofibres', 'nanowire', 'polymer'], help="Dataset to use")
args = parser.parse_args()

# Print NumPy arrays in full instead of abbreviating long ones.
np.set_printoptions(threshold=np.inf)


# Fractions passed to create_data_splits() in main() for the
# train / validation / test edge splits.
train_ratio = 0.6
val_ratio = 0.2
test_ratio = 0.2


def get_model(args, in_dim=32):
    """Build the network named by ``args.model`` plus an Adam optimizer for it.

    Args:
        args: Namespace with ``model`` (architecture name) and ``hid_dim``
            (passed to the model as ``hidden_channels``).
        in_dim: Input feature size, passed to the model as ``in_channels``.

    Returns:
        ``(model, optimizer)``; the model has already been moved to ``device``.

    Raises:
        ValueError: If ``args.model`` is not one of the known names.
    """
    # Each entry is a zero-argument factory, so only the requested
    # architecture is ever instantiated.
    factories = {
        'gcn': lambda: GCN(in_channels=in_dim, hidden_channels=args.hid_dim),
        'sage': lambda: GraphSAGE(in_channels=in_dim, hidden_channels=args.hid_dim),
        'gat': lambda: GAT(in_channels=in_dim, hidden_channels=args.hid_dim, heads=4),
        'egnn': lambda: BatchedEGNNLinkPredictor(
            in_channels=in_dim,
            hidden_channels=args.hid_dim,
            batch_size=64,  # Adjust this based on your GPU memory
            dropout=0.1,
        ),
        'equiformer': lambda: BatchedEquiformerV2LinkPredictor(
            in_channels=in_dim,
            hidden_channels=args.hid_dim,
            num_layers=2,
            max_ell=2,
            dropout=0.1,
            batch_size=2048,  # Adjust based on your GPU memory
        ),
    }

    factory = factories.get(args.model)
    if factory is None:
        raise ValueError(f"Unknown model type: {args.model}")
    network = factory().to(device)

    # The equiformer variant is trained with a much larger step size.
    learning_rate = 0.001
    if args.model == 'equiformer':
        learning_rate = 0.1
    adam = torch.optim.Adam(network.parameters(), lr=learning_rate)

    return network, adam



def train_epoch(model, ddata, optimizer):
    """Perform one full-batch optimisation step on the training edges.

    Args:
        model: Network exposing ``encode(x, edge_index)`` and
            ``decode_break(z, edges)``.
        ddata: Dict holding the graph under ``'all'`` and the training
            edges / targets under ``'train_edge'`` / ``'train_y'``.
        optimizer: Optimizer bound to ``model``'s parameters.

    Returns:
        The MSE training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Encode using the whole graph, then score only the training edges.
    graph = ddata['all']
    node_emb = model.encode(graph.x.to(device), graph.edge_index.to(device))
    prediction = model.decode_break(node_emb, ddata['train_edge'].to(device))
    target = ddata['train_y'].to(device)

    step_loss = F.mse_loss(prediction, target)
    step_loss.backward()
    optimizer.step()

    return step_loss.item()


@torch.no_grad()
def test_epoch(model, ddata, test=False):
    """Score the model on the validation or test edges.

    Args:
        model: Network exposing ``encode(x, edge_index)`` and
            ``decode_break(z, edges)``.
        ddata: Dict holding the graph under ``'all'`` and the held-out
            edges / targets under ``'val_edge'`` / ``'val_y'`` and
            ``'test_edge'`` / ``'test_y'``.
        test: When true evaluate the test split, otherwise the validation
            split.

    Returns:
        ``(mse, mae)`` as computed by scikit-learn's ``mean_squared_error``
        and ``mean_absolute_error``.
    """
    model.eval()

    split = 'test' if test else 'val'
    index = ddata[f'{split}_edge']
    label = ddata[f'{split}_y']

    z = model.encode(ddata['all'].x.to(device), ddata['all'].edge_index.to(device))
    out = model.decode_break(z, index.to(device))

    # The scikit-learn metrics operate on host arrays. Predictions were always
    # moved off the device; do the same for the targets so evaluation also
    # works when the graph (and therefore its labels) lives on the GPU.
    out = out.cpu().numpy()
    if torch.is_tensor(label):
        label = label.detach().cpu().numpy()

    mse = mean_squared_error(label, out)
    mae = mean_absolute_error(label, out)

    return mse, mae



def run_experiment(ddata):
    """Train a fresh model with early stopping and report its test metrics.

    A new model/optimizer pair is built from the global ``args``. Training
    runs for at most 500 epochs and stops once the validation MSE has not
    improved for 50 consecutive epochs. The weights from the epoch with the
    lowest validation MSE are restored before the test split is scored.

    Args:
        ddata: Dict with the graph under ``'all'`` and the train/val/test
            edges and targets, as consumed by ``train_epoch`` and
            ``test_epoch``.

    Returns:
        Dict with keys ``'test_mse'`` and ``'test_mae'``.
    """
    import copy

    model, optimizer = get_model(args, in_dim=ddata['all'].num_node_features)

    best_val_loss = float('inf')
    patience = 50
    epochs_no_improve = 0
    best_model_state = None
    max_epochs = 500

    for _ in range(max_epochs):
        train_epoch(model, ddata, optimizer)
        val_mse, _val_mae = test_epoch(model, ddata)

        if val_mse < best_val_loss:
            best_val_loss = val_mse
            # state_dict() hands back references to the live parameter
            # tensors, and a shallow dict.copy() keeps sharing them, so the
            # "snapshot" would silently follow every later optimizer step.
            # A deep copy freezes the weights as they are right now.
            best_model_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            break

    # If validation never improved (e.g. the loss was NaN from the start)
    # there is no snapshot to restore; evaluate the final weights instead of
    # failing inside load_state_dict(None).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    test_mse, test_mae = test_epoch(model, ddata, test=True)

    return {
        'test_mse': test_mse,
        'test_mae': test_mae}

def main():
    """Evaluate the selected model on every graph and print MAE summaries.

    For each graph returned by ``get_polymer_network()`` this builds
    label-balanced train/val/test edge splits, calls ``run_experiment``
    three times and prints the mean ± std test MAE. After all graphs it
    prints the mean ± std of the per-graph mean MAEs.
    """
    # NOTE(review): the loader is fixed to get_polymer_network(); the
    # --dataset option is not consulted here — confirm this is intended.
    data_list = get_polymer_network()
    all_mae = []
    num_runs = 3

    def create_data_splits(data, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2,
                           max_count=100):
        """Return balanced ``(train_idx, val_idx, test_idx)`` edge-index lists.

        Edges are grouped by their distinct ``data.y`` value. Every group
        contributes the same number of edges to each split; that number is
        the given ratio of the smallest group's size, capped at
        ``max_count``. All three lists are empty when ``data.y`` is empty.
        """
        y = data.y.cpu().numpy()
        if y.size == 0:
            return [], [], []

        unique_vals = np.unique(y)
        idx_by_val = {val: np.where(y == val)[0] for val in unique_vals}
        min_count = min(len(group) for group in idx_by_val.values())

        # Split sizes depend only on the smallest group, so compute them once.
        n_train = int(train_ratio * min_count)
        n_val = int(val_ratio * min_count)
        n_test = int(test_ratio * min_count)

        train_idx, val_idx, test_idx = [], [], []
        for val in unique_vals:
            group = idx_by_val[val]
            np.random.shuffle(group)

            train_idx.extend(group[:min(n_train, max_count)])
            val_idx.extend(group[n_train:n_train + min(n_val, max_count)])
            test_idx.extend(group[n_train + n_val:n_train + n_val + min(n_test, max_count)])

        # Mix the label groups together inside each split.
        np.random.shuffle(train_idx)
        np.random.shuffle(val_idx)
        np.random.shuffle(test_idx)
        return train_idx, val_idx, test_idx

    for graph_idx, data in enumerate(data_list):
        print(f"\n=== Graph {graph_idx+1}/{len(data_list)}: {data.num_nodes} nodes, {data.num_edges} edges ===")

        train_idx, val_idx, test_idx = create_data_splits(data, train_ratio, val_ratio, test_ratio)

        # A split comes out empty when the rarest label value has too few
        # edges for the requested ratios. Such a graph cannot be trained and
        # scored, so skip it instead of aborting the whole run.
        if not (train_idx and val_idx and test_idx):
            print(f"Graph {graph_idx+1}: skipped, not enough labelled edges to build all three splits")
            continue

        ddata = {
            'all': data,
            'train_edge': data.edge_index[:, train_idx],
            'val_edge': data.edge_index[:, val_idx],
            'test_edge': data.edge_index[:, test_idx],
            'train_y': data.y[train_idx],
            'val_y': data.y[val_idx],
            'test_y': data.y[test_idx]
        }

        # Repeat training on the same split; each run builds a fresh model.
        maes = [run_experiment(ddata)['test_mae'] for _ in range(num_runs)]
        mean_mae = np.mean(maes)
        std_mae = np.std(maes)
        all_mae.append(mean_mae)
        print(f"Graph {graph_idx+1}: Mean Test MAE = {mean_mae:.4f} ± {std_mae:.4f}")

    if not all_mae:
        # Avoid printing NaN statistics over an empty list.
        print("\nNo graphs were evaluated; no overall MAE to report.")
        return

    print(f"\n=== Overall Mean MAE across all graphs: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f} ===")



# Start the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()