import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import Swin_V2_T_Weights, swin_v2_t
from torchvision import datasets, transforms, models

class Switch_policy(torch.nn.Module):
    """ResNet-18 based scorer that maps an image plus one extra scalar to (0, 1).

    The backbone is a randomly initialised torchvision ResNet-18 whose first
    convolution is replaced to accept single-channel input and whose final
    fully connected layer is replaced to emit 4 features. Those 4 features are
    concatenated with `extras` and passed through a linear layer + sigmoid.
    """

    def __init__(self):
        super().__init__()
        # `weights=None` is the current spelling of the deprecated
        # `pretrained=False`; both build the network without loading weights.
        self.main = models.resnet18(weights=None)
        # Swap the stock 3-channel stem for a 1-channel one.
        self.main.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False)
        # Replace the classification layer with a 4-feature projection.
        self.main.fc = nn.Linear(self.main.fc.in_features, 4)
        # 4 backbone features + 1 extra value -> 1 score.
        self.critic_linear = nn.Linear(4 + 1, 1)

    def forward(self, x, extras):
        """
        Input:
            `x` image batch with 1 channel (fed to the modified ResNet-18)
            `extras` tensor concatenated to the backbone output along dim 1;
                     must have 1 column so the result has 4 + 1 features
        Output:
            tensor of shape (bs, 1) with values in (0, 1)
        """
        features = self.main(x)
        joined = torch.cat((features, extras), 1)
        return torch.sigmoid(self.critic_linear(joined))



class Classifier1(nn.Module):
    """7-way classifier: frozen Swin-V2-T backbone followed by a linear head.

    The backbone is loaded with `Swin_V2_T_Weights.DEFAULT` and all of its
    parameters have `requires_grad` disabled, so only `self.linear` is
    trainable. The head consumes the backbone's 1000-dimensional output.
    """

    def __init__(self):
        super().__init__()

        backbone = swin_v2_t(weights=Swin_V2_T_Weights.DEFAULT)
        # Freeze every backbone parameter; only the linear head learns.
        backbone.requires_grad_(False)
        self.model_ft = backbone

        self.linear = nn.Linear(1000, 7)

    def forward(self, input):
        """Return class probabilities (softmax over dim 1) for `input`."""
        logits = self.linear(self.model_ft(input))
        return F.softmax(logits, dim=1)

def get_grid(pose, grid_size, device):
    """
    Build two 2-D sampling grids from a batch of poses: one for a pure
    rotation and one for a pure translation.

    Input:
        `pose` FloatTensor(bs, 3) holding (x, y, angle in degrees) per row
        `grid_size` 4-tuple (bs, _, grid_h, grid_w)
        `device` torch.device (cpu or gpu); `pose` is moved there if needed
    Output:
        `rot_grid` FloatTensor(bs, grid_h, grid_w, 2)
        `trans_grid` FloatTensor(bs, grid_h, grid_w, 2)

    """
    pose = pose.to(device=device, dtype=torch.float32)
    x, y, t = pose[:, 0], pose[:, 1], pose[:, 2]

    # Degrees -> radians.
    t = t * np.pi / 180.
    cos_t, sin_t = t.cos(), t.sin()

    # Constant columns, created directly on the pose's device/dtype.
    zeros = torch.zeros_like(x)
    ones = torch.ones_like(x)

    # Rotation about the origin: [[cos, -sin, 0], [sin, cos, 0]].
    theta1 = torch.stack([
        torch.stack([cos_t, -sin_t, zeros], 1),
        torch.stack([sin_t, cos_t, zeros], 1),
    ], 1)

    # Translation by (x, y): [[1, 0, x], [0, 1, y]].
    theta2 = torch.stack([
        torch.stack([ones, zeros, x], 1),
        torch.stack([zeros, ones, y], 1),
    ], 1)

    # `align_corners=False` is what affine_grid falls back to when the
    # argument is omitted; passing it explicitly avoids the UserWarning.
    size = torch.Size(grid_size)
    rot_grid = F.affine_grid(theta1, size, align_corners=False)
    trans_grid = F.affine_grid(theta2, size, align_corners=False)

    return rot_grid, trans_grid

def get_grid_3d(pose, grid_size, device):
    """
    Build two 3-D sampling grids from a batch of planar poses: one for a
    rotation and one for a translation. Both transforms act only on the
    x/y coordinates of the grid; the z coordinate is passed through.

    Input:
        `pose` FloatTensor(bs, 3) holding (x, y, angle in degrees) per row
        `grid_size` 5-tuple (bs, _, s1, s2, s3); F.affine_grid reads the
                    three spatial entries as (depth, height, width), so the
                    x/y coordinates index s3 and s2 respectively.
                    NOTE(review): the previous docstring labelled these
                    (grid_h, grid_w, height) -- confirm the ordering callers
                    actually rely on.
        `device` torch.device (cpu or gpu); `pose` is moved there if needed
    Output:
        `rot_grid` FloatTensor(bs, s1, s2, s3, 3)
        `trans_grid` FloatTensor(bs, s1, s2, s3, 3)

    """
    pose = pose.to(device=device, dtype=torch.float32)
    x, y, t = pose[:, 0], pose[:, 1], pose[:, 2]

    # Degrees -> radians.
    t = t * np.pi / 180.
    cos_t, sin_t = t.cos(), t.sin()

    # Constant columns, created directly on the pose's device/dtype.
    zeros = torch.zeros_like(x)
    ones = torch.ones_like(x)

    # Rotation in the x/y plane, identity on z (3x4 affine matrix).
    theta1 = torch.stack([
        torch.stack([cos_t, -sin_t, zeros, zeros], 1),
        torch.stack([sin_t, cos_t, zeros, zeros], 1),
        torch.stack([zeros, zeros, ones, zeros], 1),
    ], 1)

    # Translation by (x, y, 0) (3x4 affine matrix).
    theta2 = torch.stack([
        torch.stack([ones, zeros, zeros, x], 1),
        torch.stack([zeros, ones, zeros, y], 1),
        torch.stack([zeros, zeros, ones, zeros], 1),
    ], 1)

    # `align_corners=False` is what affine_grid falls back to when the
    # argument is omitted; passing it explicitly avoids the UserWarning.
    size = torch.Size(grid_size)
    rot_grid = F.affine_grid(theta1, size, align_corners=False)
    trans_grid = F.affine_grid(theta2, size, align_corners=False)

    return rot_grid, trans_grid



class ChannelPool(nn.MaxPool1d):
    """Max-pool a 4-D tensor across its channel dimension.

    `forward` calls `F.max_pool1d` directly with a window equal to the
    number of channels, so the pooling parameters given to the constructor
    are not used by it.
    """

    def forward(self, x):
        batch, channels, dim_a, dim_b = x.size()

        # (batch, channels, A*B) -> (batch, A*B, channels) so that the 1-D
        # pooling window slides over the channel axis.
        flat = x.view(batch, channels, dim_a * dim_b)
        channel_last = flat.permute(0, 2, 1).contiguous()

        # Window == number of channels, stride 1 -> one output per position.
        reduced = F.max_pool1d(channel_last, channels, 1)
        out_channels = reduced.size(2)

        # Back to (batch, out_channels, A, B).
        return reduced.permute(0, 2, 1).view(batch, out_channels, dim_a, dim_b)


# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/utils.py#L32
class AddBias(nn.Module):
    """Add a learnable per-feature (2-D input) or per-channel bias to `x`."""

    def __init__(self, bias):
        super().__init__()
        # Stored as a column vector of shape (len(bias), 1).
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        # 2-D input: broadcast over the batch dimension only.
        # Otherwise: reshape to (1, C, 1, 1) to broadcast over trailing dims.
        target_shape = (1, -1) if x.dim() == 2 else (1, -1, 1, 1)
        return x + self._bias.t().view(target_shape)


# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/model.py#L10
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)


# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/model.py#L82
class NNBase(nn.Module):
    """Base network holding an optional GRU cell and its bookkeeping.

    When `recurrent` is truthy a `nn.GRUCell(recurrent_input_size,
    hidden_size)` is created with orthogonal weights and zero biases.
    """

    def __init__(self, recurrent, recurrent_input_size, hidden_size):
        super().__init__()
        self._hidden_size = hidden_size
        self._recurrent = recurrent

        if not recurrent:
            return

        self.gru = nn.GRUCell(recurrent_input_size, hidden_size)
        # Orthogonal init for both weight matrices, zeros for both biases.
        for weight in (self.gru.weight_ih, self.gru.weight_hh):
            nn.init.orthogonal_(weight.data)
        for bias in (self.gru.bias_ih, self.gru.bias_hh):
            bias.data.fill_(0)

    @property
    def is_recurrent(self):
        """Whether this network was built with a GRU cell."""
        return self._recurrent

    @property
    def rec_state_size(self):
        """Size of the recurrent state: `hidden_size` if recurrent, else 1."""
        return self._hidden_size if self._recurrent else 1

    @property
    def output_size(self):
        """Size of the output features (`hidden_size`)."""
        return self._hidden_size

    def _forward_gru(self, x, hxs, masks):
        """Run the GRU cell for one step or over a flattened rollout.

        If `x` has as many rows as `hxs`, a single step is performed.
        Otherwise `x` is treated as a (T, N, -1) tensor flattened to
        (T * N, -1), with N = `hxs.size(0)`, and the cell is stepped T times.
        The hidden state is multiplied by the matching mask before each step.

        Returns the per-step outputs (flattened to match `x`) and the final
        hidden state.
        """
        num_envs = hxs.size(0)

        # Single-step case: output and new hidden state are the same tensor.
        if x.size(0) == num_envs:
            hxs = self.gru(x, hxs * masks[:, None])
            return hxs, hxs

        # Rollout case: recover the (T, N, ...) layout of inputs and masks.
        num_steps = x.size(0) // num_envs
        seq = x.view(num_steps, num_envs, x.size(1))
        seq_masks = masks.view(num_steps, num_envs, 1)

        step_outputs = []
        for step_input, step_mask in zip(seq, seq_masks):
            hxs = self.gru(step_input, hxs * step_mask)
            step_outputs.append(hxs)

        # Stack to (T, N, -1), then flatten back to (T * N, -1).
        flat_out = torch.stack(step_outputs, dim=0).view(num_steps * num_envs, -1)
        return flat_out, hxs
