from typing import  List
import itertools

import torch
from torch import nn
import torch.nn.functional as F
from torch import optim

from .gnn import GINE
from infrastructure import pytorch_util as ptu
import numpy as np
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.nn.models import GAT
from torch_geometric.data import Data, Batch

from torch.optim.lr_scheduler import CosineAnnealingLR
from .utils import CosineAnnealingWithFlatTail

class DiscriminatorNet(nn.Module):
    """Graph discriminator: a GINE encoder followed by an MLP scoring head.

    ``forward`` batches the observations into one graph batch, encodes it with
    the GINE network and maps the resulting embeddings through the MLP and a
    sigmoid, so every prediction lies in ``(0, 1)``.

    NOTE(review): the MLP input width is ``embed_dim * n_gcn_layers``, which
    presumably means ``GINE`` concatenates one pooled embedding per layer for
    each graph -- confirm against ``gnn.GINE``.
    """

    def __init__(
            self,
            ob_feature_dim: int,
            embed_dim: int,
            n_gcn_layers: int,
            n_layers: int,
            layer_size: int,
            learning_rate: float
        ):
        """Build the encoder, the scoring head, the optimizer and the scheduler.

        Args:
            ob_feature_dim: first argument passed to ``GINE`` (node feature
                width, judging by the name).
            embed_dim: hidden and output width of the GINE encoder.
            n_gcn_layers: number of GINE layers.
            n_layers: number of hidden layers of the MLP head.
            layer_size: hidden width of the MLP head.
            learning_rate: base learning rate; the optimizer is created with
                ``learning_rate / 20``.
        """
        super().__init__()

        self.gcn_out_dim = embed_dim
        self.gcn = GINE(ob_feature_dim, embed_dim, self.gcn_out_dim, n_gcn_layers).to(ptu.device)

        self.network = ptu.build_mlp(
            input_size=self.gcn_out_dim * n_gcn_layers,
            output_size=1,
            n_layers=n_layers,
            size=layer_size,
        ).to(ptu.device)

        # Materialise the parameters: a bare ``itertools.chain`` would be
        # exhausted by the optimizer, leaving ``self.params`` empty afterwards.
        self.params = list(
            itertools.chain(self.gcn.parameters(), self.network.parameters())
        )
        self.optimizer = optim.RAdam(
            self.params,
            lr=learning_rate / 20
        )

        # NOTE(review): presumably anneals for 1000 scheduler steps and then
        # holds the rate flat -- confirm in ``utils.CosineAnnealingWithFlatTail``.
        self.scheduler = CosineAnnealingWithFlatTail(
            self.optimizer, T_max=1000, eta_min=5e-6, flat_after=1000
        )

    @staticmethod
    def _cyclic_window(items, start, count):
        """Return up to ``count`` distinct items of ``items``, read cyclically.

        Reading begins at index ``start % len(items)`` and wraps around the
        end. The window never exceeds ``len(items)``, so no item is repeated.
        An empty ``items`` yields an empty list.
        """
        total = len(items)
        if total == 0:
            return []
        offset = start % total
        return [items[(offset + i) % total] for i in range(min(count, total))]

    def forward(self, obs):
        """Score a sequence of graph observations.

        Args:
            obs: iterable of objects exposing ``node_type``, ``edge_index`` and
                ``edge_attr`` tensors.

        Returns:
            Sigmoid of the MLP output computed on the GINE embeddings.
        """
        # NOTE(review): ``squeeze()`` without a dim drops *every* size-1 axis,
        # so a single-node graph could lose its node axis -- confirm shapes.
        data_list = [
            Data(
                x=o.node_type.to(ptu.device).to(torch.float32).squeeze(),
                edge_index=o.edge_index.to(ptu.device),
                edge_attr=o.edge_attr.to(torch.float32).to(ptu.device),
            )
            for o in obs
        ]

        batch = Batch.from_data_list(data_list)
        embeddings = self.gcn(batch.x, edge_index=batch.edge_index,
                              edge_attr=batch.edge_attr, dropout=0, batch=batch.batch)

        logits = self.network(embeddings)
        return torch.sigmoid(logits)

    def evaluate(self, real, generator):
        """Compute the discriminator loss on equally sized real/generated sets.

        Both inputs are truncated to the shorter length before scoring.

        Returns:
            ``(loss, length)`` where ``loss`` is
            ``mean(pred_generator - pred_real)`` and ``length`` is the number
            of samples taken from each input.
        """
        length = min(len(real), len(generator))
        pred_real = self(real[0:length])
        pred_generator = self(generator[0:length])
        loss = torch.mean(pred_generator - pred_real)
        return loss, length

    def update(self, real, generator, step, pretrain=False):
        """Run one shuffled pass over ``generator`` in mini-batches of 128.

        Each generated mini-batch is paired with a cyclic window of the
        shuffled real samples. The loss is always recorded; parameters are
        updated only when ``pretrain`` or ``step`` is truthy, and the learning
        rate scheduler advances only outside pretraining.

        Args:
            real: sequence of real observations.
            generator: sequence of generated observations.
            step: truthy to train when not pretraining. NOTE(review): if
                callers pass an iteration counter, iteration 0 skips training
                -- confirm this is intended.
            pretrain: when true, always train and leave the scheduler alone.

        Returns:
            Dict of logging scalars (mean loss, sample count and, outside
            pretraining, the current learning rate).

        Raises:
            ValueError: if ``generator`` has samples but ``real`` is empty.
        """
        batch_size = 128
        n_gen = len(generator)
        n_real = len(real)
        if n_gen and not n_real:
            raise ValueError(
                "`real` must contain at least one sample when `generator` is non-empty"
            )

        # Same RNG call order as before: generator permutation first.
        gen_shuff = [generator[i] for i in np.random.permutation(n_gen)]
        real_shuff = [real[i] for i in np.random.permutation(n_real)]

        train = bool(pretrain or step)
        disc_infos = []
        for start in range(0, n_gen, batch_size):
            end = min(start + batch_size, n_gen)
            # A plain ``real_shuff[start % n : end % n]`` slice is empty when
            # the span is a multiple of ``n_real``; the cyclic window is not.
            real_batch = self._cyclic_window(real_shuff, start, end - start)
            loss, _ = self.evaluate(real_batch, gen_shuff[start:end])
            disc_infos.append(ptu.to_numpy(loss))
            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                if not pretrain:
                    self.scheduler.step()

        # Avoid numpy's "mean of empty slice" warning when nothing was processed.
        mean_loss = np.mean(disc_infos) if disc_infos else np.nan

        if pretrain:
            return {
                "Discriminator Pretrain Loss": mean_loss,
                "Disc Pretrain Generator Datapoints": len(generator)
            }
        return {
            "Discriminator Loss": mean_loss,
            "Discriminator Learning Rate": self.scheduler.get_last_lr()[0],
            "Disc Generator Datapoints": len(generator)
        }
    