import os
import time
import wandb
import numpy as np
from pprint import pformat
from tqdm import tqdm, trange
from omegaconf import OmegaConf
from contextlib import contextmanager

import torch
import torch.distributed as dist
import torch_geometric
from torch_scatter import scatter
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops

from parsers import Parser, get_config
from utils.loader import (
    load_seed, load_device, load_ema, load_data,
    load_diffusion_model_optim, load_evaluater,
)
from utils.logger import set_log
from dataset.misc import batched_to_list

# Size torch's intra-op CPU thread pool: honour a positive integer in the
# standard OMP_NUM_THREADS variable, otherwise default to 32 threads.
try:
    _num_threads = int(os.environ['OMP_NUM_THREADS'])
except (KeyError, ValueError):
    _num_threads = 32
torch.set_num_threads(_num_threads if _num_threads > 0 else 32)


class Trainer:
    """Train a graph diffusion model and periodically evaluate generated samples.

    ``__init__`` builds the data loaders, model/optimizer/scheduler and EMA via
    the project's ``utils.loader`` helpers; ``train`` runs the epoch loop with
    wandb logging and checkpointing; ``test`` samples graphs and scores them
    with the evaluator built from the test split.
    """

    def __init__(self, config, exp_name='run'):
        """Load data, model, optimizer, EMA and (if a test split exists) the evaluator.

        Args:
            config: experiment config with ``train``, ``data``, ``model``,
                ``sample`` and ``diffusion`` sections. Mutated in place:
                ``model.max_degree`` (and, when the dataset provides it,
                ``model.max_degrees``) are filled in from the dataset.
            exp_name: run name; used for checkpoint file names and passed to
                the evaluator.
        """
        super().__init__()

        self.exp_name = exp_name
        self.config = config
        self.log_folder_name, self.log_dir, self.ckpt_dir = set_log(self.config)

        self.seed = load_seed(self.config.train.seed)
        self.device = load_device()
        # load_device() may return a list of GPU ids; batches are moved to the first one.
        self.device_id = f'cuda:{self.device[0]}' if isinstance(self.device, list) else self.device

        cluster_path = getattr(self.config.data, 'cluster_path', None)
        self.dataset = load_data(self.config.data, return_loader=False, cluster_path=cluster_path)
        self.train_loader = self.dataset.get_dataloader(
            split='train', shuffle=True, **self.config.data
        )
        self.valid_loader = self._optional_loader('valid')
        self.test_loader = self._optional_loader('test')

        if hasattr(self.config.data, 'cluster'):
            self.config.model.max_degree = self.dataset.max_degrees[self.config.data.cluster]
        else:
            self.config.model.max_degree = self.dataset.max_degree

        # The full per-cluster degree table is optional: skip it when the dataset
        # has no ``max_degrees`` (AttributeError) or, presumably, when the config
        # rejects the value's type on assignment (ValueError).
        try:
            self.config.model.max_degrees = self.dataset.max_degrees
        except (AttributeError, ValueError):
            pass

        extra_kwargs = {}
        if self.config.model.target == 'GuidedGNN':
            extra_kwargs['num_classes'] = self.dataset.kmeans.n_clusters
        self.model, self.optimizer, self.scheduler = load_diffusion_model_optim(
            self.config, self.device, **extra_kwargs
        )
        self.ema = load_ema(self.model, decay=self.config.train.ema)

        self.num_epochs = self.config.train.num_epochs
        self.grad_norm = self.config.train.grad_norm
        self.save_interval = self.config.train.save_interval
        self.eval_interval = self.config.train.eval_interval
        self.sample_from_empty = self.config.sample.sample_from_empty
        # Forwarded as keyword arguments to model.sample() by test().
        self.sample_params = self.config.sample

        # if self.config.diffusion.target == 'binary':
        #     from diffusion.diffusion_utils import EmpiricalEmptyGraphGenerator

        #     self.model.initial_graph_sampler = EmpiricalEmptyGraphGenerator(
        #         self.dataset.get_split_dataset('train')
        #     )

        if self.test_loader is not None:
            self.evaluater = load_evaluater(
                self.dataset.get_split_dataset('test'), exp_name=exp_name, device=self.device_id
            )

        # Only the loaders and evaluator are needed from here on; drop the dataset
        # reference (presumably to release its memory).
        del self.dataset

    def _optional_loader(self, split):
        """Return an unshuffled loader for ``split`` if the dataset lists it in ``splits``, else None."""
        splits = getattr(self.dataset, 'splits', None)
        if splits is None or split not in splits:
            return None
        return self.dataset.get_dataloader(split=split, shuffle=False, **self.config.data)

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily load the EMA weights into ``self.model``.

        The current parameters are stored, replaced by the EMA weights for the
        body of the ``with`` block, and restored on exit even if the body
        raises. When ``context`` is given, each swap is printed.
        """
        self.ema.store(self.model.parameters())
        self.ema.copy_to(self.model.parameters())
        if context is not None:
            print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            self.ema.restore(self.model.parameters())
            if context is not None:
                print(f"{context}: Restored training weights")

    def update_stats(self, attr_key, stats):
        """Append scalar values from ``stats`` to the history dict ``self.<attr_key>``.

        Only keys already present in the history dict are tracked; keys missing
        from ``stats`` are skipped. Tensors / NumPy scalars are converted with
        ``.item()``; plain Python numbers are accepted as-is. Non-scalar or
        non-numeric values are skipped instead of aborting training.
        """
        attr = getattr(self, attr_key)
        for key, history in attr.items():
            if key not in stats:
                continue
            value = stats[key]
            if hasattr(value, 'item'):
                try:
                    value = value.item()
                except (ValueError, RuntimeError):
                    continue  # multi-element tensor/array: not a scalar stat
            if isinstance(value, (int, float)):
                history.append(value)

        setattr(self, attr_key, attr)

    def get_mean_stats(self, stats, prefix=''):
        """Average each history in ``stats`` into ``{f'{prefix}_{key}': mean}``.

        Empty histories map to NaN (set explicitly, avoiding NumPy's
        empty-slice warning); histories NumPy cannot average are skipped.
        """
        output = {}
        for key, values in stats.items():
            try:
                mean = np.mean(values) if len(values) > 0 else np.nan
            except (TypeError, ValueError):
                continue
            output[f'{prefix}_{key}'] = mean

        return output

    def train(self, valid_ema=True):
        """Run the full training loop.

        Each epoch does the following:

        - one optimisation pass over ``train_loader``;
        - an optional LR-scheduler step;
        - validation loss with the raw weights and, if ``valid_ema``, with the EMA weights;
        - sampling-based test metrics every ``eval_interval`` epochs;
        - a ``wandb.log`` call;
        - a checkpoint every ``save_interval`` epochs and at the last epoch.

        Args:
            valid_ema: also report validation loss under the EMA weights.
        """
        self.config.exp_name = self.exp_name
        self.ckpt_name = f'{self.exp_name}'
        print('\033[91m' + f'{self.ckpt_name}' + '\033[0m')

        # -------- Training --------
        pbar = trange(self.num_epochs, desc='[Epoch]', position=1, leave=False)
        for epoch in pbar:
            self.train_stats = {'loss': []}
            self.valid_stats = {'loss': []}
            self.valid_ema_stats = {'loss': []}
            self.test_metrics = {'degree_mmd': [], 'spectral_mmd': [], 'clustering_mmd': [], 'orbits_mmd': []}
            t_start = time.time()
            is_eval_epoch = (epoch + 1) % self.eval_interval == 0
            run_test = is_eval_epoch and self.test_loader is not None

            self._train_one_epoch()
            if self.config.train.lr_schedule:
                self.scheduler.step()

            if self.valid_loader is not None:
                self._validate(valid_ema)

            if run_test:
                self.update_stats('test_metrics', self.test(self.sample_from_empty))

            log_dict = {'epoch': epoch, 'time': time.time() - t_start}
            log_dict.update(self.get_mean_stats(self.train_stats, prefix='train'))
            tqdm_log = f"[EPOCH {epoch+1:04d}] | train loss: {log_dict['train_loss']:.3e}"
            if self.valid_loader is not None:
                log_dict.update(self.get_mean_stats(self.valid_stats, prefix='valid'))
                tqdm_log += f" | valid loss: {log_dict['valid_loss']:.3e}"
                if valid_ema:
                    log_dict.update(self.get_mean_stats(self.valid_ema_stats, prefix='valid_ema'))
                    tqdm_log += f" | valid ema loss: {log_dict['valid_ema_loss']:.3e}"
            if run_test:
                log_dict.update(self.get_mean_stats(self.test_metrics, prefix='test'))

            wandb.log(log_dict)
            pbar.set_description(tqdm_log)
            if is_eval_epoch:
                tqdm.write(tqdm_log)

            # -------- Save checkpoints --------
            # Always save at the last epoch so the final (un-suffixed) checkpoint
            # exists even when num_epochs is not a multiple of save_interval.
            is_last_epoch = epoch == self.num_epochs - 1
            if (epoch + 1) % self.save_interval == 0 or is_last_epoch:
                self._save_checkpoint(epoch)

    def _train_one_epoch(self):
        """One optimisation pass over ``train_loader``; per-batch losses go to ``train_stats``."""
        self.model.train()
        for train_bdata in self.train_loader:
            self.optimizer.zero_grad()
            train_bdata = train_bdata.to(self.device_id)
            loss = self.model(train_bdata)
            loss.backward()

            # Clip the global gradient norm before stepping.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
            self.optimizer.step()

            self.ema.update(self.model.parameters())
            self.update_stats('train_stats', {'loss': loss})

    def _validate(self, valid_ema):
        """Record validation losses (raw weights and, optionally, EMA weights) per batch."""
        # TODO: fix stats
        self.model.eval()
        for valid_bdata in self.valid_loader:
            valid_bdata = valid_bdata.to(self.device_id)
            with torch.no_grad():
                valid_loss = self.model(valid_bdata)
                self.update_stats('valid_stats', {'loss': valid_loss})
                if valid_ema:
                    with self.ema_scope():
                        ema_loss = self.model(valid_bdata)
                        self.update_stats('valid_ema_stats', {'loss': ema_loss})

    def _save_checkpoint(self, epoch):
        """Save config, model and EMA state under ``ckpt_dir``.

        Intermediate checkpoints are suffixed with the 1-based epoch number; the
        one written at the last epoch has no suffix.
        """
        save_name = '' if epoch == self.num_epochs - 1 else f'_{epoch+1}'
        torch.save({
            'model_config': self.config,
            'model_state_dict': self.model.state_dict(),
            'ema_state_dict': self.ema.state_dict(),
        }, f'{self.ckpt_dir}/{self.ckpt_name + save_name}.pth')

    def test(self, sample_from_empty=False):
        """Sample graphs with the current model weights and score them with the evaluator.

        When ``config.diffusion.target == 'binary'`` and ``sample_from_empty``
        is set, 10 graphs are drawn via ``model.sample(num_samples=10)``.
        Otherwise sampling is conditioned on each test batch. For non-binary
        targets the sampled values (indexed along dim 0 like the columns of
        ``bdata.full_edge_index``) mark edges where non-zero. Those edges plus
        their reversed copies become ``edge_index``. ``sample_from_empty`` is
        ignored for non-binary targets.

        Returns:
            Whatever ``self.evaluater`` returns for the generated graphs (also printed).

        Raises:
            RuntimeError: if batch-conditioned sampling is requested without a test split.
        """
        # -------- Sampling --------
        self.model.eval()
        is_binary = self.config.diffusion.target == 'binary'
        if is_binary and sample_from_empty:
            generated_data_list = self.model.sample(num_samples=10).cpu().to_data_list()
        else:
            if self.test_loader is None:
                raise RuntimeError('test() needs a test split to condition sampling on')
            generated_data_list = []
            for bdata in tqdm(self.test_loader):
                bdata = bdata.to(self.device_id)
                if is_binary:
                    bdata = self.model.sample(bdata).cpu()
                    generated_data_list.extend(batched_to_list(bdata))
                    continue

                full_edge_attr = self.model.sample(
                    bdata, return_all=False, device=self.device_id, **self.sample_params
                )
                is_edge_indices = full_edge_attr.nonzero(as_tuple=True)[0]
                edge_index = bdata.full_edge_index[:, is_edge_indices]
                # Add reversed copies of the selected edges.
                bdata.edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)
                # batched_to_list is used rather than Batch.to_data_list(); presumably
                # because the batch's slice bookkeeping for the rebuilt edge_index
                # is stale.
                generated_data_list.extend(batched_to_list(bdata.cpu()))

        metrics = self.evaluater(generated_data_list)
        print(metrics)
        return metrics

# TODO: change to DDP
if __name__ == '__main__':
    ts = time.strftime('%b%d-%H:%M:%S', time.gmtime())
    args, unknown = Parser().parse()
    # Unrecognised CLI tokens are treated as OmegaConf dot-list overrides
    # (key.path=value) and merged on top of the loaded config.
    cli = OmegaConf.from_dotlist(unknown)
    config = get_config(args.config, args.seed)
    config = OmegaConf.merge(config, cli)
    # vars() on a DictConfig only exposes OmegaConf internals; print the config contents as YAML.
    print(OmegaConf.to_yaml(config))

    # -------- Train --------
    seed = config.train.seed
    exp_name = f'{args.prefix}-r.{seed}-{ts}'
    diff_type = config.diffusion.target

    run = wandb.init(
        project='AdjacencyDiffVer1',
        group=f'{diff_type}_{args.prefix}',
        name=exp_name,
    )
    trainer = Trainer(config, exp_name=exp_name)
    # train()'s only parameter is valid_ema; do not pass the run name positionally.
    trainer.train()

    # -------- Final evaluation over 5 consecutive seeds --------
    for sd in range(args.seed, args.seed + 5):
        load_seed(sd)
        metrics = {}
        for tag, from_empty in (('orig', False), ('empty', True)):
            # Best effort: a failing evaluation mode (e.g. no test split) is reported
            # and skipped so the other mode and the remaining seeds still get logged.
            try:
                mode_metrics = trainer.test(sample_from_empty=from_empty)
                for k, v in mode_metrics.items():
                    metrics[f'final_{tag}_{k}'] = v
            except Exception as err:
                print(f'[seed {sd}] final {tag} evaluation failed: {err!r}')

        wandb.log(metrics)

    run.finish()