import shutil
import logging
import os
import torch
import numpy as np
from tqdm import tqdm



def get_logger(log_dir, verbose='info'):
    """Configure the root logger to write to stderr and ``<log_dir>/log.txt``.

    WARNING: ``log_dir`` is deleted recursively and recreated on every call.

    Calling this function again replaces (and closes) the handlers installed by
    the previous call instead of stacking new ones on top of them.

    Args:
        log_dir: directory that will hold ``log.txt``; any existing contents are removed.
        verbose: name of a ``logging`` level, case-insensitive (e.g. 'info', 'debug').

    Returns:
        The root ``logging.Logger``.

    Raises:
        ValueError: if ``verbose`` does not name a logging level. This is checked
            before ``log_dir`` is touched, so a bad level never wipes existing logs.
    """
    # Validate first: a typo in the level must not destroy an existing log directory.
    level = getattr(logging, verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(verbose))

    logger = logging.getLogger()
    # Remove handlers from a previous call; otherwise every message is emitted once
    # per call and the old FileHandler keeps a (possibly deleted) file open.
    for handler in list(logger.handlers):
        if getattr(handler, '_get_logger_owned', False):
            logger.removeHandler(handler)
            handler.close()

    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    os.makedirs(log_dir)

    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
    for handler in (logging.StreamHandler(),
                    logging.FileHandler(os.path.join(log_dir, 'log.txt'))):
        handler.setFormatter(formatter)
        handler._get_logger_owned = True  # marker so a later call can remove it
        logger.addHandler(handler)
    logger.setLevel(level)

    return logger


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every generator.
    """
    import random  # local import: the module header does not import ``random``

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # NOTE: benchmark=True lets cuDNN autotune kernels. This favours speed
        # over bit-exact reproducibility of GPU runs.
        torch.backends.cudnn.benchmark = True


def save_model(workdir, network, name):
    """Serialize ``network`` with ``torch.save`` to the file ``<workdir>/<name>``."""
    torch.save(network, os.path.join(workdir, name))


def load_model(model_path):
    """Load an object saved with ``torch.save``, mapping all tensors onto the CPU.

    Args:
        model_path: path of the checkpoint file.

    Returns:
        The deserialized object, or ``None`` (after logging a message) when the
        path does not exist.
    """
    if not os.path.exists(model_path):
        logging.info(f'The trained model path {model_path} does not exists.')
        return None

    logging.info(f"Loading model from {model_path}")
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # On newer torch versions the weights_only default may reject fully pickled
    # nn.Module objects. Confirm against the installed torch version.
    return torch.load(model_path, map_location=torch.device('cpu'))


def split_dataset(data_ori, data_seed):
    """Randomly split the rows of ``data_ori`` into train/test/validation parts.

    Test and validation each receive ``int(0.1 * len)`` rows and training gets the
    remainder. The permutation comes from a ``torch.Generator`` seeded with
    ``data_seed``, so the split is reproducible.

    Args:
        data_ori: indexable dataset whose first dimension enumerates samples
            (must support ``.shape`` and indexing with a list of indices).
        data_seed: seed for the split permutation.

    Returns:
        ``(training_set, test_set, val_set)``. Note the order: test comes before val.
    """
    n_total = data_ori.shape[0]
    n_test = int(n_total * 0.1)
    n_val = int(n_total * 0.1)
    n_train = n_total - n_test - n_val

    seeded_gen = torch.Generator().manual_seed(data_seed)
    subsets = torch.utils.data.random_split(data_ori, [n_train, n_val, n_test], generator=seeded_gen)
    train_idx, val_idx, test_idx = (subset.indices for subset in subsets)

    training_set = data_ori[train_idx]
    val_set = data_ori[val_idx]
    test_set = data_ori[test_idx]
    logging.info(f"size of the training set: {training_set.shape[0]}, size of the test set: {test_set.shape[0]}, size of the validation set: {val_set.shape[0]}.")
    return training_set, test_set, val_set


def get_grid_new(points, resolution=200):
    """Build a regular ``resolution**3`` grid around ``points`` with a 5% margin.

    Args:
        points: tensor whose rows are points; only columns 0, 1 and 2 are used
            to compute the per-axis bounding box.
        resolution: number of samples along each axis.

    Returns:
        dict with
          - "grid_points": float tensor of shape (resolution**3, 3), moved to CUDA
            when available. Rows follow ``np.meshgrid``'s default 'xy' ordering.
          - "xyz": list ``[x, y, z]`` of the 1-D NumPy axes.
    """
    lower = points.min(dim=0).values.cpu().numpy()
    upper = points.max(dim=0).values.cpu().numpy()

    # Pad the bounding box by 5% of its extent on each side.
    pad = (upper - lower) * 0.05
    grid_lo = lower - pad
    grid_hi = upper + pad

    axes = [np.linspace(grid_lo[axis], grid_hi[axis], resolution) for axis in range(3)]
    mesh = np.meshgrid(*axes)  # default indexing='xy'
    flat = np.stack([coord.ravel() for coord in mesh], axis=1)

    grid_points = torch.tensor(flat, dtype=torch.float)
    if torch.cuda.is_available():
        grid_points = grid_points.cuda()
    return {"grid_points": grid_points, "xyz": axes}


def run_func_in_batches(func, x, max_batch_size, out_dim):
    """Evaluate ``func`` on ``x`` in chunks of at most ``max_batch_size`` rows.

    Args:
        func: callable applied to each row slice ``x[i:j]``. Its output must hold
            ``(j - i) * out_dim`` elements, or ``j - i`` when ``out_dim`` is None.
        x: tensor-like input sliced along its first dimension.
        max_batch_size: maximum number of rows passed to ``func`` per call.
        out_dim: per-row output width, or None for a scalar per row.

    Returns:
        Tensor allocated with ``torch.empty`` (default dtype, CPU) of shape
        ``(len(x),)`` when ``out_dim`` is None, otherwise ``(len(x), out_dim)``.
        Outputs are detached before being copied in.
    """
    n_rows = x.shape[0]
    if out_dim is None:
        out = torch.empty(n_rows)
        row_shape = ()
    else:
        out = torch.empty(n_rows, out_dim)
        row_shape = (out_dim,)

    n_batches = int(np.ceil(n_rows / max_batch_size))
    for i in tqdm(range(n_batches)):
        start = i * max_batch_size
        stop = start + max_batch_size
        x_i = x[start:stop]
        # view(n, None) raised a TypeError before; a 1-D view is used for scalar outputs.
        out[start:stop] = func(x_i).detach().view(x_i.shape[0], *row_shape)
    torch.cuda.empty_cache()
    return out


class ExponentialMovingAverage:
    """Maintain an exponential moving average (EMA) of a collection of parameters.

    Only tensors with ``requires_grad=True`` are tracked. ``update`` and ``copy_to``
    filter their argument the same way, so pass the same parameter collection, in
    the same order, each time. ``store``/``restore`` act on the unfiltered collection.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """Snapshot the trainable ``parameters`` as the initial shadow values.

        Args:
            parameters: iterable of tensors (e.g. ``model.parameters()``); iterated once.
            decay: EMA decay factor in [0, 1].
            use_num_updates: if True, count updates and cap the decay at
                ``(1 + n) / (10 + n)`` as a warm-up; if False, always use ``decay``.

        Raises:
            ValueError: if ``decay`` is below 0 or above 1.
        """
        if decay > 1.0 or decay < 0.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = None if not use_num_updates else 0
        self.shadow_params = [t.clone().detach() for t in self._trainable(parameters)]
        self.collected_params = []

    @staticmethod
    def _trainable(params):
        """Return, as a list, the members of ``params`` that require gradients."""
        return [t for t in params if t.requires_grad]

    def update(self, parameters):
        """Move each shadow tensor towards its matching trainable parameter.

        shadow <- shadow - (1 - d) * (shadow - param), where d is ``decay``,
        optionally capped by the warm-up schedule.
        """
        effective_decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            warmup_cap = (1 + self.num_updates) / (10 + self.num_updates)
            effective_decay = min(effective_decay, warmup_cap)
        step = 1.0 - effective_decay
        with torch.no_grad():
            for shadow, live in zip(self.shadow_params, self._trainable(parameters)):
                shadow.sub_(step * (shadow - live))

    def copy_to(self, parameters):
        """Overwrite the trainable ``parameters`` in place with the shadow (EMA) values."""
        for shadow, live in zip(self.shadow_params, self._trainable(parameters)):
            live.data.copy_(shadow.data)

    def store(self, parameters):
        """Keep clones of all ``parameters`` so ``restore`` can put them back later."""
        self.collected_params = [tensor.clone() for tensor in parameters]

    def restore(self, parameters):
        """Copy the values saved by ``store`` back into ``parameters``, in order."""
        for saved, live in zip(self.collected_params, parameters):
            live.data.copy_(saved.data)

    def state_dict(self):
        """Return decay, update counter and shadow tensors (by reference, not copied)."""
        return {
            'decay': self.decay,
            'num_updates': self.num_updates,
            'shadow_params': self.shadow_params,
        }

    def load_state_dict(self, state_dict):
        """Restore the fields produced by ``state_dict``."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']


def cal_trace(mat):
    """Return the trace of the matrix formed by the last two dims of ``mat``.

    Batch dimensions are preserved and a trailing size-1 dimension is kept,
    e.g. (B, N, N) -> (B, 1).
    """
    main_diag = torch.diagonal(mat, dim1=-2, dim2=-1)
    return torch.sum(main_diag, dim=-1, keepdim=True)


def uniform_triangles_sample(triangles):
    """Draw one uniformly distributed point inside each triangle.

    Args:
        triangles: NumPy array of shape (n, 3, 3): n triangles with three 3-D
            vertices each (the (-1, 2, 3) reshape below requires 3-D vertices).

    Returns:
        (n, 3) array of sample points. Uses NumPy's global RNG (``np.random.random``).
    """
    origins = triangles[:, 0]
    # Two edge vectors per triangle, both starting at the first vertex.
    origin_pairs = np.tile(origins, (1, 2)).reshape((-1, 2, 3))
    edges = triangles[:, 1:] - origin_pairs

    # Two U(0, 1) weights per triangle cover the parallelogram spanned by the edges.
    weights = np.random.random((len(edges), 2, 1))
    # Weights summing to more than 1 land in the far half of the parallelogram.
    # Reflecting them, (w0, w1) -> (1 - w0, 1 - w1), maps them into the triangle
    # and keeps the distribution uniform.
    outside = weights.sum(axis=1).reshape(-1) > 1.0
    weights = np.where(outside[:, None, None], 1.0 - weights, weights)

    return (edges * weights).sum(axis=1) + origins


def check_memory(data=None, keep_quiet=False):
    """Optionally log the size of ``data``, then release cached CUDA memory.

    Args:
        data: optional tensor whose storage size (element_size * nelement) is reported.
        keep_quiet: if True, nothing is logged. The CUDA cache is still emptied.
    """
    if data is not None:
        size_gb = data.element_size() * data.nelement() / (1024 ** 3)
        if not keep_quiet:
            logging.info(f"The data (shape {data.shape}) occupy {size_gb:.3f} G of memory on {data.device}.")

    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    if not keep_quiet:
        logging.info(f"Memory allocated: {torch.cuda.memory_allocated() / (1024 ** 3):.3f}G")


@torch.no_grad()
def Kabsch(x, xref):
    """Find, per frame, the rotation that best superimposes ``xref`` onto ``x`` (Kabsch).

    Works for point sets of any dimension ``d`` (previously hard-coded to 3).

    Args:
        x: tensor of shape (batch, n_atoms, d).
        xref: reference of shape (n, d), cast to ``x``'s dtype/device. Only the
            first ``n`` atoms of each frame in ``x`` are used for the fit.

    Returns:
        (R, b): ``R`` of shape (batch, d, d) is a proper rotation (det +1) with
        ``x[:, :n] ~= (xref - xref.mean(0)) @ R + b``. ``b`` of shape (batch, 1, d)
        is the per-frame centroid of the fitted atoms.
    """
    assert isinstance(x, torch.Tensor), 'Input x is not a torch tensor'

    xref = xref.to(x)
    n_fit = xref.shape[0]
    dim = xref.shape[-1]

    xref = xref - torch.mean(xref, 0, True)
    x_fit = x[:, :n_fit, :]
    b = torch.mean(x_fit, 1, True)
    x_notran = x_fit - b

    if not torch.isfinite(x_notran).all():
        logging.warning("Some points are non-finite (inf/NaN) before alignment!")

    # Cross-covariance matrix, one (dim x dim) matrix per frame.
    prod = torch.matmul(x_notran.transpose(1, 2), xref)

    try:
        u, _, vh = torch.linalg.svd(prod)
    except torch._C._LinAlgError:
        error_idx = torch.where(~torch.isfinite(prod))[0].unique()
        logging.warning(f"Index {error_idx} of prod is non-finite! It was changed into zero before SVD.")
        prod = torch.nan_to_num(prod, nan=0.0, posinf=0.0, neginf=0.0)
        u, _, vh = torch.linalg.svd(prod)

    # If u @ vh is a reflection (det -1), flip the last singular direction so that
    # R is a proper rotation. Multiplying u's columns by `signs` equals u @ diag(signs).
    signs = torch.ones(x.size(0), dim, device=x.device, dtype=u.dtype)
    signs[:, -1] = torch.sign(torch.linalg.det(torch.matmul(u, vh)))
    R = torch.matmul(u * signs.unsqueeze(-2), vh)
    return R.transpose(1, 2), b


def get_RMSD(xvec, xref):
    """Per-frame RMSD between ``xvec`` and ``xref`` after optimal rigid superposition.

    Args:
        xvec: frames of shape (batch, n_atoms, d), as consumed by ``Kabsch``.
        xref: reference of shape (n_atoms, d). It is cast to ``xvec``'s dtype and
            device (matching what ``Kabsch`` does internally), so mixed
            float32/float64 or CPU/GPU inputs no longer fail in the matmul below.

    Returns:
        Tensor of shape (batch,): sqrt(sum of squared coordinate errors / n_atoms).
    """
    xref = xref.to(xvec)
    R, b = Kabsch(xvec, xref)
    b0 = torch.mean(xref, 0, True)
    # Map the centered reference into each frame's pose and compare.
    error = xvec - b - torch.matmul(xref - b0, R)
    return torch.sqrt(torch.sum(error ** 2, dim=(1, 2)) / xref.shape[0])




