import argparse
import torch
from torch import nn
import numpy as np
from data import get_dataset
import pandas as pd
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from torch.utils.data import DataLoader
from tqdm import tqdm
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=False)

from graphs import atoms2graphs, GraphDataset
from utils import get_id_train_val_test
from ceitnet import CEITNet
torch.set_default_dtype(torch.float32)
import random
import os

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

# Seed the Python, NumPy and PyTorch random number generators so that
# repeated runs draw the same random numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to the CPU so the script still starts on a
# host without a usable CUDA device instead of failing on the first
# ``.to(device)`` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# PyTorch documents both CUDA seeding calls as silently ignored when CUDA is
# not available, so they are safe on the CPU fallback path as well.
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # if using multi-GPU.
# Ask cuDNN for deterministic kernels and turn off its autotuner, which may
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared structure adaptor; structure_to_graphs() calls its get_atoms() on
# every stored structure. The (misspelled) name is kept because other code
# refers to it.
adptor = JarvisAtomsAdaptor()

# Index lists that split 0..8 into two groups.
# NOTE(review): presumably the diagonal / off-diagonal positions of a
# row-major flattened 3x3 tensor; not referenced in this part of the file.
diagonal = [0, 4, 8]
off_diagonal = [1, 2, 3, 5, 6, 7]

def structure_to_graphs(
    df: pd.DataFrame,
    use_corrected_structure: bool = False,
    reduce_cell: bool = False,
    cutoff: float = 4.0,
    max_neighbors: int = 16
):
    """Build one graph per row of ``df`` from its ``p_input`` column.

    Every ``p_input`` entry is read through its ``"structure"`` and
    ``"equivalent_atoms"`` keys. The structure is passed through the
    module-level adaptor and the result is handed to ``atoms2graphs``.

    Args:
        df: Frame holding a ``p_input`` column.
        use_corrected_structure: Accepted for interface compatibility; this
            function does not read it.
        reduce_cell: Forwarded to ``atoms2graphs`` as ``reduce``.
        cutoff: Forwarded to ``atoms2graphs`` as ``cutoff``.
        max_neighbors: Forwarded to ``atoms2graphs`` as ``max_neighbors``.

    Returns:
        The ``.values`` array of the per-row conversion results.
    """

    def _row_to_graph(entry):
        # Convert the stored structure first, then gather the graph options.
        atoms = adptor.get_atoms(entry["structure"])
        graph_options = {
            "cutoff": cutoff,
            "max_neighbors": max_neighbors,
            "reduce": reduce_cell,
            "equivalent_atoms": entry['equivalent_atoms'],
            "use_canonize": True,
        }
        return atoms2graphs(atoms, **graph_options)

    # parallel_apply is the pandarallel counterpart of Series.apply.
    per_row = df["p_input"].parallel_apply(_row_to_graph)
    return per_row.values

class PolynomialLRDecay(torch.optim.lr_scheduler._LRScheduler):
    """Polynomial learning-rate decay from ``start_lr`` down to ``end_lr``.

    After ``t`` calls to :meth:`step` made by the training loop, every
    parameter group receives::

        (start_lr - end_lr) * (1 - t / max_iters) ** power + end_lr

    ``t`` is clamped to ``[0, max_iters]``, so the rate is exactly
    ``start_lr`` before the first step, exactly ``end_lr`` once ``max_iters``
    steps have been taken, and stays at ``end_lr`` afterwards. The learning
    rates configured on the optimizer itself are not used by the formula.

    Args:
        optimizer: Wrapped optimizer.
        max_iters: Number of steps over which the rate decays; must be > 0.
        start_lr: Learning rate at step 0.
        end_lr: Learning rate from step ``max_iters`` onwards.
        power: Exponent of the polynomial (1 gives a linear decay).
        last_epoch: Index of the last completed step, forwarded to the base
            scheduler (``-1`` starts a fresh schedule).

    Raises:
        ValueError: If ``max_iters`` is not positive.
    """

    def __init__(self, optimizer, max_iters, start_lr, end_lr, power=1, last_epoch=-1):
        if max_iters <= 0:
            raise ValueError(f"max_iters must be positive, got {max_iters}")
        self.max_iters = max_iters
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.power = power
        # Kept as a public attribute for compatibility; mirrors ``last_epoch``.
        self.last_iter = 0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the decayed learning rate, once per parameter group."""
        # The base class counter is the source of truth: its constructor
        # performs an initial step() that moves last_epoch from -1 to 0, so a
        # separately incremented counter would already be one step ahead.
        completed = min(max(self.last_epoch, 0), self.max_iters)
        remaining = 1 - completed / self.max_iters
        lr = (self.start_lr - self.end_lr) * remaining ** self.power + self.end_lr
        return [lr for _ in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step and refresh ``last_iter``."""
        result = super().step(epoch)
        self.last_iter = self.last_epoch
        return result

def get_pyg_dataset(data, target, reduce_cell=False):
    """Wrap raw records in a ``GraphDataset``.

    ``data`` is turned into a DataFrame, one graph per row is produced by
    ``structure_to_graphs`` (``reduce_cell`` is forwarded to it), and both are
    handed to ``GraphDataset`` together with ``target``.
    """
    frame = pd.DataFrame(data)
    return GraphDataset(
        df=frame,
        graphs=structure_to_graphs(frame, reduce_cell=reduce_cell),
        target=target,
    )

def _frame_average(outputs, rot_list):
    """Average each prediction over its stack of rotation matrices.

    For sample ``i`` with rotation stack ``R = rot_list[i]`` the result is the
    mean over the leading dimension of ``R @ outputs[i] @ R^T``. The results
    are stacked into one tensor.
    """
    averaged = []
    for bi, rotations in enumerate(rot_list):
        R = rotations.to(device)
        RT = R.transpose(1, 2)
        # Replicate the prediction once per rotation before the products.
        out = outputs[bi].repeat(R.shape[0], 1, 1)
        averaged.append(torch.matmul(torch.matmul(R, out), RT).mean(dim=0))
    return torch.stack(averaged)


def _predict_tensors(model, args, structure, mask, equality, rot_list):
    """Run ``model`` on one batch using the call convention of ``args.model``."""
    # "ceitnet" is the name used by the rest of this script; "gmtnet" is the
    # name the training loop used to check for the same three-argument call.
    if args.model in ("ceitnet", "gmtnet"):
        return model(structure, mask, equality)
    if args.model in ("mace", "ecomformer"):
        return model(structure).view(-1, 3, 3)
    outputs = model(structure)
    if args.model == "megnet":
        outputs = outputs.view(-1, 3, 3)
    # ablation for frame average
    return _frame_average(outputs, rot_list)


def train(model, args):
    """Train ``model`` and keep the checkpoint with the best validation MAE.

    The dataset named by ``args.target`` is loaded, split into train / val /
    test subsets, converted to graph datasets and wrapped in data loaders.
    The model is optimised with AdamW under a polynomial learning-rate decay.
    After every epoch the mean absolute error on the validation set is
    printed; whenever it improves, the model's ``state_dict`` is written to
    ``runs/<args.name>/model_best_<args.model>.pt``.

    Args:
        model: Network to train; moved to the module-level ``device``.
        args: Namespace providing ``load_preprocessed``, ``target``,
            ``use_corrected_structure``, ``split_seed``, ``train_ratio``,
            ``val_ratio``, ``test_ratio``, ``reduce_cell``, ``batch_size``,
            ``learning_rate``, ``weight_decay``, ``epochs``, ``loss``,
            ``model`` and ``name``.

    Raises:
        KeyError: If ``args.loss`` is not one of ``mse``, ``l1``, ``huber``.
        RuntimeError: If the validation loader yields no batches.
    """
    # load the dataset
    if args.load_preprocessed:
        print("load preprocessed dataset ...")
    dataset_sym = get_dataset(dataset_name=args.target,use_corrected_structure=args.use_corrected_structure,load_preprocessed=args.load_preprocessed)
    # preprocess the dataset and random split
    id_train, id_val, id_test = get_id_train_val_test(
            total_size=len(dataset_sym),
            split_seed=args.split_seed,
            train_ratio=args.train_ratio,
            val_ratio=args.val_ratio,
            test_ratio=args.test_ratio,
            keep_data_order=False,
        )
    dataset_train = [dataset_sym[x] for x in id_train]
    dataset_val = [dataset_sym[x] for x in id_val]
    dataset_test = [dataset_sym[x] for x in id_test]

    pyg_dataset_train = get_pyg_dataset(dataset_train, args.target, args.reduce_cell)
    pyg_dataset_val = get_pyg_dataset(dataset_val, args.target, args.reduce_cell)
    pyg_dataset_test = get_pyg_dataset(dataset_test, args.target, args.reduce_cell)

    # form dataloaders
    collate_fn = pyg_dataset_train.collate
    train_loader = DataLoader(
        pyg_dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
    )

    val_loader = DataLoader(
        pyg_dataset_val,
        batch_size=64,
        shuffle=False,
        collate_fn=collate_fn,
        # Keep the final partial batch so every validation sample is scored.
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )

    # The test loader is only used for the size report below.
    test_loader = DataLoader(
        pyg_dataset_test,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )
    print("n_train:", len(train_loader.dataset))
    print("n_val:", len(val_loader.dataset))
    print("n_test:", len(test_loader.dataset))
    # set up training configs
    model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    steps_per_epoch = len(train_loader)
    total_iter = steps_per_epoch * args.epochs
    scheduler = PolynomialLRDecay(optimizer, max_iters=total_iter, start_lr=args.learning_rate, end_lr=0.00001, power=1)

    criteria = {
        "mse": nn.MSELoss(),
        "l1": nn.L1Loss(),
        "huber": nn.HuberLoss(),
    }
    criterion = criteria[args.loss]

    # Make sure the checkpoint directory exists even when train() is called
    # without going through main().
    checkpoint_path = "runs/%s/model_best_%s.pt"%(args.name, args.model)
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # training epoch
    best_score = float("inf")
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{args.epochs}", unit='batch') as pbar:
            for data in train_loader:
                structure, mask, equality, labels, rot_list = data
                structure, mask, equality, labels = structure.to(device), mask.to(device), equality.to(device), labels.to(device)
                optimizer.zero_grad()

                outputs = _predict_tensors(model, args, structure, mask, equality, rot_list)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                pbar.set_postfix({'training_loss': running_loss / (pbar.n + 1)})
                pbar.update(1)
                # The schedule is stepped per batch, not per epoch.
                scheduler.step()

        # Validation
        model.eval()
        label_list = []
        output_list = []

        # No gradients are needed for scoring; skipping them saves memory.
        with torch.no_grad():
            for data in val_loader:
                structure, mask, equality, labels, rot_list = data
                structure, mask, equality, labels = structure.to(device), mask.to(device), equality.to(device), labels.to(device)
                outputs = _predict_tensors(model, args, structure, mask, equality, rot_list)
                output_list.append(outputs.reshape(-1, 9))
                label_list.append(labels.reshape(-1, 9))

        if not output_list:
            raise RuntimeError("validation loader produced no batches")

        # Concatenate rather than stack: the last batch may be smaller.
        outputs = torch.cat(output_list)
        labels = torch.cat(label_list)
        # Convert to a Python float so the best score is not a device tensor.
        mae = (outputs - labels).abs().mean(dim=-1).mean().item()

        if mae < best_score:
            best_score = mae
            torch.save(model.state_dict(), checkpoint_path)
        print("Validation mae ", mae)


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``; this converter reads the text instead.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse the command line, build the model and run training.

    Exits through ``argparse`` (status 2) when a boolean option has an
    unreadable value or when ``--model`` names a model this script cannot
    build.
    """
    parser = argparse.ArgumentParser(description='Training script')

    # Define command-line arguments
    # training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size of training and evaluating')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-05, help='weight decay')
    parser.add_argument('--loss', type=str, default='huber', help='mse or l1 or huber')
    parser.add_argument('--model', type=str, default='ceitnet', help='model to train; only ceitnet is built by this script')
    parser.add_argument('--project', type=str, default='test', help='name of project for wandb visualization')
    parser.add_argument('--name', type=str, default='test', help='name of project for storage')
    parser.add_argument('--reduce_cell', type=_str2bool, default=False, help='reduce the cell into irreducible atom sets, not used')
    # dataset parameters
    parser.add_argument('--split_seed', type=int, default=32, help='the random seed of spliting data')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='training ratio used in data split')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='evaluate ratio used in data split')
    parser.add_argument('--test_ratio', type=float, default=0.1, help='test ratio used in data split')
    parser.add_argument('--target', type=str, default='dielectric', help='dielectric, piezoelectric, or elastic')
    parser.add_argument('--test_augment', type=str, default='None', help='None, XZ_exchange, Xrotate, Yrotate, Zrotate')
    parser.add_argument('--threshold', type=float, default=100., help='threshold to remove samples')
    parser.add_argument('--use_corrected_structure', type=_str2bool, default=True, help='correct input structure or not')
    parser.add_argument('--load_preprocessed', type=_str2bool, default=True, help='load previous processed dataset')

    args = parser.parse_args()

    print('Training settings:')
    print(f'  Epochs: {args.epochs}')
    print(f'  Learning rate: {args.learning_rate}')
    print(args)
    # Re-seed PyTorch with the split seed before the model is created.
    torch.manual_seed(args.split_seed)
    torch.cuda.manual_seed_all(args.split_seed)
    # load the model
    if args.model == "ceitnet":
        model = CEITNet(args)
    else:
        # Fail here with a clear message instead of reaching train() with an
        # unbound ``model`` variable.
        parser.error(f"unsupported model {args.model!r}; only 'ceitnet' can be built")

    # Create the run directory; exist_ok avoids the check-then-create race.
    os.makedirs('runs/' + args.name, exist_ok=True)

    train(model, args)

if __name__ == "__main__":
    main()
