import os

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
# from torch.optim import ReduceLROnPlateau

import torch_geometric.datasets as tg_dataset
import torch_geometric
from torch_geometric.datasets import PCQM4Mv2
from torch_geometric.utils import dropout_node, dropout_edge
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import AddRandomWalkPE

from rdkit.Chem import MolFromSmiles, GetPeriodicTable

from tqdm import tqdm

# RDKit periodic-table object; nothing in the visible code reads it.
TABLE = GetPeriodicTable()

from model import *

# class PubChem(PCQM4Mv2):
#     def __init__()

class PretrainingGraphDataset(Dataset):
    """Dataset yielding two randomly augmented views of each molecule.

    Every item is rebuilt from the SMILES string of the stored graph: the node
    features are the atomic numbers of the parsed molecule and the connectivity
    is the stored graph's ``edge_index``. Two views are then derived:

    * ``graph1`` -- node dropout; surviving nodes are relabelled and the
      features are restricted to them.
    * ``graph2`` -- undirected edge dropout; all nodes are kept.

    Args:
        graphs: Indexable, sized collection whose items expose ``.smiles`` and
            ``.edge_index``.
        masking_ratio: Drop probability handed to both ``dropout_node`` and
            ``dropout_edge``.
        pretrain_type: Stored on the instance; this class does not read it.
        transform: Callable applied to each view before it is returned.
            ``None`` (the default) selects ``AddRandomWalkPE(20)``.
    """

    def __init__(self, graphs, masking_ratio=0.4, pretrain_type='masking', transform=None):
        self.graphs = graphs
        self.pretrain_type = pretrain_type
        self.masking_ratio = masking_ratio
        # Build the default transform per instance instead of once at
        # class-definition time (a call in a default argument is evaluated
        # only once and shared by every instance).
        self.transform = AddRandomWalkPE(20) if transform is None else transform

    def __len__(self):
        """Return the number of stored graphs."""
        return len(self.graphs)

    def __getitem__(self, index):
        """Return ``{'graph1': view, 'graph2': view}`` for the graph at ``index``.

        NOTE(review): assumes the SMILES parses (``MolFromSmiles`` did not
        return ``None``) and that RDKit's atom order matches the node indexing
        of the stored ``edge_index`` -- confirm against how ``graphs`` is built.
        """
        # Fetch the stored item once; indexing may be non-trivial for lazy datasets.
        source = self.graphs[index]
        molecule = MolFromSmiles(source.smiles)
        atom_type = torch.LongTensor([atom.GetAtomicNum() for atom in molecule.GetAtoms()])
        edge_index = source.edge_index

        # View 1: drop nodes. relabel_nodes=True renumbers the survivors so the
        # edge index lines up with the masked feature tensor below.
        edge_index_1, _, node_mask_1 = dropout_node(
            edge_index,
            p=self.masking_ratio,
            num_nodes=atom_type.size(0),
            relabel_nodes=True,
        )
        # View 2: drop edges; the node set is untouched.
        edge_index_2, _ = dropout_edge(edge_index, p=self.masking_ratio, force_undirected=True)

        return {
            'graph1': self.transform(Data(x=atom_type[node_mask_1], edge_index=edge_index_1)),
            'graph2': self.transform(Data(x=atom_type, edge_index=edge_index_2)),
        }
    
# Use the GPU when PyTorch reports one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attn_kwargs = {'dropout': 0.5}
model = GPS(channels=768, pe_dim=64, num_atom_type=119, num_layers=12, tau=0.2, attn_type='multihead', attn_kwargs=attn_kwargs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
#                             min_lr=0.00001)

# On-disk cache of the molecules whose SMILES RDKit can parse. The file name
# (including its spelling) is kept as-is so existing caches are still found.
FILTERED_PUBCHEM_PATH = './dataset/raw/pcqm4m-v2/fitered_pubchem.pt'

if os.path.exists(FILTERED_PUBCHEM_PATH):
    # NOTE(review): torch.load unpickles the file, so only load caches written
    # by this script. Newer PyTorch releases restrict unpickling by default
    # (weights_only) -- confirm this still loads on the installed version.
    filtered_pubchem = torch.load(FILTERED_PUBCHEM_PATH)
else:
    pubchem = torch_geometric.datasets.PCQM4Mv2(root='./dataset/')
    filtered_pubchem = []
    for data in tqdm(pubchem):
        # MolFromSmiles signals an unparsable SMILES by returning None; test
        # for that directly rather than catching every exception.
        if MolFromSmiles(data.smiles) is None:
            continue
        filtered_pubchem.append(data)
    print(f'Final number of training samples: {len(filtered_pubchem):02d}')
    torch.save(filtered_pubchem, FILTERED_PUBCHEM_PATH)
dataset = PretrainingGraphDataset(filtered_pubchem)
train_loader = DataLoader(dataset, batch_size=512, shuffle=True)

def train(num_epochs=100):
    """Run contrastive pretraining, checkpointing the model after every epoch.

    Operates on the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. ``model.tau`` is set to 0.7 before the first epoch and
    multiplied by 0.9 after each one.

    Args:
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The sum of the per-batch losses of the last epoch divided by
        ``len(train_loader.dataset)``.
        NOTE(review): if ``model.cl_loss`` already averages over the batch,
        dividing by ``len(train_loader)`` would be the matching normaliser --
        confirm against the model code.
    """
    checkpoint_path = './checkpoints/GPS/model.pth'
    # Make sure the target directory exists before training starts, so a
    # missing folder cannot fail the save after a full epoch of work.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    model.train()
    model.tau = 0.7
    total_loss = 0
    for epoch in range(1, num_epochs + 1):
        total_loss = 0
        loop = tqdm(train_loader, total=len(train_loader))
        loop.set_description(f'Epoch [{epoch}/{num_epochs}]')
        for data in loop:
            graph1, graph2 = data['graph1'].to(device), data['graph2'].to(device)
            optimizer.zero_grad()
            model.redraw_projection.redraw_projections()

            loss = model.cl_loss(graph1, graph2)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Read the scalar once; it feeds both the running sum and the bar.
            batch_loss = loss.item()
            total_loss += batch_loss
            loop.set_postfix(loss=batch_loss)
        # The same file is overwritten each epoch, so only the latest weights are kept.
        torch.save(model.state_dict(), checkpoint_path)
        model.tau *= 0.9
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    """Return the summed absolute error over ``loader`` per dataset item.

    Puts the module-level ``model`` in eval mode, runs it on every batch and
    accumulates ``|model output - data.y|``; the total is divided by
    ``len(loader.dataset)``.
    """
    model.eval()

    def _batch_abs_error(batch):
        # Move the batch to the model's device before the forward pass.
        batch = batch.to(device)
        prediction = model(
            batch.x,
            batch.pe,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
        )
        residual = prediction.squeeze() - batch.y
        return residual.abs().sum().item()

    accumulated = sum(_batch_abs_error(batch) for batch in loader)
    return accumulated / len(loader.dataset)

# Start pretraining as soon as this file is executed; there is no __main__
# guard, so importing the module triggers training as well.
train()
# for epoch in range(1, 101):
#     loss = train()
#     # scheduler.step(loss)
#     print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
    
#     torch.save(model.state_dict(), './checkpoints/GPS/model.pth')