import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from opelab.core.baseline import Baseline
from opelab.core.policy import Policy
from opelab.core.data import DataType, to_numpy
from collections import deque
import random

class QNetwork(nn.Module):
    """Action-value critic Q(s, a).

    The state and action are concatenated along the last dimension and fed
    through a two-hidden-layer ReLU MLP that ends in a single output unit.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        in_features = state_dim + action_dim
        layers = [
            nn.Linear(in_features, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        # Kept as `self.net` (a Sequential) so state_dict keys stay net.0/net.2/net.4.
        self.net = nn.Sequential(*layers)

    def forward(self, states, actions):
        """Return Q(states, actions) with the trailing unit dimension dropped."""
        joint = torch.cat((states, actions), dim=-1)
        q_values = self.net(joint)
        return q_values.squeeze(-1)

class ValueNetwork(nn.Module):
    """State-value critic V(s): a two-hidden-layer ReLU MLP with a scalar head."""

    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        # Kept as `self.net` (a Sequential) so state_dict keys stay net.0/net.2/net.4.
        self.net = nn.Sequential(*layers)

    def forward(self, states):
        """Return V(states) with the trailing unit dimension dropped."""
        values = self.net(states)
        return values.squeeze(-1)

class DRV2(Baseline):
    """Doubly-robust off-policy evaluator with learned Q and V critics.

    ``evaluate`` first fits a Q-network and a V-network on the logged
    transitions. It uses importance-weighted squared-error regression and
    Polyak-averaged target networks. It then plugs the critics into the
    per-decision doubly-robust estimator, averaged over trajectories:

        V_DR = V(s_0) + sum_t gamma^t * rho_{0:t}
                        * (r_t + gamma * V(s_{t+1}) - Q(s_t, a_t))

    Here rho_{0:t} is the cumulative ratio target_prob / behavior_prob, and
    V(s_T) = 0 after the last logged step.

    Note: the networks and optimizers are created once in ``__init__`` and
    are not reset, so repeated ``evaluate`` calls keep training the same
    critics.
    """

    def __init__(self, state_dim, action_dim, gamma=0.99,
                 hidden_dim=64, lr=1e-3, batch_size=64,
                 epochs=50, tau=0.005, device='cuda:0'):
        super().__init__()
        self.device = device
        self.gamma = gamma

        print("Using device:", self.device)

        # Q-networks with target networks
        self.q_net = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.q_target = QNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.q_target.load_state_dict(self.q_net.state_dict())

        # Value network with target network
        self.v_net = ValueNetwork(state_dim, hidden_dim).to(device)
        self.v_target = ValueNetwork(state_dim, hidden_dim).to(device)
        self.v_target.load_state_dict(self.v_net.state_dict())

        # Optimizers
        self.q_optim = optim.Adam(self.q_net.parameters(), lr=lr)
        self.v_optim = optim.Adam(self.v_net.parameters(), lr=lr)

        self.batch_size = batch_size
        self.epochs = epochs
        self.tau = tau

    def _polyak_update(self, target, source):
        """Soft-update ``target`` towards ``source``: t <- tau * s + (1 - tau) * t."""
        with torch.no_grad():
            for t, s in zip(target.parameters(), source.parameters()):
                t.copy_(self.tau * s + (1 - self.tau) * t)

    def _process_trajectories(self, data):
        """Flatten trajectories into (s, a, r, s', done) transitions.

        Each trajectory is indexed as ``traj['states']``, ``traj['actions']``
        and ``traj['rewards']``. The last step of every trajectory is flagged
        as terminal. Its "next state" is a placeholder (the state itself)
        that is never bootstrapped from.
        """
        transitions = []
        for traj in data:
            traj_len = len(traj['rewards'])
            for t in range(traj_len):
                state = traj['states'][t]
                action = traj['actions'][t]
                reward = traj['rewards'][t]
                next_state = traj['states'][t+1] if t < traj_len-1 else state
                done = t == traj_len-1  # Last step is terminal

                transitions.append((state, action, reward, next_state, done))
        return transitions

    def _to_tensor(self, values):
        """Convert array-like ``values`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(np.asarray(values, dtype=np.float32), device=self.device)

    def _policy_probs(self, policy, states, actions):
        """Return ``policy.prob(s, a)`` for each pair as a 1-D tensor clamped to >= 1e-6.

        Building the tensor directly (instead of stack + squeeze) keeps shape
        (N,) even when N == 1.
        """
        probs = [policy.prob(s, a) for s, a in zip(states, actions)]
        probs = np.asarray(probs, dtype=np.float32).reshape(-1)
        return torch.as_tensor(probs, device=self.device).clamp(min=1e-6)

    def _fit_critics(self, dataloader, gamma):
        """Train Q and V for ``self.epochs`` passes over ``dataloader``.

        Each batch performs one importance-weighted Q step, then one
        importance-weighted V step, then Polyak updates of both target nets.
        """
        with tqdm(total=self.epochs, desc="Training Networks") as pbar:
            for _ in range(self.epochs):
                epoch_q_loss = 0.0
                epoch_v_loss = 0.0

                for s, a, r, ns, d, rho in dataloader:
                    # Q step: regress Q(s, a) onto r + gamma * V_target(s'),
                    # with no bootstrap at terminal transitions.
                    with torch.no_grad():
                        not_done = ~d
                        v_next = torch.zeros_like(r)
                        v_next[not_done] = self.v_target(ns[not_done])
                        q_target = r + gamma * v_next

                    q_pred = self.q_net(s, a)
                    q_loss = (rho * (q_pred - q_target)**2).mean()

                    self.q_optim.zero_grad()
                    q_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
                    self.q_optim.step()

                    # V step: the regression target is the average of the
                    # online and target Q at the logged action.
                    with torch.no_grad():
                        v_target = 0.5 * (self.q_net(s, a) + self.q_target(s, a))

                    v_pred = self.v_net(s)
                    v_loss = (rho * (v_pred - v_target)**2).mean()

                    self.v_optim.zero_grad()
                    v_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.v_net.parameters(), 1.0)
                    self.v_optim.step()

                    self._polyak_update(self.q_target, self.q_net)
                    self._polyak_update(self.v_target, self.v_net)

                    epoch_q_loss += q_loss.item()
                    epoch_v_loss += v_loss.item()

                pbar.set_postfix({
                    'Q Loss': epoch_q_loss/len(dataloader),
                    'V Loss': epoch_v_loss/len(dataloader)
                })
                pbar.update(1)

    def _trajectory_dr(self, traj, target_policy, behavior_policy, gamma):
        """Per-decision doubly-robust estimate for one trajectory.

        The same network (``v_net``) supplies both the V(s_0) baseline and
        the V(s_{t+1}) bootstrap terms. Using one V function throughout is
        what makes the control variates telescope.
        """
        states = traj['states']
        actions = traj['actions']
        rewards = traj['rewards']
        traj_len = len(rewards)
        n_eval = max(traj_len, 1)  # V(s_0) is needed even for an empty trajectory

        with torch.no_grad():
            s_batch = self._to_tensor(states[:n_eval])
            v_vals = self.v_net(s_batch).cpu().numpy()
            q_vals = None
            if traj_len:
                a_batch = self._to_tensor(actions[:traj_len]).reshape(traj_len, -1)
                q_vals = self.q_net(s_batch[:traj_len], a_batch).cpu().numpy()

        dr_traj = float(v_vals[0])
        cum_rho = 1.0
        discount = 1.0
        for t in range(traj_len):
            target_prob = target_policy.prob(states[t], actions[t])
            behavior_prob = behavior_policy.prob(states[t], actions[t])
            cum_rho *= target_prob / max(behavior_prob, 1e-6)

            # No bootstrap after the last logged step (treated as terminal).
            v_next = float(v_vals[t + 1]) if t < traj_len - 1 else 0.0
            td_error = rewards[t] + gamma * v_next - float(q_vals[t])
            dr_traj += discount * cum_rho * td_error
            discount *= gamma
        return dr_traj

    def evaluate(self, data, target_policy, behavior_policy, gamma=None, reward_estimator=None) -> float:
        """Estimate the target policy's value from logged trajectories.

        Args:
            data: sequence of trajectories, each indexable by 'states',
                'actions' and 'rewards'.
            target_policy: policy exposing ``prob(state, action)``.
            behavior_policy: logging policy exposing ``prob(state, action)``.
            gamma: discount factor; defaults to ``self.gamma``.
            reward_estimator: accepted for interface compatibility; unused.

        Returns:
            The mean per-trajectory doubly-robust estimate.

        Raises:
            ValueError: if ``data`` contains no transitions.
        """
        if gamma is None:
            gamma = self.gamma

        transitions = self._process_trajectories(data)
        if not transitions:
            raise ValueError("DRV2.evaluate requires at least one transition in `data`")
        states, actions, rewards, next_states, dones = zip(*transitions)
        n = len(transitions)

        states_np = np.asarray(states, dtype=np.float32)
        actions_np = np.asarray(actions, dtype=np.float32)

        # Importance ratios, computed from host arrays (no per-item GPU round trips).
        behavior_probs = self._policy_probs(behavior_policy, states_np, actions_np)
        target_probs = self._policy_probs(target_policy, states_np, actions_np)
        rhos = target_probs / behavior_probs

        dataset = TensorDataset(
            self._to_tensor(states_np),
            # (N, -1) so scalar actions become (N, 1) and concatenate with states.
            self._to_tensor(actions_np).reshape(n, -1),
            self._to_tensor(rewards),
            self._to_tensor(next_states),
            torch.as_tensor(np.asarray(dones, dtype=bool), device=self.device),
            rhos,
        )
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._fit_critics(dataloader, gamma)

        total_dr = 0.0
        with tqdm(data, desc="Computing DR Estimates") as traj_pbar:
            # Count trajectories explicitly: tqdm refreshes `.n` only periodically.
            for count, traj in enumerate(traj_pbar, start=1):
                total_dr += self._trajectory_dr(traj, target_policy, behavior_policy, gamma)
                traj_pbar.set_postfix({'Current DR Estimate': total_dr/count})

        return total_dr / len(data)