import argparse
import torch
from torch import nn
import numpy as np
from data import get_dataset
import pandas as pd
import pickle as pk
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from jarvis.core.atoms import Atoms
from torch.utils.data import DataLoader
from tqdm import tqdm
from e3nn.io import CartesianTensor
from pandarallel import pandarallel
from data import get_symmetry_dataset
pandarallel.initialize(progress_bar=False)

from graphs import atoms2graphs, GraphDataset
from utils import get_id_train_val_test
from ceitnet import CEITNet
import matplotlib.pyplot as plt
from e3nn import o3
import pdb
import json
from pathlib import Path
import glob
# torch config
torch.set_default_dtype(torch.float32)
import torch
import numpy as np
import random
import os

# Set the random seed for Python, NumPy, and PyTorch so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU when no CUDA device is present; a hard-coded "cuda" device
# makes every later `.to(device)` fail on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CUDA seeding calls are silently ignored when CUDA is unavailable, so they
# are safe to keep unconditionally.
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # if using multi-GPU.
# Configure cuDNN to use deterministic algorithms (trades speed for repeatability)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared pymatgen -> JARVIS converter; structure_to_graphs uses its get_atoms().
adptor = JarvisAtomsAdaptor()

# Positions of the diagonal / off-diagonal entries of a row-major flattened
# 3x3 matrix (their use is not visible in this file section).
diagonal = [0, 4, 8]
off_diagonal = [1, 2, 3, 5, 6, 7]

def structure_to_graphs(
    df: pd.DataFrame,
    use_corrected_structure: bool = False,
    reduce_cell: bool = False,
    cutoff: float = 4.0,
    max_neighbors: int = 64
):
    """Build one graph per row of ``df`` from its ``p_input`` column.

    Each ``p_input`` entry is indexed by ``"structure"`` (converted with the
    module-level ``adptor.get_atoms``) and ``"equivalent_atoms"``; both are
    handed to ``atoms2graphs`` together with the neighbour settings below.
    Rows are processed with pandarallel's ``parallel_apply``.

    Args:
        df: Frame with a ``p_input`` column.
        use_corrected_structure: Accepted for interface compatibility; not
            used by this function.
        reduce_cell: Forwarded to ``atoms2graphs`` as ``reduce``.
        cutoff: Neighbour cutoff forwarded to ``atoms2graphs``.
        max_neighbors: Neighbour cap forwarded to ``atoms2graphs``.

    Returns:
        The ``.values`` of the resulting Series: one graph object per row.
    """
    shared_options = dict(
        cutoff=cutoff,
        max_neighbors=max_neighbors,
        reduce=reduce_cell,
        use_canonize=True,
    )

    def _row_to_graph(entry):
        """Convert one ``p_input`` record into a graph via ``atoms2graphs``."""
        jarvis_atoms = adptor.get_atoms(entry["structure"])
        return atoms2graphs(
            jarvis_atoms,
            equivalent_atoms=entry['equivalent_atoms'],
            **shared_options,
        )

    return df["p_input"].parallel_apply(_row_to_graph).values

class PolynomialLRDecay(torch.optim.lr_scheduler._LRScheduler):
    """Polynomially decay the learning rate from ``start_lr`` to ``end_lr``.

    After ``t`` scheduler steps every parameter group gets::

        (start_lr - end_lr) * (1 - min(t, max_iters) / max_iters) ** power + end_lr

    ``t`` is the base scheduler's ``last_epoch`` counter, which is 0 right
    after construction. So the first optimizer step uses ``start_lr``, and
    the rate reaches ``end_lr`` after ``max_iters`` steps. Further steps keep
    it at ``end_lr`` rather than overshooting. (Unclamped, a fractional
    ``power`` would produce a complex lr.) The groups' own base learning
    rates are ignored; all groups receive the same value.

    Args:
        optimizer: Wrapped optimizer.
        max_iters: Number of steps over which to decay. ``<= 0`` means the
            schedule is already finished and ``end_lr`` is used from the start.
        start_lr: Learning rate at step 0.
        end_lr: Learning rate at and after ``max_iters`` steps.
        power: Polynomial exponent (1 = linear decay).
        last_epoch: Index of the last step when resuming (-1 = fresh start).
    """

    def __init__(self, optimizer, max_iters, start_lr, end_lr, power=1, last_epoch=-1):
        self.max_iters = max_iters
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.power = power
        # Kept for backward compatibility: counts calls to step(), including
        # the one the base class makes inside __init__. The schedule itself
        # is driven by the base class's ``last_epoch``.
        self.last_iter = 0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.max_iters <= 0:
            # Avoid ZeroDivisionError (e.g. an empty training loader).
            progress = 1.0
        else:
            progress = min(max(self.last_epoch, 0), self.max_iters) / self.max_iters
        lr = (self.start_lr - self.end_lr) * (1 - progress) ** self.power + self.end_lr
        return [lr for _ in self.base_lrs]

    def step(self, epoch=None):
        self.last_iter += 1  # legacy counter, see __init__
        return super().step(epoch)

def group_decay(model):
    """Split ``model`` parameters into weight-decay and no-decay groups.

    A parameter goes to the no-decay group when its qualified name contains
    ``"bias"``, ``"bn"`` or ``"norm"``; everything else is decayed. Order within
    each group follows ``model.named_parameters()``.

    Returns:
        Two optimizer param-group dicts: the decayed parameters first (using
        the optimizer's default weight decay), then the exempt ones with
        ``weight_decay`` forced to 0.
    """
    exempt_markers = ("bias", "bn", "norm")
    with_decay, without_decay = [], []

    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        (without_decay if is_exempt else with_decay).append(param)

    return [{"params": with_decay}, {"params": without_decay, "weight_decay": 0}]


def get_pyg_dataset(data, target, reduce_cell=False):
    """Wrap raw samples in a ``GraphDataset``.

    ``data`` is turned into a DataFrame, each row is converted to a graph with
    ``structure_to_graphs``, and frame, graphs and ``target`` are passed to
    ``GraphDataset``.
    """
    frame = pd.DataFrame(data)
    return GraphDataset(
        df=frame,
        graphs=structure_to_graphs(frame, reduce_cell=reduce_cell),
        target=target,
    )


def load_model_checkpoint(model: torch.nn.Module, args, *, device: torch.device) -> str:
    """Restore ``model`` weights from ``args.ckpt_path`` and put it in eval mode.

    The file is loaded onto CPU first. Three layouts are accepted, checked in
    this order: a dict holding a dict under ``"state_dict"``, a dict holding a
    dict under ``"model"``, or a bare state dict.

    Args:
        model: Module whose parameters are overwritten in place.
        args: Namespace providing ``ckpt_path`` and, optionally,
            ``ckpt_strict`` (default True) for ``load_state_dict``.
        device: Device the model is moved to after loading.

    Returns:
        The checkpoint path that was loaded.
    """
    ckpt_path = args.ckpt_path
    loaded = torch.load(ckpt_path, map_location="cpu")

    weights = loaded
    if isinstance(loaded, dict):
        for wrapper_key in ("state_dict", "model"):
            candidate = loaded.get(wrapper_key)
            if isinstance(candidate, dict):
                weights = candidate
                break

    strict = bool(getattr(args, "ckpt_strict", True))
    model.load_state_dict(weights, strict=strict)
    model.to(device)
    model.eval()
    print(f"[CKPT] Loaded checkpoint: {ckpt_path} (strict={strict})")
    return ckpt_path

def _make_loader(dataset, batch_size, *, shuffle, drop_last, collate_fn):
    """Wrap ``dataset`` in a DataLoader using the script's worker/pinning settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        drop_last=drop_last,
        num_workers=4,
        pin_memory=True,
    )


def _forward(model, args, structure, mask):
    """Run ``model`` on one batch; raise ValueError if ``args.model`` is unsupported."""
    if args.model == "ceitnet":
        return model(structure, mask, None)
    raise ValueError(f"Model {args.model} not supported")


def _validation_mae(model, loader, args):
    """Return the mean absolute error of ``model`` over every batch of ``loader``.

    Predictions and labels are flattened to rows of 18 values (the layout the
    original script assumed) and compared element-wise. Runs without autograd.
    Returns ``nan`` when the loader yields no batches, so a tiny validation
    split cannot crash training.
    """
    model.eval()
    output_list, label_list = [], []
    with torch.no_grad():
        for structure, mask, _, labels, _ in loader:
            structure, mask, labels = structure.to(device), mask.to(device), labels.to(device)
            outputs = _forward(model, args, structure, mask)
            output_list.append(outputs.reshape(-1, 18))
            label_list.append(labels.reshape(-1, 18))

    if not output_list:
        print("Warning: validation loader produced no batches")
        return float("nan")

    # torch.cat (not stack) so a smaller final batch is handled correctly.
    diff = torch.cat(output_list) - torch.cat(label_list)
    return diff.abs().mean(dim=-1).mean().item()


def train(model, args):
    """Train ``model`` and keep the checkpoint with the best validation MAE.

    Loads the dataset named by ``args.target``, splits it into
    train/val/test, and converts each split with ``get_pyg_dataset``. It then
    trains with AdamW and a per-batch ``PolynomialLRDecay`` schedule. After
    every epoch it evaluates validation MAE. From epoch index 21 onwards, each
    new best model is saved to ``runs/<name>/model_best_<model>.pt``.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
        KeyError: if ``args.loss`` is not one of ``mse``, ``l1``, ``huber``.
    """
    # load the dataset
    if args.load_preprocessed:
        print("load preprocessed dataset ...")
    dataset_sym = get_dataset(
        dataset_name=args.target,
        use_corrected_structure=args.use_corrected_structure,
        load_preprocessed=args.load_preprocessed,
    )

    # random split of sample indices
    id_train, id_val, id_test = get_id_train_val_test(
        total_size=len(dataset_sym),
        split_seed=args.split_seed,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        keep_data_order=False,
    )
    dataset_train = [dataset_sym[x] for x in id_train]
    # Drop training samples whose 'piezoelectric' tensor has (near-)zero norm.
    # NOTE(review): the key is hard-coded independently of args.target —
    # confirm this is intended when training on other targets.
    dataset_train = [
        itm for itm in dataset_train
        if (torch.tensor(itm['piezoelectric']) ** 2).sum() ** 0.5 > 1e-5
    ]
    dataset_val = [dataset_sym[x] for x in id_val]
    dataset_test = [dataset_sym[x] for x in id_test]

    pyg_dataset_train = get_pyg_dataset(dataset_train, args.target, args.reduce_cell)
    pyg_dataset_val = get_pyg_dataset(dataset_val, args.target, args.reduce_cell)
    pyg_dataset_test = get_pyg_dataset(dataset_test, args.target, args.reduce_cell)

    # form dataloaders
    collate_fn = pyg_dataset_train.collate
    train_loader = _make_loader(
        pyg_dataset_train, args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )
    # Keep every validation sample. drop_last=True would silently exclude up
    # to batch_size - 1 samples from the metric, and would yield no batches
    # at all for a split smaller than one batch.
    val_loader = _make_loader(
        pyg_dataset_val, args.batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn
    )
    test_loader = _make_loader(
        pyg_dataset_test, 1, shuffle=False, drop_last=False, collate_fn=collate_fn
    )
    print("n_train:", len(train_loader.dataset))
    print("n_val:", len(val_loader.dataset))
    print("n_test:", len(test_loader.dataset))

    # set up training configs
    model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    steps_per_epoch = len(train_loader)
    total_iter = steps_per_epoch * args.epochs
    # The scheduler is stepped once per batch, so it decays over total_iter steps.
    scheduler = PolynomialLRDecay(optimizer, max_iters=total_iter, start_lr=args.learning_rate, end_lr=0.00001, power=1)
    criteria = {
        "mse": nn.MSELoss(),
        "l1": nn.L1Loss(),
        "huber": nn.HuberLoss(),
    }
    criterion = criteria[args.loss]

    best_score = float("inf")
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{args.epochs}", unit='batch') as pbar:
            for structure, mask, _, labels, _ in train_loader:
                structure, mask, labels = structure.to(device), mask.to(device), labels.to(device)
                optimizer.zero_grad()

                outputs = _forward(model, args, structure, mask)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                pbar.set_postfix({'training_loss': running_loss / (pbar.n + 1)})
                pbar.update(1)
                scheduler.step()

        # Validation
        mae = _validation_mae(model, val_loader, args)

        # Only checkpoint after the first 21 epochs; a NaN MAE never compares
        # as smaller, so an empty validation split never overwrites the file.
        if mae < best_score and epoch > 20:
            best_score = mae
            torch.save(model.state_dict(), "runs/%s/model_best_%s.pt"%(args.name, args.model))

        print("Validation mae ", mae)

    return


def main():
    parser = argparse.ArgumentParser(description='Training script')

    # Define command-line arguments
    # training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size of training and evaluating')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-05, help='weight decay')
    parser.add_argument('--loss', type=str, default='huber', help='mse or l1 or huber')
    parser.add_argument('--model', type=str, default='ceitnet', help='ceitnet')
    parser.add_argument('--name', type=str, default='test', help='name of project for storage')
    parser.add_argument('--reduce_cell', type=bool, default=False, help='reduce the cell into irreducible atom sets')
    # dataset parameters
    parser.add_argument('--split_seed', type=int, default=32, help='the random seed of spliting data')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='training ratio used in data split')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='evaluate ratio used in data split')
    parser.add_argument('--test_ratio', type=float, default=0.1, help='test ratio used in data split')
    parser.add_argument('--target', type=str, default='piezoelectric', help='dielectric, piezoelectric, or elastic')
    parser.add_argument('--threshold', type=float, default=100., help='threshold to remove samples')
    parser.add_argument('--use_corrected_structure', type=bool, default=True, help='correct input structure or not')
    parser.add_argument('--load_preprocessed', type=bool, default=True, help='load previous processed dataset')
    # checkpoint parameters (optional)
    parser.add_argument('--ckpt_path', type=str, default='', help='Path to a checkpoint .pt file (overrides ckpt_kind)')

    args = parser.parse_args()

    print('Training settings:')
    print(f'  Epochs: {args.epochs}')
    print(f'  Learning rate: {args.learning_rate}')
    print(args)
    torch.manual_seed(args.split_seed)
    torch.cuda.manual_seed_all(args.split_seed)
    # load the model
    if args.model == "ceitnet":
        model = CEITNet(args)
    else:
        raise ValueError(f"Model {args.model} not supported")

    if not os.path.exists('runs/' + args.name):
        # Create the directory
        os.makedirs('runs/' + args.name)

    train(model, args)

if __name__ == "__main__":
    main()
