import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# One shared seed for both RNG libraries so that a run is reproducible.
seed = 5
torch.manual_seed(seed)
np.random.seed(seed)

# The script is pinned to the CPU. To use a GPU when one is present, select
# torch.device('cuda') whenever torch.cuda.is_available() returns True.
device = torch.device('cpu')
print(f'device: {device}')

class Generator(nn.Module):
    """Map latent vectors of size ``latent_dim`` to vectors of size ``obs_dim``.

    With ``linear=True`` the map is a single bias-free ``nn.Linear``.
    Otherwise it is ``Linear -> LeakyReLU(0.1) -> Linear`` (both linear
    layers with bias), where the second layer maps ``obs_dim`` to ``obs_dim``.
    The network is stored in the attribute ``f``.
    """

    def __init__(self, latent_dim, obs_dim, linear=True):
        super().__init__()

        if not linear:
            # Small MLP: project to obs_dim, apply a leaky rectifier, mix again.
            self.f = nn.Sequential(
                nn.Linear(latent_dim, obs_dim),
                nn.LeakyReLU(.1),
                nn.Linear(obs_dim, obs_dim),
            )
            return

        # Pure linear map without an offset term.
        self.f = nn.Linear(latent_dim, obs_dim, bias=False)

    def forward(self, input):
        """Apply the generator network to ``input`` and return the result."""
        network = self.f
        return network(input)


class Classifier(nn.Module):
    """Produce one scalar output (a logit) per observation of size ``obs_dim``.

    With ``linear=True`` the model is a single bias-free ``nn.Linear`` onto one
    output. Otherwise it is a three-layer ReLU network with layer widths
    ``obs_dim -> obs_dim -> obs_dim // 2 -> 1``. The network is stored in the
    attribute ``embedder`` and ``obs_dim`` is kept on the instance.
    """

    def __init__(self, obs_dim, linear=True):
        super().__init__()
        self.obs_dim = obs_dim

        half_width = obs_dim // 2
        # Only the selected branch is evaluated, so only its layers are created.
        self.embedder = (
            nn.Linear(obs_dim, 1, bias=False)
            if linear
            else nn.Sequential(
                nn.Linear(obs_dim, obs_dim),
                nn.ReLU(),
                nn.Linear(obs_dim, half_width),
                nn.ReLU(),
                nn.Linear(half_width, 1),
            )
        )

    def forward(self, input):
        """Return the network output for ``input`` (last dimension has size 1)."""
        network = self.embedder
        return network(input)



def environment(causal_mean, causal_sigma, env_mean_prior, eta=0.5, n=500, fixed_noise=None):
    """Sample the latent features and binary labels of one environment.

    A fresh environment mean is drawn on every call as
    ``env_mean_prior + 0.75 * standard normal noise``. Each sample then gets

    * causal features ``caus_noise * sqrt(causal_sigma) + (2y - 1) * causal_mean``
    * environment features ``env_noise * sqrt(1) + (2y - 1) * env_mean``

    so the label ``y`` in {0, 1} selects the sign of both means.

    Args:
        causal_mean: 1-D tensor; its length is the number of causal features.
        causal_sigma: tensor whose square root scales the causal noise.
        env_mean_prior: 1-D tensor; its length is the number of environment
            features.
        eta: Bernoulli probability of label 1. Ignored when ``fixed_noise``
            is given.
        n: number of samples to draw. Ignored when ``fixed_noise`` is given.
        fixed_noise: optional ``(y, caus_noise, env_noise)`` tuple to reuse
            instead of sampling; only the environment mean is then random.

    Returns:
        ``(latents, y)`` where ``latents`` concatenates the causal and the
        environment features along dim 1 and ``y`` holds the labels.
    """
    caus_dim, env_dim = len(causal_mean), len(env_mean_prior)

    # The environment mean is drawn before anything else; keeping this order
    # keeps the sequence of RNG draws (and so seeded results) unchanged.
    env_mean = env_mean_prior + torch.randn(env_dim) * .75
    env_sigma = torch.tensor([1.])

    if fixed_noise is None:
        # torch.full builds the constant probability vector directly instead
        # of going through an n-element Python list.
        y = torch.bernoulli(torch.full((n,), float(eta)))
        caus_noise = torch.randn(n, caus_dim)
        env_noise = torch.randn(n, env_dim)
    else:
        y, caus_noise, env_noise = fixed_noise

    # Map labels {0, 1} to signs {-1, +1} as a column so that multiplying by a
    # mean vector broadcasts to the (samples, features) outer product. This
    # replaces the deprecated torch.ger without needing a newer torch API.
    signs = (2 * y - 1).unsqueeze(1)
    caus_sample = caus_noise * torch.sqrt(causal_sigma) + signs * causal_mean
    env_sample = env_noise * torch.sqrt(env_sigma) + signs * env_mean

    latents = torch.cat([caus_sample, env_sample], dim=1)
    return latents, y


def compute_penalty(loss, dummy_w):
    """Return the squared norm of the gradient of ``loss`` w.r.t. ``dummy_w``.

    The gradient is taken with ``create_graph=True`` so the returned penalty
    can itself be back-propagated through.
    """
    gradients = torch.autograd.grad(loss, dummy_w, create_graph=True)
    grad_norm = torch.norm(gradients[0])
    return grad_norm ** 2


# ---- Experiment configuration ----
d_c, d_e = 3, 6  # number of causal / environment latent dimensions
obs_dim = d_c + d_e  # must be >= d_c + d_e
eta = 0.5  # Bernoulli probability used when drawing the test labels
test_envs_n = 5000  # number of test environments
lambd = 1e-11  # factor applied to the averaged loss when the penalty is active
maxE = 10  # number of training environments that are generated
opt_iters = 60000  # optimizer steps per training run
lr = 3e-2  # Adam learning rate

# ---- Ground-truth data parameters (order of the random draws matters for
# reproducibility under the fixed seed) ----
causal_mean = torch.randn(d_c) + torch.randn(d_c) * .5
causal_sigma = torch.tensor([2.])
env_mean_prior = torch.randn(d_e) * 1.5

# Weights 2 * mean / sigma applied to the causal features as a reference
# classifier. The division is done before moving to `device`, so both operands
# live on the same device; previously only causal_mean was moved, which would
# fail on a non-CPU device.
causal_classifier = (2 * causal_mean / causal_sigma).to(device)
print(f'd_c: {d_c}, d_e: {d_e}, test_envs_n: {test_envs_n}, lambd: {lambd}, maxE: {maxE}, opt_iters: {opt_iters}, lr: {lr}')

# ---- Models ----
# NOTE(review): `f` is not referenced again in the visible script, but it is
# constructed before `phi`; presumably its random initialisation advances the
# RNG and so affects phi's initial weights. It is kept for reproducibility.
f = Generator(d_c + d_e, obs_dim, linear=True).to(device)
phi = Classifier(obs_dim, linear=True).to(device)
# Snapshot of the initial classifier weights, used to restart training.
start_weights = phi.embedder.weight.data.clone()
# Scalar multiplier (value 1.0) whose gradient defines the penalty. It is
# created directly on `device` so it stays a leaf Parameter there; calling
# .to(device) on a Parameter returns a plain non-leaf tensor off-CPU.
dummy_w = nn.Parameter(torch.ones(1, device=device))

# Repeat the whole experiment 5 times. Every repetition draws fresh training and
# test environments and writes its own results file.
# NOTE(review): the loop variable `iter` shadows the Python builtin of the same name.
for iter in range(5):

    # Data generation and result buffers do not need autograd tracking.
    with torch.no_grad():
        # maxE independent training environments with 10000 samples each.
        train_environments = [environment(causal_mean, causal_sigma, env_mean_prior, n=10000) for _ in range(1, maxE+1)]
        test_obs_n = 10000
        # Labels and noise are drawn once and handed to every test environment via
        # fixed_noise, so the test environments differ only through the
        # environment mean that environment() draws on each call.
        test_y = torch.bernoulli(torch.tensor([eta] * test_obs_n))
        test_caus_noise = torch.randn(test_obs_n, d_c)
        test_env_noise = torch.randn(test_obs_n, d_e)
        test_environments = [environment(causal_mean, causal_sigma, env_mean_prior, n=test_obs_n,
                                         fixed_noise=(test_y, test_caus_noise, test_env_noise)) for _ in range(test_envs_n)]

        # Result buffers with one slot per value of E (index E-1).
        train_losses, test_losses, shift_losses = np.zeros(maxE), np.zeros(maxE), np.zeros(maxE)
        train_accs, test_accs, shift_accs = np.zeros(maxE), np.zeros(maxE), np.zeros(maxE)
        caus_losses, caus_accs = np.zeros(maxE), np.zeros(maxE)

    # Train on the first E training environments, for E = 1 .. maxE.
    for E in range(1, maxE+1):
        print(f'RUNNING E = {E}')
        # Restart from the same initial classifier weights for every E and
        # create a fresh Adam optimizer for the new parameter.
        with torch.no_grad():
            phi.embedder.weight = nn.Parameter(start_weights.clone())
        opt = torch.optim.Adam(phi.parameters(), lr=lr)

        for i in range(opt_iters):
            # Averages over the E environments of the loss, the gradient penalty
            # and the training accuracy.
            loss, penalty, acc = 0., 0., 0.
            for latents, labels in train_environments[:E]:
                latents, labels = latents.to(device), labels.to(device)
                # Logits are scaled by dummy_w so the penalty can be taken as
                # the gradient of the per-environment loss w.r.t. dummy_w.
                out = phi(latents).squeeze() * dummy_w
                env_loss = F.binary_cross_entropy_with_logits(out, labels)
                loss += env_loss / E
                penalty += compute_penalty(env_loss, dummy_w) / E
                acc += ((out > 0).long() == labels.long()).sum().float() / len(labels) / E

            opt.zero_grad()
            # With a single environment only the averaged loss is minimised;
            # otherwise the objective is loss * lambd + penalty.
            obj = loss
            if E != 1:
                obj = loss * lambd + penalty
            obj.backward()
            opt.step()
            # Checkpoint every 10000 optimizer steps: log, record the training
            # metrics and evaluate on all test environments.
            if (i+1) % 10000 == 0:
                print(i + 1, loss.item(), penalty.item(), obj.item())
                train_losses[E-1], train_accs[E-1] = loss.item(), acc

                # test_* accumulate sums over the test environments; worst_*
                # track the maximum loss / minimum accuracy on the sign-flipped
                # inputs.
                # NOTE(review): `loss` and `acc` are rebound inside the loop below,
                # after the training values have already been printed and stored.
                test_loss, test_acc = 0., 0.
                worst_loss, worst_acc = 0., 1.
                with torch.no_grad():
                    for env_ind, (test_latents, test_labels) in enumerate(test_environments):
                        test_latents, test_labels = test_latents.to(device), test_labels.to(device)

                        # Metrics on the test environment as generated.
                        out = phi(test_latents).flatten()
                        loss = F.binary_cross_entropy_with_logits(out, test_labels)
                        test_loss += loss.item()
                        test_acc += ((out > 0).long() == test_labels.long()).sum().float() / len(test_labels)

                        # Shifted copy: flip the sign of every latent column from
                        # index d_c onward (the environment features appended
                        # after the causal ones by environment()).
                        modified_latents = test_latents.clone()
                        modified_latents[:, d_c:] *= -1
                        out = phi(modified_latents).flatten()
                        loss = F.binary_cross_entropy_with_logits(out, test_labels)
                        worst_loss = np.max([loss.item(), worst_loss])
                        acc = ((out > 0).long() == test_labels.long()).sum().float() / len(test_labels)
                        worst_acc = np.min([acc, worst_acc])
                        # Reference metrics of causal_classifier applied to the
                        # first d_c columns (untouched by the sign flip). Computed
                        # on the first test environment at checkpoints of E == 1
                        # and copied into all maxE slots.
                        if E == 1 and env_ind == 0:
                            caus_out = modified_latents[:, :d_c] @ causal_classifier
                            caus_loss = F.binary_cross_entropy_with_logits(caus_out, test_labels).item()
                            caus_acc = ((caus_out > 0).long() == test_labels.long()).sum().float() / len(test_labels)
                            caus_losses, caus_accs = np.repeat(caus_loss, maxE), np.repeat(caus_acc, maxE)

                    print(f'iter = {iter} E = {E} (CAUS ACC: {caus_acc})')
                    print(f'ACC: {test_acc / test_envs_n}')
                    print(f'WORST LOSS: {worst_loss}')
                    print(f'WORST ACC: {worst_acc}')

                # Store results and stop training for this E on the last
                # optimizer step or once the worst shifted accuracy drops below
                # 5e-3.
                # NOTE(review): this check is nested in the checkpoint branch, so
                # it only runs every 10000 steps; if opt_iters is not a multiple
                # of 10000 the final-step condition is never evaluated.
                if i == opt_iters - 1 or worst_acc < 5e-3:
                    test_losses[E-1] = test_loss / test_envs_n
                    test_accs[E-1] = test_acc / test_envs_n
                    shift_losses[E-1] = worst_loss
                    shift_accs[E-1] = worst_acc
                    # Rows: train loss/acc, test loss/acc, shifted loss/acc,
                    # causal reference loss/acc; one column per E. The file is
                    # rewritten after every finished E of this repetition.
                    stacked = np.vstack([
                        train_losses, train_accs,
                        test_losses, test_accs,
                        shift_losses, shift_accs,
                        caus_losses, caus_accs])
                    np.save(f'resultsnew{iter}_seed{seed}.npy', stacked)
                    print(f'RESULTS FOR E = {E}:\n{stacked}')
                    break


