import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

from solver.eq_LPsolver import CoarseCorrelatedEquilibriumLPSolver

from utils import sample

class LSVI_UCB_GENSUM(object):
    """Least-squares value iteration with UCB bonuses for a two-player general-sum game.

    For every step ``h`` the agent keeps one linear Q-function per player
    (``W1[h]`` and ``W2[h]``) over the features returned by
    ``rep_learners[h].phi`` for a *joint* action. It also keeps the inverse
    regularized Gram matrix ``Sigma_invs[h]``, which drives the exploration
    bonus. Actions are chosen by solving a coarse correlated equilibrium (CCE)
    of the two players' optimistic Q-matrices at each observation.

    Joint actions are indexed row-major as ``a = a1 * num_actions + a2``. This
    matches the ``(num_actions, num_actions)`` reshape in ``solve_cce`` and the
    decomposition in ``act_batch``.
    """

    def __init__(
        self,
        obs_dim,
        state_dim,
        num_actions,
        horizon,
        alpha,
        device,
        rep_learners,
        lamb=1,
        recent_size=0,
    ):
        """Create the agent.

        Args:
            obs_dim: Observation dimension. It is stored only; the methods
                here do not use it.
            state_dim: Factor of the feature size:
                ``feature_dim = state_dim * num_actions ** 2``.
            num_actions: Number of actions per player. The joint action space
                has ``num_actions ** 2`` elements.
            horizon: Number of steps per episode; one set of weights per step.
            alpha: Coefficient of the UCB exploration bonus.
            device: Torch device on which weights and Gram inverses live.
            rep_learners: Sequence indexed by step. Each entry exposes
                ``phi(obs, actions, tau=...)`` returning features whose last
                dimension is ``feature_dim`` (presumably ``(batch, feature_dim)``
                -- TODO confirm).
            lamb: Ridge regularization added to the Gram matrix.
            recent_size: If positive, forwarded to ``buffer.get_full`` to
                restrict regression to the most recent samples.
        """
        self.obs_dim = obs_dim
        self.state_dim = state_dim
        self.num_action = num_actions
        self.action_dim = num_actions ** 2
        self.horizon = horizon

        self.feature_dim = state_dim * self.action_dim

        self.device = device

        self.rep_learners = rep_learners

        self.lamb = lamb
        self.alpha = alpha

        self.recent_size = recent_size

        # Random initial weights (generated on CPU and then moved, so that
        # seeded runs are reproducible regardless of device).
        self.W1 = torch.rand((self.horizon, self.feature_dim)).to(self.device)
        self.W2 = torch.rand((self.horizon, self.feature_dim)).to(self.device)
        # Zero until the first update, so the exploration bonus starts at 0.
        self.Sigma_invs = torch.zeros((self.horizon, self.feature_dim, self.feature_dim)).to(self.device)

        # Upper clip for optimistic Q-values. It is kept as a float tensor on
        # the same device as the estimates it is compared with.
        self.Q_max = torch.tensor(float(self.horizon), device=self.device)

    def Q_values(self, obs, h):
        """Return both players' optimistic Q-values for every joint action at step ``h``.

        Args:
            obs: Observation batch tensor on ``self.device``; ``len(obs)`` is
                the batch size.
            h: Step index selecting the representation, weights and Gram inverse.

        Returns:
            Tuple ``(Qs1, Qs2)`` of tensors of shape ``(len(obs), action_dim)``.
            Each entry is the linear estimate plus ``alpha`` times the bonus
            ``sqrt(phi^T Sigma_inv phi)``, clipped from above at ``Q_max``.
        """
        batch = len(obs)
        w1 = self.W1[h]
        w2 = self.W2[h]
        sigma_inv = self.Sigma_invs[h]
        Qs1 = torch.zeros((batch, self.action_dim), device=self.device)
        Qs2 = torch.zeros((batch, self.action_dim), device=self.device)
        for a in range(self.action_dim):
            # One-hot encoding of joint action ``a`` for the whole batch.
            actions = torch.zeros((batch, self.action_dim), device=self.device)
            actions[:, a] = 1
            with torch.no_grad():
                feature = self.rep_learners[h].phi(obs, actions, tau=0.1)
            # The bonus depends only on the features, so both players share it.
            ucb = torch.sqrt(torch.sum(torch.matmul(feature, sigma_inv) * feature, 1))
            Qs1[:, a] = torch.minimum(torch.matmul(feature, w1) + self.alpha * ucb, self.Q_max)
            Qs2[:, a] = torch.minimum(torch.matmul(feature, w2) + self.alpha * ucb, self.Q_max)

        return Qs1, Qs2

    def solve_cce(self, Q1, Q2):
        """Solve the CCE of one stage game given flat joint-action payoffs.

        Args:
            Q1, Q2: Arrays of length ``action_dim`` holding player 1 and
                player 2 payoffs. They are reshaped row-major to
                ``(num_actions, num_actions)``, with rows for player 1's action.

        Returns:
            ``(ne, v1, v2)`` as produced by ``CoarseCorrelatedEquilibriumLPSolver``:
            the equilibrium distribution over joint actions and each player's value.
        """
        Q1 = Q1.reshape(self.num_action, self.num_action)
        Q2 = Q2.reshape(self.num_action, self.num_action)
        _, _, ne, v1, v2 = CoarseCorrelatedEquilibriumLPSolver(Q1, Q2)
        return ne, v1, v2

    def act_batch(self, obs, h, stochastic=True):
        """Choose marginal actions for a batch of observations at step ``h``.

        Args:
            obs: Batch of observations accepted by ``torch.FloatTensor``.
            h: Step index.
            stochastic: If True, sample a joint action from each CCE
                distribution; otherwise take its most likely joint action.

        Returns:
            Integer array whose rows are ``[a1, a2]``. These are the two
            players' actions, decoded from joint index ``a`` as
            ``a1 = a // num_actions`` and ``a2 = a % num_actions``.
        """
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            Qs1, Qs2 = self.Q_values(obs, h)
        Qs1 = Qs1.cpu().numpy()
        Qs2 = Qs2.cpu().numpy()

        nes = np.array([self.solve_cce(q1, q2)[0] for q1, q2 in zip(Qs1, Qs2)])

        if stochastic:
            action = sample(nes, np.arange(self.action_dim), len(obs))
        else:
            action = np.argmax(nes, -1)

        # Split each joint index into the two players' marginal actions.
        joint = np.asarray(action, dtype=np.int64)
        return np.stack([joint // self.num_action, joint % self.num_action], axis=1)

    def act(self, obs, h):
        """Return the greedy marginal actions ``[a1, a2]`` for one observation.

        Args:
            obs: A single (unbatched) observation.
            h: Step index.

        Returns:
            Length-2 integer array ``[a1, a2]``. It is the deterministic
            (``stochastic=False``) choice of ``act_batch``.
        """
        batch = np.expand_dims(np.asarray(obs, dtype=np.float32), 0)
        return self.act_batch(batch, h, stochastic=False)[0]

    def _next_values(self, next_obses, h):
        """Return each player's per-sample CCE value at step ``h`` as ``(N, 1)`` tensors."""
        with torch.no_grad():
            Q_prime1, Q_prime2 = self.Q_values(next_obses, h)
        Q_prime1 = Q_prime1.cpu().numpy()
        Q_prime2 = Q_prime2.cpu().numpy()
        values1 = []
        values2 = []
        for q1, q2 in zip(Q_prime1, Q_prime2):
            _, v1, v2 = self.solve_cce(q1, q2)
            values1.append(v1)
            values2.append(v2)
        ne_v1 = torch.as_tensor(values1, dtype=torch.float, device=self.device).unsqueeze(-1)
        ne_v2 = torch.as_tensor(values2, dtype=torch.float, device=self.device).unsqueeze(-1)
        return ne_v1, ne_v2

    def update(self, buffers):
        """Refit every step's weights by backward ridge regression.

        Steps are processed from ``horizon - 1`` down to 0. The regression
        target at the last step is the reward. At earlier steps it is the
        reward plus each player's CCE value of the optimistic Q at ``h + 1``,
        which was refit earlier in the same pass.

        Args:
            buffers: Sequence of length ``horizon``. Each entry provides
                ``get_full(device=..., [recent_size=...])`` returning
                ``(obs, actions, rewards1, rewards2, next_obs)``. Rewards are
                presumably shaped ``(N, 1)`` so they broadcast against the
                features -- TODO confirm.

        Raises:
            ValueError: if ``len(buffers) != horizon``.
        """
        if len(buffers) != self.horizon:
            raise ValueError(
                "expected {} buffers (one per step), got {}".format(self.horizon, len(buffers))
            )

        get_kwargs = {"device": self.device}
        if self.recent_size > 0:
            get_kwargs["recent_size"] = self.recent_size
        eye = torch.eye(self.feature_dim, device=self.device)

        for h in reversed(range(self.horizon)):
            obses, actions, rewards1, rewards2, next_obses = buffers[h].get_full(**get_kwargs)

            with torch.no_grad():
                feature = self.rep_learners[h].phi(obses, actions, tau=0.1)
            Sigma = torch.matmul(feature.T, feature) + self.lamb * eye
            self.Sigma_invs[h] = torch.inverse(Sigma)

            if h == self.horizon - 1:
                target_Q1 = rewards1
                target_Q2 = rewards2
            else:
                ne_v1, ne_v2 = self._next_values(next_obses, h + 1)
                target_Q1 = rewards1 + ne_v1
                # Each player bootstraps from their own equilibrium value.
                target_Q2 = rewards2 + ne_v2

            self.W1[h] = torch.matmul(self.Sigma_invs[h], torch.sum(feature * target_Q1, 0))
            self.W2[h] = torch.matmul(self.Sigma_invs[h], torch.sum(feature * target_Q2, 0))

    def save_weight(self, path, n=None):
        """Save every step's weights and Gram inverse to files under ``path``.

        File names are ``W1_{h}``, ``W3_{h}`` (holding ``W2``) and
        ``Sigma_{h}``. When ``n`` is given, ``_{n}`` is appended before ``.pth``.
        """
        suffix = "" if n is None else "_{}".format(n)
        for h in range(self.horizon):
            torch.save(self.W1[h], "{}/W1_{}{}.pth".format(path, h, suffix))
            # W2 is stored under the "W3" prefix. load_weight reads the same
            # name, so it is kept for compatibility with existing checkpoints.
            torch.save(self.W2[h], "{}/W3_{}{}.pth".format(path, h, suffix))
            torch.save(self.Sigma_invs[h], "{}/Sigma_{}{}.pth".format(path, h, suffix))

    def load_weight(self, path):
        """Load the un-suffixed files written by ``save_weight(path)`` onto ``self.device``."""
        for h in range(self.horizon):
            self.W1[h] = torch.load("{}/W1_{}.pth".format(path, h), map_location=self.device)
            self.W2[h] = torch.load("{}/W3_{}.pth".format(path, h), map_location=self.device)
            self.Sigma_invs[h] = torch.load("{}/Sigma_{}.pth".format(path, h), map_location=self.device)
