import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from seed import set_seed

###### Loss Preliminaries ######
# FPR estimators (sigmoid-smoothed)
def fpr_est_weighted(g, z0, i0, lam_g, p, c):
    """Smoothed, importance-weighted false-positive-rate estimate on ``z0``.

    Each sample contributes ``sigmoid(c * (g(z0) - lam_g))`` (a differentiable
    stand-in for the indicator ``g(z0) > lam_g``). The contribution is weighted
    by 1 when its indicator in ``i0`` is 0 and by ``1 / p`` otherwise. The
    weighted sum is divided by the number of samples ``len(z0)``.

    Args:
        g: callable/model mapping ``z0`` to scores whose first axis indexes samples.
        z0: samples the FPR is estimated on.
        i0: per-sample indicators, shape ``(n,)`` or ``(n, 1)``.
        lam_g: decision threshold (tensor, possibly trainable).
        p: weight denominator for samples with ``i0 != 0``.
        c: sigmoid sharpness; larger values approach a hard threshold.

    Returns:
        Scalar tensor. Gradients flow to ``g``'s parameters and ``lam_g`` when
        they require grad; no gradient tracking is forced.
    """
    scores = g(z0)
    flags = torch.as_tensor(i0, device=scores.device).reshape(-1)
    weights = torch.where(flags == 0, 1.0, 1.0 / p)
    # Put the weights on the sample axis and broadcast over any trailing score
    # axes, so an (n, 1) indicator no longer broadcasts into an (n, n, 1) grid.
    weights = weights.reshape(-1, *([1] * (scores.dim() - 1)))
    soft = torch.sigmoid(c * (scores - lam_g))
    return torch.sum(weights * soft) / len(z0)

def fpr_est(g, z0, lam_g, c):
    """Smoothed (unweighted) false-positive-rate estimate on ``z0``.

    Sums ``sigmoid(c * (g(z0) - lam_g))`` and divides by ``len(z0)``.

    Args:
        g: callable/model mapping ``z0`` to scores.
        z0: samples the FPR is estimated on.
        lam_g: decision threshold (tensor, possibly trainable).
        c: sigmoid sharpness; larger values approach a hard threshold.

    Returns:
        Scalar tensor. Gradients flow to ``g``'s parameters and ``lam_g`` when
        they require grad; no gradient tracking is forced.
    """
    soft = torch.sigmoid(c * (g(z0) - lam_g))
    return torch.sum(soft) / len(z0)

# TPR estimator (sigmoid-smoothed)
def tpr_est(g, z1, lam_g, c):
    """Smoothed true-positive-rate estimate on ``z1``.

    Sums ``sigmoid(c * (g(z1) - lam_g))`` and divides by ``len(z1)``.

    Args:
        g: callable/model mapping ``z1`` to scores.
        z1: samples the TPR is estimated on.
        lam_g: decision threshold (tensor, possibly trainable).
        c: sigmoid sharpness; larger values approach a hard threshold.

    Returns:
        Scalar tensor. Gradients flow to ``g``'s parameters and ``lam_g`` when
        they require grad; no gradient tracking is forced.
    """
    soft = torch.sigmoid(c * (g(z1) - lam_g))
    return torch.sum(soft) / len(z1)

# loss
def objective(mode, g, z0, z1, i0, lam_g, p, beta=0.5, c=30):
    """Surrogate loss ``-TPR_hat(z1) + beta * FPR_hat(z0)``.

    Args:
        mode: ``'weighted'`` uses ``fpr_est_weighted`` (needs ``i0`` and ``p``);
            ``'unweighted'`` uses ``fpr_est`` (ignores ``i0`` and ``p``).
        g: scoring model.
        z0: samples for the FPR term.
        z1: samples for the TPR term.
        i0: per-sample indicators for ``z0`` (weighted mode only).
        lam_g: decision threshold.
        p: weight denominator (weighted mode only).
        beta: trade-off coefficient on the FPR term.
        c: sigmoid sharpness passed to the estimators.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``mode`` is not ``'weighted'`` or ``'unweighted'``.
            Previously such a mode returned ``None`` silently.
    """
    if mode not in ('weighted', 'unweighted'):
        raise ValueError(f"unknown mode {mode!r}; expected 'weighted' or 'unweighted'")
    # Evaluate the TPR term first, as before, so dropout masks are drawn in the
    # same order and seeded runs stay reproducible.
    tpr = tpr_est(g, z1, lam_g, c)
    if mode == 'weighted':
        fpr = fpr_est_weighted(g, z0, i0, lam_g, p, c)
    else:
        fpr = fpr_est(g, z0, lam_g, c)
    return -tpr + beta * fpr

###### Two Layer NNs ######
class G1(nn.Module):
    """Two-layer MLP scorer: Linear -> ReLU -> Dropout -> Linear(hidden_dim, 1).

    Args:
        input_dim: number of input features.
        hidden_dim: width of the hidden layer (default 32).
        dropout: dropout probability on the hidden activations (default 0.5).
    """

    def __init__(self, input_dim, hidden_dim=32, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return raw scores (no output activation), last dimension of size 1."""
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def forward_np(self, x):
        """Run ``forward`` and return the squeezed scores as a NumPy array.

        The result is moved to CPU first, so this also works for CUDA models.
        Dropout is active while the module is in training mode; call ``eval()``
        first for deterministic scores.
        """
        return self.forward(x).detach().cpu().squeeze().numpy()

###### Linear Model (under construction) ######
class G2(nn.Module):
    """Linear scorer ``x -> w^T x + b`` (under construction).

    Args:
        input_dim: number of input features.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 1)

    def nonneg(self):
        """Clamp the weight entries to be >= 0 in place; the bias is untouched."""
        # In-place clamp under no_grad keeps the same Parameter tensor the
        # optimizer references, without touching the deprecated ``.data`` API.
        with torch.no_grad():
            self.fc1.weight.clamp_(min=0)

    def weight_np(self):
        """Return the weight matrix, shape ``(1, input_dim)``, as a NumPy array.

        The tensor is moved to CPU first, so this also works for CUDA models.
        For CPU models the array may share memory with the parameter.
        """
        return self.fc1.weight.detach().cpu().numpy()

    def forward(self, x):
        """Return raw scores, last dimension of size 1."""
        return self.fc1(x)

###### Training Loop ######
def train_g(device, id, g, mode, training_param, beta, c, p, input_size, z0, z1, i0, num_epoch, batch_size, show_log, seed):
    """Jointly train a scorer ``g`` and a scalar threshold ``lam_g`` on ``objective``.

    ``(z0, i0)`` and ``z1`` are shuffled and split 80/20 into train/validation.
    Each epoch runs AdamW over paired mini-batches (``zip`` stops at the shorter
    loader), then evaluates the full validation objective. The ``(g, lam_g)``
    pair with the lowest validation loss is restored before returning.

    Args:
        device: torch device the model and data are moved to.
        id: architecture used when ``g`` is None (1 -> G1, 2 -> G2).
            ``id == 2`` also applies ``g.nonneg()`` after every epoch.
        g: existing model to continue training, or None to build one.
        mode: ``'weighted'`` or ``'unweighted'``, forwarded to ``objective``.
        training_param: dict with AdamW settings.
            ``'lr_g'`` / ``'wd_g'`` apply to the model;
            ``'lr_lam'`` / ``'wd_lam'`` apply to the threshold.
        beta, c, p: forwarded to ``objective``.
        input_size: input feature count used when building a new model.
        z0, i0: samples for the FPR term and their per-sample indicators
            (array-like, same length).
        z1: samples for the TPR term (array-like).
        num_epoch: number of training epochs.
        batch_size: mini-batch size for both loaders.
        show_log: print progress every 10 epochs when true.
        seed: passed to ``set_seed`` before any randomness is drawn.

    Returns:
        ``(g, lam_g, train_losses)``. ``train_losses[k]`` is the mean training
        cost of epoch ``k + 1``.

    Raises:
        ValueError: if ``g`` is None and ``id`` is neither 1 nor 2.
    """
    set_seed(seed)

    # Initialize g (if not provided) and the threshold.
    if g is None:
        if id == 1:
            g = G1(input_size)
        elif id == 2:
            g = G2(input_size)
        else:
            raise ValueError(f"unknown model id {id!r}; expected 1 or 2")
    g = g.to(device)
    lam_g = torch.zeros(1, requires_grad=True, device=device)

    # Cast features to the model's parameter dtype (e.g. float64 numpy ->
    # float32) so the first nn.Linear call cannot fail on a dtype mismatch.
    first_param = next(g.parameters(), None)
    feat_dtype = None if first_param is None else first_param.dtype

    # Tensorize and permute (same RNG draws, in the same order, as before).
    perm0 = torch.randperm(len(z0))
    z0 = torch.as_tensor(z0, dtype=feat_dtype)[perm0].to(device)
    i0 = torch.as_tensor(i0)[perm0].to(device)
    perm1 = torch.randperm(len(z1))
    z1 = torch.as_tensor(z1, dtype=feat_dtype)[perm1].to(device)

    # 80/20 train/validation split.
    n0_train = int(len(z0) * 0.8)
    n1_train = int(len(z1) * 0.8)
    z0_train, z0_val = z0[:n0_train], z0[n0_train:]
    i0_train, i0_val = i0[:n0_train], i0[n0_train:]
    z1_train, z1_val = z1[:n1_train], z1[n1_train:]

    loader0 = DataLoader(TensorDataset(z0_train, i0_train), batch_size=batch_size, shuffle=True)
    loader1 = DataLoader(TensorDataset(z1_train), batch_size=batch_size, shuffle=True)

    # All model parameters share one group, whatever the architecture. The
    # threshold gets its own learning rate and weight decay.
    optimizer = torch.optim.AdamW([
        {'params': list(g.parameters()), 'lr': training_param['lr_g'], 'weight_decay': training_param['wd_g']},
        {'params': [lam_g], 'lr': training_param['lr_lam'], 'weight_decay': training_param['wd_lam']},
    ])

    best_loss = np.inf
    best_state = None
    best_lam = None
    version = 0
    train_losses = []

    for epoch in range(1, num_epoch + 1):
        # Train
        g.train()
        running_loss = 0.0
        n_batches = 0
        for (z0_batch, i0_batch), (z1_batch,) in zip(loader0, loader1):
            optimizer.zero_grad()
            cost = objective(mode, g, z0_batch, z1_batch, i0_batch, lam_g, p, beta, c)
            cost.backward()
            optimizer.step()
            running_loss += cost.item()
            n_batches += 1
        # zip() stops at the shorter loader, so average over the batches seen.
        epoch_loss = running_loss / n_batches if n_batches else float('nan')
        train_losses.append(epoch_loss)

        if id == 2:
            g.nonneg()

        # Validation
        g.eval()
        with torch.no_grad():
            val_loss = objective(mode, g, z0_val, z1_val, i0_val, lam_g, p, beta, c).item()

        if val_loss < best_loss:
            best_loss = val_loss
            # state_dict() returns live references to the parameters; clone them
            # so later optimizer steps cannot overwrite the snapshot.
            best_state = {k: v.detach().clone() for k, v in g.state_dict().items()}
            best_lam = lam_g.detach().clone()
            version = epoch

        # Log
        if show_log and epoch % 10 == 0:
            print(f'Epoch {epoch}/{num_epoch}, Loss: {epoch_loss}, Val Loss: {val_loss}, Version: {version}')

    # Restore the best (g, lam_g) pair found on validation, once, after training.
    if best_state is not None:
        g.load_state_dict(best_state)
        with torch.no_grad():
            lam_g.copy_(best_lam)

    return g, lam_g, train_losses

def compute_scores(g, z, device):
    """Score every sample in ``z`` with ``g`` and return the scores as a list.

    The model is moved to ``device`` and evaluated in eval mode without autograd,
    so dropout is disabled and no graph is built. Its previous train/eval mode
    is restored afterwards. Inputs are cast to the model's parameter dtype.

    Args:
        g: scoring model (``nn.Module``).
        z: array-like or tensor of samples.
        device: torch device to run on.

    Returns:
        List of NumPy scalars (one per sample for ``(n, 1)`` outputs). A single
        sample yields a one-element list rather than raising.
    """
    g = g.to(device)
    first_param = next(g.parameters(), None)
    dtype = None if first_param is None else first_param.dtype
    inputs = torch.as_tensor(z, dtype=dtype, device=device)

    was_training = g.training
    g.eval()
    try:
        with torch.no_grad():
            scores = g(inputs)
    finally:
        g.train(was_training)

    # squeeze() drops the trailing singleton axis of (n, 1) outputs. atleast_1d
    # keeps a single-sample result iterable (list() fails on a 0-d array).
    return list(np.atleast_1d(scores.cpu().squeeze().numpy()))