import os.path as osp
import time
import argparse
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import LayerNorm, Linear, BatchNorm1d
from tqdm import tqdm
from ipdb import set_trace as stc

import torch_geometric.transforms as T
from torch_geometric.loader import RandomNodeSampler
from torch_geometric.nn import GroupAddRev, SAGEConv
from torch_geometric.utils import index_to_mask

from ogb.nodeproppred import Evaluator, PygNodePropPredDataset  # noqa


class GNNBlock(torch.nn.Module):
    """Pre-activation graph block: norm -> ReLU -> optional dropout mask -> SAGEConv.

    Args:
        in_channels: Feature width seen by the normalisation layer and the
            convolution input.
        out_channels: Feature width produced by the convolution.
        norm: ``'layer'`` to use ``LayerNorm`` or ``'batch'`` to use
            ``BatchNorm1d``.

    Raises:
        ValueError: If ``norm`` is neither ``'layer'`` nor ``'batch'``.
    """

    def __init__(self, in_channels, out_channels, norm):
        super().__init__()
        if norm == 'layer':
            self.norm = LayerNorm(in_channels, elementwise_affine=True)
        elif norm == 'batch':
            self.norm = BatchNorm1d(in_channels)
        else:
            # Fail here rather than with an AttributeError on `self.norm`
            # during the first forward pass.
            raise ValueError(
                f"Unsupported norm {norm!r}; expected 'layer' or 'batch'")
        self.conv = SAGEConv(in_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the normalisation layer and the convolution."""
        self.norm.reset_parameters()
        self.conv.reset_parameters()

    def forward(self, x, edge_index, dropout_mask=None):
        """Apply norm, ReLU, the optional dropout mask, then the convolution.

        ``dropout_mask`` is multiplied element-wise into the activations and
        is only used while the module is in training mode.
        """
        x = self.norm(x).relu()
        if self.training and dropout_mask is not None:
            x = x * dropout_mask
        return self.conv(x, edge_index)


class RevGNN(torch.nn.Module):
    """Reversible GNN classifier.

    The forward pass is ``lin1`` -> ``num_layers`` ``GroupAddRev``-wrapped
    ``GNNBlock`` layers -> norm -> ReLU -> dropout -> ``lin2``.

    Args:
        in_channels: Input feature width.
        hidden_channels: Hidden width; must be divisible by ``num_groups``.
        out_channels: Output width of ``lin2``.
        num_layers: Number of reversible blocks.
        dropout: Dropout probability, used both for the mask shared by the
            blocks and for the final dropout.
        num_groups: Number of groups passed to ``GroupAddRev``; each block
            operates on ``hidden_channels // num_groups`` channels.
        norm: ``'layer'`` or ``'batch'``.

    Raises:
        ValueError: If ``norm`` is unsupported or ``hidden_channels`` is not
            divisible by ``num_groups``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_groups=2, norm='layer'):
        super().__init__()

        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        if norm == 'layer':
            self.norm = LayerNorm(hidden_channels, elementwise_affine=True)
        elif norm == 'batch':
            self.norm = BatchNorm1d(hidden_channels)
        else:
            # Otherwise `self.norm` would be missing and forward() would
            # crash with an AttributeError.
            raise ValueError(
                f"Unsupported norm {norm!r}; expected 'layer' or 'batch'")

        # Explicit check instead of `assert`, which is stripped under -O.
        if hidden_channels % num_groups != 0:
            raise ValueError(
                f'hidden_channels ({hidden_channels}) must be divisible by '
                f'num_groups ({num_groups})')
        group_channels = hidden_channels // num_groups
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GNNBlock(group_channels, group_channels, norm)
            self.convs.append(GroupAddRev(conv, num_groups=num_groups))

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.norm.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return the output of ``lin2`` for the given features and graph."""
        x = self.lin1(x)

        # Generate a dropout mask which will be shared across GNN blocks:
        mask = None
        if self.training and self.dropout > 0:
            mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)
            mask = mask.requires_grad_(False)
            # Rescale kept entries so the expected activation is unchanged.
            mask = mask / (1 - self.dropout)

        for conv in self.convs:
            x = conv(x, edge_index, mask)
        x = self.norm(x).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x)


class JKRevGNN(torch.nn.Module):
    """Reversible GNN that concatenates all layer outputs before the head.

    The output of ``lin1`` and of every reversible block is concatenated
    along the feature dimension, projected to ``teacher_channels`` by
    ``W_jk``, then passed through ``last_norm`` -> ReLU -> dropout -> ``lin2``.

    Args:
        in_channels: Input feature width.
        hidden_channels: Hidden width; must be divisible by ``num_groups``.
        teacher_channels: Width of the projected embedding returned next to
            the logits.
        out_channels: Output width of ``lin2``.
        num_layers: Number of reversible blocks.
        dropout: Dropout probability (shared mask and final dropout).
        num_groups: Number of groups passed to ``GroupAddRev``.
        norm: ``'layer'`` or ``'batch'``.

    Raises:
        ValueError: If ``norm`` is unsupported or ``hidden_channels`` is not
            divisible by ``num_groups``.
    """

    def __init__(self, in_channels, hidden_channels, teacher_channels, out_channels, num_layers,
                 dropout, num_groups=2, norm='layer'):
        super().__init__()

        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(teacher_channels, out_channels)
        # NOTE(review): `self.norm` is not used by forward(); it is kept so
        # the set of registered parameters stays as it was.
        if norm == 'layer':
            self.norm = LayerNorm(hidden_channels, elementwise_affine=True)
            self.last_norm = LayerNorm(teacher_channels, elementwise_affine=True)
        elif norm == 'batch':
            self.norm = BatchNorm1d(hidden_channels)
            self.last_norm = BatchNorm1d(teacher_channels)
        else:
            # Otherwise `self.last_norm` would be missing and forward()
            # would crash with an AttributeError.
            raise ValueError(
                f"Unsupported norm {norm!r}; expected 'layer' or 'batch'")

        # Explicit check instead of `assert`, which is stripped under -O.
        if hidden_channels % num_groups != 0:
            raise ValueError(
                f'hidden_channels ({hidden_channels}) must be divisible by '
                f'num_groups ({num_groups})')
        group_channels = hidden_channels // num_groups
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GNNBlock(group_channels, group_channels, norm)
            self.convs.append(GroupAddRev(conv, num_groups=num_groups))

        # One slice of width `hidden_channels` per layer output, plus one for
        # the output of `lin1`.
        self.W_jk = torch.nn.Linear(hidden_channels * (num_layers + 1), teacher_channels)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.norm.reset_parameters()
        self.last_norm.reset_parameters()
        self.W_jk.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return ``(logits, h_norelu, h_relu)``.

        ``h_norelu`` is the ``W_jk`` projection; ``h_relu`` is the same
        embedding after ``last_norm``, ReLU and dropout.
        """
        # Snapshots of every intermediate representation.
        # NOTE(review): the clones are presumably needed because GroupAddRev
        # may release its input storage for reversible backprop - confirm
        # against the PyG implementation before removing them.
        layer_outputs = []

        x = self.lin1(x)
        layer_outputs.append(x.clone())

        # Generate a dropout mask which will be shared across GNN blocks:
        mask = None
        if self.training and self.dropout > 0:
            mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)
            mask = mask.requires_grad_(False)
            # Rescale kept entries so the expected activation is unchanged.
            mask = mask / (1 - self.dropout)

        for conv in self.convs:
            x = conv(x, edge_index, mask)
            layer_outputs.append(x.clone())
        x = torch.cat(layer_outputs, dim=1)
        x = self.W_jk(x)
        h_norelu = x.clone()
        x = self.last_norm(x).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        h_relu = x.clone()

        return self.lin2(x), h_norelu, h_relu


class NoJKRevGNN(torch.nn.Module):
    """Reversible GNN that projects only the last layer's output.

    The forward pass is ``lin1`` -> ``num_layers`` reversible blocks ->
    ``W_jk`` (projection to ``teacher_channels``) -> ``last_norm`` -> ReLU ->
    dropout -> ``lin2``. Unlike ``JKRevGNN`` no intermediate outputs are
    concatenated.

    Args:
        in_channels: Input feature width.
        hidden_channels: Hidden width; must be divisible by ``num_groups``.
        teacher_channels: Width of the projected embedding returned next to
            the logits.
        out_channels: Output width of ``lin2``.
        num_layers: Number of reversible blocks.
        dropout: Dropout probability (shared mask and final dropout).
        num_groups: Number of groups passed to ``GroupAddRev``.
        norm: ``'layer'`` or ``'batch'``.

    Raises:
        ValueError: If ``norm`` is unsupported or ``hidden_channels`` is not
            divisible by ``num_groups``.
    """

    def __init__(self, in_channels, hidden_channels, teacher_channels, out_channels, num_layers,
                 dropout, num_groups=2, norm='layer'):
        super().__init__()

        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(teacher_channels, out_channels)
        # NOTE(review): `self.norm` is not used by forward(); it is kept so
        # the set of registered parameters stays as it was.
        if norm == 'layer':
            self.norm = LayerNorm(hidden_channels, elementwise_affine=True)
            self.last_norm = LayerNorm(teacher_channels, elementwise_affine=True)
        elif norm == 'batch':
            self.norm = BatchNorm1d(hidden_channels)
            self.last_norm = BatchNorm1d(teacher_channels)
        else:
            # Otherwise `self.last_norm` would be missing and forward()
            # would crash with an AttributeError.
            raise ValueError(
                f"Unsupported norm {norm!r}; expected 'layer' or 'batch'")

        # Explicit check instead of `assert`, which is stripped under -O.
        if hidden_channels % num_groups != 0:
            raise ValueError(
                f'hidden_channels ({hidden_channels}) must be divisible by '
                f'num_groups ({num_groups})')
        group_channels = hidden_channels // num_groups
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GNNBlock(group_channels, group_channels, norm)
            self.convs.append(GroupAddRev(conv, num_groups=num_groups))

        # Projects the final hidden representation to the teacher's width.
        self.W_jk = torch.nn.Linear(hidden_channels, teacher_channels)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.norm.reset_parameters()
        self.last_norm.reset_parameters()
        self.W_jk.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return ``(logits, h_norelu, h_relu)``.

        ``h_norelu`` is the ``W_jk`` projection; ``h_relu`` is the same
        embedding after ``last_norm``, ReLU and dropout.
        """
        x = self.lin1(x)

        # Generate a dropout mask which will be shared across GNN blocks:
        mask = None
        if self.training and self.dropout > 0:
            mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)
            mask = mask.requires_grad_(False)
            # Rescale kept entries so the expected activation is unchanged.
            mask = mask / (1 - self.dropout)

        for conv in self.convs:
            x = conv(x, edge_index, mask)
        x = self.W_jk(x)
        h_norelu = x.clone()
        x = self.last_norm(x).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        h_relu = x.clone()

        return self.lin2(x), h_norelu, h_relu


class RevGNN_knowledge(torch.nn.Module):
    """``RevGNN`` variant that also returns detached CPU hidden features.

    Besides the output of ``lin2``, ``forward`` returns the last block's
    output before normalisation and the same features after norm + ReLU,
    both detached from the autograd graph and moved to the CPU.

    Args:
        in_channels: Input feature width.
        hidden_channels: Hidden width; must be divisible by ``num_groups``.
        out_channels: Output width of ``lin2``.
        num_layers: Number of reversible blocks.
        dropout: Dropout probability (shared mask and final dropout).
        num_groups: Number of groups passed to ``GroupAddRev``.
        norm: ``'layer'`` or ``'batch'``.

    Raises:
        ValueError: If ``norm`` is unsupported or ``hidden_channels`` is not
            divisible by ``num_groups``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_groups=2, norm='layer'):
        super().__init__()

        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        if norm == 'layer':
            self.norm = LayerNorm(hidden_channels, elementwise_affine=True)
        elif norm == 'batch':
            self.norm = BatchNorm1d(hidden_channels)
        else:
            # Otherwise `self.norm` would be missing and forward() would
            # crash with an AttributeError.
            raise ValueError(
                f"Unsupported norm {norm!r}; expected 'layer' or 'batch'")

        # Explicit check instead of `assert`, which is stripped under -O.
        if hidden_channels % num_groups != 0:
            raise ValueError(
                f'hidden_channels ({hidden_channels}) must be divisible by '
                f'num_groups ({num_groups})')
        group_channels = hidden_channels // num_groups
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = GNNBlock(group_channels, group_channels, norm)
            self.convs.append(GroupAddRev(conv, num_groups=num_groups))

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.norm.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return ``(logits, h_norelu, h_relu)``; both embeddings are detached CPU tensors."""
        x = self.lin1(x)

        # Generate a dropout mask which will be shared across GNN blocks:
        mask = None
        if self.training and self.dropout > 0:
            mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)
            mask = mask.requires_grad_(False)
            # Rescale kept entries so the expected activation is unchanged.
            mask = mask / (1 - self.dropout)

        for conv in self.convs:
            x = conv(x, edge_index, mask)
        # Snapshot before normalisation / activation.
        h_norelu = x.cpu().detach()
        x = self.norm(x).relu()
        # Snapshot after norm + ReLU but before the final dropout.
        h_relu = x.cpu().detach()
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x), h_norelu, h_relu


def set_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    # Disable the auto-tuner and request deterministic cuDNN kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True
    # torch.use_deterministic_algorithms(True) is intentionally left off.

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Command-line interface of the training script.
parser = argparse.ArgumentParser()
# NOTE(review): `--data` is not read anywhere in the visible script; the
# dataset name is hard-coded where the dataset is built.
parser.add_argument('--data', type=str, default='ogbn-products')
parser.add_argument('--device', type=int, default=0,
                    help='CUDA device index (CPU is used when CUDA is unavailable)')
parser.add_argument('--train_num_parts', type=int, default=10,
                    help='number of partitions of the training sampler')
parser.add_argument('--test_num_parts', type=int, default=1,
                    help='number of partitions of the evaluation sampler')
parser.add_argument('--hidden_channels', type=int, default=320)
parser.add_argument('--num_groups', type=int, default=2)
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--dropout', type=float, default=0.5)
# The models only implement 'layer' and 'batch'. The previous default,
# 'norm', matched neither, so no norm layer was created and the first forward
# pass crashed with an AttributeError.
parser.add_argument('--norm', type=str, default='layer',
                    choices=['layer', 'batch'],
                    help='normalisation layer used by the model')
parser.add_argument('--teacher_dir', type=str, default='[your_path]/codes_epic/products/teacher',
                    help='directory holding the teacher embs_t.npy and logits.npy')
parser.add_argument('--outputs_dir', type=str, default='[your_path]/codes_epic/products/outputs')
parser.add_argument('--student_dir', type=str, default='[your_path]/codes_epic/products/student',
                    help='directory where the best student checkpoint is written')
parser.add_argument('--embs_relu', action='store_true', default=False,
                    help='match the post-ReLU student embedding instead of the pre-ReLU one')
parser.add_argument('--r_t', type=float, default=1.0,
                    help='weight of the teacher-logit distillation loss')
parser.add_argument('--r_e', type=float, default=1.0,
                    help='weight of the embedding-matching loss')
parser.add_argument('--seed', type=int, default=0)


args = parser.parse_args()

# Seed every RNG before any data loading or model construction.
set_seed(args.seed)

teacher_dir = args.teacher_dir
student_dir = args.student_dir
outputs_dir = args.outputs_dir
# Each line names the directory it actually prints (all three used to say
# "teacher_path").
print(f'The teacher_path is {teacher_dir}')
print(f'The student_path is {student_dir}')
print(f'The outputs_path is {outputs_dir}')

# Both embedding targets come from the same file, so read it once and share
# the tensor instead of holding two identical copies in memory.
# NOTE(review): confirm that the pre-ReLU and post-ReLU targets are really
# meant to be the same teacher embedding.
embs_norelu = torch.from_numpy(np.load(osp.join(teacher_dir, 'embs_t.npy')))
embs_relu = embs_norelu
# Width of the teacher embedding; sizes the student's projection layer.
args.teacher_channels = embs_relu.shape[1]
logits = torch.from_numpy(np.load(osp.join(teacher_dir, 'logits.npy')))



# Use the requested GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')

# Applied to every sampled batch: move it to `device`, then convert the edge
# list to a sparse adjacency (`adj_t`).
transform = T.Compose([T.ToDevice(device), T.ToSparseTensor()])
root = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'products')
# Two independent copies: `dataset` keeps the full graph for evaluation and
# `obs_dataset` is trimmed below to the observed (train + valid) nodes.
dataset = PygNodePropPredDataset('ogbn-products', root, transform=T.AddSelfLoops())
obs_dataset = PygNodePropPredDataset('ogbn-products', root, transform=T.AddSelfLoops())
evaluator = Evaluator(name='ogbn-products')

data = dataset[0]
obs_data = obs_dataset[0]
split_idx = dataset.get_idx_split()
num_all_nodes = data.y.shape[0]
for split in ['train', 'valid', 'test']:
    data[f'{split}_mask'] = index_to_mask(split_idx[split], num_all_nodes)
    obs_data[f'{split}_mask'] = index_to_mask(split_idx[split], num_all_nodes)

train_num, valid_num, test_num = (
    len(split_idx[name]) for name in ('train', 'valid', 'test'))
obs_num = train_num + valid_num

# Keep only the first `obs_num` nodes.
# NOTE(review): this assumes train/valid nodes occupy indices [0, obs_num)
# in the dataset ordering - confirm against the split files.
obs_data.num_nodes = obs_num
for key in ('x', 'y', 'train_mask', 'valid_mask', 'test_mask'):
    obs_data[key] = obs_data[key][:obs_num]

# Drop every edge that touches a node outside the observed range.
edge_index = obs_data.edge_index
mask = (edge_index[0] < obs_num) & (edge_index[1] < obs_num)
obs_data.edge_index = edge_index[:, mask]

# Attach the teacher targets so the sampler slices them together with the nodes.
obs_data.embs_norelu = embs_norelu[:obs_num]
obs_data.embs_relu = embs_relu[:obs_num]
obs_data.logits = logits[:obs_num]

train_loader = RandomNodeSampler(obs_data, num_parts=args.train_num_parts, shuffle=True, num_workers=5)
test_loader = RandomNodeSampler(data, num_parts=args.test_num_parts, num_workers=5)

model_kwargs = dict(
    in_channels=dataset.num_features,
    hidden_channels=args.hidden_channels,
    teacher_channels=args.teacher_channels,
    out_channels=dataset.num_classes,
    num_layers=args.num_layers,  # You can try 1000 layers for fun
    dropout=args.dropout,
    num_groups=args.num_groups,
    norm=args.norm,
)
model = NoJKRevGNN(**model_kwargs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)


def train(args, epoch, criterion_gt, criterion_sl):
    """Run one distillation epoch over ``train_loader``.

    The per-batch loss is a convex combination of up to three terms with
    weights ``1``, ``r_t`` and ``r_e`` divided by ``1 + r_t + r_e``:

    * ``criterion_gt`` on the student's log-probabilities and the labels,
    * ``criterion_sl`` on the student's and the teacher's log-probabilities
      (only when ``args.r_t != 0``),
    * the mean squared error between student and teacher embeddings
      (only when ``args.r_e != 0``).

    Args:
        args: Namespace providing ``r_t``, ``r_e`` and ``embs_relu``.
        epoch: Epoch number, used only for the progress-bar label.
        criterion_gt: Supervised loss taking ``(log_probs, labels)``.
        criterion_sl: Distillation loss taking ``(log_probs, target_log_probs)``.

    Returns:
        The mean loss over the epoch, weighted by the number of training
        nodes in each batch.
    """
    model.train()
    r_t, r_e = args.r_t, args.r_e
    weight_norm = 1 + r_t + r_e

    total_loss = total_examples = 0
    # Context manager so the bar is closed even if a batch raises.
    with tqdm(total=len(train_loader)) as pbar:
        pbar.set_description(f'Training epoch: {epoch:03d}')

        for data in train_loader:
            optimizer.zero_grad()

            # Memory-efficient aggregations:
            data = transform(data)
            train_mask = data.train_mask
            out, embs_norelu, embs_relu = model(data.x, data.adj_t)
            out = out.log_softmax(dim=1)

            loss_gt = criterion_gt(out[train_mask], data.y[train_mask].view(-1))
            loss = (1 / weight_norm) * loss_gt

            if r_t != 0:
                # `criterion_sl` is built with log_target=True, so the target
                # must be log-probabilities. log_softmax leaves inputs that
                # are already log-probabilities unchanged, so this is safe
                # whichever form the stored teacher outputs have.
                teacher_log_probs = data.logits[train_mask].log_softmax(dim=1)
                loss_t = criterion_sl(out[train_mask], teacher_log_probs)
                loss = loss + (r_t / weight_norm) * loss_t

            # Independent of `r_t`: previously the embedding term was
            # silently dropped whenever r_t == 0.
            if r_e != 0:
                if args.embs_relu:
                    embs_s = embs_relu[train_mask]
                    embs_t = data.embs_relu[train_mask]
                else:
                    embs_s = embs_norelu[train_mask]
                    embs_t = data.embs_norelu[train_mask]
                # Squared L2 distance averaged over every embedding entry.
                loss_e = ((embs_s - embs_t).norm(p=2) ** 2) / embs_t.numel()
                loss = loss + (r_e / weight_norm) * loss_e

            loss.backward()
            optimizer.step()

            num_train = int(train_mask.sum())
            total_loss += float(loss) * num_train
            total_examples += num_train
            pbar.update(1)

    return total_loss / total_examples

@torch.no_grad()
def test(epoch):
    """Evaluate the model on every split of ``test_loader``.

    Args:
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        ``(train_acc, valid_acc, test_acc)`` as reported by ``evaluator``.
    """
    model.eval()

    splits = ('train', 'valid', 'test')
    y_true = {name: [] for name in splits}
    y_pred = {name: [] for name in splits}

    pbar = tqdm(total=len(test_loader))
    pbar.set_description(f'Evaluating epoch: {epoch:03d}')

    for batch in test_loader:
        batch = transform(batch)
        # The two embedding outputs are not needed for evaluation.
        scores, _, _ = model(batch.x, batch.adj_t)
        pred = scores.argmax(dim=-1, keepdim=True)

        for name in splits:
            selected = batch[f'{name}_mask']
            y_true[name].append(batch.y[selected].cpu())
            y_pred[name].append(pred[selected].cpu())

        pbar.update(1)

    pbar.close()

    # One accuracy per split, in train / valid / test order.
    accs = []
    for name in splits:
        result = evaluator.eval({
            'y_true': torch.cat(y_true[name], dim=0),
            'y_pred': torch.cat(y_pred[name], dim=0),
        })
        accs.append(result['acc'])

    train_acc, valid_acc, test_acc = accs
    return train_acc, valid_acc, test_acc


# Per-epoch wall-clock times and the metrics at the best validation epoch.
times = []
best_val = 0.0
final_train = 0.0
final_test = 0.0

# Supervised loss on log-probabilities, and KL divergence between two
# log-probability inputs for distillation.
criterion_gt = torch.nn.NLLLoss()
criterion_sl = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)

# The checkpoint path depends only on the run configuration, so build it
# once instead of repeating the long f-string in two branches.
embs_tag = 'embsrelu' if args.embs_relu else 'embsnorelu'
checkpoint_path = osp.join(
    args.student_dir,
    f'seed{args.seed}_nojkrev{args.num_layers}-{args.hidden_channels}'
    f'_ind_lr{args.lr}_{args.norm}norm_rt{args.r_t}_re{args.r_e}'
    f'_{embs_tag}_testsubgraph{args.test_num_parts}.pth',
)

for epoch in range(1, args.epochs+1):
    start = time.time()
    loss = train(args, epoch, criterion_gt, criterion_sl)
    train_acc, val_acc, test_acc = test(epoch)
    # Model selection on validation accuracy; the train/test numbers of the
    # best epoch are reported and that epoch's weights are saved.
    if val_acc > best_val:
        best_val = val_acc
        final_train = train_acc
        final_test = test_acc
        torch.save(model.state_dict(), checkpoint_path)
    print(f'Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {test_acc:.4f}')
    times.append(time.time() - start)

print(f'Final Train: {final_train:.4f}, Best Val: {best_val:.4f}, '
      f'Final Test: {final_test:.4f}')
# The median of an empty tensor raises, so skip the timing line when no
# epoch ran (e.g. --epochs 0).
if times:
    print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")