import time
import wandb
import random
import numpy as np
import os.path as osp
from tqdm import tqdm, trange
from pprint import pformat
from omegaconf import OmegaConf

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torch_geometric.data import Batch, Data
from torch_geometric.utils import (
    degree, add_remaining_self_loops,
    to_dense_adj, dense_to_sparse
)

from dig.auggraph.method.SMixup.model.GraphMatching import GraphMatching
from dig.auggraph.method.SMixup.utils.sinkhorn import Sinkhorn
from dig.auggraph.method.SMixup.utils.utils import triplet_loss
from dig.auggraph.dataset.aug_dataset import TripleSet

from parsers import Parser, get_config
from dataset.loader import get_batched_datalist, MultiEpochsPYGDataLoader
from dataset.misc import batched_to_list
from dataset.property import get_properties
from utils import set_seed
from utils.loader import (
    load_data, 
    load_sampler,
    load_downstream_model,
    load_diffusion_guidance_optim,
)

# Cap PyTorch's intra-op CPU parallelism at two threads for this process.
torch.set_num_threads(2)


class MultiClassClassificationLoss(nn.Module):
    """Cross-entropy loss that also reports the batch accuracy in percent.

    Note the argument order of :meth:`forward`: ``(targets, outputs)``.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, targets, outputs):
        """Compute loss and accuracy for one batch.

        Args:
            targets: Ground truth, either one row of class scores per sample
                (e.g. one-hot) or a 1-D tensor of class indices.
            outputs: Model logits with one row per sample.

        Returns:
            Tuple ``(loss, accuracy)`` where ``accuracy`` is a 0-dim tensor
            holding the percentage of correctly classified samples.
        """
        loss = self.loss(outputs, targets)
        accuracy = self._calculate_accuracy(outputs, targets)
        return loss, accuracy

    def _get_correct(self, outputs):
        """Return the index of the largest entry along dim 1."""
        return torch.argmax(outputs, dim=1)

    def _calculate_accuracy(self, outputs, targets):
        """Return the percentage of samples whose arg-max matches the target."""
        predicted = self._get_correct(outputs)
        # 1-D targets are already class indices; only per-class rows need
        # to be reduced with an arg-max.
        if targets.dim() > 1:
            targets = self._get_correct(targets)
        return 100. * (predicted == targets).sum().float() / targets.size(0)


class MultiClassClassificationLossWithMixup(MultiClassClassificationLoss):
    """Cross-entropy loss that understands mixed-up batches.

    When ``targets`` is a ``Data`` object it must carry ``y1``, ``y2`` and
    ``lam``; the loss is then the ``lam``-weighted blend of the losses against
    both label sets and the accuracy is a constant zero placeholder.
    Any other ``targets`` is handled like the parent class does.
    """

    def __init__(self):
        super().__init__()

    def forward(self, targets, outputs):
        """Return ``(loss, accuracy)`` for plain or mixed-up targets."""
        if not isinstance(targets, Data):
            plain_loss = self.loss(outputs, targets)
            plain_acc = self._calculate_accuracy(outputs, targets)
            return plain_loss, plain_acc

        assert hasattr(targets, 'y1') and hasattr(targets, 'y2')
        first_term = targets.lam * self.loss(outputs, targets.y1.long())
        second_term = (1-targets.lam) * self.loss(outputs, targets.y2.long())
        # Accuracy is not defined for blended labels, so a zero is reported.
        return first_term + second_term, torch.tensor([0])


class EarlyStopping:
    """Signal when a monitored score has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` becomes ``True``.
        verbose: If true, print the counter on every non-improving call.
        delta: Minimum margin by which the score has to beat the best score
            so far to count as an improvement.
        mode: ``"max"`` if larger scores are better; any other value is
            treated as "smaller is better".
    """

    def __init__(self, patience=7, verbose=False, delta=0, mode="max"):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        self.delta = delta
        if mode == "max":
            # ``np.inf`` rather than the ``np.Inf`` alias, which NumPy 2.0 removed.
            self.best_score = -np.inf
            self.check_func = lambda x, y: x >= y
            self._delta_sign = 1
        else:
            self.best_score = np.inf
            self.check_func = lambda x, y: x <= y
            # When minimising, an improvement must undercut the best score by delta.
            self._delta_sign = -1

    def __call__(self, score):
        """Record ``score`` and return whether training should stop."""
        threshold = self.best_score + self._delta_sign * self.delta
        if self.check_func(score, threshold):
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}\n")
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop


def train_GMNET(gmnet, dataset, batch_size=128, lr=1e-3, epochs=10, device='cuda:0'):
    """Train the graph-matching network with a triplet loss.

    The dataset is wrapped in ``TripleSet`` so every batch yields an
    (anchor, positive, negative) triple. The network is run on the
    anchor/positive and anchor/negative pairs and the mean of
    ``triplet_loss`` over the batch is minimised with Adam.

    Args:
        gmnet: Graph-matching model; expected to already live on ``device``.
        dataset: Training graphs, passed to ``TripleSet``.
        batch_size: Batch size of the triple loader.
        lr: Adam learning rate.
        epochs: Number of passes over the loader.
        device: Device every batch is moved to.

    Returns:
        The same ``gmnet`` instance, switched to eval mode.
    """
    gmnet.train()
    dataset = TripleSet(dataset)
    train_loader = MultiEpochsPYGDataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
    optimizer = torch.optim.Adam(gmnet.parameters(), lr=lr, weight_decay=1e-4)

    for epoch in range(1, epochs + 1):
        print("====epoch {} ====".format(epoch))

        train_loss = 0.0
        for anchor_data, pos_data, neg_data in train_loader:
            anchor_data, pos_data, neg_data = anchor_data.to(device), pos_data.to(device), neg_data.to(device)

            optimizer.zero_grad()

            x_1, y = gmnet(anchor_data, pos_data, pred_head=False)
            x_2, z = gmnet(anchor_data, neg_data, pred_head=False)

            loss = triplet_loss(x_1, y, x_2, z)
            loss = torch.mean(loss)
            loss.backward()
            optimizer.step()
            # Accumulate a plain float so the running sum does not keep
            # autograd-attached tensors alive and prints as a number.
            train_loss += loss.item()

        # max(..., 1) keeps the report from dividing by zero on an empty loader.
        print("Epoch [{}] Train_loss {}".format(epoch, train_loss / max(len(train_loader), 1)))

    print("GMNET training done.")
    gmnet.eval()
    return gmnet

def mixup(gmnet, batch, alpha, sim_method='cos', normalize_method='softmax', temperature=1.0,):
    """Mix every graph of ``batch`` with a randomly chosen partner graph.

    The batch is paired with a shuffled copy of itself. ``gmnet.dis_encoder``
    provides node embeddings for both sides, from which a soft node
    assignment between each pair is derived. Adjacency and node features of
    the partner are mapped through that assignment and blended with the
    first graph using the mixing ratio ``lam``.

    Args:
        gmnet: Model exposing ``dis_encoder(batch1, batch2, node_emd=True)``.
        batch: Batch of graphs to mix.
        alpha: Beta-distribution parameter for ``lam``; ``lam`` is 0.5 when
            ``alpha`` is not positive. ``lam`` is folded to be at least 0.5.
        sim_method: ``'cos'`` (cosine similarity divided by ``temperature``)
            or ``'abs_diff'`` (negative norm of embedding differences).
        normalize_method: ``'softmax'`` (over dim 0) or ``'sinkhorn'``.
        temperature: Divisor applied to the cosine similarity only.

    Returns:
        A ``Batch`` of mixed graphs carrying ``y1``, ``y2`` and
        ``edge_weights`` per graph, with the scalar ``lam`` attached.

    Raises:
        ValueError: If ``sim_method`` or ``normalize_method`` is unknown.
    """
    if sim_method not in ('cos', 'abs_diff'):
        raise ValueError(f"Unsupported sim_method: {sim_method!r}")
    if normalize_method not in ('softmax', 'sinkhorn'):
        raise ValueError(f"Unsupported normalize_method: {normalize_method!r}")

    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 0.5

    # Keep the first graph dominant in every pair.
    lam = max(lam, 1 - lam)

    batch1 = batch.clone()
    data_list = list(batch1.to_data_list())
    data_list2 = data_list.copy()
    random.shuffle(data_list2)
    # Follow the device of the incoming batch instead of forcing CUDA.
    batch2 = Batch.from_data_list(data_list2).to(batch1.x.device)

    h1, h2 = gmnet.dis_encoder(batch1, batch2, node_emd=True)
    h1, h2 = h1.detach(), h2.detach()

    # data_list2 is a shallow copy, so both lists hold the same Data objects.
    # Embeddings are therefore sliced into locals instead of being stored on
    # the objects, where the two sides would overwrite each other.
    slices1 = batch1._slice_dict['x']
    slices2 = batch2._slice_dict['x']

    mixed_data_list = []
    for i, (graph1, graph2) in enumerate(zip(data_list, data_list2)):
        emb1 = h1[slices1[i]:slices1[i + 1], :]
        emb2 = h2[slices2[i]:slices2[i + 1], :]

        if sim_method == 'cos':
            emb1 = emb1 / emb1.norm(dim=1)[:, None]
            emb2 = emb2 / emb2.norm(dim=1)[:, None]
            match = emb1 @ emb2.T / temperature
        else:  # 'abs_diff'
            match = -(emb1.unsqueeze(1) - emb2.unsqueeze(0)).norm(dim=-1)

        if normalize_method == 'softmax':
            normalized_match = F.softmax(match.detach().clone(), dim=0)
        else:  # 'sinkhorn'
            # NOTE(review): Sinkhorn is called directly on the score matrix;
            # confirm the imported symbol is a function and not a module class.
            normalized_match = Sinkhorn(match.detach().clone())

        adj1 = to_dense_adj(graph1.edge_index, max_num_nodes=graph1.num_nodes)[0].double()
        adj2 = to_dense_adj(graph2.edge_index, max_num_nodes=graph2.num_nodes)[0].double()
        assignment = normalized_match.double()
        mixed_adj = lam * adj1 + (1-lam) * assignment @ adj2 @ assignment.T

        # Drop weak edges so the mixed graph stays sparse.
        mixed_adj[mixed_adj < 0.1] = 0
        mixed_x = lam * graph1.x + (1-lam) * normalized_match.float() @ graph2.x

        edge_index, edge_weights = dense_to_sparse(mixed_adj)
        data = Data(x=mixed_x.float(), edge_index=edge_index, edge_weights=edge_weights,
                    y1=graph1.y, y2=graph2.y)
        mixed_data_list.append(data)

    b = Batch.from_data_list(mixed_data_list)
    b.lam = lam
    return b

def train(model, device, loader, optimizer, loss_func, grad_norm=None,
          mixup_flag=False, gmnet=None, alpha=0.1, sim_method='cos'):
    """Run one training epoch.

    Args:
        model: Network called as ``model(batch)``.
        device: Device every batch is moved to.
        loader: Iterable of batches exposing ``num_graphs`` (and ``labels``
            when mixup is off); ``loader.dataset`` must support ``len``.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_func: Callable ``(targets, outputs) -> (loss, accuracy)``.
        grad_norm: If not ``None``, maximum gradient norm applied before
            each optimizer step.
        mixup_flag: If true, each batch is replaced by ``mixup(...)`` output
            and passed to ``loss_func`` as the targets.
        gmnet: Graph-matching model handed to ``mixup``.
        alpha: Mixing parameter handed to ``mixup``.
        sim_method: Similarity method handed to ``mixup``.

    Returns:
        Tuple ``(accuracy, loss)``: the per-batch values weighted by
        ``num_graphs`` and divided by ``len(loader.dataset)``.
    """
    model.train()

    loss_all = 0
    acc_all = 0

    for data in loader:

        optimizer.zero_grad()

        data = data.to(device)

        if mixup_flag:
            data = mixup(gmnet, data, alpha=alpha, sim_method=sim_method)

        output = model(data)

        if mixup_flag:
            loss, acc = loss_func(data, output)
        else:
            loss, acc = loss_func(data.labels, output)

        loss.backward()

        # Clipping has to happen between backward() and step(); doing it
        # after the step would leave the applied update unclipped.
        if grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)

        optimizer.step()

        num_graphs = data.num_graphs

        loss_all += loss.item() * num_graphs
        acc_all += acc.item() * num_graphs

    return acc_all / len(loader.dataset), loss_all / len(loader.dataset)


@torch.no_grad()
def eval(model, device, loader, loss_func):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Per-batch loss and accuracy from ``loss_func(batch.labels, output)`` are
    weighted by ``batch.num_graphs`` and divided by ``len(loader.dataset)``.

    Returns:
        Tuple ``(accuracy, loss)``.
    """
    model.eval()

    total_loss, total_acc = 0, 0
    for batch in loader:
        batch = batch.to(device)
        prediction = model(batch)

        batch_loss, batch_acc = loss_func(batch.labels, prediction)
        weight = batch.num_graphs

        total_loss += batch_loss.item() * weight
        total_acc += batch_acc.item() * weight

    mean_acc = total_acc / len(loader.dataset)
    mean_loss = total_loss / len(loader.dataset)
    return mean_acc, mean_loss


def run_graph_pred(model, train_loader, valid_loader, test_loader, device='cuda:0', epochs=1000, lr=1e-2, 
                   step_size=50, gamma=0.5, patience=500, mixup_flag=False, gmnet=None, train_dataset=None,
                   gm_batch_size=None, gm_lr=None, gm_epochs=10, alpha=0.1, sim_method='cos', **kwargs):
    """Train ``model`` with early stopping and report the selected test accuracy.

    The model is optimised with AdamW and a ``StepLR`` schedule. After every
    epoch the validation and test loaders are evaluated, all metrics are
    logged to wandb, and early stopping is checked on validation accuracy.
    With ``mixup_flag`` set, ``gmnet`` is first trained on ``train_dataset``
    and the mixup-aware loss is used.

    Extra keyword arguments are accepted and ignored.

    Returns:
        ``{'final_acc': ...}``: the test accuracy at the last epoch that
        reached the maximum validation accuracy.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size, gamma)
    early_stopper = EarlyStopping(patience=patience, verbose=False, mode='max')

    if mixup_flag:
        loss_func = MultiClassClassificationLossWithMixup()
        # The matching network is fitted once, before downstream training.
        gmnet.to(device)
        gmnet = train_GMNET(gmnet, train_dataset, gm_batch_size, gm_lr, gm_epochs, device)
    else:
        loss_func = MultiClassClassificationLoss()

    history = {}
    pbar = trange(0, (epochs), desc = '[Epoch]', position = 1, leave=False)
    for epoch in pbar:
        train_acc, train_loss = train(
            model, device, train_loader, optimizer, loss_func, mixup_flag=mixup_flag,
            gmnet=gmnet, alpha=alpha, sim_method=sim_method,
        )
        scheduler.step()

        valid_acc, valid_loss = eval(model, device, valid_loader, loss_func)
        test_acc, test_loss = eval(model, device, test_loader, loss_func)

        results = {
            'train_acc': train_acc, 'train_loss': train_loss,
            'valid_acc': valid_acc, 'valid_loss': valid_loss,
            'test_acc': test_acc, 'test_loss': test_loss,
        }
        wandb.log(results)

        for metric_name, metric_value in results.items():
            history.setdefault(metric_name, []).append(metric_value)

        metrics_text = ', '.join(f'{k}: {results[k]}' for k in sorted(results))
        msg = f"Epoch: {epoch}, {metrics_text}"
        pbar.set_description(msg)

        if early_stopper(valid_acc):
            break

        # A zero counter means this epoch just set a new best validation score.
        if early_stopper.counter == 0:
            tqdm.write(msg)

    # Among epochs tied for the best validation accuracy, take the latest one.
    valid_curve = np.array(history['valid_acc'])
    best_idx = np.where(valid_curve == valid_curve.max())[0][-1]
    final_results = {'final_acc': history['test_acc'][best_idx]}
    wandb.log(final_results)
    tqdm.write(f"Best epoch: {best_idx}, final ACC: {final_results['final_acc']}")

    return final_results


def main_fold(config, train_dataset, valid_dataset, test_dataset, device='cuda:0', seed=10,
              augment=False, train_guidance=False, subset_ratio=None):
    """Run the downstream graph-classification experiment on one split.

    Steps:
      1. Build the downstream model, taking input/output widths from
         ``train_dataset``.
      2. If ``augment`` is set, obtain a sampler (loaded from
         ``config.augment.ckpt_path`` or, with ``train_guidance``, trained
         here and optionally saved), generate new graphs with it and either
         replace or extend the training set. Validation and test sets are
         augmented too when the config asks for it.
      3. Build the data loaders, optionally a ``GraphMatching`` network for
         mixup, and hand everything to ``run_graph_pred``.

    Args:
        config: Experiment config with ``data``, ``model``, ``train`` and
            (when augmenting) ``augment`` sections.
        train_dataset: Training split.
        valid_dataset: Validation split.
        test_dataset: Test split.
        device: Device used for models and sampling.
        seed: Seed for the stratified subset split and the checkpoint name.
        augment: Whether to augment the data with the sampler.
        train_guidance: Whether to train a guidance model before sampling.
        subset_ratio: If below 1, only this stratified fraction of the
            training graphs is used as input for augmentation.

    Returns:
        The result dict produced by ``run_graph_pred``.
    """
    data_params, model_params, train_params = config.data, config.model, config.train
    data_name = data_params.data_name

    # Model widths follow the data rather than the config file.
    model_params.in_channels = train_dataset.x.shape[1]
    model_params.out_channels = train_dataset.n_classes
    model = load_downstream_model(model_params).to(device)

    if augment:
        augment_params = config.augment
        ckpt_path = augment_params.ckpt_path

        # Materialise the split as a plain list of graphs.
        train_dataset = [train_dataset.get(i) for i in train_dataset.indices()]
        data_list = train_dataset.copy()

        if subset_ratio is not None and subset_ratio < 1:
            from sklearn.model_selection import train_test_split

            # The "test" side of the stratified split is the subset that is kept.
            stratify_labels = torch.cat([data.y for data in data_list])
            _, subset_idx = train_test_split(np.arange(len(data_list)), test_size=subset_ratio,
                                             random_state=seed, stratify=stratify_labels)
            data_list = [data_list[i] for i in subset_idx]

        if train_guidance:
            guidance_config = augment_params.guidance_config

            if guidance_config.diffusion.guidance_type == 'graph_class':
                guidance_config.guidance.output_dim = model_params.out_channels

            elif guidance_config.diffusion.guidance_type == 'graph_prop':
                data_list, property_attr, _, _ = get_properties(data_list, attr_name='prop_attr')
                guidance_config.guidance.output_dim = property_attr.shape[1]

            freeze_model = guidance_config.diffusion.freeze_model
            # NOTE(review): sampler_sched is returned but never stepped below;
            # confirm whether that is intended.
            guidance_config, sampler, sampler_optim, sampler_sched = load_diffusion_guidance_optim(
                guidance_config, device, ckpt_path, freeze_model=freeze_model
            )

            ts = time.strftime('%b%d-%H:%M:%S', time.gmtime())
            ckpt_prefix = osp.basename(ckpt_path).split('-')[0]
            guidance_ckpt_name = f"{ckpt_prefix}-guidance_{data_name}-r.{seed}-{ts}.pth"
            # dirname keeps the leading '/' of absolute paths and also copes
            # with a bare file name; splitting on '/' and re-joining did not.
            guidance_ckpt_dir = osp.dirname(ckpt_path)
            guidance_ckpt_path = osp.join(guidance_ckpt_dir, guidance_ckpt_name)

            num_epochs, batch_size = guidance_config.num_epochs, int(guidance_config.batch_size)
            dataloader = MultiEpochsPYGDataLoader(data_list, batch_size=batch_size, shuffle=True)

            print('Training guidance model ...')
            for epoch in (pbar := trange(0, (num_epochs), desc = '[Epoch]', position = 1, leave=False)):
                losses = []
                for bdata in dataloader:
                    bdata = bdata.to(device)
                    # Reset gradients so each step only uses the current batch.
                    sampler_optim.zero_grad()
                    loss = sampler(bdata)
                    loss.backward()
                    losses.append(loss.item())

                    torch.nn.utils.clip_grad_norm_(sampler.parameters(), 1.0)
                    sampler_optim.step()

                tqdm_log = f"[EPOCH {epoch+1:04d}] | train loss: {np.mean(losses):.3e}"
                pbar.set_description(tqdm_log)

            save_flag = guidance_config.save_flag
            if save_flag:
                print(f'Saving to {guidance_ckpt_path}')
                torch.save({
                    'model_config': guidance_config,
                    'model_state_dict': sampler.state_dict(),
                }, guidance_ckpt_path)

            torch.cuda.empty_cache()

        else:
            sampler = load_sampler(ckpt_path, device=device)

        def sample_graphs(data_list, nodes_max=None, edges_max=None, batch_size=1, sample_params=None,
                          num_repeats=1, keys_to_keep=None):
            """Pass ``data_list`` through the sampler ``num_repeats`` times.

            Batches are size-bounded when ``nodes_max``/``edges_max`` is
            given; otherwise ``augment_params.batch_size`` is used (the
            ``batch_size`` argument is overridden in that case).
            """
            sample_params = {} if sample_params is None else sample_params
            if nodes_max is not None or edges_max is not None:
                dataloader = get_batched_datalist(data_list, nodes_max=nodes_max, edges_max=edges_max)
                print(len(dataloader))
            else:
                batch_size = int(augment_params.batch_size)
                dataloader = MultiEpochsPYGDataLoader(data_list, batch_size=batch_size, shuffle=False)

            new_data_list = []
            for i in range(num_repeats):
                print(f'Augmenting... repeat {i + 1} / {num_repeats}')
                for _, bdata in tqdm(enumerate(dataloader)):
                    bdata = bdata.to(device)
                    bdata = sampler.sample(
                        bdata, device=device, **sample_params
                    ).cpu()
                    new_data_list.extend(batched_to_list(bdata, keys_to_keep=keys_to_keep))

            # These datasets use the node degree as the only node feature, so
            # it is recomputed from the sampled edges.
            degree_feature_datasets = ['Reddit_binary', 'Reddit_multi_5k', 'Reddit_multi_12k',
                                       'IMDB_binary', 'IMDB_multi', 'Collab']
            for new_data in new_data_list:
                if data_name in degree_feature_datasets:
                    new_data.x = degree(new_data.edge_index[0]).unsqueeze(-1)

                new_data.edge_index, _ = add_remaining_self_loops(new_data.edge_index)

            return new_data_list

        def remove_attributes_from_dataset(dataset, keys_to_remove):
            """Drop every key in ``keys_to_remove`` from each graph, in place."""
            if len(keys_to_remove) > 0:
                for data in dataset:
                    for k in keys_to_remove:
                        data.pop(k, None)

            return dataset

        # Both size limits are optional config entries.
        try:
            nodes_max = augment_params.nodes_max
        except Exception:
            nodes_max = None

        try:
            edges_max = augment_params.edges_max
        except Exception:
            edges_max = None

        batch_size = int(augment_params.batch_size)
        sample_params = OmegaConf.to_container(augment_params.sample)
        num_repeats = int(augment_params.num_repeats)
        keys_to_keep = train_dataset[0].keys()

        new_data_list = sample_graphs(data_list, nodes_max=nodes_max, edges_max=edges_max,
                                      batch_size=batch_size, sample_params=sample_params,
                                      num_repeats=num_repeats, keys_to_keep=keys_to_keep)

        replace_flag = augment_params.replace_flag
        # Attributes the sampled graphs lack are stripped from the originals
        # so that all graphs share the same set of keys.
        keys_to_remove = list(set(train_dataset[0].keys()) - set(new_data_list[0].keys()))
        if replace_flag:
            train_dataset = new_data_list
        else:
            train_dataset = train_dataset + new_data_list
        train_dataset = remove_attributes_from_dataset(train_dataset, keys_to_remove)

        keys_to_keep = train_dataset[0].keys()
        augment_valid, augment_test = augment_params.augment_valid, augment_params.augment_test
        if augment_valid:
            data_list = [valid_dataset.get(i) for i in valid_dataset.indices()]
            valid_dataset = sample_graphs(data_list, nodes_max=nodes_max, edges_max=edges_max,
                                          batch_size=batch_size, sample_params=sample_params,
                                          num_repeats=1, keys_to_keep=keys_to_keep)
        else:
            valid_dataset = [valid_dataset.get(i) for i in valid_dataset.indices()]
        valid_dataset = remove_attributes_from_dataset(valid_dataset, keys_to_remove)

        if augment_test:
            data_list = [test_dataset.get(i) for i in test_dataset.indices()]
            test_dataset = sample_graphs(data_list, nodes_max=nodes_max, edges_max=edges_max,
                                          batch_size=batch_size, sample_params=sample_params,
                                          num_repeats=1, keys_to_keep=keys_to_keep)
        else:
            test_dataset = [test_dataset.get(i) for i in test_dataset.indices()]
        test_dataset = remove_attributes_from_dataset(test_dataset, keys_to_remove)

        # Release the sampler before downstream training starts.
        del sampler, data_list, new_data_list
        torch.cuda.empty_cache()
        print('Using augmented view to train')

    batch_size = data_params.batch_size
    train_loader = MultiEpochsPYGDataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    valid_loader = MultiEpochsPYGDataLoader(valid_dataset, shuffle=False, batch_size=batch_size)
    test_loader = MultiEpochsPYGDataLoader(test_dataset, shuffle=False, batch_size=batch_size)

    mixup_flag = train_params.mixup_flag
    if mixup_flag:
        gmnet = GraphMatching(in_dim=train_dataset[0].x.shape[1], num_layers=train_params.gm_num_layers,
                              hidden=train_params.gm_hidden)
    else:
        gmnet = None

    results = run_graph_pred(model, train_loader, valid_loader, test_loader, device,
                             train_dataset=train_dataset, gmnet=gmnet, **train_params)
    return results


def main(config, seed=10, fold=None, augment=False, train_guidance=False):
    """Run the experiment on every requested split and print summary statistics.

    Args:
        config: Experiment config; ``config.data`` is passed to ``load_data``
            and ``config.augment.subset_ratio`` is read when present.
        seed: Random seed, applied via ``set_seed`` and forwarded to each fold.
        fold: Index of a single split to run; ``None`` runs all splits.
        augment: Forwarded to ``main_fold``.
        train_guidance: Forwarded to ``main_fold``.
    """
    set_seed(seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data_params = config.data
    # subset_ratio is optional: a missing or non-numeric entry means "no subset".
    # Exception (not a bare except) so Ctrl-C and SystemExit still propagate.
    try:
        subset_ratio = float(config.augment.subset_ratio)
    except Exception:
        subset_ratio = None

    # Alternative cluster files kept for reference:
    # cluster_path = 'data/misc/large_network_repository_network_repository_snap/processed/all_8clusters_seg-thres1000.pt'
    # cluster_path = 'data/misc/full_network_repository/processed/entr10_dens0.1_denm0_degm3_degv3_node4000_edge50000-feat_none-all_10clusters_ne.pt'
    cluster_path = 'data/misc/full_network_repository/processed/entr10_dens0.1_denm50_degm3_degv3_node4000_edge50000-feat_none-ext_github_stargazers-all_10clusters_ne.pt'
    dataset = load_data(data_params, return_loader=False, cluster_path=cluster_path)

    # A dataset may carry one split dict or a list of them (one per fold).
    splits = dataset.splits
    splits = [splits] if not isinstance(splits, list) else splits
    if fold is not None:
        splits = [splits[fold]]

    all_results = {}
    for spl in splits:
        train_dataset = dataset[spl['train']]
        valid_dataset = dataset[spl['valid']]
        test_dataset = dataset[spl['test']]

        results = main_fold(config, train_dataset, valid_dataset, test_dataset, device=device, seed=seed,
                            augment=augment, train_guidance=train_guidance, subset_ratio=subset_ratio)

        for k, value in results.items():
            all_results.setdefault(k, []).append(value)

    for k, values in all_results.items():
        print("##################")
        print(f"Result of {len(splits)} folds")
        print(f"{k}: {np.mean(values)} +/- {np.std(values)}")
        print("##################")

if __name__ == "__main__":
    # Command-line entry point: merge the file config with dotlist overrides,
    # derive the wandb group/run names from the settings, then run main().
    args, unknown = Parser().parse()
    cli = OmegaConf.from_dotlist(unknown)
    config = get_config(args.config, args.seed)
    config = OmegaConf.merge(config, cli)
    print(pformat(vars(config)))

    data_name = config.data.data_name
    augment = args.augment
    ckpt_path = config.augment.ckpt_path
    # Part of the checkpoint file name before the first '-'.
    ckpt_prefix = ckpt_path.split('/')[-1].split('-')[0]

    # A threshold given on the command line overrides the configured one.
    thres = args.thres if args.thres is not None else config.augment.sample.thres
    config.augment.sample.thres = thres  # overwrite

    num_repeats = config.augment.num_repeats
    replace_flag = config.augment.replace_flag

    group_postfix = f'aug{str(augment)}'
    if augment:
        # subset_ratio is optional; Exception (not a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            subset_ratio = float(config.augment.subset_ratio)
        except Exception:
            subset_ratio = None

        if subset_ratio is not None and subset_ratio < 1:
            group_postfix += f'_sub{subset_ratio}'

        group_postfix += f'_{ckpt_prefix}_thres{str(thres)}_nrep{num_repeats}_repl{str(replace_flag)}'

    group_name = f'{data_name}_{group_postfix}'
    if args.train_guidance:
        guidance_type = config.augment.guidance_config.diffusion.guidance_type
        group_name = f'{guidance_type}_guide_' + group_name

    if len(args.prefix) > 0:
        group_name = args.prefix + '_' + group_name

    # Mark runs whose config explicitly disables inpaint_every_step.
    if (hasattr(config.augment.sample, 'inpaint_every_step')
        and not config.augment.sample.inpaint_every_step):

        group_name = 'last_only_' + group_name

    datetime_now = time.strftime('%b%d-%H:%M:%S', time.gmtime())
    exp_name = f'{group_name}-r.{args.seed}-fold{str(args.fold)}-{datetime_now}'
    wandb.init(
        project='AdjDiff_GraphPropPred_MixupVer0',
        group=group_name,
        name=exp_name,
        config=dict(config),
    )
    main(config, args.seed, args.fold, args.augment, args.train_guidance)