import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch_geometric.data import DataLoader
import numpy as np
from torch.autograd import grad
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def compute_ma(input):
    """Return the mean absolute value of the tensor ``input`` as a Python float.

    The reduction runs on whatever device ``input`` lives on; only the final
    scalar is copied to the CPU before being converted with ``.item()``.
    """
    mean_abs = input.abs().mean()
    return mean_abs.cpu().item()

def run(train_dataset, val_dataset, test_dataset, save_dir, log_dir, model, epochs, batch_size, lr, lr_decay_factor, lr_decay_step_size, weight_decay,
        energy_and_force, p):
    """Train ``model`` and report the test metric at the best validation epoch.

    Each epoch runs ``train`` once over ``train_dataset``, then evaluates the
    validation and test sets with ``val``. Whenever the validation metric
    improves, the test metric of that epoch is remembered. If ``save_dir`` is
    set, a checkpoint (model, optimizer and scheduler state) is also written to
    ``save_dir/valid_checkpoint.pt``.

    Args:
        train_dataset, val_dataset, test_dataset: Datasets wrapped in
            ``DataLoader``s; only the training loader is shuffled.
        save_dir: Checkpoint directory (created if missing). An empty value
            (``''`` or ``None``) disables checkpointing.
        log_dir: TensorBoard log directory (created if missing). An empty value
            (``''`` or ``None``) disables logging.
        model: Module to train; moved to CUDA when available, else CPU.
        epochs: Number of training epochs.
        batch_size: Batch size used for all three loaders.
        lr, weight_decay: Adam learning rate and weight decay.
        lr_decay_factor, lr_decay_step_size: ``StepLR`` gamma and step size.
            The scheduler is stepped once per epoch.
        energy_and_force: If true, forces (negative gradient of the predicted
            energy w.r.t. positions) are also fitted and evaluated.
        p: Weight of the force term in both the loss and the metric.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=lr_decay_step_size, gamma=lr_decay_factor)
    loss_func = torch.nn.L1Loss()
    metric_func = compute_ma

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)
    best_val = float('inf')
    best_test = float('inf')

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    writer = None
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)

    try:
        for epoch in range(1, epochs + 1):
            print("\n=====Epoch {}".format(epoch), flush=True)

            train_loss = train(model, optimizer, train_loader, energy_and_force, p, loss_func, device)
            val_metric = val(model, val_loader, energy_and_force, p, metric_func, device)
            test_metric = val(model, test_loader, energy_and_force, p, metric_func, device)

            print({'Train': train_loss, 'Validation': val_metric, 'Test': test_metric})

            if writer is not None:
                writer.add_scalar('train_loss', train_loss, epoch)
                writer.add_scalar('val_metric', val_metric, epoch)
                writer.add_scalar('test_metric', test_metric, epoch)

            # Model selection is on the validation metric only; the test metric
            # is just recorded for the selected epoch.
            if val_metric < best_val:
                best_val = val_metric
                best_test = test_metric
                if save_dir:
                    print('Saving checkpoint...')
                    checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_valid_mae': best_val}
                    torch.save(checkpoint, os.path.join(save_dir, 'valid_checkpoint.pt'))

            scheduler.step()

        print(f'Best validation MAE so far: {best_val}')
        print(f'Test MAE when got best validation result: {best_test}')
    finally:
        # Close (and flush) the TensorBoard event file even if training raises.
        if writer is not None:
            writer.close()



def train(model, optimizer, train_loader, energy_and_force, p, loss_func, device):
    """Run one optimisation epoch over ``train_loader``.

    Args:
        model: Module to train; switched to train mode.
        optimizer: Optimizer stepped once per batch.
        train_loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
            When ``energy_and_force`` is set, batches must also expose
            ``.pos`` and ``.force``.
        energy_and_force: If true, the loss is
            ``loss_func(energy) + p * loss_func(force)``, where the predicted
            force is ``-d(out)/d(pos)``.
        p: Weight of the force loss.
        loss_func: Callable ``(prediction, target) -> scalar tensor``.
        device: Device each batch is moved to.

    Returns:
        The mean of the per-batch losses, with every batch weighted equally
        regardless of its size. Returns ``nan`` when ``train_loader`` yields
        no batches.
    """
    model.train()
    loss_accum = 0.0
    num_batches = 0
    for batch_data in tqdm(train_loader):
        optimizer.zero_grad()
        batch_data = batch_data.to(device)
        out = model(batch_data)
        if energy_and_force:
            # NOTE(review): grad() w.r.t. batch_data.pos requires pos to require
            # grad; nothing here sets that, so presumably the model does — confirm.
            # create_graph=True lets loss.backward() differentiate through the
            # force term (a second-order gradient w.r.t. the parameters).
            force = -grad(outputs=out, inputs=batch_data.pos, grad_outputs=torch.ones_like(out), create_graph=True, retain_graph=True)[0]
            e_loss = loss_func(out, batch_data.y.unsqueeze(1))
            f_loss = loss_func(force, batch_data.force)
            loss = e_loss + p * f_loss
        else:
            loss = loss_func(out, batch_data.y.unsqueeze(1))
        loss.backward()
        optimizer.step()
        loss_accum += loss.item()
        num_batches += 1

    # Count batches explicitly: relying on the enumerate index left it unbound
    # (UnboundLocalError) for an empty loader.
    if num_batches == 0:
        return float('nan')
    return loss_accum / num_batches

def val(model, data_loader, energy_and_force, p, metric_func, device):
    """Evaluate ``model`` on ``data_loader`` and return a scalar metric.

    Predictions and targets for the whole loader are gathered on ``device``.
    ``metric_func`` is then applied once to their difference.

    Args:
        model: Module to evaluate; switched to eval mode.
        data_loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
            When ``energy_and_force`` is set, batches must also expose
            ``.pos`` and ``.force``.
        energy_and_force: If true, forces are also evaluated, computed as
            ``-d(out)/d(pos)``.
        p: Weight of the force metric in the returned value.
        metric_func: Callable mapping an error tensor to a float.
        device: Device batches are moved to.

    Returns:
        ``metric_func(targets - preds)``. With ``energy_and_force``, returns
        ``energy_metric + p * force_metric`` instead, and prints both parts.
        For an empty loader the error tensor is empty, so the result is
        whatever ``metric_func`` gives for that (``nan`` for a mean).
    """
    model.eval()

    def _gather(chunks):
        # Concatenate once at the end instead of re-copying the growing
        # result on every batch; an empty loader yields an empty tensor.
        return torch.cat(chunks, dim=0) if chunks else torch.empty(0, device=device)

    pred_chunks, target_chunks = [], []
    force_pred_chunks, force_target_chunks = [], []

    # Forces need autograd (gradient of the energy w.r.t. positions); pure
    # energy evaluation does not, so skip building the graph in that case.
    with torch.set_grad_enabled(energy_and_force):
        for batch_data in tqdm(data_loader):
            batch_data = batch_data.to(device)
            out = model(batch_data)
            if energy_and_force:
                # NOTE(review): requires batch_data.pos to require grad; nothing
                # here sets that, so presumably the model does — confirm.
                # No create_graph: evaluation never back-propagates through the
                # forces, so the first-order graph can be freed immediately.
                force = -grad(outputs=out, inputs=batch_data.pos, grad_outputs=torch.ones_like(out))[0]
                force_pred_chunks.append(force.detach())
                force_target_chunks.append(batch_data.force)
            pred_chunks.append(out.detach())
            target_chunks.append(batch_data.y.unsqueeze(1))

    preds = _gather(pred_chunks)
    targets = _gather(target_chunks)

    if energy_and_force:
        energy_mae = metric_func(targets - preds)
        force_mae = metric_func(_gather(force_target_chunks) - _gather(force_pred_chunks))
        print({'Energy MAE': energy_mae, 'Force MAE': force_mae})
        return energy_mae + p * force_mae

    return metric_func(targets - preds)