import numpy as np
import torch
import random
import logging
from time import time
from pytorchtools import EarlyStopping
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from loss import MaskedMSELoss

# From https://github.com/dr-aheydari/SoftAdapt/tree/main
from weighting.softadapt import LossWeightedSoftAdapt
from weighting.softadapt import NormalizedSoftAdapt
from weighting.softadapt import SoftAdapt

def init_dl_program(
    device_name,
    seed=None,
    use_cudnn=True,
    deterministic=True,
    benchmark=False,
    use_tf32=False,
    max_threads=None
):
    """Configure threads, RNG seeds and cuDNN/TF32 flags, then resolve the device(s).

    Args:
        device_name: one device spec (a ``str`` such as ``'cpu'``/``'cuda:0'``,
            an ``int``, or a ``torch.device``) or an iterable of such specs;
            each spec is passed to ``torch.device``.
        seed: if not None, seeds ``random`` with ``seed``, NumPy with
            ``seed + 1`` and torch with ``seed + 2``; every CUDA device
            (visited in reverse order) is then seeded with the next
            consecutive value.
        use_cudnn, deterministic, benchmark: written to the matching
            ``torch.backends.cudnn`` flags.
        use_tf32: written to ``torch.backends.cudnn.allow_tf32`` and
            ``torch.backends.cuda.matmul.allow_tf32`` when this torch build
            exposes them.
        max_threads: if not None, caps torch intra-op and inter-op threads
            and, when the optional ``mkl`` module is importable, MKL threads.

    Returns:
        A single ``torch.device`` when exactly one device was requested,
        otherwise a list of ``torch.device`` in the order given.

    Raises:
        ValueError: if no device is requested.
        RuntimeError: if a CUDA device is requested but CUDA is unavailable.
    """
    if max_threads is not None:
        torch.set_num_threads(max_threads)  # intra-op
        # set_num_interop_threads may only be called before inter-op work has
        # started, so only call it when the value actually needs to change.
        if torch.get_num_interop_threads() != max_threads:
            torch.set_num_interop_threads(max_threads)  # inter-op
        try:
            import mkl
        except ImportError:
            pass  # mkl is optional; best-effort only
        else:
            mkl.set_num_threads(max_threads)

    if seed is not None:
        random.seed(seed)
        seed += 1
        np.random.seed(seed)
        seed += 1
        torch.manual_seed(seed)

    if isinstance(device_name, (str, int, torch.device)):
        device_name = [device_name]
    device_name = list(device_name)
    if not device_name:
        raise ValueError("device_name must name at least one device")

    devices = []
    # Visit devices in reverse so the last set_device() call leaves the
    # first-listed CUDA device as the current one.
    for t in reversed(device_name):
        t_device = torch.device(t)
        devices.append(t_device)
        if t_device.type == 'cuda':
            if not torch.cuda.is_available():
                raise RuntimeError(
                    f"CUDA device {t_device} requested but CUDA is not available")
            torch.cuda.set_device(t_device)
            if seed is not None:
                seed += 1
                torch.cuda.manual_seed(seed)
    devices.reverse()

    torch.backends.cudnn.enabled = use_cudnn
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark

    # Older torch builds do not have the TF32 switches.
    if hasattr(torch.backends.cudnn, 'allow_tf32'):
        torch.backends.cudnn.allow_tf32 = use_tf32
        torch.backends.cuda.matmul.allow_tf32 = use_tf32

    return devices if len(devices) > 1 else devices[0]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2-D sine-cosine position-embedding table for a square grid.

    Args:
        embed_dim: length of each embedding vector. The helpers this calls
            assert evenness twice (at embed_dim and embed_dim // 2), so it
            must be divisible by 4.
        grid_size: number of positions along each side of the square grid.
        cls_token: when True, an all-zero row is prepended as the slot for an
            extra (class) token.

    Returns:
        np.ndarray of shape (grid_size*grid_size, embed_dim), or
        (1 + grid_size*grid_size, embed_dim) when ``cls_token`` is True.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # np.meshgrid's default 'xy' indexing: the first output varies along
    # columns (w), the second along rows (h).
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape(2, 1, grid_size, grid_size)

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    cls_row = np.zeros((1, embed_dim))
    return np.concatenate([cls_row, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-channel coordinate grid as two concatenated 1-D sin-cos embeddings.

    ``grid[0]`` is encoded into the first ``embed_dim // 2`` channels and
    ``grid[1]`` into the last ``embed_dim // 2``. When the grid comes from
    ``get_2d_sincos_pos_embed`` (np.meshgrid with default 'xy' indexing),
    ``grid[0]`` holds the column coordinate and ``grid[1]`` the row coordinate.

    Args:
        embed_dim: total output width; must be even.
        grid: array whose first axis indexes the two coordinate channels;
            every other axis is flattened into positions.

    Returns:
        np.ndarray of shape (num_positions, embed_dim).
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(half, grid[channel])
              for channel in (0, 1)]  # each (H*W, D/2)
    return np.concatenate(halves, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine embedding of a set of scalar positions.

    For frequencies ``omega_d = 1 / 10000 ** (d / (embed_dim / 2))``,
    ``d = 0 .. embed_dim/2 - 1``, row m of the output is
    ``[sin(pos_m * omega), cos(pos_m * omega)]``.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: positions to encode: any array-like (list, NumPy array, ...);
            it is flattened, so M is its total number of elements.

    Returns:
        np.ndarray of shape (M, embed_dim).
    """
    assert embed_dim % 2 == 0
    # np.float was an alias for builtin float (float64) and was removed in
    # NumPy 1.24; use the explicit dtype.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    # asarray lets plain lists be used, as the docstring has always promised.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to match ``model``'s patch grid.

    The leading extra tokens (as many as ``model.pos_embed`` has beyond
    ``model.patch_embed.num_patches``) are copied unchanged; the remaining
    patch tokens are treated as a square grid and bicubically resized.
    Nothing happens when the key is missing or the grid side is unchanged.

    Args:
        model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: mapping (e.g. a state dict); mutated in place.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_pos = checkpoint_model['pos_embed']
    dim = ckpt_pos.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches

    # Side lengths of the (assumed square) old and new patch grids.
    old_side = int((ckpt_pos.shape[-2] - n_extra) ** 0.5)
    new_side = int(n_patches ** 0.5)
    if old_side == new_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    head = ckpt_pos[:, :n_extra]
    # (B, N, C) -> (B, C, side, side) for 2-D interpolation, then back to (B, N', C).
    grid = ckpt_pos[:, n_extra:]
    grid = grid.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, size=(new_side, new_side), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((head, grid), dim=1)

def _mean_batch_accuracy(loader, model, classifier, device, mask_ratio):
    """Return the mean of per-batch accuracies (percent) of ``classifier`` on ``loader``.

    Calls ``model(x, mask_ratio)`` and classifies its first output. Each batch
    counts equally regardless of size. The caller sets eval mode and
    ``torch.no_grad()``. Raises ZeroDivisionError if ``loader`` is empty.
    """
    batch_accs = []
    for x, y, _ in loader:
        x = x.to(device)
        y = y.type(torch.LongTensor).to(device)
        latent = model(x, mask_ratio)[0]
        prediction = classifier(latent).argmax(1)
        correct = prediction.eq(y.view_as(prediction)).sum()
        batch_accs.append((100. * correct / len(y)).item())
    return sum(batch_accs) / len(batch_accs)


def semi_adaptive_train(labeledloader, allloader, valloader, testloader,
                model, classifier, epochs, lr, device, save_path, patience=30):
    """Train ``model`` + ``classifier`` with a supervised and a self-supervised pass per epoch.

    Each epoch:
      1. a supervised pass over ``labeledloader`` (cross-entropy on the
         classifier output; the model is called with ``mask_ratio=0.``);
      2. a self-supervised pass over ``allloader`` minimising a weighted sum of
         MaskedMSELoss(output, input, padding mask) and BCE(output_adj, adj).
         The two weights start at 1 and are recomputed by LossWeightedSoftAdapt
         every ``epochs_to_make_updates`` (3) epochs from the per-epoch mean
         losses collected since the previous update;
      3. validation accuracy; whenever it improves, test accuracy is recomputed.

    Args:
        labeledloader, allloader, valloader, testloader: iterables yielding
            ``(x, y, _)`` batches.
        model: called as ``model(x, mask_ratio)``; must return
            ``(latent, output, padding_masks, adj, output_adj)``.
        classifier: maps ``latent`` to class logits.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        device: target device for batches.
        save_path: indexable of three paths: [0] weights before training,
            [1] passed to EarlyStopping as ``checkpoint_pth``, [2] weights after
            the last epoch run.
        patience: forwarded to EarlyStopping.

    Returns:
        ``(test_acc, best_epoch, processing_time)``: mean per-batch test accuracy
        (percent) at the best-validation epoch (0.0 if no epoch ran), that
        1-based epoch (0 if none), and the mean wall-clock seconds per epoch
        spent in the two training passes (validation excluded).
    """
    params = [{'params': model.parameters()},
              {'params': classifier.parameters()}]

    BCE = torch.nn.BCELoss()
    CE = torch.nn.CrossEntropyLoss()
    MSE = MaskedMSELoss()

    unlabel_adapt_object = LossWeightedSoftAdapt(beta=0.1)
    epochs_to_make_updates = 3

    # Per-epoch mean losses accumulated since the last SoftAdapt update.
    values_of_unlabel_mse_loss = []
    values_of_unlabel_bce_loss = []

    unlabel_adapt_weights = torch.tensor([1., 1.])

    early_stopping = EarlyStopping(patience, verbose=False,
                                   checkpoint_pth=save_path[1])

    optimizer = torch.optim.AdamW(params, lr=lr)

    torch.save({'model': model.state_dict(),
                'classifier': classifier.state_dict()}, save_path[0])

    # Start below any real accuracy so the first epoch always records a best
    # epoch and a test accuracy (starting at 0 left test_acc unbound when the
    # validation accuracy never exceeded 0).
    best_score = float('-inf')
    best_epoch = 0
    test_acc = 0.0

    times = []
    for epoch in range(epochs):
        start = time()
        model.train()
        classifier.train()

        # Monitoring statistics: computed, but currently not logged or returned.
        acc_epoch = 0
        acc_epoch_unlabel = 0
        loss_epoch_label = 0
        loss_epoch_unlabel = 0

        # One random masking ratio in {0.1, ..., 0.9} for the whole epoch.
        mask_ratio = random.randint(1, 9) * 0.1

        # 1) Supervised pass (no masking).
        for x, y, _ in labeledloader:
            optimizer.zero_grad()

            x, y = x.to(device), y.type(torch.LongTensor).to(device)

            latent = model(x, mask_ratio=0.)[0]
            cls_pred = classifier(latent)
            loss_label = CE(cls_pred, y)
            loss_epoch_label += loss_label.item()

            loss_label.backward()
            optimizer.step()

            prediction = cls_pred.argmax(1)
            correct = prediction.eq(y.view_as(prediction)).sum()
            acc_epoch += (100. * correct / len(y)).item()

        # 2) Self-supervised pass with adaptive loss weighting.
        mse_losses = []
        bce_losses = []
        last_batch = len(allloader) - 1
        for i, (ux, uy, _) in enumerate(allloader):
            ux = ux.to(device)
            uy = uy.type(torch.LongTensor).to(device)  # labels used only for monitoring

            optimizer.zero_grad()

            latent, output, padding_masks, adj, output_adj = model(ux, mask_ratio)
            mse_loss = MSE(output, ux, padding_masks.unsqueeze(-1))
            bce_loss = BCE(output_adj, adj)

            # Store plain floats; keeping the loss tensors would hold autograd
            # references for the whole epoch.
            mse_losses.append(mse_loss.item())
            bce_losses.append(bce_loss.item())

            if i == last_batch:
                values_of_unlabel_mse_loss.append(torch.mean(torch.tensor(mse_losses)))
                values_of_unlabel_bce_loss.append(torch.mean(torch.tensor(bce_losses)))

                if epoch % epochs_to_make_updates == 0 and epoch != 0:
                    unlabel_adapt_weights = unlabel_adapt_object.get_component_weights(
                        torch.tensor(values_of_unlabel_mse_loss),
                        torch.tensor(values_of_unlabel_bce_loss),
                        verbose=False,)
                    values_of_unlabel_mse_loss = []
                    values_of_unlabel_bce_loss = []

            loss_unlabel = unlabel_adapt_weights[0] * mse_loss + unlabel_adapt_weights[1] * bce_loss
            loss_unlabel.backward()
            optimizer.step()

            # Monitoring only: no autograd graph is needed for this forward.
            with torch.no_grad():
                prediction = classifier(latent).argmax(1)
            correct = prediction.eq(uy.view_as(prediction)).sum()
            acc_epoch_unlabel += (100. * correct / len(uy)).item()

            loss_epoch_unlabel += loss_unlabel.item()

        acc_epoch /= len(labeledloader)
        acc_epoch_unlabel /= len(allloader)
        loss_epoch_label /= len(labeledloader)
        loss_epoch_unlabel /= len(allloader)

        times.append(time() - start)

        # 3) Evaluation.
        model.eval()
        classifier.eval()
        with torch.no_grad():
            # NOTE(review): evaluation reuses this epoch's random training
            # mask_ratio while the supervised pass uses mask_ratio=0.; confirm
            # this is intended.
            val_acc = _mean_batch_accuracy(valloader, model, classifier, device, mask_ratio)

            if val_acc > best_score:
                best_score = val_acc
                best_epoch = epoch + 1
                test_acc = _mean_batch_accuracy(testloader, model, classifier, device, mask_ratio)

        # NOTE(review): a higher val_acc is better here; confirm the project's
        # EarlyStopping treats its first argument as a score to maximise.
        early_stopping(val_acc, model, classifier)
        if early_stopping.early_stop:
            print('Early Stopping.')
            break

    # len(times) == number of epochs actually run (equals epoch + 1 on exit).
    processing_time = sum(times) / len(times) if times else 0.0
    torch.save({'model': model.state_dict(),
                'classifier': classifier.state_dict()}, save_path[2])

    return test_acc, best_epoch, processing_time

def test(testloader, model, classifier, device):
    """Score ``classifier`` (on ``model`` latents) over the whole of ``testloader``.

    Predictions and labels from every batch are pooled and each metric is
    computed once on the full set. Averaging per-batch scores would weight a
    short final batch like a full one, and macro-averaged recall/precision/F1
    are not averages of per-batch values.

    Args:
        testloader: iterable yielding ``(x, y, _)`` batches.
        model: called as ``model(x)``; its first output is fed to ``classifier``.
        classifier: maps latents to class logits.
        device: device inputs are moved to.

    Returns:
        ``(accuracy, recall, precision, f1)``: accuracy as a fraction; the other
        three macro-averaged over classes (sklearn ``average='macro'``).

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    model.eval()
    classifier.eval()
    true_chunks = []
    pred_chunks = []
    with torch.no_grad():
        for x, y, _ in testloader:
            latent = model(x.to(device))[0]
            prediction = classifier(latent).argmax(1)
            pred_chunks.append(prediction.detach().cpu().numpy())
            true_chunks.append(y.detach().cpu().numpy())

    if not true_chunks:
        raise ValueError("testloader yielded no batches")

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)
    return (accuracy_score(y_true, y_pred),
            recall_score(y_true, y_pred, average='macro'),
            precision_score(y_true, y_pred, average='macro'),
            f1_score(y_true, y_pred, average='macro'))

def setup_logger(name, log_file, level=logging.INFO):
    """Return logger ``name`` writing '<asctime> <message>' to ``log_file`` and the console.

    Safe to call repeatedly with the same ``name``. A file handler for the same
    path, or a plain console ``StreamHandler``, is only attached if the logger
    lacks one, so repeated calls do not duplicate every record. A call with a
    different ``log_file`` attaches an additional file handler.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_file: path of the file records are appended to.
        level: level set on the logger (default ``logging.INFO``).

    Returns:
        The configured ``logging.Logger``.
    """
    import os

    formatter = logging.Formatter('%(asctime)s %(message)s')
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # FileHandler stores os.path.abspath(os.fspath(filename)) as baseFilename.
    target = os.path.abspath(os.fspath(log_file))
    has_file = any(isinstance(h, logging.FileHandler) and h.baseFilename == target
                   for h in logger.handlers)
    # FileHandler subclasses StreamHandler, so match the exact console type.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)

    if not has_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not has_console:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger