from .mvgrl import *
from .min_norm_solver import MinNormSolver, gradient_normalizers


class MVGRL_SVIAug_Pareto_Trainer(MVGRL_SVIAugTrainer):
    """
    Use MGDA for training.

    The model returns one logit tensor per task (``'diff'`` and ``'3rd'``).
    Each epoch the per-task gradients are normalised with
    ``gradient_normalizers`` and passed to
    ``MinNormSolver.find_min_norm_element``; the coefficients it returns are
    used to weight the task losses in the objective that is optimised.
    """
    @staticmethod
    def get_args(argv_list=None):
        """Build the argument parser for this trainer and parse arguments.

        Args:
            argv_list: Optional list of argument strings. When ``None``,
                argparse falls back to ``sys.argv[1:]``.

        Returns:
            The parsed ``argparse.Namespace``.
        """
        parser = MVGRL_SVIAugTrainer.get_parser()
        parser.add_argument('--patience', type=int, default=999,
                            help='Patient epochs to wait before early stopping. 0 for no early stopping.')
        parser.add_argument('--k', type=int, default=20, help='Number of propagation steps for graph diffusion.')
        parser.add_argument('--alpha', type=float, default=0.2, help='Propagation factor for graph diffusion.')
        parser.add_argument('--lr1', type=float, default=0.001, help='Learning rate of mvgrl.')
        parser.add_argument('--lr2', type=float, default=0.01, help='Learning rate of linear evaluator.')
        parser.add_argument('--wd1', type=float, default=0., help='Weight decay of mvgrl.')
        parser.add_argument('--wd2', type=float, default=0., help='Weight decay of linear evaluator.')
        parser.add_argument('--epsilon', type=float, default=0.01, help='Edge mask threshold of diffusion graph.')
        parser.add_argument("--hid-dim", type=int, default=512, help='Hidden layer dim.')
        parser.add_argument("--sample-size", type=int, default=4000, help='Subgraph size.')
        parser = MVGRL_SVIAugTrainer.add_sim_parameters(parser)
        parser.add_argument("--pareto", type=float, default=0.5, help='Pareto weight of the 3rd law.')
        parser.add_argument("--normalization-type", type=str, default='loss+',
                            choices=['l2', 'loss', 'loss+', 'none'],
                            help='Normalization type for gradient.')

        # parse_args(None) reads sys.argv[1:], so no explicit branch is needed.
        return parser.parse_args(argv_list)

    def data_process(self, data, cache_dir=CACHE_DIR + '/mvgrl'):
        """Run the parent preprocessing and attach the SVI similarity graph.

        Args:
            data: Dataset object; ``data.svi_emb`` is read here.
            cache_dir: Cache directory forwarded to the parent implementation.

        Returns:
            The parent's processed data with an extra ``"sim_graph"`` entry
            built by ``load_svi_similarity_graph``.
        """
        args = self.args
        processed_data = super().data_process(data, cache_dir)
        processed_data["sim_graph"] = load_svi_similarity_graph(args.dataset, data.svi_emb,
                                                                args.sim_type, args.sim_k,
                                                                args.sim_epsilon, args.sim_sigma, self.device)
        return processed_data

    @staticmethod
    def _task_gradients(loss, params):
        """Return the gradient of ``loss`` for every tensor in ``params``.

        ``torch.autograd.grad`` hands back freshly allocated tensors, so the
        result never aliases ``param.grad`` and cannot be clobbered by a later
        ``zero_grad``/``backward``. Parameters the loss does not depend on get
        an explicit zero gradient, which keeps the returned list aligned
        index-by-index with ``params`` (and therefore with the lists of the
        other tasks).
        """
        raw = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        return [torch.zeros_like(p) if g is None else g.detach()
                for p, g in zip(params, raw)]

    def train(self, data, **kwargs):
        """Train the model with min-norm loss weighting and return embeddings.

        Args:
            data: Dataset object providing ``g``, ``road_feat`` and ``svi_emb``.
            **kwargs: Must contain ``'diff_graph'``, ``'diff_weight'`` and
                ``'sim_graph'``. ``diff_graph.edata['edge_weight']`` is
                overwritten with ``diff_weight``.

        Returns:
            The detached CPU tensor produced by ``model.get_embedding`` on the
            full graphs.
        """
        args = self.args
        device = self.device
        graph = data.g
        road_feat = data.road_feat
        svi_emb = data.svi_emb
        diff_graph = kwargs['diff_graph']
        edge_weight = kwargs['diff_weight']
        diff_graph.edata['edge_weight'] = edge_weight
        sim_graph = kwargs['sim_graph']

        n_node = graph.number_of_nodes()
        # random.sample raises ValueError when asked for more items than the
        # population holds, so never request more nodes than the graph has.
        sample_size = min(args.sample_size, n_node)
        lbl = torch.cat((torch.ones(sample_size * 2),
                         torch.zeros(sample_size * 2))).to(device)

        model = MVGRL_SVIAugWithProjection(road_feat.shape[1], svi_emb.shape[1],
                                           args.proj_dim_a, args.proj_dim_b, args.hid_dim,
                                           road_feat_only=args.road_feat_only).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1,
                                     weight_decay=args.wd1)
        criterion = nn.BCEWithLogitsLoss()
        node_list = list(range(n_node))
        stopper = EarlyStopper(patience=args.patience)
        # autograd.grad rejects tensors that do not require grad.
        trainable = [p for p in model.parameters() if p.requires_grad]
        # Becomes True once stopper.step has been called at least once.
        stopper_used = False

        tasks = ['diff', '3rd']
        # Step 3: Training
        for epoch in range(args.epochs):
            # Sample sub-graph
            sample_idx = torch.LongTensor(random.sample(node_list, sample_size))
            g = dgl.node_subgraph(graph, sample_idx)
            dg = dgl.node_subgraph(diff_graph, sample_idx)
            sg = dgl.node_subgraph(sim_graph, sample_idx)

            feat_a = road_feat[sample_idx]
            feat_b = svi_emb[sample_idx]
            ew = dg.edata.pop('edge_weight')
            # Row-shuffled features passed to the model alongside the originals.
            shuf_idx = torch.randperm(sample_size)
            shuf_feat_a = feat_a[shuf_idx, :]
            shuf_feat_b = feat_b[shuf_idx, :]

            g = g.to(device)
            dg = dg.to(device)
            sg = sg.to(device)
            feat_a = feat_a.to(device)
            feat_b = feat_b.to(device)
            ew = ew.to(device)
            shuf_feat_a = shuf_feat_a.to(device)
            shuf_feat_b = shuf_feat_b.to(device)

            model.train()

            logits = model(g, dg, sg, feat_a, feat_b, shuf_feat_a, shuf_feat_b, ew)

            # Prepare for Pareto: per-task loss and per-task gradient.
            task_losses = {}
            loss_data = {}
            grads = {}
            for i, logit in enumerate(logits):
                t = tasks[i]
                loss_ = criterion(logit, lbl)
                task_losses[t] = loss_
                loss_data[t] = loss_.detach()
                grads[t] = self._task_gradients(loss_, trainable)

            gn = gradient_normalizers(grads, loss_data, args.normalization_type)
            for t in tasks:
                grads[t] = [gr / gn[t] for gr in grads[t]]

            # Frank-Wolfe iteration to compute scales.
            sol, _min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
            scale = {t: float(sol[i]) for i, t in enumerate(tasks)}

            optimizer.zero_grad()
            # The autograd graph was retained above, so the loss tensors from
            # the first pass can be reused instead of recomputing them.
            loss = 0.0
            for t in tasks:
                loss = loss + task_losses[t] * scale[t]
                loss_data[t] = task_losses[t].item()
            loss.backward()
            optimizer.step()

            print(f'''Epoch: {epoch} | Total Loss: {loss.item():0.4f} | Diff Loss: {loss_data["diff"]:0.4f} | \
                  3rd Loss: {loss_data["3rd"]:0.4f} | Scale: {scale[tasks[0]]:0.4f}/{scale[tasks[1]]:0.4f}''')
            if args.patience > 0:
                stopper_used = True
                if stopper.step(loss.item(), model):
                    break

        # Step 4: Get embedding
        model.eval()
        if stopper_used:
            # NOTE(review): presumably stopper.step is what stores the
            # checkpoint; with --patience 0 (early stopping disabled) it is
            # never called, so the final weights are kept instead of loading
            # a checkpoint that was never written. Confirm against EarlyStopper.
            model.load_state_dict(stopper.load_checkpoint())
        graph = graph.to(device)
        diff_graph = diff_graph.to(device)
        sim_graph = sim_graph.to(device)
        feat_a = road_feat.to(device)
        feat_b = svi_emb.to(device)
        edge_weight = edge_weight.to(device)
        # Inference only: skip building an autograd graph over the full graph.
        with torch.no_grad():
            emb = model.get_embedding(graph, diff_graph, sim_graph, feat_a, feat_b, edge_weight)
        return emb.cpu().detach()