import torch
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from copy import deepcopy
from model.utils import init_base_model

class EarlyStopping:
    """Track validation accuracy and flag when training should stop.

    Every call records one validation accuracy. A call counts as an
    improvement unless the score is strictly below ``best_score + delta``.
    An improvement resets the patience counter and snapshots the model's
    ``state_dict``. After ``patience`` consecutive non-improving calls,
    ``early_stop`` becomes True.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: If True, print a message each time a new best is saved.
        delta: Margin added to ``best_score`` that a new score must reach to
            count as an improvement.
    """

    def __init__(self, patience, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        # Running state: best_score stays None until the first call.
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_val_acc = float('-inf')
        # Deep copy of the best model's state_dict, or None before any call.
        self.best_model_state = None

    def __call__(self, val_acc, model):
        """Record ``val_acc`` for ``model`` and update the stopping state."""
        if self.best_score is None:
            # The first observation is always taken as the best so far.
            self.best_score = val_acc
            self.save_checkpoint(val_acc, model)
            return

        if val_acc < self.best_score + self.delta:
            # Not an improvement: spend one unit of patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = val_acc
        self.save_checkpoint(val_acc, model)
        self.counter = 0

    def save_checkpoint(self, val_acc, model):
        """Remember ``val_acc`` and a deep copy of ``model.state_dict()``."""
        if self.verbose:
            message = f'Validation accuracy increased ({self.best_val_acc:.6f} --> {val_acc:.6f}). Saving model ...'
            print(message)
        self.best_val_acc = val_acc
        self.best_model_state = deepcopy(model.state_dict())


def train_epoch(model, optimizer, data):
    """Perform a single full-batch optimisation step on the training nodes.

    The model is switched to training mode, run on the whole graph, and the
    negative log-likelihood of the log-softmax outputs is minimised over the
    rows selected by ``data.train_mask``.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``. Presumably
            it returns per-node logits of shape (num_nodes, num_classes);
            TODO confirm.
        optimizer: Optimizer over ``model``'s parameters.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.

    Returns:
        The training loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    mask = data.train_mask
    log_probs = F.log_softmax(model(data.x, data.edge_index), dim=1)
    loss = F.nll_loss(log_probs[mask], data.y[mask])

    loss.backward()
    optimizer.step()
    return loss.item()

def test(model, data):
    """Evaluate ``model`` on the train, validation and test node masks.

    Runs one forward pass in eval mode with autograd disabled and compares
    the arg-max class of each node against ``data.y``.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``.
        data: Object exposing ``x``, ``edge_index``, ``y``, ``train_mask``,
            ``val_mask`` and ``test_mask``.

    Returns:
        ``[train_acc, val_acc, test_acc]`` as Python floats. A mask that
        selects no nodes yields ``nan`` (mean of an empty tensor).

    Note:
        The model is left in eval mode. Callers must switch it back with
        ``model.train()`` before further training, as ``train_epoch`` does.
    """
    model.eval()
    # Evaluation needs no gradients. Without no_grad, every call would build
    # and keep alive an autograd graph for the whole forward pass.
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = F.log_softmax(out, dim=1).argmax(dim=1)
    return [
        (pred[mask] == data.y[mask]).float().mean().item()
        for mask in (data.train_mask, data.val_mask, data.test_mask)
    ]


# The name starts with "test", so pytest would collect this helper as a test
# case wherever it is imported into a test module; opt it out explicitly.
test.__test__ = False

def run_training(model, optimizer, data, epochs, patience, save_path=None):
    """Train ``model`` with early stopping and restore its best weights.

    Each epoch performs one ``train_epoch`` step followed by ``test``, and
    records the loss and the train/val/test accuracies. Training ends after
    ``epochs`` epochs, or earlier once ``EarlyStopping`` has seen ``patience``
    consecutive epochs without a validation-accuracy improvement.

    Afterwards, the best validation snapshot is optionally written to
    ``save_path`` with ``torch.save`` and loaded back into ``model``. The
    restored model's validation and test accuracy are then printed.

    Args:
        model: Module to train in place.
        optimizer: Optimizer over ``model``'s parameters.
        data: Graph data object consumed by ``train_epoch`` and ``test``.
        epochs: Maximum number of epochs to run.
        patience: Early-stopping patience, in epochs.
        save_path: Destination for the best ``state_dict``, or None to skip
            saving.

    Returns:
        Tuple ``(train_losses, train_accuracies, val_accuracies,
        test_accuracies)`` of per-epoch lists.
    """
    train_losses, train_accuracies, val_accuracies, test_accuracies = [], [], [], []
    early_stopping = EarlyStopping(patience=patience)

    with tqdm(total=epochs, desc="Training Progress") as pbar:
        for _ in range(epochs):
            loss = train_epoch(model, optimizer, data)
            train_acc, val_acc, test_acc = test(model, data)

            train_losses.append(loss)
            train_accuracies.append(train_acc)
            val_accuracies.append(val_acc)
            test_accuracies.append(test_acc)

            pbar.set_postfix({
                'Loss': f'{loss:.4f}',
                'Train Acc': f'{train_acc:.4f}',
                'Val Acc': f'{val_acc:.4f}',
                'Test Acc': f'{test_acc:.4f}'
            })
            pbar.update(1)

            early_stopping(val_acc, model)
            if early_stopping.early_stop:
                # pbar.write prints without garbling the live progress bar.
                pbar.write(f"No performance improvement for {patience} epochs. Stopping early.")
                break

    best_state = early_stopping.best_model_state
    if best_state is None:
        # No epoch ran (epochs <= 0), so there is no snapshot. Saving None or
        # calling load_state_dict(None) would be wrong, so use the model as-is.
        print("No training epochs were run; evaluating the current model.")
    else:
        if save_path is not None:
            torch.save(best_state, save_path)
        # Restore the weights from the best validation epoch.
        model.load_state_dict(best_state)

    _, val_acc, test_acc = test(model, data)
    print(f"Best model validation accuracy: {val_acc:.4f}")
    print(f"Best model test accuracy: {test_acc:.4f}")

    return train_losses, train_accuracies, val_accuracies, test_accuracies

def train(data, params, device, save_path, architecture):
    """Create a model for ``architecture``, train it on ``data`` and return it.

    Args:
        data: Graph data object; moved to ``device`` via ``.to(device)``
            before use.
        params: Hyper-parameter mapping. This function reads
            ``'learning_rate'``, ``'weight_decay'``, ``'epochs'`` and
            ``'patience'``; the whole mapping is also forwarded to
            ``init_base_model``.
        device: Target device, used for the data and passed to
            ``init_base_model``.
        save_path: Where ``run_training`` saves the best weights, or None.
        architecture: Architecture identifier forwarded to ``init_base_model``.

    Returns:
        ``(model, results)``, where ``results`` is the tuple of per-epoch
        histories returned by ``run_training``.
    """
    graph = data.to(device)
    net = init_base_model(graph, params, device, architecture)

    adam = optim.Adam(
        net.parameters(),
        lr=params['learning_rate'],
        weight_decay=params['weight_decay'],
    )
    history = run_training(
        net,
        adam,
        graph,
        params["epochs"],
        params["patience"],
        save_path,
    )
    return net, history