import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd.functional import hessian
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.gamma import Gamma
import matplotlib.pyplot as plt
import ipdb
from timeit import default_timer as timer
import numpy as np

from models_utils import base_learner
#from models_utils_debug import *

import time
import copy


class GBML(nn.Module):
    """Gradient-based meta-learner built around a ``base_learner`` network.

    Two families of ``model_type`` are handled by the constructor:

    * ``'cavia'`` / ``'lava'``: the adapted quantity is a single ``(1, z_dim)``
      tensor ``theta_0`` with one scalar learning rate ``lr``.
    * anything else: ``theta_0`` and ``lr`` are ``nn.ParameterList`` objects
      with one entry per shape reported by ``self.f.get_theta_shape()``.

    Only the ``'cavia'`` and ``'lava'`` types are supported by :meth:`adapt`.
    """

    def __init__(self, input_size, z_dim, output_size, n_layers, lr0, reg, cnn=0, loss_type=0, use_batch_norm=False, model_type='cavia'):
        """Build the base learner and the meta-parameters.

        Args:
            input_size, output_size, n_layers, cnn, use_batch_norm: forwarded
                to ``base_learner`` (semantics defined there).
            z_dim: size of the adapted vector for 'cavia'/'lava'.
            lr0: initial value of the (learnable) inner learning rate(s).
            reg: regularisation strength added to the per-sample Hessians in
                the 'lava' adaptation.
            loss_type: selects the loss in :meth:`criterion` (0, 1, 2 or 3).
            model_type: 'cavia', 'lava', or any other string for the
                full-parameter variant.
        """
        super().__init__()

        self.z_dim = z_dim
        self.reg = reg

        self.use_batch_norm = use_batch_norm
        self.model_type = model_type

        self.cnn = cnn
        self.loss_type = loss_type

        self.gamma = 0.9
        self.beta = 0.5

        self.use_v = False
        # Running gradient estimate maintained by update_v(); it must exist
        # before the first call because update_v() tests `self.v is None`.
        self.v = None

        # reg * identity with a leading batch axis of size 1; stored as a
        # frozen parameter so it follows the module across devices.
        self.I = nn.Parameter(self.reg * torch.unsqueeze(torch.eye(self.z_dim), 0), requires_grad=False)

        if model_type in ['cavia', 'lava']:
            self.f = base_learner(input_size, z_dim, 32, output_size, n_layers, cnn=cnn, model_type=model_type, use_batch_norm=use_batch_norm)
            self.theta_0 = nn.Parameter(torch.zeros(1, z_dim), requires_grad=True)
            self.lr = nn.Parameter(torch.ones(1) * lr0)
        else:
            self.f = base_learner(input_size, 0, 64, output_size, n_layers, cnn=cnn, model_type=model_type, use_batch_norm=use_batch_norm)
            theta_shapes = self.f.get_theta_shape()
            self.theta_0 = nn.ParameterList([nn.Parameter(torch.zeros(t_size)) for t_size in theta_shapes])
            self.num_layers = len(self.theta_0)

            # One element-wise learning rate per entry of theta_0.
            self.lr = nn.ParameterList([nn.Parameter(torch.ones(t_size) * lr0) for t_size in self.f.get_theta_shape()])

            # Tensors with more than one dimension get Xavier initialisation;
            # 1-D tensors stay at zero.
            for theta in self.theta_0:
                if theta.dim() > 1:
                    torch.nn.init.xavier_uniform_(theta)

    def update_v(self, batch, zt, zt1):
        """Update the running gradient estimate ``self.v``.

        The loss is averaged over the leading axis of ``batch[2]`` /
        ``batch[3]`` and differentiated with respect to ``zt`` and ``zt1``.
        On the first call ``v`` becomes the gradient at ``zt``; afterwards
        ``v <- grad(zt1) + (1 - beta) * (v - grad(zt))``. The stored value is
        detached from the autograd graph.

        Args:
            batch: indexable whose items 2 and 3 are the inputs and targets.
            zt, zt1: tensors the losses are differentiated against; they must
                require grad and take part in ``self.forward``.
        """
        x_q, y_q = batch[2], batch[3]
        n_items = x_q.shape[0]
        lt, lt1 = 0, 0
        for i in range(n_items):
            lt += self.criterion(self.forward(x_q[i], zt), y_q[i])
            lt1 += self.criterion(self.forward(x_q[i], zt1), y_q[i])
        grad_t = torch.autograd.grad(outputs=(lt / n_items), inputs=zt, allow_unused=True)[0]
        grad_t1 = torch.autograd.grad(outputs=(lt1 / n_items), inputs=zt1, allow_unused=True)[0]

        if self.v is None:
            self.v = grad_t.detach()
        else:
            self.v = (grad_t1 + (1 - self.beta) * (self.v - grad_t)).detach()

    def mix(self, x_s, y_s, x_q, y_q, theta):
        """Mix support and query activations at a randomly chosen layer.

        A layer index ``l`` is drawn uniformly from ``[0, n_layers)``, the
        network is evaluated as ``self.f(x, theta, 0, l)`` on both sets, and
        the results (and the targets) are blended with per-query coefficients
        drawn from Beta(0.5, 0.5). The blend is then passed through
        ``self.f(x_mix, theta, l, -1)``.

        NOTE(review): the two trailing arguments of ``self.f`` are presumably
        start/stop layer indices -- confirm against ``base_learner``.

        Returns:
            ``(y_hat_mix, y_mix)``: prediction on the mixed activations and
            the correspondingly mixed targets.
        """
        n_query = x_q.shape[0]
        # One (with replacement) support index per query item.
        s_idx = np.random.choice(x_s.shape[0], n_query)
        max_l = self.f.learner.n_layers
        l = np.random.randint(0, max_l)
        alpha = 0.5
        beta = 0.5
        lamb = torch.from_numpy(np.random.beta(alpha, beta, n_query)).view(-1, 1).float().to(x_s.device)
        h_s = self.f(x_s, theta, 0, l)
        h_q = self.f(x_q, theta, 0, l)
        x_mix = lamb * h_s[s_idx] + (1 - lamb) * h_q
        y_mix = lamb * y_s[s_idx] + (1 - lamb) * y_q
        y_hat_mix = self.f(x_mix, theta, l, -1)

        return y_hat_mix, y_mix

    def forward(self, x, theta=None):
        """Evaluate the base learner on ``x``; ``theta`` defaults to ``theta_0``."""
        if theta is None:
            theta = self.theta_0
        return self.f(x, theta)

    def adapt(self, x, y, z=None, steps=1):
        """Adapt the meta-parameters to the data ``(x, y)``.

        * ``'cavia'``: ``steps`` gradient steps on the mean of the per-sample
          gradients, starting from ``theta_0``.
        * ``'lava'``: one gradient step *per sample* giving ``theta_i'``, then
          ``z = (sum_i H_i)^-1 sum_i H_i theta_i'`` where ``H_i`` is the
          per-sample Hessian at ``theta_i'`` regularised as
          ``(H_i + reg * I) / (1 + reg)``.

        Args:
            x, y: inputs and targets; the leading axis indexes samples.
            z: accepted for interface compatibility; every supported branch
                starts from ``self.theta_0`` and ignores this value.
            steps: number of inner steps (only used for 'cavia').

        Returns:
            ``(z, det_H)``: the adapted tensor and the determinant of the
            inverse summed Hessian (``0`` for 'cavia'; a Python float in the
            batch-norm 'lava' branch; a NumPy array otherwise).

        Raises:
            NotImplementedError: if ``model_type`` is neither 'cavia' nor
                'lava'.
        """

        def compute_loss_stateless_model(theta0, sample, target):
            # Loss of one (sample, target) pair; a batch axis of size 1 is
            # added because the model is called on batches.
            batch = sample.unsqueeze(0)
            targets = target.unsqueeze(0)
            predictions = self.forward(batch, theta0)
            return self.criterion(predictions, targets)

        # Per-sample gradients w.r.t. the first argument, vectorised over the
        # sample axis of x and y.
        ft_compute_grad = torch.func.grad(compute_loss_stateless_model)
        ft_compute_sample_grad = torch.vmap(ft_compute_grad, in_dims=(None, 0, 0))

        if self.model_type in ['cavia']:
            z = self.theta_0
            for _ in range(steps):
                per_sample_grad = ft_compute_sample_grad(z, x, y).squeeze(1)
                z = z - self.lr * torch.mean(per_sample_grad, 0)

            det_H = 0

        elif self.model_type == 'lava':

            if self.use_batch_norm:
                # In this branch self.forward is unpacked as a
                # (predictions, stats) pair, i.e. the base learner is expected
                # to return batch statistics alongside its output.
                theta_lava = self.theta_0.clone()

                def compute_loss_stateless_model2(theta0, x, y, idx):
                    # Whole-batch forward pass (so batch statistics see every
                    # sample); `idx` is a one-hot mask selecting one loss.
                    predictions, _ = self.forward(x, theta0)
                    return torch.sum(self.criterion(predictions, y, mean=False) * idx)

                ft_compute_grad = torch.func.grad(compute_loss_stateless_model2)
                ft_compute_sample_grad = torch.vmap(ft_compute_grad, in_dims=(None, None, None, 0))
                idx = torch.arange(x.shape[0]).to(x.device)
                idx_hot = F.one_hot(idx).float()
                per_sample_grad = ft_compute_sample_grad(theta_lava, x, y, idx_hot).squeeze(1)

                theta_primes = (theta_lava - self.lr * per_sample_grad).unsqueeze(1)

                _, stats_pre = self.forward(x, torch.mean(theta_primes, 0))

                # NOTE(review): compute_loss_stateless_model accepts three
                # arguments, so this four-argument call raises TypeError as
                # written. Passing `stats` through needs a forward()/base
                # learner signature that is not visible in this file -- fix
                # together with base_learner.
                H = torch.vmap(torch.func.hessian(lambda theta, x, y, stats: compute_loss_stateless_model(theta, x, y, stats)), in_dims=(0, 0, 0, None))(theta_primes, x, y, stats_pre)[:, 0, :, 0, :]

                H = (1 / (1 + self.reg)) * (H + self.I.repeat(H.shape[0], 1, 1))
                H_sum_inv = torch.linalg.inv(torch.sum(H, 0, keepdim=True))
                theta_grad_proj = torch.bmm(H, theta_primes.squeeze(1).unsqueeze(-1))
                # Assign to `z` so the adapted value is what gets returned
                # (previously stored in an unused local, returning the
                # caller's `z` argument instead).
                z = torch.bmm(H_sum_inv, torch.sum(theta_grad_proj, 0, keepdim=True))[:, :, 0]  # (1 x z_dim)

                det_H = torch.linalg.det(H_sum_inv).detach().cpu().item()
            else:
                z = self.theta_0
                per_sample_grad = ft_compute_sample_grad(z, x, y).squeeze(1)

                # One adapted copy of theta per sample, with a size-1 axis
                # re-inserted to match theta_0's leading axis.
                theta_primes = (z - self.lr * per_sample_grad).unsqueeze(1)
                H1 = torch.vmap(torch.func.hessian(lambda theta, x, y: compute_loss_stateless_model(theta, x, y)), in_dims=(0, 0, 0))(theta_primes, x, y)
                # Drop the size-1 axes inherited from theta's leading axis.
                H = H1[:, 0, :, 0, :]
                H = (1 / (1 + self.reg)) * (H + self.reg * torch.eye(self.z_dim).to(x.device))
                H_sum_inv = torch.linalg.inv(torch.sum(H, 0, keepdim=True))
                theta_grad_proj = torch.bmm(H, theta_primes.squeeze(1).unsqueeze(-1))
                z = torch.bmm(H_sum_inv, torch.sum(theta_grad_proj, 0, keepdim=True))[:, :, 0]  # (1 x z_dim)
                det_H = torch.linalg.det(H_sum_inv).detach().cpu().data.numpy()

        else:
            raise NotImplementedError(
                "adapt() supports model_type 'cavia' or 'lava', got {!r}".format(self.model_type)
            )

        return z, det_H

    def criterion(self, y_pred, y, mean=True):
        """Compute the loss selected by ``self.loss_type``.

        * 0: squared error, averaged over everything (``mean=True``) or over
          the last axis only (``mean=False``).
        * 1: cross-entropy between ``y_pred`` and ``y``.
        * 2: cross-entropy between ``y_pred`` and the columns ``y[:, 2:]``.
        * 3: binary cross-entropy on logits against ``y > 0.5`` with a
          positive-class weight of 0.001.

        For types 1-3 ``mean`` switches the reduction between ``'mean'`` and
        ``'none'``.

        Raises:
            ValueError: if ``loss_type`` is not one of 0, 1, 2, 3.
        """
        reduction = 'mean' if mean else 'none'

        if self.loss_type == 0:
            squared_error = (y_pred - y) ** 2
            if mean:
                return torch.mean(squared_error)
            return torch.mean(squared_error, -1)

        if self.loss_type == 1:
            return nn.CrossEntropyLoss(reduction=reduction)(y_pred, y)

        if self.loss_type == 2:
            # Only the trailing columns are used as targets. The first two
            # columns of `y` fed a policy-gradient term that was computed
            # (via an undefined `self.sm`) but never used; it was removed.
            probs = y[:, 2:]
            return nn.CrossEntropyLoss(reduction=reduction)(y_pred, probs)

        if self.loss_type == 3:
            # Keep the weight on the same device as the predictions instead of
            # assuming CUDA is available.
            pos_weight = torch.tensor([0.001], device=y_pred.device)
            return nn.BCEWithLogitsLoss(reduction=reduction, pos_weight=pos_weight)(y_pred, (y > 0.5) * 1.)

        raise ValueError("unsupported loss_type: {!r}".format(self.loss_type))


class VR_MAML(nn.Module):
    """MAML-style learner whose outer update uses a recursive gradient estimate.

    ``theta_0`` is a two-slot Python list: slot 1 holds the current
    meta-parameters (an ``nn.ParameterList``) and slot 0 holds a snapshot
    taken by :meth:`copy_params`. :meth:`gradient_step` combines the gradient
    at the current parameters with the previous estimate and the gradient at
    the snapshot.

    Because ``theta_0`` is a plain list, the parameters inside it are not
    registered with ``nn.Module`` (they are updated manually in
    :meth:`gradient_step` rather than through an optimizer).
    """

    def __init__(self, input_size, z_dim, output_size, n_layers, lr0, lr_out, gamma, cnn=0, deconv=0, loss_type=0):
        """Build the base learner and initialise the meta-parameters.

        Args:
            input_size, output_size, n_layers, cnn, deconv: forwarded to
                ``base_learner`` (semantics defined there).
            z_dim: stored on the instance; not otherwise used in this class.
            lr0: inner-loop step size used by :meth:`adapt`.
            lr_out: outer-loop step size used by :meth:`gradient_step`.
            gamma: mixing coefficient of the recursive gradient estimate.
            loss_type: selects the loss in :meth:`criterion` (0, 1, 2 or 3).
        """
        super().__init__()

        self.z_dim = z_dim
        self.hessian = hessian
        self.cnn = cnn
        self.loss_type = loss_type

        self.gamma = gamma
        self.first_time = True
        # forward() branches on this flag; default to the plain
        # (single-parameter-set) path so forward() works out of the box.
        self.svgd = False

        self.f = base_learner(input_size, 0, 64, output_size, n_layers, cnn=cnn, deconv=deconv)
        # [snapshot (filled by copy_params), current parameters]
        self.theta_0 = [None, nn.ParameterList([nn.Parameter(torch.zeros(t_size)) for t_size in self.f.get_theta_shape()])]
        self.num_layers = len(self.theta_0[1])

        # Previous gradient estimate; plain zero tensors that are not attached
        # to the autograd graph of the parameters.
        self.old_grads = [torch.zeros_like(w) for w in self.theta_0[1]]

        # Tensors with more than one dimension get Xavier initialisation;
        # 1-D tensors stay at zero.
        for theta in self.theta_0[1]:
            if theta.dim() > 1:
                torch.nn.init.xavier_uniform_(theta)

        self.lr_inner = lr0
        self.lr_outer = lr_out

    def copy_params(self):
        """Store a deep copy of the current parameters in ``theta_0[0]``."""
        self.theta_0[0] = nn.ParameterList([copy.deepcopy(param) for param in self.theta_0[1]])

    def forward(self, x, theta=None, stats=None):
        """Evaluate the base learner on ``x``.

        Args:
            x: model input.
            theta: parameters to use; defaults to the current
                meta-parameters ``theta_0[1]``.
            stats: forwarded to the base learner as its third argument.

        Returns:
            Whatever ``self.f(x, theta, stats)`` returns. When ``self.svgd``
            is set, ``theta`` is iterated as a collection of parameter sets
            and the result is ``(mean of self.f(x, w)[0] over w, None)``.
        """
        if theta is None:
            # Slot 1 holds the live parameters; slot 0 is only a snapshot
            # (and may be None), so the list itself is not a valid theta.
            theta = self.theta_0[1]

        if self.svgd:
            member_outputs = [torch.unsqueeze(self.f(x, w)[0], 0) for w in theta]
            return torch.mean(torch.cat(member_outputs, 0), 0), None

        return self.f(x, theta, stats)

    def adapt(self, x, y, theta_0, steps=1):
        """Take one differentiable gradient step from ``theta_0`` on ``(x, y)``.

        Args:
            x, y: inputs and targets.
            theta_0: sequence of tensors to differentiate against.
            steps: accepted for interface compatibility; a single step is
                always taken.

        Returns:
            ``(theta, 0)`` where ``theta`` is the list of updated tensors. The
            update keeps the autograd graph (``create_graph=True``) so outer
            gradients can flow through it.
        """
        output = self.forward(x, theta_0)
        # forward() may return a (prediction, stats) pair (it does so
        # explicitly in its svgd branch); only the prediction enters the loss.
        y_hat = output[0] if isinstance(output, tuple) else output
        loss = self.criterion(y_hat, y, mean=False)

        theta_grad_s = torch.autograd.grad(outputs=torch.mean(loss), inputs=theta_0, create_graph=True)
        theta = [w - self.lr_inner * g for w, g in zip(theta_0, theta_grad_s)]

        return theta, 0

    def criterion(self, y_pred, y, mean=True):
        """Compute the loss selected by ``self.loss_type``.

        * 0: squared error, averaged over everything (``mean=True``) or over
          the last axis only (``mean=False``).
        * 1: cross-entropy between ``y_pred`` and ``y``.
        * 2: cross-entropy between ``y_pred`` and the columns ``y[:, 2:]``.
        * 3: binary cross-entropy on logits against ``y > 0.5`` with a
          positive-class weight of 0.001.

        For types 1-3 ``mean`` switches the reduction between ``'mean'`` and
        ``'none'``.

        Raises:
            ValueError: if ``loss_type`` is not one of 0, 1, 2, 3.
        """
        reduction = 'mean' if mean else 'none'

        if self.loss_type == 0:
            squared_error = (y_pred - y) ** 2
            if mean:
                return torch.mean(squared_error)
            return torch.mean(squared_error, -1)

        if self.loss_type == 1:
            return nn.CrossEntropyLoss(reduction=reduction)(y_pred, y)

        if self.loss_type == 2:
            # Only the trailing columns are used as targets. The first two
            # columns of `y` fed a policy-gradient term that was computed
            # (via an undefined `self.sm`) but never used; it was removed.
            probs = y[:, 2:]
            return nn.CrossEntropyLoss(reduction=reduction)(y_pred, probs)

        if self.loss_type == 3:
            # Keep the weight on the same device as the predictions instead of
            # assuming CUDA is available.
            pos_weight = torch.tensor([0.001], device=y_pred.device)
            return nn.BCEWithLogitsLoss(reduction=reduction, pos_weight=pos_weight)(y_pred, (y > 0.5) * 1.)

        raise ValueError("unsupported loss_type: {!r}".format(self.loss_type))

    def gradient_step(self, l_old, l_new):
        """Apply one outer update to the current parameters ``theta_0[1]``.

        On the first call the update direction is the plain gradient of
        ``l_new``. Afterwards it is

            ``grad(l_new) + (1 - gamma) * (previous direction - grad(l_old))``

        where ``grad(l_old)`` is taken with respect to the snapshot
        ``theta_0[0]``; :meth:`copy_params` must therefore have been called
        and ``l_old`` must have been computed from that snapshot.

        The parameters are replaced by fresh ``nn.Parameter`` objects, and the
        direction is stored (detached) in ``self.old_grads``.

        Args:
            l_old: loss evaluated with the snapshot parameters (unused on the
                first call).
            l_new: loss evaluated with the current parameters.
        """
        params = list(self.theta_0[1])
        grads_new = torch.autograd.grad(outputs=l_new, inputs=params, create_graph=False)

        if self.first_time:
            direction = list(grads_new)
        else:
            grads_snapshot = torch.autograd.grad(outputs=l_old, inputs=list(self.theta_0[0]), create_graph=False)
            direction = [
                g_new + (1 - self.gamma) * (g_prev - g_snap)
                for g_prev, g_snap, g_new in zip(self.old_grads, grads_snapshot, grads_new)
            ]

        # Detach before re-wrapping so the new parameters are clean leaves.
        self.theta_0[1] = nn.ParameterList(
            [nn.Parameter((w - self.lr_outer * g).detach()) for w, g in zip(params, direction)]
        )
        self.old_grads = [g.detach() for g in direction]
        self.first_time = False

        return


