from re import L
from meta_curvature import Curvature
import torch
import torch.nn as nn
import torch.nn.functional as F
import ipdb
import numpy as np
from collections import OrderedDict
from models.leo import ImageEncoder
from utils import reshape_time_series_input, flatten_params_torch, reshape_param, reshape_time_series_x


def create_linear_layer(input_dim, output_dim, num_layers, hidden_dim):
    """Build a ReLU MLP as an ``nn.Sequential``.

    Layout: ``Linear(input_dim, hidden_dim)``, then ``num_layers`` repetitions
    of ``ReLU, Linear(hidden_dim, hidden_dim)``, then ``ReLU,
    Linear(hidden_dim, output_dim)``. There is no activation on the output.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output features.
        num_layers: Number of extra hidden-to-hidden linear layers.
        hidden_dim: Width of every hidden layer.

    Returns:
        The assembled ``nn.Sequential``.
    """
    modules = [nn.Linear(input_dim, hidden_dim)]
    for _ in range(num_layers):
        modules += [nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)]
    modules += [nn.ReLU(), nn.Linear(hidden_dim, output_dim)]
    return nn.Sequential(*modules)


class SetTransformer(nn.Module):
    """Single-head self-attention set encoder that pools to one vector.

    Each row of ``cat(x, y)`` is embedded by a two-layer ReLU MLP. Attention is
    computed between rows, and the attended rows are averaged.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        width = 128
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )

        # Creation order q, v, k is kept so parameters() order and seeded
        # initialisation stay the same.
        self.q = nn.Linear(width, output_dim)
        self.v = nn.Linear(width, output_dim)
        self.k = nn.Linear(width, output_dim)

    def forward(self, x, y):
        """Return the mean of the attended row embeddings of ``cat(x, y)``.

        Assumes ``x`` and ``y`` are 2-D with rows as set elements, which gives a
        ``(1, output_dim)`` result (TODO: confirm with callers).
        """
        feats = self.encoder(torch.cat((x, y), -1))

        queries = self.q(feats)
        keys = self.k(feats)
        values = self.v(feats)

        # The softmax over dim 0 normalises each column of the (n x n) score
        # matrix; after the transpose, each row of `attn` sums to one.
        attn = F.softmax(queries @ keys.T, 0).T
        pooled = attn @ values
        return pooled.mean(0, keepdim=True)


class DeepSets(nn.Module):
    """DeepSets encoder: per-element MLP, mean pooling, then an ``OutNet`` head."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden = 128
        # Per-element encoder; there is no activation after the last layer.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden),
        )

        # Head applied to the pooled set representation.
        self.out_net = OutNet(hidden, output_dim)

    def forward(self, x, y):
        """Encode the set ``cat(x, y)`` and return ``out_net`` of its mean embedding.

        Assumes dim 0 indexes the set elements (TODO: confirm with callers).
        """
        z = self.net(torch.cat((x, y), -1))
        # Permutation-invariant pooling over dim 0; keepdim leaves a leading
        # dimension of size 1.
        z = z.mean(0, keepdim=True)
        # The unreachable `return z` that used to follow this line was removed.
        return self.out_net(z)


class OutNet(nn.Module):
    """Two-layer MLP head: ``Linear(hidden, hidden) -> ReLU -> Linear(hidden, output_dim)``."""

    def __init__(self, hidden, output_dim):
        super().__init__()
        layers = [
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        ]
        self.out_net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to ``x``; the last dimension of ``x`` must equal ``hidden``."""
        head = self.out_net
        return head(x)


class PointNet(nn.Module):
    """PointNet-style set encoder: shared 1x1 convolutions, global pooling, MLP head.

    Args:
        input_dim: Feature size of ``cat(x, y)``, used as Conv1d input channels.
        output_dim: Output size of the ``OutNet`` head.
        pooling: ``'max'`` for max pooling over points; any other value
            selects mean pooling.
    """

    def __init__(self, input_dim, output_dim, pooling='max'):
        super().__init__()
        print("input_output", input_dim, output_dim)
        hidden = 128
        # Kernel size 1 makes each conv a per-point linear map shared across points.
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, hidden, 1)

        self.net = nn.Sequential(self.conv1, nn.ReLU(), self.conv2, nn.ReLU(), self.conv3)

        self.out_net = OutNet(hidden, output_dim)

        self.pooling = pooling

    def forward(self, x, y):
        """Encode the point set ``cat(x, y)`` into a ``(1, output_dim)`` tensor.

        ``x`` and ``y`` must be 2-D ``(num_points, features)`` tensors, because
        ``permute(1, 0)`` requires exactly two dimensions.
        """
        x = torch.cat((x, y), -1)
        # (num_points, dim) -> (1, dim, num_points), the layout Conv1d expects.
        x_permute = x.permute(1, 0).unsqueeze(0)
        z = self.net(x_permute)  # (1, hidden, num_points)

        if self.pooling == 'max':
            z = z.max(2)[0].contiguous()
        else:
            # Bug fix: the mean was previously computed and discarded, so the
            # un-pooled (1, hidden, num_points) tensor reached out_net.
            z = torch.mean(z, 2).contiguous()

        return self.out_net(z)  # (1, hidden) -> (1, output_dim)


class GradientEncoder(nn.Module):
    """Encode a support set as a learnt projection of a probe network's gradient.

    A small probe MLP (``grad_network``) is evaluated at its own parameters
    ``theta_0`` on ``(x, y)``. The gradient of ``criterion`` with respect to
    those parameters is flattened and projected by the learnt matrix
    ``eigen_vectors`` (shape ``(k, num_params)``). The result is negated and
    passed through an ``OutNet(k, k)`` head.

    Args:
        x_input_dim: Input size of the probe network.
        y_output_dim: Output size of the probe network.
        k: Size of the encoding.
        num_layers: Ignored; the probe network always uses 3 extra hidden layers.
        criterion: Callable ``criterion(prediction, y)`` returning a scalar loss.
        use_img: Stored on the instance; not read by the methods below.
    """

    def __init__(self, x_input_dim, y_output_dim, k, num_layers, criterion, use_img):
        super().__init__()

        # Width of the probe network.
        hidden = 40
        # NOTE: the `num_layers` argument is deliberately overridden. The
        # original comment notes that 7 layers was the original setting.
        num_layers = 3
        self.use_img = use_img

        self.grad_network = create_linear_layer(x_input_dim, y_output_dim, num_layers, hidden)

        # Projection from the flattened gradient to k dimensions ("linear version").
        self.eigen_vectors = nn.Parameter(torch.randn((k, self.num_params)) * 0.01)
        self.criterion = criterion

        # Plain list view of the probe's parameters (already registered
        # through grad_network), ordered as alternating weight, bias.
        self.theta_0 = list(self.grad_network.parameters())

        # Registered as a parameter, but not read by forward/adapt here.
        self.learning_rate = nn.Parameter(torch.rand(1) * 0.01)

        self.out_net = OutNet(k, k)

    @property
    def num_params(self):
        """Total number of scalar parameters in ``grad_network`` (a numpy integer)."""
        theta = list(self.grad_network.parameters())
        return np.sum([np.prod(x.shape) for x in theta])

    def real_forward(self, x, theta):
        """Run the probe MLP functionally with the parameter list ``theta``.

        ``theta`` alternates weight and bias tensors in the same layout as
        ``theta_0``. ReLU follows every layer except the last.
        """
        n = len(self.theta_0)
        h = x
        for i in range(n // 2 - 1):
            h = F.relu(F.linear(h, theta[2 * i], bias=theta[2 * i + 1]))
        return F.linear(h, theta[n - 2], bias=theta[n - 1])

    def forward(self, x, y, steps=1, return_loss=False):
        """Return the gradient-based encoding of the support set ``(x, y)``.

        ``steps`` and ``return_loss`` are accepted for interface compatibility
        but are currently unused.
        """
        out = self.real_forward(x, self.theta_0)
        l1 = self.criterion(out, y)
        return self.adapt(l1)

    def adapt(self, loss):
        """Project the gradient of ``loss`` w.r.t. ``theta_0`` into the encoding space.

        ``create_graph=True`` keeps the gradient itself differentiable, so the
        encoding can be trained end to end. The result has a leading dimension
        of size 1; it is presumably ``(1, k)`` if ``flatten_params_torch``
        returns a 1-D vector (TODO: confirm).
        """
        theta_0 = self.theta_0

        theta_grad_s = torch.autograd.grad(outputs=loss, inputs=theta_0, retain_graph=True, create_graph=True)
        grad_flat = flatten_params_torch(theta_grad_s, self.num_params)
        zs = -self.eigen_vectors @ grad_flat
        zs = self.out_net(zs)
        return zs.unsqueeze(0)




class HyperLearner(nn.Module):
    """Hypernetwork that predicts the weights of a small MLP from a support set.

    A set encoder (selected by ``encoder``) maps a support set ``(x, y)`` to a
    ``k``-dimensional code. ``hyper_net`` maps that code to a flat vector of
    ``num_output_params`` values (by default ``num_params`` of ``self.net``).
    ``forward`` evaluates the MLP functionally, using either ``theta_0`` or a
    caller-supplied parameter list.

    Recognised ``encoder`` values are ``'set-transformer'``, ``'deepsets'``,
    ``'pointnet'`` and ``'gradient'``. Any other value leaves ``self.encoder``
    unset, so ``encode`` cannot be used.
    """

    def __init__(self, dim, k, learnt_learning_rate, num_layers, activation, output_dim=1, loss='mse', mse_dim=1, num_output_params=None, imgs=0, traj_length=None, encoder='deepsets', num_hypernet_layers=1):
        super().__init__()

        # Stored only; forward() always applies ReLU.
        self.activation = activation
        self.loss = loss
        self.mse_dim = mse_dim
        self.imgs = imgs
        self.traj_length = traj_length

        hidden = 40
        if self.imgs:
            num_channels = 3
            filters = 64
            img_hidden = 64
            img_encoder_dim = 64

            self.image_encoder = ImageEncoder(num_channels, filters, img_hidden, img_encoder_dim)
            x_input_dim = img_hidden + self.traj_length
            y_output_dim = img_hidden
        else:
            x_input_dim = dim
            y_output_dim = output_dim

        # The set encoders consume concatenated (x, y) rows.
        encoder_input_dim = x_input_dim + y_output_dim

        # Base MLP: Linear(dim, hidden), num_layers x [ReLU, Linear(hidden, hidden)],
        # ReLU, Linear(hidden, output_dim).
        # NOTE(review): the first layer uses `dim` even when `imgs` is set
        # (x_input_dim differs there); confirm this is intended.
        layers = [nn.Linear(dim, hidden)]
        for _ in range(num_layers):
            layers += [nn.ReLU(), nn.Linear(hidden, hidden)]
        layers += [nn.ReLU(), nn.Linear(hidden, output_dim)]

        self.net = nn.Sequential(*layers)
        # Plain list view of the net's parameters: alternating weight, bias.
        self.theta_0 = list(self.net.parameters())

        if learnt_learning_rate:
            self.learning_rate = nn.Parameter(torch.rand(1) * 0.01)
        else:
            self.learning_rate = 0.01

        hidden_encoder = k
        self.k = k

        if encoder == 'set-transformer':
            self.encoder = SetTransformer(encoder_input_dim, hidden_encoder)
        elif encoder == 'deepsets':
            self.encoder = DeepSets(encoder_input_dim, hidden_encoder)
        elif encoder == 'pointnet':
            self.encoder = PointNet(encoder_input_dim, hidden_encoder)
        elif encoder == 'gradient':
            self.encoder = GradientEncoder(x_input_dim, y_output_dim, hidden_encoder, 3, criterion=self.criterion, use_img=imgs)

        # Hypernetwork: maps the k-dim set code to a flat parameter vector.
        if num_output_params is None:
            num_output_params = self.num_params

        if num_hypernet_layers == 1:
            hypernet_modules = [nn.Linear(hidden_encoder, num_output_params)]
        else:
            hidden = 128
            hypernet_modules = [nn.Linear(hidden_encoder, hidden), nn.ReLU()]
            for _ in range(num_hypernet_layers - 1):
                hypernet_modules += [nn.Linear(hidden, hidden), nn.ReLU()]
            hypernet_modules.append(nn.Linear(hidden, num_output_params))

        self.hyper_net = nn.Sequential(*hypernet_modules)

        if self.imgs:
            self.linear_layer = nn.Linear(self.traj_length * img_hidden, img_hidden)
        self.condition_net = nn.Sequential(nn.Linear(dim + self.k, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, output_dim))

    def encode_img(self, x):
        """Embed images with ``tmp_img_encoder`` and project them with ``tmp_linear``.

        NOTE(review): ``__init__`` never creates ``tmp_img_encoder`` or
        ``tmp_linear``, so this raises ``AttributeError`` unless a caller
        attaches them. Possibly meant ``image_encoder`` / ``linear_layer``; confirm.
        """
        h = self.tmp_img_encoder(x)
        h = torch.flatten(h, 1)
        return self.tmp_linear(h)

    def encode(self, x, y):
        """Map a support set ``(x, y)`` to a flat predicted parameter vector.

        When ``imgs`` is set, the inputs are first transformed by
        ``reshape_time_series_input``. A 1-D ``y`` gets a trailing feature
        dimension before encoding.
        """
        if self.imgs:
            x, y = reshape_time_series_input(x, y, self.traj_length, self.image_encoder, self.linear_layer)

        if len(y.shape) == 1:
            y = y.unsqueeze(-1)

        encoded1 = self.encoder(x, y)
        theta = self.hyper_net(encoded1).squeeze(0)
        return theta

    def forward(self, x, theta=None):
        """Evaluate the MLP functionally.

        Args:
            x: Input batch. When ``imgs`` is set, it is first transformed by
                ``reshape_time_series_x``.
            theta: List of alternating weight and bias tensors. Defaults to
                ``theta_0``. ReLU follows every layer except the last.
        """
        if self.imgs:
            x = reshape_time_series_x(x, self.traj_length, self.image_encoder, self.linear_layer)

        if theta is None:
            theta = self.theta_0

        n = len(theta)
        h = x
        for i in range(n // 2 - 1):
            h = F.relu(F.linear(h, theta[2 * i], bias=theta[2 * i + 1]))
        return F.linear(h, theta[n - 2], bias=theta[n - 1])

    @property
    def num_params(self):
        """Total number of scalar parameters in ``theta_0`` (a numpy integer)."""
        return np.sum([np.prod(x.shape) for x in self.theta_0])

    def criterion(self, y_pred, y):
        """Compute the loss selected by ``self.loss``.

        Raises:
            ValueError: if ``self.loss`` is not ``'mse'``, ``'cross_entropy'``
                or ``'bce_loss'``. Previously the method returned ``None``.
        """
        if self.loss == 'mse':
            return torch.mean((y_pred - y) ** 2)
        if self.loss == 'cross_entropy':
            return nn.CrossEntropyLoss()(y_pred.squeeze(-1), y.long().squeeze(-1))
        if self.loss == 'bce_loss':
            return nn.BCEWithLogitsLoss()(y_pred.squeeze(-1), y.squeeze(-1))
        raise ValueError(f"Unknown loss {self.loss!r}; expected 'mse', 'cross_entropy' or 'bce_loss'.")

    def adapt(self, loss, return_grads=False):
        """Take one gradient step on ``theta_0`` and return the adapted weights.

        ``create_graph=True`` keeps the step differentiable, giving second-order
        meta-gradients.

        Args:
            loss: Scalar loss computed from ``theta_0``.
            return_grads: If True, also return the (possibly warped) gradients.

        Returns:
            The list of adapted weights, or ``(weights, grads)`` when
            ``return_grads`` is True.
        """
        theta_grad_s = torch.autograd.grad(outputs=loss, inputs=self.theta_0, create_graph=True, retain_graph=True)

        # Bug fix: __init__ never sets `use_meta_curvature`, so reading it
        # directly raised AttributeError on every call. Default to plain
        # gradients unless a caller has attached use_meta_curvature/meta_curvature.
        if getattr(self, 'use_meta_curvature', False):
            theta_grad_s = self.meta_curvature.warp_grads(theta_grad_s)

        temp_weights = [w.clone() - self.learning_rate * g for w, g in zip(self.theta_0, theta_grad_s)]

        if return_grads:
            return temp_weights, theta_grad_s

        return temp_weights





class MAMLClassifier(nn.Module):
    """MLP classifier with batch norm whose weights can be adapted MAML-style.

    The weights are owned by ``self.net`` but applied functionally from a
    parameter list. Adapted copies returned by :meth:`adapt` can therefore be
    evaluated without mutating the module.
    """

    def __init__(self, dim, num_classes, use_meta_curvature, learnt_learning_rate, num_layers, hidden_dim):
        super().__init__()

        width = hidden_dim
        modules = [nn.Linear(dim, width)]
        for _ in range(num_layers):
            modules.extend([nn.ReLU(), nn.Linear(width, width)])
        modules.extend([nn.ReLU(), nn.Linear(width, num_classes)])

        self.net = nn.Sequential(*modules)
        # Plain list view of the net's parameters: alternating weight, bias.
        self.theta_0 = list(self.net.parameters())
        self.regressor = False

        # One batch norm per hidden layer (num_layers + 1). Running statistics
        # are not tracked, so batch statistics are always used.
        self.batch_norms = nn.ModuleList(
            nn.BatchNorm1d(width, track_running_stats=False) for _ in range(num_layers + 1)
        )
        self.use_meta_curvature = use_meta_curvature

        self.learning_rate = nn.Parameter(torch.rand(1)) if learnt_learning_rate else 0.01

        if use_meta_curvature:
            self.meta_curvature = Curvature(self.theta_0)

    def forward(self, x, theta=None):
        """Compute class logits for ``x``.

        ``theta`` is a list of alternating weight and bias tensors (default
        ``theta_0``). Each hidden layer applies linear, batch norm, then ReLU;
        the last layer is linear only.
        """
        params = self.theta_0 if theta is None else theta

        h = x
        for layer_idx in range(len(params) // 2 - 1):
            pre_act = F.linear(h, params[2 * layer_idx], bias=params[2 * layer_idx + 1])
            h = F.relu(self.batch_norms[layer_idx](pre_act))
        return F.linear(h, params[-2], bias=params[-1])

    def criterion(self, y_pred, y, loss_scale=1.0):
        """Mean cross-entropy between logits and integer labels (``loss_scale`` is unused)."""
        return F.cross_entropy(y_pred, y)

    def adapt(self, loss):
        """Return weights after one gradient step on ``theta_0``.

        The gradient is taken without ``create_graph``, so the step is
        first-order. It is optionally warped by meta-curvature.
        """
        grads = torch.autograd.grad(outputs=loss, inputs=self.theta_0, retain_graph=True)

        if self.use_meta_curvature:
            grads = self.meta_curvature.warp_grads(grads)

        return [w.clone() - self.learning_rate * g for w, g in zip(self.theta_0, grads)]


class ConvMaml(nn.Module):
    """Four-block conv net for few-shot classification with functional adaptation.

    If ``adapt_top_layers`` is False, every layer (4 convs, one hidden FC,
    output) is held in ``theta_0`` and adapted by :meth:`adapt`. If True, the
    conv body and hidden FC form a regular ``nn.Sequential`` (``net_body``),
    and only the final linear classifier with ``ways`` outputs is in ``theta_0``.
    """

    def __init__(self, channels, ways, use_meta_curvature, adapt_top_layers):
        super().__init__()

        self.use_meta_curvature = use_meta_curvature
        self.ways = ways

        self.adapt_top_layers = adapt_top_layers

        # Kept for compatibility; no extra hidden FC layers are built from it.
        self.network_layers = 0

        # Spatial side length of the feature map after the conv body. The
        # values presumably assume 28x28 (1-channel) and 84x84 (3-channel)
        # inputs; confirm.
        self.h_shape = 1 if channels == 1 else 5
        self.stride = 1
        self.filters = 32

        hidden_fc = 64
        hidden_input_dim = self.filters * self.h_shape * self.h_shape
        hidden_out = ways

        self.max_pool = nn.MaxPool2d(2)
        if self.adapt_top_layers:
            modules_body = [
                nn.Conv2d(channels, self.filters, 3, stride=1, padding=1),
                nn.ReLU(),
                self.max_pool,
            ]
            for _ in range(3):
                modules_body += [
                    nn.Conv2d(self.filters, self.filters, 3, stride=1, padding=1),
                    nn.BatchNorm2d(self.filters, track_running_stats=False),
                    nn.ReLU(),
                    self.max_pool,
                ]
            modules_body += [nn.Flatten(), nn.Linear(hidden_input_dim, hidden_fc), nn.ReLU()]

            self.net_body = nn.Sequential(*modules_body)
            self.theta_shapes = []
        else:
            # 4 conv (weight, bias) pairs followed by the hidden FC layer.
            self.theta_shapes = [[self.filters, channels, 3, 3], [self.filters],
                                 [self.filters, self.filters, 3, 3], [self.filters],
                                 [self.filters, self.filters, 3, 3], [self.filters],
                                 [self.filters, self.filters, 3, 3], [self.filters],
                                 [hidden_fc, hidden_input_dim], [hidden_fc]]

        # Output classifier; it is always part of theta_0.
        self.theta_shapes.append([hidden_out, hidden_fc])
        self.theta_shapes.append([hidden_out])

        self.batch_norm1 = nn.BatchNorm2d(self.filters, track_running_stats=False)
        self.batch_norm2 = nn.BatchNorm2d(self.filters, track_running_stats=False)
        self.batch_norm3 = nn.BatchNorm2d(self.filters, track_running_stats=False)
        self.batch_norm4 = nn.BatchNorm2d(self.filters, track_running_stats=False)

        # Plain list: the modules are already registered as batch_norm1..4;
        # this only provides indexed access in forward().
        self.batch_norms = [self.batch_norm1, self.batch_norm2, self.batch_norm3, self.batch_norm4]

        # Fixed inner-loop step size.
        self.lr = 0.4

        self.theta_0 = nn.ParameterList([nn.Parameter(torch.zeros(t_size)) for t_size in self.theta_shapes])
        # Kaiming-uniform init for weight tensors; biases stay zero.
        for param in self.theta_0:
            if param.dim() > 1:
                torch.nn.init.kaiming_uniform_(param)

        if use_meta_curvature:
            self.meta_curvature = Curvature(self.theta_0)

    def get_theta(self):
        """Return the adaptable parameter list ``theta_0``."""
        return self.theta_0

    def criterion(self, y_pred, y):
        """Mean cross-entropy between logits and integer labels."""
        return nn.CrossEntropyLoss()(y_pred, y)

    @property
    def num_params(self):
        """Number of scalar entries across ``theta_shapes`` (a numpy integer)."""
        return np.sum([np.prod(x) for x in self.theta_shapes])

    def forward(self, x, theta=None):
        """Compute logits for image batch ``x`` using parameter list ``theta`` (default ``theta_0``)."""
        if theta is None:
            theta = self.theta_0

        if self.adapt_top_layers:
            h = self.net_body(x)
            head_idx = 0
        else:
            # Each block: conv (theta[2l], theta[2l+1]) -> batch norm -> ReLU -> 2x max-pool.
            h = x
            for layer in range(4):
                h = F.conv2d(h, theta[2 * layer], bias=theta[2 * layer + 1], stride=1, padding=1)
                h = self.max_pool(F.relu(self.batch_norms[layer](h)))
            h = torch.flatten(h, 1)
            h = F.relu(F.linear(h, theta[8], bias=theta[9]))
            head_idx = 10

        return F.linear(h, theta[head_idx], bias=theta[head_idx + 1])

    def adapt(self, loss, return_grads=False):
        """Return weights after one first-order gradient step on ``theta_0``.

        The gradient is taken without ``create_graph``, so no second-order
        terms flow back through it. It is optionally warped by meta-curvature.

        Args:
            loss: Scalar loss computed from ``theta_0``.
            return_grads: If True, also return the (possibly warped) gradients.
        """
        theta_grad_s = torch.autograd.grad(outputs=loss, inputs=self.theta_0, retain_graph=True)

        # The unused no_grad "pre-warp" weight copy was removed.
        if self.use_meta_curvature:
            theta_grad_s = self.meta_curvature.warp_grads(theta_grad_s)

        temp_weights = [w - self.lr * g for w, g in zip(self.theta_0, theta_grad_s)]

        if return_grads:
            return temp_weights, theta_grad_s

        return temp_weights



# Unused hinge-loss snippet kept for reference only; the string literal below is never executed.
"""
            hinge_loss = torch.Tensor([0]).to(device)
            if self.use_img:
                y_random = y[torch.randperm(batch_size)]
                distance = torch.norm(y - y_random, p=2, dim=-1)

                hinge_loss = torch.max(torch.zeros_like(distance).to(device), torch.ones_like(distance).to(device) - distance).mean()

"""