import shutil
import logging
import os
import torch
import numpy as np
import math

from tqdm import tqdm

def get_logger(log_dir, verbose = 'info'):
    """Reset the root logger so it writes to the console and to ``<log_dir>/log.txt``.

    An existing ``log_dir`` is deleted and recreated, so earlier log files in
    it are lost. Handlers previously attached to the root logger are removed
    and closed before that happens.

    Args:
        log_dir: Directory that will hold ``log.txt``.
        verbose: Name of a ``logging`` level (case-insensitive), e.g. ``'info'``.

    Returns:
        The configured root logger.

    Raises:
        ValueError: If ``verbose`` does not name an integer logging level.
    """
    # Validate first so that a bad level name cannot wipe an existing log directory.
    level = getattr(logging, verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(verbose))

    logger = logging.getLogger()  # root logger

    # Remove AND close the old handlers before touching the directory: merely
    # clearing the list would leave a previous log.txt file handle open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.propagate = False  # kept from the original; the root logger has no parent

    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)  # remove old log directory
    os.makedirs(log_dir)  # create new log directory

    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s',
                                  datefmt='%m-%d %H:%M:%S')

    console_handler = logging.StreamHandler()  # print to console
    file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))  # print to file
    for handler in (console_handler, file_handler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.setLevel(level)
    return logger

def set_seed_everywhere(seed):
    """Seed the NumPy and torch random number generators with ``seed``.

    When CUDA is available, every CUDA device is seeded too and cuDNN
    benchmark mode is switched on.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    # NOTE(review): benchmark mode lets cuDNN pick algorithms at run time;
    # confirm this is intended alongside seeding for reproducibility.
    torch.backends.cudnn.benchmark = True

def save_model(workdir, network, name):
    """Write ``network`` with ``torch.save`` to the file ``name`` inside ``workdir``."""
    torch.save(network, os.path.join(workdir, name))

def load_model(model_path):
    """Load an object saved with ``torch.save``, mapping its tensors to the CPU.

    Args:
        model_path: Path of the file to load.

    Returns:
        Whatever ``torch.load`` returns for ``model_path``, or ``None`` when
        the path does not exist (a message is logged in both cases).
    """
    # Guard clause: nothing to load.
    if not os.path.exists(model_path):
        logging.info(f'The trained model path {model_path} does not exists.')
        return None

    logging.info(f"Loading model from {model_path}")
    cpu = torch.device('cpu')
    return torch.load(model_path, map_location=cpu)

def uniform_triangles_sample(triangles):
    """Draw one random point on each triangle.

    Args:
        triangles: NumPy array of shape ``(n, 3, 3)``: ``n`` triangles, three
            vertices each, three coordinates per vertex.

    Returns:
        Array of shape ``(n, 3)`` holding one sampled point per triangle.
        Randomness comes from the global ``np.random`` state.
    """
    origins = triangles[:, 0]
    # Two edge vectors per triangle, both starting at the first vertex.
    edges = triangles[:, 1:] - triangles[:, :1]

    # One weight in [0, 1) per edge. Taken as-is the pair spans the whole
    # parallelogram built on the two edges.
    weights = np.random.random((len(edges), 2, 1))

    # Pairs summing to more than 1 fall in the far half of the parallelogram;
    # mapping w -> |w - 1| folds them back inside the triangle.
    outside = weights.sum(axis=1) > 1.0  # shape (n, 1)
    weights = np.abs(np.where(outside[:, np.newaxis, :], weights - 1.0, weights))

    # Weighted sum of the two edges, shifted to the triangle's first vertex.
    return (edges * weights).sum(axis=1) + origins


def run_func_in_batches(funcs, x, max_batch_size, out_dim):
    """Apply ``funcs`` to ``x`` in slices along dim 0 and gather the results.

    Args:
        funcs: Callable taking a slice of ``x`` and returning a tensor with
            one result per row of that slice.
        x: Input indexed along its first dimension.
        max_batch_size: Maximum number of rows passed to ``funcs`` at once.
        out_dim: Size of the second output dimension, or ``None`` for a 1-D
            output with one value per row.

    Returns:
        A tensor created with ``torch.empty`` (default dtype and device) of
        shape ``(x.shape[0],)`` when ``out_dim`` is ``None``, otherwise
        ``(x.shape[0], out_dim)``.
    """
    num_rows = x.shape[0]
    num_batches = int(np.ceil(num_rows / max_batch_size))
    out_shape = (num_rows,) if out_dim is None else (num_rows, out_dim)
    out = torch.empty(out_shape)

    for i in tqdm(range(num_batches)):
        start = i * max_batch_size
        stop = start + max_batch_size
        x_i = x[start:stop]
        # Reshape to the layout of `out`. With out_dim=None the target is 1-D;
        # passing None into view() (as before) raised a TypeError.
        batch_shape = (x_i.shape[0],) + out_shape[1:]
        out[start:stop] = funcs(x_i).detach().view(batch_shape)
    torch.cuda.empty_cache()  # release cached GPU memory held by the allocator
    return out

def split_dataset(data_ori, data_seed):
    """Randomly split ``data_ori`` along its first axis into 80/10/10 parts.

    Test and validation each receive ``int(0.1 * N)`` rows; the remainder is
    used for training. The shuffle is driven by a torch generator seeded with
    ``data_seed``.

    Returns:
        ``(training_set, test_set, val_set)`` -- note that the test set comes
        before the validation set.
    """
    n_total = data_ori.shape[0]
    n_holdout = int(n_total * 0.1)  # size of the test set and of the validation set
    n_train = n_total - 2 * n_holdout

    rng = torch.Generator().manual_seed(data_seed)
    train_part, val_part, test_part = torch.utils.data.random_split(
        data_ori, [n_train, n_holdout, n_holdout], generator=rng)

    # random_split returns Subset wrappers; materialise them by index.
    training_set, val_set, test_set = (
        data_ori[part.indices] for part in (train_part, val_part, test_part)
    )
    logging.info(f"size of training set: {training_set.shape[0]}, size of validation set: {val_set.shape[0]}, size of test set: {test_set.shape[0]}.")
    return training_set, test_set, val_set

def check_memory(data=None, keep_quiet=False):
    """Log the size of ``data`` and the allocated CUDA memory.

    Args:
        data: Optional tensor whose footprint (element size times element
            count) is reported together with its shape and device.
        keep_quiet: When true nothing is logged; the CUDA cache is still
            emptied if CUDA is available.
    """
    verbose = not keep_quiet
    gib = 1024 ** 3

    if data is not None:
        memory_bytes = data.nelement() * data.element_size()
        if verbose:
            logging.info(f"The data (shape {data.shape}) occupy {memory_bytes / gib:.3f} G of memory on {data.device}.")

    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    if verbose:
        logging.info(f"Memory allocated: {torch.cuda.memory_allocated() / gib:.3f} G")

import math

import math

def get_temperature(epoch, max_epochs, T_min=0.5, T_max=1.0, T_min_cutoff=0.3, T_max_cutoff=0.5, mode="cosine"):
    """Return the temperature for ``epoch`` on a hold / ramp / hold schedule.

    The temperature is ``T_min`` before epoch ``int(T_min_cutoff * max_epochs)``
    and ``T_max`` from epoch ``int(T_max_cutoff * max_epochs)`` onwards. In
    between it moves from ``T_min`` towards ``T_max``: linearly when ``mode``
    is ``"linear"``, along a half-cosine for any other value of ``mode``.
    """
    ramp_start = int(T_min_cutoff * max_epochs)
    ramp_end = int(T_max_cutoff * max_epochs)

    if epoch < ramp_start:
        return T_min
    if epoch >= ramp_end:
        return T_max

    # Fraction of the ramp already covered; max(1, ...) guards the division.
    progress = (epoch - ramp_start) / max(1, ramp_end - ramp_start)
    if mode == "linear":
        blend = progress
    else:
        blend = 0.5 * (1 - math.cos(math.pi * progress))
    return T_min + (T_max - T_min) * blend



class ExponentialMovingAverage:
    """Keep an exponential moving average (EMA) of trainable parameters.

    Only parameters with ``requires_grad=True`` are tracked. ``update`` and
    ``copy_to`` must be given the same parameters, in the same order, as the
    constructor, so that shadow copies pair up with the right parameters.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of parameters, e.g. ``model.parameters()``.
            decay: EMA decay factor in ``[0, 1]``.
            use_num_updates: If true, count calls to ``update`` and use the
                smaller of ``decay`` and ``(1 + n) / (10 + n)`` as the decay.

        Raises:
            ValueError: If ``decay`` lies outside ``[0, 1]``.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None  # number of update() calls so far
        # Shadow only the trainable parameters. update() and copy_to() pair the
        # shadows with the requires_grad-filtered parameter list, so shadowing
        # frozen parameters as well would shift the pairs out of alignment.
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up: a smaller effective decay during the first updates.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            trainable = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, trainable):
                # In-place form of: s_param = decay * s_param + (1 - decay) * param
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy currently maintained parameters to the given model parameters.
        """
        trainable = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, trainable):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return the decay, the update counter and the shadow parameters as a dict."""
        return dict(decay=self.decay, num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict):
        """Restore the fields produced by ``state_dict``."""
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow_params = state_dict["shadow_params"]


@torch.no_grad()
def Kabsch(x, xref):
    """
    Find, for every frame in ``x``, the rotation and translation that align
    the reference ``xref`` with that frame (Kabsch algorithm).

    Args:
        x: tensor of shape (batch, n, 3) with n >= xref.shape[0]; only the
            first ``xref.shape[0]`` points of each frame are used.
        xref: reference of shape (m, 3); converted to the dtype and device of ``x``.

    Returns:
        (R, b): ``R`` has shape (batch, 3, 3) and is oriented so that
        ``(xref - mean(xref)) @ R`` approximates ``x - b``; ``b`` has shape
        (batch, 1, 3) and holds the per-frame mean of the points used.
    """
    assert isinstance(x, torch.Tensor), 'Input x is not a torch tensor'

    xref = xref.to(x)
    atom_idx = list(range(0, xref.shape[0]))

    # Centre the reference and every frame on their own means.
    ref_centered = xref - torch.mean(xref[atom_idx, :], 0, True)
    b = torch.mean(x[:, atom_idx, :], 1, True)
    x_centered = x[:, atom_idx, :] - b

    if not torch.isfinite(x_centered).all():
        print("Some point go to the infinity!!")

    # Batched product of centred coordinates, shape: batch x 3 x 3.
    prod = torch.matmul(x_centered.transpose(1, 2), ref_centered)

    try:
        u, _, vh = torch.linalg.svd(prod)
    except torch._C._LinAlgError:
        # Retry the SVD with non-finite entries replaced by zero.
        error_idx = torch.where(~torch.isfinite(prod))[0].unique()
        print(f"Index {error_idx} of prod is NAN! It was changed into zero before SVD.")
        prod = torch.nan_to_num(prod, nan=0.0, posinf=0.0, neginf=0.0)
        u, _, vh = torch.linalg.svd(prod)

    # diag(1, 1, det(u @ vh)) flips the last axis when needed so the result is
    # a proper rotation (determinant +1) rather than a reflection.
    correction = torch.eye(3, device=x.device, dtype=u.dtype).repeat(x.size(0), 1, 1)
    correction[:, 2, 2] = torch.sign(torch.linalg.det(torch.matmul(u, vh)))

    R = torch.bmm(torch.bmm(u, correction), vh)
    return R.transpose(1, 2), b


def get_RMSD(xvec, xref):
    """Root-mean-square deviation of each frame in ``xvec`` from ``xref`` after alignment.

    Args:
        xvec: tensor of frames, indexed as (batch, atoms, 3) by ``Kabsch``.
        xref: reference structure of shape (atoms, 3); it may have a different
            dtype or device than ``xvec``.

    Returns:
        1-D tensor with one RMSD value per frame.
    """
    R, b = Kabsch(xvec, xref)
    # Kabsch converts xref to xvec's dtype/device internally; do the same here,
    # otherwise the arithmetic below fails for a mismatched reference.
    xref = xref.to(xvec)
    b0 = torch.mean(xref, 0, True)
    # Residual between each centred frame and the rotated, centred reference.
    error = xvec - b - torch.matmul(xref - b0, R)
    return torch.sqrt(torch.sum(error ** 2, dim=(1, 2)) / xref.shape[0])

