import argparse
import torch
from torch import nn
import numpy as np
from data import get_dataset
import pandas as pd
import pickle as pk
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from jarvis.core.atoms import Atoms
from torch.utils.data import DataLoader
from tqdm import tqdm
# import wandb
from e3nn.io import CartesianTensor
from pandarallel import pandarallel
from data import get_symmetry_dataset
pandarallel.initialize(progress_bar=False)
import gc
from graphs import atoms2graphs, GraphDataset
from utils import get_id_train_val_test
from ceitnet import CEITNet
import matplotlib.pyplot as plt
from e3nn import o3
import pdb
# torch config
torch.set_default_dtype(torch.float32)
import torch
import numpy as np
import random
import os

# Seed every random number generator this script touches (Python, NumPy,
# PyTorch CPU and CUDA) with the same fixed value so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA instead of failing at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both CUDA seeding calls are silently ignored when CUDA is unavailable.
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # if using multi-GPU.
# Request deterministic cuDNN kernels and disable the cuDNN autotuner, whose
# per-run kernel selection would otherwise make results vary between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared converter instance; `structure_to_graphs` calls `adptor.get_atoms`
# on every structure it processes.
adptor = JarvisAtomsAdaptor()

# Index sets over range(9): the multiples of 4 and everything else. These are
# the diagonal / off-diagonal positions of a row-major flattened 3x3 matrix.
# NOTE(review): neither list is referenced in the visible part of this file.
diagonal = list(range(0, 9, 4))
off_diagonal = [idx for idx in range(9) if idx % 4]

def structure_to_graphs(
    df: pd.DataFrame,
    use_corrected_structure: bool = False,
    reduce_cell: bool = False,
    cutoff: float = 4.0,
    max_neighbors: int = 16
):
    """Build one graph per row of ``df["p_input"]`` using pandarallel.

    Every ``p_input`` entry is indexed with the keys ``"structure"`` and
    ``"equivalent_atoms"``. The structure is converted with the module-level
    ``adptor`` and then handed to ``atoms2graphs``.

    Args:
        df: Frame holding a ``"p_input"`` column.
        use_corrected_structure: Accepted for interface compatibility; this
            function does not read it.
        reduce_cell: Forwarded to ``atoms2graphs`` as ``reduce``.
        cutoff: Forwarded to ``atoms2graphs`` as ``cutoff``.
        max_neighbors: Forwarded to ``atoms2graphs`` as ``max_neighbors``.

    Returns:
        The ``.values`` array of the per-row results, in row order.
    """
    # Options that are identical for every row; only the structure and its
    # equivalent-atom labels vary per entry.
    graph_options = dict(
        cutoff=cutoff,
        max_neighbors=max_neighbors,
        reduce=reduce_cell,
        use_canonize=True,
    )

    def _entry_to_graph(entry):
        """Convert a single ``p_input`` entry into its graph."""
        jarvis_atoms = adptor.get_atoms(entry["structure"])
        return atoms2graphs(
            jarvis_atoms,
            equivalent_atoms=entry["equivalent_atoms"],
            **graph_options,
        )

    # Swap `parallel_apply` for `apply` to run serially while debugging.
    return df["p_input"].parallel_apply(_entry_to_graph).values


class PolynomialLRDecay(torch.optim.lr_scheduler._LRScheduler):
    """Polynomial decay of the learning rate from ``start_lr`` to ``end_lr``.

    After ``t`` calls to :meth:`step` every parameter group uses::

        lr(t) = (start_lr - end_lr) * (1 - min(t, max_iters) / max_iters) ** power
                + end_lr

    so the first optimizer step runs at exactly ``start_lr`` and the rate
    stays at ``end_lr`` once ``max_iters`` steps have been taken. The
    optimizer's own base learning rates are not used by the formula.

    Args:
        optimizer: Wrapped optimizer.
        max_iters: Number of scheduler steps over which to decay; must be > 0.
        start_lr: Learning rate at step 0.
        end_lr: Learning rate from step ``max_iters`` onwards.
        power: Exponent of the polynomial (1 gives a linear ramp).
        last_epoch: Index of the last completed step, forwarded to the base
            class (-1 starts a fresh schedule).

    Raises:
        ValueError: if ``max_iters`` is not positive.
    """

    def __init__(self, optimizer, max_iters, start_lr, end_lr, power=1, last_epoch=-1):
        if max_iters <= 0:
            raise ValueError(f"max_iters must be positive, got {max_iters}")
        self.max_iters = max_iters
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.power = power
        # Mirror of `last_epoch`, kept so code reading `last_iter` still works.
        self.last_iter = 0
        # The base constructor performs one initial `step()`, so every
        # attribute `get_lr` reads must be assigned before this call.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate, once per parameter group."""
        # `last_epoch` is maintained by the base class (0 after construction,
        # +1 per step). Using it rather than a private counter avoids counting
        # the constructor's implicit step and honours `step(epoch)`.
        # Clamping keeps the base of the power inside [0, 1], so the rate
        # never undershoots `end_lr` or turns complex for fractional powers.
        done = min(max(self.last_epoch, 0), self.max_iters)
        remaining = 1 - done / self.max_iters
        lr = (self.start_lr - self.end_lr) * remaining ** self.power + self.end_lr
        return [lr for _ in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one iteration (or jump to ``epoch``)."""
        result = super().step(epoch)
        self.last_iter = self.last_epoch
        return result


def get_pyg_dataset(data, target, reduce_cell=False):
    """Turn raw records into a ``GraphDataset``.

    ``data`` is loaded into a DataFrame, graphs are built for it with
    ``structure_to_graphs`` (forwarding ``reduce_cell``), and both are handed
    to ``GraphDataset`` together with ``target``.
    """
    frame = pd.DataFrame(data)
    return GraphDataset(
        df=frame,
        graphs=structure_to_graphs(frame, reduce_cell=reduce_cell),
        target=target,
    )

def _predict(model, args, structure, mask):
    """Run ``model`` on one batch using the call convention of ``args.model``.

    Raises:
        ValueError: if ``args.model`` names an architecture this script does
            not know how to call.
    """
    if args.model == "ceitnet":
        # The third positional argument is always passed as None here.
        return model(structure, mask, None)
    raise ValueError(f"Model {args.model} not supported")


def _validation_mae(model, args, loader):
    """Return the mean absolute error of ``model`` over all batches of ``loader``."""
    model.eval()
    predictions = []
    targets = []
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # keeps validation memory flat.
    with torch.no_grad():
        for structure, mask, _equality, labels in loader:
            outputs = _predict(model, args, structure.to(device), mask.to(device))
            # Each sample is flattened to 36 values.
            # NOTE(review): presumably a 6x6 tensor per sample - confirm this
            # width for every supported --target.
            predictions.append(outputs.reshape(-1, 36))
            targets.append(labels.to(device).reshape(-1, 36))
    if not predictions:
        raise RuntimeError("validation loader produced no batches")
    # `cat` (unlike `stack`) accepts a shorter final batch, so every
    # validation sample contributes to the score.
    errors = (torch.cat(predictions) - torch.cat(targets)).abs()
    return errors.mean(dim=-1).mean().item()


def train(model, args):
    """Train ``model`` and keep the weights with the lowest validation MAE.

    The dataset named by ``args.target`` is loaded, split into
    train/validation/test subsets, converted to graph datasets and wrapped in
    data loaders. The model is optimised with AdamW under a linearly decaying
    learning rate; after every epoch the validation MAE is printed and, when
    it improves, the state dict is written to
    ``runs/<args.name>/model_best_<args.model>.pt`` (the directory must
    already exist).

    Args:
        model: Network to train; moved to the module-level ``device``.
        args: Parsed command-line namespace. Fields read here: ``target``,
            ``use_corrected_structure``, ``load_preprocessed``, ``split_seed``,
            ``train_ratio``, ``val_ratio``, ``test_ratio``, ``reduce_cell``,
            ``batch_size``, ``learning_rate``, ``weight_decay``, ``epochs``,
            ``loss``, ``model`` and ``name``.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
        KeyError: if ``args.loss`` is not one of ``mse``, ``l1``, ``huber``.
        RuntimeError: if the validation loader yields no batches.
    """
    # load the dataset
    if args.load_preprocessed:
        print("load preprocessed dataset ...")
    dataset_sym = get_dataset(
        dataset_name=args.target,
        use_corrected_structure=args.use_corrected_structure,
        load_preprocessed=args.load_preprocessed,
    )

    # preprocess the dataset and random split
    id_train, id_val, id_test = get_id_train_val_test(
        total_size=len(dataset_sym),
        split_seed=args.split_seed,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        keep_data_order=False,
    )
    dataset_train = [dataset_sym[x] for x in id_train]
    dataset_val = [dataset_sym[x] for x in id_val]
    dataset_test = [dataset_sym[x] for x in id_test]

    pyg_dataset_train = get_pyg_dataset(dataset_train, args.target, args.reduce_cell)
    pyg_dataset_val = get_pyg_dataset(dataset_val, args.target, args.reduce_cell)
    pyg_dataset_test = get_pyg_dataset(dataset_test, args.target, args.reduce_cell)

    # form dataloaders
    collate_fn = pyg_dataset_train.collate
    train_loader = DataLoader(
        pyg_dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
    )

    val_loader = DataLoader(
        pyg_dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        # Keep the final partial batch: dropping it would silently exclude
        # samples from the validation score.
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )

    # The test loader is only used for the size report below.
    test_loader = DataLoader(
        pyg_dataset_test,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )
    print("n_train:", len(train_loader.dataset))
    print("n_val:", len(val_loader.dataset))
    print("n_test:", len(test_loader.dataset))

    # set up training configs
    model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    # The scheduler is stepped once per batch, so its horizon is the total
    # number of optimizer steps across all epochs.
    steps_per_epoch = len(train_loader)
    total_iter = steps_per_epoch * args.epochs
    scheduler = PolynomialLRDecay(
        optimizer,
        max_iters=total_iter,
        start_lr=args.learning_rate,
        end_lr=0.00001,
        power=1,
    )
    criteria = {
        "mse": nn.MSELoss(),
        "l1": nn.L1Loss(),
        "huber": nn.HuberLoss(),
    }
    criterion = criteria[args.loss]

    best_score = float("inf")
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{args.epochs}", unit='batch') as pbar:
            for data in train_loader:
                # `equality` is part of the collated batch but is not consumed
                # by the model call, so it is left on the host.
                structure, mask, equality, labels = data
                structure = structure.to(device)
                mask = mask.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                outputs = _predict(model, args, structure, mask)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                pbar.set_postfix({'training_loss': running_loss / (pbar.n + 1)})
                pbar.update(1)
                scheduler.step()

                # Drop references to this batch before the next one is
                # fetched so its memory can be reclaimed early.
                del structure, mask, equality, labels, data, outputs, loss

        # Release gradient buffers before validation.
        for param in model.parameters():
            param.grad = None
        gc.collect()

        # Validation
        mae = _validation_mae(model, args, val_loader)

        if mae < best_score:
            best_score = mae
            torch.save(model.state_dict(), "runs/%s/model_best_%s.pt" % (args.name, args.model))

        print("Validation mae ", mae)

    return

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse command-line options, build the model and start training.

    Raises:
        ValueError: if ``--model`` names an unsupported architecture.
    """
    parser = argparse.ArgumentParser(description='Training script')

    # Define command-line arguments
    # training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size of training and evaluating')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-05, help='weight decay')
    parser.add_argument('--loss', type=str, default='huber', choices=['mse', 'l1', 'huber'], help='mse or l1 or huber')
    parser.add_argument('--model', type=str, default='ceitnet', help='ceitnet')
    parser.add_argument('--name', type=str, default='test', help='name of project for storage')
    parser.add_argument('--reduce_cell', type=_str2bool, default=False, help='reduce the cell into irreducible atom sets')
    # dataset parameters
    parser.add_argument('--split_seed', type=int, default=32, help='the random seed of spliting data')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='training ratio used in data split')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='evaluate ratio used in data split')
    parser.add_argument('--test_ratio', type=float, default=0.1, help='test ratio used in data split')
    parser.add_argument('--target', type=str, default='elastic', help='dielectric, piezoelectric, or elastic')
    parser.add_argument('--threshold', type=float, default=100., help='threshold to remove samples')
    parser.add_argument('--use_corrected_structure', type=_str2bool, default=True, help='correct input structure or not')
    parser.add_argument('--load_preprocessed', type=_str2bool, default=True, help='load previous processed dataset')

    args = parser.parse_args()

    print('Training settings:')
    print(f'  Epochs: {args.epochs}')
    print(f'  Learning rate: {args.learning_rate}')
    print(args)
    # Re-seed PyTorch with the split seed so model initialisation follows the
    # seed chosen on the command line.
    torch.manual_seed(args.split_seed)
    torch.cuda.manual_seed_all(args.split_seed)

    # load the model
    if args.model == "ceitnet":
        model = CEITNet(args)
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError(f"Model {args.model} not supported")

    # Checkpoints are written below this directory during training.
    os.makedirs('runs/' + args.name, exist_ok=True)

    train(model, args)


if __name__ == "__main__":
    main()