import torch
from utils.mol_utils import smiles_to_mols
import numpy as np
import json 
import pandas as pd
from torch_geometric.data import Batch, Data
from torch.utils.data import DataLoader
from torch import Tensor
import copy 
import networkx as nx
import time 
import argparse
import sys
from reg_models import LinearGC, SGCReg
from find_metrics import find_metric

# Atomic number of every supported element. The insertion order of this table
# also fixes the one-hot column assigned to each element below.
ATOM_AN = {
    'C': 6,
    'N': 7,
    'O': 8,
    'F': 9,
    'P': 15,
    'S': 16,
    'Cl': 17,
    'Br': 35,
    'I': 53,
}
# Element symbol -> one-hot column index.
ATOM_ID = {symbol: column for column, symbol in enumerate(ATOM_AN)}
# One-hot column index -> element symbol (inverse of ATOM_ID).
ID_ATOM = dict(enumerate(ATOM_AN))

# Module-level flag; not read anywhere in the visible code (MolData takes its
# own `return_adjs` argument).
return_adjs = True

class MolData ():
    """List-like collection of molecular graphs built from a SMILES CSV file.

    The CSV is read from ``data/<dataset>.csv`` and every molecule is turned
    into a ``torch_geometric.data.Data`` object holding one-hot atom types
    (``x``), one directed edge per bond (``edge_index``), the bond order as a
    float (``edge_attr``) and the regression target (``y``).

    Indexing with a scalar returns a single graph. Indexing with a slice or a
    sequence of positions returns a shallow copy of the dataset restricted to
    those graphs; the graph objects themselves are shared with the parent.
    Only ``graphs`` (and the internal row bookkeeping) is restricted in such a
    copy -- ``smiles``, ``ys`` and the underlying data frame still describe
    the full CSV.
    """

    def __init__ (self, dataset, prop_name, return_adjs=True, device='cpu'):
        """Load ``data/<dataset>.csv`` and build one graph per molecule.

        Args:
            dataset: ``'qm9'`` or ``'zinc250k'``; selects the SMILES column,
                the atom vocabulary and ``max_nodes``.
            prop_name: Target property. Taken from the CSV column of that name
                when present, otherwise computed with ``find_metric``.
            return_adjs: Stored on the instance; not used by this class.
            device: Device every graph is moved to.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        self.dataset = dataset
        df = pd.read_csv(f'data/{dataset.lower()}.csv')

        if dataset == 'qm9':
            smiles_col = 'SMILES1'
            self.atoms = ['C', 'N', 'O', 'F']
            self.max_nodes = 9
        elif dataset == 'zinc250k':
            smiles_col = 'smiles'
            self.atoms = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I']
            self.max_nodes = 38
        else:
            raise ValueError(f"[ERROR] Unexpected value data_name={dataset}")

        self.smiles = df[smiles_col]
        self.ys = torch.tensor(df[prop_name]) if prop_name in df else torch.tensor (find_metric(self.smiles, prop_name))
        self.device = device
        self._df_ = df
        self.return_adjs = return_adjs

        self.__create_graphs__()

    @property
    def num_node_features (self):
        """Number of node features, read from the first graph."""
        return self.graphs[0].num_node_features

    def __get_node_attributes__ (self, atom, features):
        """Return the one-hot element encoding of ``atom`` as a float array.

        The vector has ``len(self.atoms)`` entries and the hot position is
        ``ATOM_ID[atom.GetSymbol()]``. ``features`` is accepted for interface
        compatibility but is currently ignored: no extra per-atom features
        are appended.

        Raises:
            KeyError: If the atom's symbol is not in ``ATOM_ID``.
        """
        x = np.zeros(len(self.atoms), dtype=float)
        x[ATOM_ID[atom.GetSymbol()]] = 1
        return x

    def __create_graphs__ (self, features=None):
        """Build ``self.graphs`` from ``self.smiles`` and ``self.ys``.

        Args:
            features: Names of extra atom features, forwarded to
                ``__get_node_attributes__``. ``None`` selects the historical
                default list.
        """
        if features is None:
            # Created per call so the default is never shared between calls.
            features = ['atomic_num', 'formal_charge', 'chiral_tag',
                        'hybridization', 'num_explicit_hs', 'is_aromatic']
        mols = smiles_to_mols (self.smiles)
        self.graphs = []
        for mol, y in zip (mols, self.ys):
            atoms = sorted(mol.GetAtoms(), key=lambda atom: atom.GetIdx())
            x = torch.tensor (np.stack([self.__get_node_attributes__(atom, features) for atom in atoms]))

            pairs, orders = [], []
            for bond in mol.GetBonds():
                pairs.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
                orders.append(bond.GetBondTypeAsDouble())
            # The explicit row count keeps the result at shape (2, 0) for a
            # molecule without bonds instead of a 1-D empty tensor.
            edge_index = torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2).t().contiguous()
            edge_attr = torch.tensor(orders)
            self.graphs.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y).to(self.device))
        # CSV row of every graph; subsets carry their own slice of this list so
        # labels can still be looked up after indexing.
        self._row_ids = list(range(len(self.graphs)))

    def _source_rows (self):
        """Return, for each graph in ``self.graphs``, its row in ``self._df_``."""
        rows = getattr(self, '_row_ids', None)
        if rows is None:
            # Instance built without row bookkeeping: graphs are taken to be
            # in CSV order, which is how `__create_graphs__` produces them.
            rows = list(range(len(self.graphs)))
        return rows

    def __getitem__(self, idx):
        """Return one graph (scalar ``idx``) or a sub-dataset (slice/sequence).

        A scalar is a Python or NumPy integer, or a 0-dimensional tensor or
        array. Anything else is treated as a slice or an iterable of
        positions and yields a shallow copy of this dataset whose ``graphs``
        list holds the selected graph objects.
        """
        if (isinstance(idx, (int, np.integer))
            or (isinstance(idx, (Tensor, np.ndarray)) and idx.ndim == 0)):
            return self.graphs[idx]

        if isinstance(idx, slice):
            idx = range(*idx.indices(len(self.graphs)))
        positions = list(idx)
        rows = self._source_rows()

        dataset = copy.copy (self)
        dataset.graphs = [self.graphs[pos] for pos in positions]
        # Remember which CSV row each selected graph came from; without this
        # `change_prop` on a subset would read the labels of other molecules.
        dataset._row_ids = [rows[pos] for pos in positions]
        return dataset

    def __len__(self):
        """Number of graphs in this dataset (or subset)."""
        return len(self.graphs)

    def change_prop (self, new_prop):
        """Replace ``y`` of every graph with column ``new_prop`` of the CSV.

        Each graph is relabelled from the CSV row it was created from, so the
        method is safe to call on subsets returned by ``__getitem__``. Graph
        objects are shared between a dataset and its subsets, so the change is
        visible through all of them. ``self.ys`` is not updated.
        """
        for graph, row in zip(self.graphs, self._source_rows()):
            graph.y = torch.tensor(self._df_.loc[row, new_prop]).to(self.device)

if __name__ == '__main__':
    # Command-line entry point. For every property in --props this either
    # trains an SGCReg regressor and saves its weights, or (with --analyze)
    # loads previously saved weights and prints the train/test MAE.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='qm9')
    # Required: the first property is needed to build the dataset.
    parser.add_argument('--props', type=str, nargs='+', required=True)
    parser.add_argument('--sgc_nlayers', type=int, default=2)
    parser.add_argument('--nepochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--analyze', action='store_true')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--device', type=str, default='cuda:0')
    args = parser.parse_args()

    DATASET = args.dataset
    DEVICE = args.device
    PROPS = args.props
    nlayers = args.sgc_nlayers
    nepochs = args.nepochs
    batch_size = args.batch_size
    lr = args.lr

    data = MolData (dataset=DATASET, prop_name=PROPS[0], device=DEVICE)

    # The JSON file lists the held-out molecules; everything else is training data.
    with open(f'data/valid_idx_{DATASET.lower()}.json') as f:
        test_idx = json.load(f)
    if DATASET == 'qm9':
        test_idx = test_idx['valid_idxs']
        test_idx = [int(i) for i in test_idx]

    test_idx = np.array(test_idx)
    all_mask = np.ones(len(data), dtype=bool)
    all_mask[test_idx] = 0
    train_idx = np.where (all_mask)[0]

    train_data = data[train_idx]
    test_data = data[test_idx]

    def collate_fn (graphs):
        """Merge a list of graphs into one PyG batch."""
        return Batch.from_data_list(graphs)

    def make_loader (dataset):
        """Shuffled mini-batch loader over ``dataset``."""
        return DataLoader (dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    def mean_error (model, loader, metric='mse'):
        """Per-sample mean squared ('mse') or absolute ('mae') error of ``model``."""
        if metric not in ('mse', 'mae'):
            raise ValueError(f"Unknown metric: {metric}")
        model.eval()
        total = 0.0
        with torch.no_grad():
            for batch in loader:
                pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                # NOTE(review): assumes `pred` and `batch.y` have the same shape;
                # a (N, 1) prediction would broadcast against (N,) targets --
                # confirm against SGCReg.
                diff = pred - batch.y
                if metric == 'mse':
                    total += (diff ** 2).sum().item()
                else:
                    total += diff.abs().sum().item()
        # Errors are summed over samples, so normalise by the sample count only.
        return total / len(loader.dataset)

    def relabel (prop):
        """Switch the regression target of every graph to ``prop``."""
        # Relabel through the full dataset: train_data/test_data share their
        # graph objects with `data` (see MolData.__getitem__), so one call
        # updates both splits and every label is read from the CSV row of its
        # own molecule.
        data.change_prop(prop)

    if args.analyze:
        for i, prop in enumerate (PROPS):
            if i > 0:
                # The dataset was built with PROPS[0] as target.
                relabel(prop)
            model = SGCReg(data.num_node_features, num_layers=nlayers).to(DEVICE)
            model_save_fname = f'config/constraints/regmodels/sgc{nlayers}_{DATASET}_{prop}.pt'
            model.load_state_dict (torch.load(model_save_fname, map_location=torch.device(DEVICE)))
            train_loader = make_loader(train_data)
            test_loader = make_loader(test_data)
            print (prop, (len(train_loader)*batch_size), mean_error(model, train_loader, metric='mae'))
            print (prop, (len(test_loader)*batch_size), mean_error(model, test_loader, metric='mae'))
        sys.exit(0)

    for i, prop in enumerate (PROPS):
        if i > 0:
            relabel(prop)

        model = SGCReg(data.num_node_features, num_layers=nlayers).to(DEVICE)

        train_loader = make_loader(train_data)
        test_loader = make_loader(test_data)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = torch.nn.MSELoss()

        def train():
            """Run one optimisation epoch and return the per-sample mean MSE."""
            model.train()
            total = 0.0
            for batch in train_loader:
                optimizer.zero_grad()
                out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                loss = criterion(out, batch.y)
                loss.backward()
                optimizer.step()
                # `loss` is the batch mean; weight it by the batch size so the
                # final value is a true per-sample average. `.item()` also
                # avoids keeping every batch's autograd graph alive.
                total += loss.item() * batch.num_graphs
            return total / len(train_loader.dataset)

        # Epochs are numbered 1..nepochs inclusive.
        for epoch in range(1, nepochs + 1):
            train_loss = train()
            test_loss = mean_error(model, test_loader, metric='mse')
            test_mae = mean_error(model, test_loader, metric='mae')
            train_mae = mean_error(model, train_loader, metric='mae')
            print(f'Epoch: {epoch:03d}, Train MSE: {train_loss:.4f}, Test MSE: {test_loss:.4f}, Train MAE: {train_mae:.4f}, Test MAE: {test_mae:.4f}')

        model_save_fname = f'config/constraints/regmodels/sgc{nlayers}_{DATASET}_{prop}.pt'
        torch.save (model.state_dict(), model_save_fname)