
import time
import os
import torch
from torch.optim import Adam
from torch_geometric.data import DataLoader
import numpy as np
from torch.autograd import grad
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
from torch_scatter import scatter


def global_mean_pool(x, batch, dim):
    """Average the slices of ``x`` along ``dim`` that share a group index.

    Args:
        x: Tensor to reduce.
        batch: Group index for every position of ``x`` along ``dim``.
        dim: Axis of ``x`` that is reduced per group.

    Returns:
        The ``scatter`` result with ``batch.max() + 1`` entries along ``dim``.
    """
    # The number of groups is derived from the largest index present.
    num_groups = int(batch.max().item() + 1)
    return scatter(
        x,
        batch,
        dim=dim,
        dim_size=num_groups,
        reduce='mean',
    )

def global_max_pool(x, batch, dim):
    """Take the maximum of the slices of ``x`` along ``dim`` per group index.

    Args:
        x: Tensor to reduce.
        batch: Group index for every position of ``x`` along ``dim``.
        dim: Axis of ``x`` that is reduced per group.

    Returns:
        The ``scatter`` result with ``batch.max() + 1`` entries along ``dim``.
    """
    # The number of groups is derived from the largest index present.
    num_groups = int(batch.max().item() + 1)
    return scatter(
        x,
        batch,
        dim=dim,
        dim_size=num_groups,
        reduce='max',
    )


def node_filip(x, y, batch, T):
    """Group-level contrastive loss built from row-wise similarities.

    Rows of ``x`` and ``y`` are L2-normalised and compared pairwise. The
    row-by-row similarity matrix is reduced to a group-by-group matrix by
    taking a max over one side's rows within each group and a mean over the
    other side's. The diagonal of that matrix is replaced by the per-group
    mean of the matching-row similarities (``sim[i, i]``).

    Args:
        x: 2-D tensor, one embedding per row.
        y: 2-D tensor with the same number of rows as ``x``.
        batch: Group index of every row (shared by ``x`` and ``y``).
        T: Temperature the group-level logits are divided by.

    Returns:
        Tuple ``(loss, acc)``: the two directional cross-entropy losses
        averaged, and the two directional top-1 accuracies averaged.
    """
    x_unit = F.normalize(x, dim=1)
    y_unit = F.normalize(y, dim=1)
    sim = torch.mm(x_unit, y_unit.transpose(1, 0))

    # Per-group mean of the similarities between matching rows.
    num_groups = int(batch.max().item() + 1)
    positives = scatter(torch.diag(sim, 0), batch, dim=0, dim_size=num_groups, reduce='mean')
    positives_diag = torch.diag(positives)
    off_diagonal = 1 - torch.eye(len(positives)).to(x_unit.device)

    # 2D-3D: max over each group's columns, then mean over each group's rows.
    logits_xy = global_mean_pool(global_max_pool(sim, batch, 1), batch, 0)
    logits_xy = off_diagonal * logits_xy + positives_diag
    logits_xy = torch.div(logits_xy, T)

    # 3D-2D: max over each group's rows, then mean over each group's columns.
    logits_yx = global_mean_pool(global_max_pool(sim, batch, 0), batch, 1)
    logits_yx = off_diagonal * logits_yx + positives_diag
    logits_yx = torch.div(logits_yx.transpose(1, 0), T)

    # Group i on one side should match group i on the other side.
    target = torch.arange(logits_xy.shape[0]).long().to(logits_xy.device)

    loss_fn = nn.CrossEntropyLoss()
    loss_xy = loss_fn(logits_xy, target)
    loss_yx = loss_fn(logits_yx, target)

    hits_xy = logits_xy.argmax(dim=1, keepdim=False).eq(target).sum().detach().cpu().item()
    hits_yx = logits_yx.argmax(dim=1, keepdim=False).eq(target).sum().detach().cpu().item()
    acc_xy = hits_xy * 1. / logits_xy.shape[0]
    acc_yx = hits_yx * 1. / logits_yx.shape[0]

    return (loss_xy + loss_yx) / 2, (acc_xy + acc_yx) / 2



def cycle_index(num, shift):
    """Return the indices ``0..num-1`` rotated left by ``shift`` positions.

    Element ``i`` of the result is ``(i + shift) % num``, so
    ``cycle_index(5, 2)`` is ``[2, 3, 4, 0, 1]``.

    Args:
        num: Number of indices to produce.
        shift: Rotation amount. Any integer is accepted; it is reduced
            modulo ``num`` (``0`` and multiples of ``num`` give the identity).

    Returns:
        A 1-D integer tensor of length ``num``.
    """
    indices = torch.arange(num)
    if num == 0:
        # Nothing to rotate; also avoids a modulo by zero.
        return indices
    # Modular arithmetic replaces the slice assignment `arr[-shift:] = ...`,
    # which raised for shift == 0 (the slice covers the whole tensor) and for
    # shift > num.
    return torch.remainder(indices + shift, num)

def do_CL(X, Y, T=0.1):
    """One-directional contrastive loss between two aligned sets of embeddings.

    Rows of ``X`` and ``Y`` are L2-normalised, their pairwise dot products are
    divided by the temperature, and row ``i`` of ``X`` is trained to pick row
    ``i`` of ``Y`` via cross-entropy.

    Args:
        X: 2-D tensor, one embedding per row (the queries).
        Y: 2-D tensor with the same number of rows as ``X`` (the keys).
        T: Temperature the similarity logits are divided by. Defaults to
            ``0.1``, the value that used to be hard-coded.

    Returns:
        Tuple ``(loss, acc)``: the cross-entropy loss tensor and the top-1
        matching accuracy as a Python float.
    """
    X = F.normalize(X, dim=-1)
    Y = F.normalize(Y, dim=-1)

    B = X.size(0)
    logits = torch.div(torch.mm(X, Y.transpose(1, 0)), T)  # B*B
    # Row i's positive is column i; build the targets on the logits' device.
    labels = torch.arange(B, dtype=torch.long, device=logits.device)

    CL_loss = F.cross_entropy(logits, labels)
    CL_acc = logits.argmax(dim=1).eq(labels).sum().item() * 1. / B

    return CL_loss, CL_acc


def dual_CL(X, Y):
    """Symmetric contrastive loss: ``do_CL`` in both directions, averaged.

    Args:
        X: First set of embeddings, one per row.
        Y: Second set of embeddings, row-aligned with ``X``.

    Returns:
        Tuple ``(loss, acc)`` holding the mean of the two directional losses
        and the mean of the two directional accuracies.
    """
    x_to_y = do_CL(X, Y)
    y_to_x = do_CL(Y, X)
    mean_loss = (x_to_y[0] + y_to_x[0]) / 2
    mean_acc = (x_to_y[1] + y_to_x[1]) / 2
    return mean_loss, mean_acc

class run_pretrain_holimol():
    r"""
    Pre-training driver that jointly optimises a 2D model and a 3D model with
    contrastive objectives plus an auxiliary dihedral classification loss.
    """
    def __init__(self):
        pass

    def run(self, device, train_dataset, model_2D, model_3D, loss_func, epochs=500, batch_size=32, vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=50, weight_decay=0, 
        energy_and_force=False, p=100, save_dir='', log_dir=''):
        r"""
        Run the pre-training loop.

        Args:
            device (torch.device): Device both models and every batch are moved to.
            train_dataset: Training data; wrapped in a shuffling :obj:`DataLoader`.
            model_2D: 2D model; must provide the methods called in :meth:`train`.
            model_3D: 3D model; must provide the methods called in :meth:`train`.
            loss_func (function): Accepted for interface compatibility; not used here.
            epochs (int, optional): Number of training epochs. (default: :obj:`500`)
            batch_size (int, optional): Number of samples per training minibatch. (default: :obj:`32`)
            vt_batch_size (int, optional): Accepted for interface compatibility; not used here. (default: :obj:`32`)
            lr (float, optional): Default learning rate handed to Adam. Both parameter groups
                currently set their own rate of :obj:`0.001`, which takes precedence. (default: :obj:`0.0005`)
            lr_decay_factor (float, optional): Accepted for interface compatibility; no scheduler is active. (default: :obj:`0.5`)
            lr_decay_step_size (int, optional): Accepted for interface compatibility; no scheduler is active. (default: :obj:`50`)
            weight_decay (float, optional): Weight decay handed to Adam. (default: :obj:`0`)
            energy_and_force (bool, optional): Accepted for interface compatibility; not used here. (default: :obj:`False`)
            p (int, optional): Accepted for interface compatibility; not used here. (default: :obj:`100`)
            save_dir (str, optional): Directory for checkpoints, written every second epoch as
                :obj:`<epoch>.pt`. If set to :obj:`''`, no checkpoint is saved. (default: :obj:`''`)
            log_dir (str, optional): Directory for TensorBoard logs (the per-epoch training loss).
                If set to :obj:`''`, nothing is logged. (default: :obj:`''`)
        """

        model_2D = model_2D.to(device)
        model_3D = model_3D.to(device)
        num_params = sum(param.numel() for param in model_2D.parameters()) + sum(param.numel() for param in model_3D.parameters())
        print(f'#Params: {num_params}')

        # NOTE(review): the per-group 'lr' of 0.001 overrides the `lr` argument
        # for every parameter. Kept as-is so existing runs keep their effective
        # learning rate -- confirm whether `lr` is meant to take effect.
        model_param_group = [
            {'params': model_2D.parameters(), 'lr': 0.001},
            {'params': model_3D.parameters(), 'lr': 0.001},
        ]
        optimizer = Adam(model_param_group, lr=lr, weight_decay=weight_decay)

        train_loader = DataLoader(train_dataset, batch_size, shuffle=True)

        if save_dir != '':
            os.makedirs(save_dir, exist_ok=True)
        writer = None
        if log_dir != '':
            os.makedirs(log_dir, exist_ok=True)
            writer = SummaryWriter(log_dir=log_dir)

        try:
            for epoch in range(1, epochs + 1):
                print("\n=====Epoch {}".format(epoch), flush=True)

                print('\nTraining...', flush=True)
                train_loss = self.train(model_2D, model_3D, device, train_loader, optimizer, epoch)

                if writer is not None:
                    writer.add_scalar('train_loss', train_loss, epoch)

                # Checkpoints go under save_dir (previously a placeholder path
                # was hard-coded and save_dir was ignored).
                if save_dir != '' and epoch % 2 == 0:
                    print('Saving checkpoint...')
                    checkpoint = {'epoch': epoch, 'model_2D_state_dict': model_2D.state_dict(), 'model_3D_state_dict': model_3D.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'num_params': num_params}
                    torch.save(checkpoint, os.path.join(save_dir, f'{epoch}.pt'))
        finally:
            # Flush and close the log even if training is interrupted.
            if writer is not None:
                writer.close()

    @staticmethod
    def _collect_dihedral_targets(orig_batch, frag_batch):
        r"""
        Build batch-wide fragment ids and dihedral anchor indices/labels.

        Args:
            orig_batch: Batch providing :obj:`batch`, :obj:`components`, :obj:`dihedral_num`,
                :obj:`dihedral_anchors` and :obj:`dihedral_labels`.
            frag_batch (Tensor): Graph index of every fragment.

        :rtype: Tuple :obj:`(component_batch, node_idx_all, dihedral_labels_all)`:
            the first column of :obj:`components` shifted by each graph's fragment offset,
            the anchor node indices shifted by each graph's node offset, and the matching
            labels (the last two as long tensors).
        """
        # Start offset of every graph's nodes / fragments inside the batch.
        binc = torch.bincount(orig_batch.batch)
        offset = torch.cat([binc.new_zeros(1), torch.cumsum(binc, dim=0)], dim=0).tolist()
        frag_binc = torch.bincount(frag_batch)
        frag_offset = torch.cat([frag_binc.new_zeros(1), torch.cumsum(frag_binc, dim=0)], dim=0).tolist()

        # Work on a copy so the batch's own `components` tensor is not modified.
        component_batch = orig_batch.components[:, 0].clone()
        node_idx_parts, label_parts = [], []
        for b in range(len(binc)):
            component_batch[offset[b]:offset[b + 1]] += frag_offset[b]

            num_dihedral = int(orig_batch.dihedral_num[b])
            if num_dihedral == 0:
                continue
            # Four anchor nodes per dihedral, stored flat; make them batch-wide.
            node_idx_parts.append(orig_batch.dihedral_anchors[b, :4 * num_dihedral] + offset[b])
            label_parts.append(orig_batch.dihedral_labels[b, :num_dihedral])

        if node_idx_parts:
            node_idx_all = torch.cat(node_idx_parts).long()
            dihedral_labels_all = torch.cat(label_parts).long()
        else:
            node_idx_all = binc.new_zeros(0)
            dihedral_labels_all = binc.new_zeros(0)
        return component_batch, node_idx_all, dihedral_labels_all

    def train(self,  model_2D, model_3D, device, loader, optimizer, epoch):
        r"""
        Train both models for one epoch.

        Args:
            model_2D: 2D model being trained.
            model_3D: 3D model being trained.
            device (torch.device): Device every batch is moved to.
            loader (Dataloader): Yields :obj:`(batch, batch1, orig_batch, orig_batch1, idx)`;
                :obj:`idx[i]` is used as the number of fragments of graph :obj:`i`.
            optimizer (Optimizer): Optimizer holding the parameters of both models.
            epoch (int): Current epoch. The cross-view contrastive terms are left out of the
                loss only when :obj:`epoch < 0`, which :meth:`run` never passes.

        :rtype: Training loss: the graph-level and fragment-level contrastive losses
            averaged, then averaged over all batches.
        """
        start_time = time.time()

        model_2D.train()
        model_3D.train()
        CL_loss_accum = 0
        AE_loss_accum, AE_acc_accum = 0, 0
        CL_acc_accum1 = 0
        CL_acc_accum2 = 0

        for batch, batch1, orig_batch, orig_batch1, idx in loader:
            batch = batch.to(device)
            batch1 = batch1.to(device)
            orig_batch = orig_batch.to(device)
            orig_batch1 = orig_batch1.to(device)

            # Graph index of every fragment: graph i contributes idx[i] entries.
            frag_batch = torch.cat([torch.tensor([i] * count) for i, count in enumerate(idx)]).to(device)

            # ----- 2D view -----
            molecule_2D_repr_orig, _, _ = model_2D.forward_cl_node(orig_batch.x, orig_batch.edge_index, orig_batch.edge_attr, orig_batch.batch)
            molecule_2D_repr = model_2D.forward_cl(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            frag_2D_repr = model_2D.forward_pool(batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch)
            molecule_2D_repr_mixed = model_2D.attention(frag_2D_repr, frag_batch)
            molecule_2D_repr_mixed = model_2D.forward_project(molecule_2D_repr_mixed)

            loss_2D = model_2D.loss_cl_graphcl(molecule_2D_repr, molecule_2D_repr_mixed)

            # ----- 3D view -----
            # The noise is scaled by 0.0, so the coordinates pass through unchanged.
            positions = 0.0 * torch.randn_like(orig_batch.positions) + orig_batch.positions
            molecule_3D_repr_orig, molecule_3D_repr = model_3D.forward_cl_node(orig_batch.x[:, 0], positions, orig_batch.batch)

            # NOTE(review): the fragment view is fed all-zero coordinates shaped
            # like orig_batch.positions (the scaled noise is not added to any
            # positions) -- confirm this is intended and not a missing
            # `+ <fragment positions>`.
            frag_positions = 0.0 * torch.randn_like(orig_batch.positions)
            frag_3D_repr = model_3D.forward_pool(orig_batch1.x[:, 0], frag_positions, orig_batch1.batch)

            molecule_3D_repr_mixed = model_3D.attention(frag_3D_repr, frag_batch)
            molecule_3D_repr_mixed = model_3D.forward_project(molecule_3D_repr_mixed)

            loss_3D = model_3D.loss_cl_graphcl(molecule_3D_repr, molecule_3D_repr_mixed)

            # ----- graph-level 2D/3D contrast -----
            CL_loss, CL_acc = dual_CL(molecule_2D_repr, molecule_3D_repr)

            component_batch, node_idx_all, dihedral_labels_all = self._collect_dihedral_targets(orig_batch, frag_batch)

            # ----- fragment-level 2D/3D contrast -----
            frag_2d = model_2D.average_pool(molecule_2D_repr_orig, component_batch)
            frag_3d = model_3D.average_pool(molecule_3D_repr_orig, component_batch)

            frag_2d = model_2D.forward_project(frag_2d)
            frag_3d = model_3D.forward_project(frag_3d)

            frag_CL_loss, frag_CL_acc = node_filip(frag_2d, frag_3d, frag_batch, 0.1)

            # ----- dihedral classification on the 2D node embeddings -----
            # Group the four anchor embeddings of each dihedral; the width is
            # read from the embeddings instead of being fixed at 300.
            hidden_dim = molecule_2D_repr_orig.size(-1)
            anchor = torch.index_select(molecule_2D_repr_orig, 0, node_idx_all).reshape(-1, 4, hidden_dim)
            # The same dihedral with its four anchors in reverse order.
            anchor_reverse = anchor.flip(dims=[1])

            anchor = anchor.reshape(anchor.shape[0], -1)
            anchor_out = model_2D.forward_aux2(anchor)
            loss_anchor1 = model_2D.ce2(anchor_out, dihedral_labels_all)

            anchor_reverse = anchor_reverse.reshape(anchor.shape[0], -1)
            anchor_out2 = model_2D.forward_aux2(anchor_reverse)
            loss_anchor2 = model_2D.ce2(anchor_out2, dihedral_labels_all)

            loss_anchor = (loss_anchor1 + loss_anchor2) / 2
            correct1 = (anchor_out.argmax(dim=1) == dihedral_labels_all).sum()
            correct2 = (anchor_out2.argmax(dim=1) == dihedral_labels_all).sum()

            # Stored as a float so no device tensor is kept across steps.
            acc = ((correct1 + correct2) / (2 * orig_batch.dihedral_num.sum())).item()
            AE_loss_accum += loss_anchor.detach().cpu().item()
            AE_acc_accum += acc

            CL_loss_accum += (CL_loss.detach().cpu().item() + frag_CL_loss.detach().cpu().item()) / 2
            CL_acc_accum1 += CL_acc
            CL_acc_accum2 += frag_CL_acc

            if epoch < 0:
                loss = loss_2D + loss_3D + loss_anchor
            else:
                loss = loss_2D + (CL_loss + frag_CL_loss) / 2 + loss_3D + loss_anchor

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_2D.parameters(), 5)
            torch.nn.utils.clip_grad_norm_(model_3D.parameters(), 5)
            optimizer.step()

        # Guard against an empty loader so the averages below stay defined.
        num_steps = max(len(loader), 1)
        CL_loss_accum /= num_steps
        CL_acc_accum1 /= num_steps
        CL_acc_accum2 /= num_steps
        AE_loss_accum /= num_steps
        AE_acc_accum /= num_steps

        print('CL Loss: {:.5f}\tCL graph Acc: {:.5f}\tCL local Acc: {:.5f}\tAE Loss: {:.5f}\tAE Acc: {:.5f}\tTime: {:.5f}'.format(
        CL_loss_accum, CL_acc_accum1, CL_acc_accum2, AE_loss_accum, AE_acc_accum, time.time() - start_time))

        return CL_loss_accum

