import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from opelab.core.baseline import Baseline
from opelab.core.data import to_numpy, Normalization


class QNetwork(nn.Module):
    """Feed-forward Q-function mapping a (state, action) pair to one scalar.

    The state and action are concatenated along the last dimension and passed
    through ``Linear`` + ``ReLU`` hidden layers followed by a final
    ``Linear`` layer with a single output unit.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Size of the last dimension of the action input.
        hidden_sizes: Width of each hidden layer, in order.
    """

    def __init__(self, state_dim, action_dim, hidden_sizes=(256, 256)):
        super().__init__()
        layers = []
        input_dim = state_dim + action_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(input_dim, h))
            layers.append(nn.ReLU())
            input_dim = h
        layers.append(nn.Linear(input_dim, 1))
        self.net = nn.Sequential(*layers)

        # Xavier init with a small gain gives small initial weights; biases
        # start at exactly zero.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, state, action):
        """Return the Q-value for each (state, action) pair.

        Args:
            state: Tensor or NumPy array whose last dimension is ``state_dim``.
            action: Tensor or NumPy array whose last dimension is
                ``action_dim``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a
            trailing dimension of size 1.
        """
        # NumPy inputs are moved to the device the network itself lives on.
        # Borrowing the device from the other argument breaks when both
        # arguments are NumPy arrays, since neither is a tensor.
        device = next(self.parameters()).device
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state).float().to(device)
        if isinstance(action, np.ndarray):
            action = torch.from_numpy(action).float().to(device)

        x = torch.cat([state, action], dim=-1)
        return self.net(x)


class FQE(Baseline):
    """Fitted Q Evaluation using two Q-networks with soft-updated targets.

    Two Q-networks are regressed onto a bootstrapped target built from the
    element-wise minimum of two target networks. The reported policy value is
    the mean of ``min(Q1, Q2)`` over the first state of every episode, using
    the target policy's deterministic action.

    Args:
        state_dim: State dimensionality passed to ``QNetwork``.
        action_dim: Action dimensionality passed to ``QNetwork``.
        device: Torch device for the networks and the cached tensors.
        gamma: Stored on the instance. Note that ``evaluate`` uses its own
            ``gamma`` argument for the Bellman target, not this attribute.
        epochs: Number of passes over the dataset per ``evaluate`` call.
        batch_size: Mini-batch size.
        lr: Adam learning rate shared by both Q-networks.
        tau: Interpolation weight used by ``soft_update``.
        target_update_freq: Number of optimizer steps between soft updates of
            the target networks.
        preprocess_once: When true, the tensors built by ``_prepare_data`` are
            cached and reused by later ``evaluate`` calls until ``load_data``
            is called again.
    """

    def __init__(self, state_dim, action_dim, device='cuda', gamma=0.99, epochs=500,
                 batch_size=2048, lr=1e-5, tau=0.05, target_update_freq=30,
                 preprocess_once=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.tau = tau
        self.device = device
        self.target_update_freq = target_update_freq
        self.preprocess_once = preprocess_once

        self._build_networks()
        # Summed (not averaged) squared error; evaluate() divides by the
        # batch size itself.
        self.criterion = nn.MSELoss(reduction='sum')

        self.data = None
        self.processed_data = None
        self.initial_states_tensor = None
        self.behavior_policy = None
        self.target_policy = None

    def _build_networks(self):
        """Create fresh Q-networks, synchronized target copies and optimizer."""
        self.q_net1 = QNetwork(self.state_dim, self.action_dim).to(self.device)
        self.q_net2 = QNetwork(self.state_dim, self.action_dim).to(self.device)

        self.target_q_net1 = QNetwork(self.state_dim, self.action_dim).to(self.device)
        self.target_q_net2 = QNetwork(self.state_dim, self.action_dim).to(self.device)
        self.target_q_net1.load_state_dict(self.q_net1.state_dict())
        self.target_q_net2.load_state_dict(self.q_net2.state_dict())

        # A single optimizer steps both online networks together.
        self.optimizer = optim.Adam(
            list(self.q_net1.parameters()) + list(self.q_net2.parameters()),
            lr=self.lr
        )

    def load_data(self, data):
        """Store ``data`` and drop any tensors cached from a previous dataset."""
        self.data = data
        self.processed_data = None
        self.initial_states_tensor = None

    def soft_update(self, net, target_net):
        """Blend ``net`` into ``target_net``: target <- tau*net + (1-tau)*target."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def _prepare_data(self, data, target_policy, behavior_policy):
        """Convert ``data`` into transition tensors on ``self.device``.

        Returns the cached tuple when ``preprocess_once`` is set and a cache
        exists. Otherwise the data is converted with ``to_numpy`` and, when
        ``preprocess_once`` is set, the result, both policies and the tensor
        of episode-initial states are cached on the instance.

        Returns:
            Tuple ``(states, actions, next_states, rewards, terminals)`` of
            float32 tensors.
        """
        if self.preprocess_once and self.processed_data is not None:
            return self.processed_data

        print("Preprocessing data...")
        states, _, actions, next_states, _, rewards, _, terminals = to_numpy(
            data, target_policy, behavior_policy,
            normalization=Normalization.STD, return_terminals=True
        )

        def to_tensor(x):
            return torch.tensor(x, dtype=torch.float32, device=self.device)

        processed_data = (
            to_tensor(states),
            to_tensor(actions),
            to_tensor(next_states),
            to_tensor(rewards),
            to_tensor(terminals)
        )

        if self.preprocess_once:
            self.processed_data = processed_data
            self.target_policy = target_policy
            self.behavior_policy = behavior_policy

            print("Preprocessing initial states...")
            # NOTE(review): transitions are requested with Normalization.STD
            # while initial states are read raw from the episodes -- confirm
            # both end up on the same scale.
            initial_states = np.array([ep['states'][0] for ep in data])
            self.initial_states_tensor = to_tensor(initial_states)

        return processed_data

    def evaluate(self, data, target_policy, behavior_policy, gamma=0.99, reward_estimator=None):
        """Train fresh Q-networks and return the estimated target-policy value.

        Args:
            data: Iterable of episodes, each indexable by ``'states'``. Used
                when no dataset was registered through ``load_data``;
                otherwise the loaded dataset takes precedence.
            target_policy: Policy being evaluated; must provide
                ``sample_tensor(states, deterministic=True)``.
            behavior_policy: Forwarded to ``to_numpy`` during preprocessing.
            gamma: Discount factor used in the Bellman target.
            reward_estimator: Accepted but not used by this method.

        Returns:
            Python float: mean of ``min(Q1, Q2)`` over episode-initial states.
        """
        # Every evaluation starts from newly initialized networks.
        self._build_networks()

        # Use one dataset for both training and the initial-state query, and
        # fall back to the argument when load_data() was never called.
        dataset = self.data if self.data is not None else data

        states, actions, next_states, rewards, terminals = self._prepare_data(
            dataset, target_policy, behavior_policy
        )
        if self.preprocess_once:
            self.target_policy = target_policy
            self.behavior_policy = behavior_policy

        # The Q-networks output shape (B, 1). Forcing rewards and terminals to
        # column vectors keeps the Bellman target at (B, 1); 1-D inputs would
        # otherwise broadcast it to (B, B). This is a no-op when they already
        # are column vectors.
        rewards = rewards.reshape(-1, 1)
        terminals = terminals.reshape(-1, 1)

        dataset_size = states.shape[0]
        indices = np.arange(dataset_size)

        self.q_net1.train()
        self.q_net2.train()

        # Counted across epochs so that target updates still fire when one
        # epoch holds fewer than target_update_freq batches.
        total_steps = 0

        pbar = tqdm(range(self.epochs), desc="Training FQE")
        for _ in pbar:
            np.random.shuffle(indices)
            epoch_loss = 0.0

            for start in range(0, dataset_size, self.batch_size):
                batch_indices = indices[start:start + self.batch_size]
                s = states[batch_indices]
                a = actions[batch_indices]
                ns = next_states[batch_indices]
                r = rewards[batch_indices]
                d = terminals[batch_indices]

                with torch.no_grad():
                    na = target_policy.sample_tensor(ns, deterministic=True)

                    q1_next = self.target_q_net1(ns, na)
                    q2_next = self.target_q_net2(ns, na)
                    q_target = torch.min(q1_next, q2_next)
                    # Terminal transitions (d == 1) do not bootstrap.
                    y = r + gamma * (1 - d) * q_target

                q1_val = self.q_net1(s, a)
                q2_val = self.q_net2(s, a)

                loss1 = self.criterion(q1_val, y)
                loss2 = self.criterion(q2_val, y)
                # Per-sample mean of the two summed losses.
                loss = (loss1 + loss2) / s.shape[0]

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item() * s.shape[0]
                total_steps += 1

                if total_steps % self.target_update_freq == 0:
                    self.soft_update(self.q_net1, self.target_q_net1)
                    self.soft_update(self.q_net2, self.target_q_net2)

            avg_loss = epoch_loss / dataset_size
            pbar.set_postfix({"loss": f"{avg_loss:.6f}"})

        self.q_net1.eval()
        self.q_net2.eval()

        if self.preprocess_once and self.initial_states_tensor is not None:
            initial_states_tensor = self.initial_states_tensor
        else:
            initial_states = np.array([ep['states'][0] for ep in dataset])
            initial_states_tensor = torch.tensor(initial_states, dtype=torch.float32, device=self.device)

        print("Evaluating policy on initial states...")

        with torch.no_grad():
            initial_actions = target_policy.sample_tensor(initial_states_tensor, deterministic=True)

            q1_vals = self.q_net1(initial_states_tensor, initial_actions)
            q2_vals = self.q_net2(initial_states_tensor, initial_actions)
            q_vals = torch.min(q1_vals, q2_vals)

            mean_value = q_vals.mean().item()
            print(f"Mean value: {mean_value:.6f}")

        return mean_value