import numpy as np
import os
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import uuid
from dataclasses import dataclass
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from typing import Any, Dict, List, Optional
from torchvision import models, transforms
from PIL import Image


@dataclass
class TrainConfig:
    """All knobs for one behaviour-cloning run (consumed by ``run_BC``)."""

    # ---------------- experiment ----------------
    device: str = "cpu"  # torch device for both data tensors and the model
    env: str = "swimmer"  # dataset name, e.g. 'swimmer', 'reacher', 'hopper'; names containing 'carla' switch to image mode
    seed: int = 0  # passed to set_seed (Python random, NumPy, PyTorch, PYTHONHASHSEED)
    eval_freq: int = 100  # evaluation period in epochs (not read by run_BC)
    max_epochs: int = 1_000_000  # number of passes over the training set
    checkpoints_path: Optional[str] = None  # save path (not read by run_BC)
    load_model: str = ""  # checkpoint (BC.state_dict) to resume from; "" starts from scratch
    batch_size: int = 256  # DataLoader batch size for train and validation
    data_size: int = 1000  # requested training samples; the test split uses data_size // 5

    # ---------------- model / optimisation ----------------
    arch: str = "256-R-256-R-256-R|T"  # actor spec: '-'-separated widths / 'R' (ReLU), '|T' = bias in head; 'resnet' selects ResNet-18
    optimizer: str = "sgd"  # 'sgd' or 'adam'
    lamW: float = 5e-2  # coefficient of the 0.5 * lamW * ||params||^2 penalty
    lr: float = 1e-2  # learning rate

    # ---------------- data ----------------
    data_folder: str = "/dataset/mujoco"  # directory holding '<env>_<split>.pkl' files

    whitening: str = "none"  # target transform: 'none', 'whiten' (ZCA) or 'normalize' (per-dimension)
    single_task: Optional[int] = None  # index of the single action column to regress (e.g. 0, 1, 2); None = all columns


class ZCA:
    """ZCA whitening of row-observation data (rows = observations, columns = variables).

    ``fit`` stores the column mean, the symmetric square root of the (1/N)
    sample covariance, and its inverse. ``transform`` maps data to zero mean
    and unit covariance; ``inverse_transform`` undoes it.

    Args:
        eps: lower bound applied to the covariance eigenvalues. Constant or
            rank-deficient data has zero (or round-off negative) eigenvalues,
            which would otherwise turn the inverse square root into NaN/inf.
    """

    def __init__(self, eps: float = 1e-8):
        self.mean = None            # (1, M) column means
        self.sigma_sqrt = None      # (M, M) Sigma^(1/2)
        self.sigma_sqrt_inv = None  # (M, M) [Sigma^(1/2)]^(-1)
        self.eps = eps

    def fit_transform(self, Y):
        """
        Function to compute ZCA whitening transformation.
        INPUT:  Y: [N x M] matrix.
            Rows: Observations
            Columns: Variables
        OUTPUT: Y_zca: [N x M] matrix - whitened data
        """
        self.fit(Y)
        return self.transform(Y)

    def fit(self, Y):
        """
        Compute the mean and Sigma^(+-1/2) of Y ([N x M]); returns self.
        """
        # Center the data
        self.mean = np.mean(Y, axis=0, keepdims=True)
        Y_centered = Y - self.mean

        # Biased (1/N) sample covariance; N is the number of observations (rows).
        n_obs = Y.shape[0]
        Sigma = (Y_centered.T @ Y_centered) / n_obs

        # Sigma is symmetric, so eigh returns real eigenvalues and orthonormal
        # eigenvectors. Clamp eigenvalues so degenerate directions stay finite.
        eig_vals, eig_vecs = np.linalg.eigh(Sigma)
        eig_vals = np.maximum(eig_vals, self.eps)
        sqrt_eig_vals = np.sqrt(eig_vals)

        # V diag(s) V^T, written as (V * s) @ V^T (column-wise scaling).
        self.sigma_sqrt = (eig_vecs * sqrt_eig_vals) @ eig_vecs.T
        self.sigma_sqrt_inv = (eig_vecs / sqrt_eig_vals) @ eig_vecs.T

        return self

    def _check_fitted(self):
        """Raise a clear error when transform/inverse_transform precede fit."""
        if self.mean is None:
            raise RuntimeError("ZCA must be fitted before calling transform/inverse_transform")

    def transform(self, Y):
        """
        Apply ZCA whitening: Y_zca = [Sigma^(1/2)]^(-1)(Y - Ybar)
        """
        self._check_fitted()
        return (Y - self.mean) @ self.sigma_sqrt_inv.T

    def inverse_transform(self, Y_zca):
        """
        De-whiten the data: Y = [Sigma^(1/2)]Y_zca + Ybar
        """
        self._check_fitted()
        return Y_zca @ self.sigma_sqrt.T + self.mean


class Normalize:
    """Per-column standardisation: subtract the mean, divide by the standard deviation.

    The scaling is stored as diagonal matrices ``V_sqrt`` / ``V_sqrt_inv`` so the
    interface mirrors ``ZCA``.

    Args:
        eps: lower bound applied to each column variance, so constant columns
            map to 0 instead of producing inf/NaN.
    """

    def __init__(self, eps: float = 1e-8):
        self.mean = None        # (1, M) column means
        self.V_sqrt = None      # (M, M) diag(std)
        self.V_sqrt_inv = None  # (M, M) diag(1 / std)
        self.eps = eps

    def fit_transform(self, Y):
        """
        Function to compute normalization transformation.
        INPUT:  Y: [N x M] matrix.
            Rows: Observations
            Columns: Variables
        OUTPUT: Y_std: [N x M] matrix - standardized data
        """
        self.fit(Y)
        return self.transform(Y)

    def fit(self, Y):
        """
        Compute the column means and standard deviations of Y ([N x M]); returns self.
        """
        # Center the data
        self.mean = np.mean(Y, axis=0, keepdims=True)
        Y_centered = Y - self.mean

        # Population (1/N) variance per column, floored to avoid dividing by zero.
        variances = np.maximum(np.var(Y_centered, axis=0), self.eps)

        # Create diagonal matrices V^(1/2) and V^(-1/2)
        sqrt_variances = np.sqrt(variances)
        self.V_sqrt = np.diag(sqrt_variances)
        self.V_sqrt_inv = np.diag(1.0 / sqrt_variances)

        return self

    def _check_fitted(self):
        """Raise a clear error when transform/inverse_transform precede fit."""
        if self.mean is None:
            raise RuntimeError("Normalize must be fitted before calling transform/inverse_transform")

    def transform(self, Y):
        """
        Apply normalization: Y_std = V^(-1/2)(Y - Ybar)
        """
        self._check_fitted()
        return (Y - self.mean) @ self.V_sqrt_inv

    def inverse_transform(self, Y_std):
        """
        De-standardize the data: Y = V^(1/2)Y_std + Ybar
        """
        self._check_fitted()
        return Y_std @ self.V_sqrt + self.mean


def compute_metrics(y, yhat, whitening, device):
    """Return the MSE between targets and predictions, measured in the un-whitened space.

    Args:
        y: target tensor, already in original (un-whitened) units.
        yhat: prediction tensor; in the whitened space when ``whitening`` is given.
        whitening: object exposing ``inverse_transform(np.ndarray) -> np.ndarray``
            (e.g. ``ZCA`` / ``Normalize``), or None when targets were not transformed.
        device: device on which the de-whitened predictions are placed.

    Returns:
        The mean-squared error as a Python float.
    """
    if whitening is not None:
        # detach(): predictions may still carry autograd history, and .numpy()
        # refuses tensors that require grad.
        yhat_np = whitening.inverse_transform(yhat.detach().cpu().numpy())
        yhat = torch.tensor(yhat_np, dtype=torch.float32, device=device)

    return F.mse_loss(y, yhat).item()


def set_seed(
        seed: int, env=None, deterministic_torch: bool = False
):
    """Seed every random number generator a training run touches.

    Args:
        seed: the seed value.
        env: optional environment exposing ``seed()`` and ``action_space.seed()``
            (Gym-style); seeded first when given.
        deterministic_torch: forwarded to ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)
    # Same order as before: NumPy, Python's random, then PyTorch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


class Buffer(Dataset):
    """Map-style dataset of (state, action) pairs read from ``{data_folder}/{env}_{split}.pkl``.

    The pickle holds a dict with ``'observations'`` and ``'actions'`` arrays whose
    first axis indexes samples (d4rl-style). Each item is a dict of tensors on
    ``device``: ``'states'``, ``'actions'`` (whitened when requested) and
    ``'original_actions'`` (never whitened).

    Args:
        data_folder: directory holding the pickled splits.
        env: environment name used in the file name.
        split: split name used in the file name; ``'test'`` uses ``data_size // 5`` samples.
        data_size: requested number of samples, capped at the number available.
        whitening: ``'whiten'`` (ZCA), ``'normalize'`` (per-column standardisation);
            any other value leaves actions untransformed.
        device: device the returned tensors are placed on.
        single_task: if given, keep only that action column (as an (N, 1) array).
        is_carla: image mode; 3-D states are converted to PIL, resized to
            224x224 and normalised with ImageNet statistics.

    Raises:
        RuntimeError: if the pickle cannot be opened or unpickled.
        ValueError: if ``single_task`` is not a valid action column index.
    """

    def __init__(
            self,
            data_folder: str,
            env: str,
            split: str,
            data_size: int,
            whitening: str = 'none',
            device: str = "cpu",
            single_task: Optional[int] = None,
            is_carla: bool = False
    ):
        self.size = 0
        self.state_dim = 0
        self.action_dim = 0
        self.is_carla = is_carla
        self.device = device

        if self.is_carla:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),  # Resizing images to 224x224
                transforms.ToTensor(),  # Converting images to tensors
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalizing with ImageNet stats
            ])

        self.states, self.actions = None, None
        self.single_task = single_task
        self._load_dataset(data_folder, env, split, data_size)

        # Keep un-whitened targets; evaluation compares against these.
        self.original_actions = self.actions.copy()

        if whitening == 'whiten':
            self.whitening = ZCA()
            self.actions = self.whitening.fit_transform(self.actions)
        elif whitening == 'normalize':
            self.whitening = Normalize()
            self.actions = self.whitening.fit_transform(self.actions)
        else:
            self.whitening = None

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Convert one sample to a tensor on ``self.device`` (image pipeline for Carla frames)."""
        if self.is_carla and data.ndim == 3:
            # Assumed (H, W, C) with values in [0, 255] -- TODO confirm against the Carla dump.
            image = Image.fromarray(data.astype(np.uint8))
            return self.transform(image).to(self.device)

        return torch.tensor(data, dtype=torch.float32, device=self.device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def _load_dataset(self, data_folder: str, env: str, split: str, data_size: int):
        """Read the pickle and populate states, actions, size and dimensions."""
        file_path = os.path.join(data_folder, '%s_%s.pkl' % (env, split))
        try:
            # NOTE(review): pickle.load can execute arbitrary code -- only point
            # data_folder at trusted dataset files.
            with open(file_path, 'rb') as file:
                dataset = pickle.load(file)
        except Exception as e:
            # Fail here with context instead of printing and later crashing with
            # an AttributeError on self.states.
            raise RuntimeError(f'Failed to load dataset from: {file_path}') from e
        print('Successfully load dataset from: ', file_path)

        observations, actions = dataset['observations'], dataset['actions']
        if split == 'test':
            data_size = data_size // 5
        # Cap at what the file holds so __len__ never exceeds the real number of rows.
        self.size = min(data_size, observations.shape[0], actions.shape[0])
        self.states = observations[:self.size]
        actions = actions[:self.size]

        if self.single_task is not None:
            n_action_cols = actions.shape[1]
            if not 0 <= self.single_task < n_action_cols:
                raise ValueError(
                    f'single_task={self.single_task} is out of range for {n_action_cols} action dimensions'
                )
            # Slice rather than index so actions stay 2-D: (N, 1).
            actions = actions[:, self.single_task:self.single_task + 1]
        self.actions = actions

        if self.states.ndim == 4:
            self.state_dim = self.states.shape[1:]  # image observations, presumably (H, W, C)
        else:
            self.state_dim = self.states.shape[1]
        self.action_dim = self.actions.shape[1]

        print(f"Dataset size: {self.size}; State Dim: {self.state_dim}; Action_Dim: {self.action_dim}.")

    def get_state_dim(self):
        """Return the state feature size (an int, or a shape tuple for images)."""
        return self.state_dim

    def get_action_dim(self):
        """Return the number of action columns kept."""
        return self.action_dim

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {
            'states': self._to_tensor(self.states[idx]),
            'actions': self._to_tensor(self.actions[idx]),
            'original_actions': self._to_tensor(self.original_actions[idx])
        }


class Actor(nn.Module):
    """Policy network: a feature extractor followed by a linear head ``W``.

    ``arch`` selects the feature extractor:

    * any string containing ``"resnet"``: torchvision ResNet-18 (no pretrained
      weights) with its final ``fc`` replaced by identity; ``state_dim`` is unused.
    * otherwise an MLP spec ``"<layers>|<bias>"``. ``<layers>`` is a ``-``-separated
      list of hidden widths (ints) and ``R`` (ReLU). ``<bias>`` is ``T`` for a bias
      in the head ``W``; any other value means no bias. If the ``|<bias>`` suffix is
      omitted, the head keeps a bias. An empty ``<layers>`` gives a pure linear model.

    ``max_action`` is stored but not used by this class.
    """

    def __init__(self, state_dim, action_dim: int, max_action: float = 1.0, arch: str = '256-R-256-R|T'):
        super().__init__()

        if "resnet" in arch:
            self.feature_map = models.resnet18(weights=None)
            feature_dim = self.feature_map.fc.in_features
            self.feature_map.fc = nn.Identity()
            self.W = nn.Linear(feature_dim, action_dim)
        else:
            layer_spec, sep, bias_flag = arch.partition('|')
            # Explicit flag: only 'T' enables the bias; no '|' suffix keeps nn.Linear's default.
            use_bias = bias_flag == 'T' if sep else True

            in_dim = state_dim
            module_list = []
            # filter(None, ...) drops empty tokens, so '' (no hidden layers) is allowed.
            for layer in filter(None, layer_spec.split('-')):
                if layer == 'R':
                    module_list.append(nn.ReLU())
                else:
                    out_dim = int(layer)
                    module_list.append(nn.Linear(in_dim, out_dim))
                    in_dim = out_dim

            # An empty nn.Sequential returns its input unchanged.
            self.feature_map = nn.Sequential(*module_list)
            self.W = nn.Linear(in_dim, action_dim, bias=use_bias)

        self.max_action = max_action

    def get_feature(self, state: torch.Tensor):
        """Return the penultimate features (input to the linear head)."""
        return self.feature_map(state)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Predict actions: ``W(feature_map(state))``."""
        H = self.get_feature(state)
        return self.W(H)

    def project(self, feature):
        """Apply only the linear head to precomputed features."""
        return self.W(feature)


class BC:
    """Behaviour-cloning trainer: L2-regularised MSE regression from states to actions.

    The objective per batch is ``0.5 * mean((actor(s) - a)^2) + 0.5 * lamW * sum ||p||^2``
    over every actor parameter (weights and biases).
    """

    def __init__(
            self,
            actor: nn.Module,
            actor_optimizer: torch.optim.Optimizer,
            lamW: float,
            device: str = "cpu",
    ):
        self.actor = actor
        self.actor.train()

        self.actor_optimizer = actor_optimizer

        self.total_it = 0
        self.lamW = lamW
        self.device = device

    def train(self, batch) -> Dict[str, float]:
        """Take one optimizer step on ``batch`` (dict with 'states' and 'actions').

        Returns:
            Dict with the pre-step ``mse_loss``, ``reg_loss`` and ``train_loss`` as floats.
        """
        self.total_it += 1

        states, actions = batch['states'], batch['actions']
        preds = self.actor(states)

        mse_loss = 0.5 * F.mse_loss(preds, actions)
        # Squared Frobenius norm of every parameter, summed.
        reg_loss = 0.5 * self.lamW * sum(param.pow(2).sum() for param in self.actor.parameters())
        train_loss = mse_loss + reg_loss

        self.actor_optimizer.zero_grad()
        train_loss.backward()
        self.actor_optimizer.step()

        return {
            'mse_loss': float(mse_loss),
            'reg_loss': float(reg_loss),
            'train_loss': float(train_loss),
        }

    @torch.no_grad()
    def NC_eval(self, dataloader):
        """MSE of the actor's predictions against the un-whitened targets of ``dataloader``.

        Predictions are de-whitened with ``dataloader.dataset.whitening`` (see
        ``compute_metrics``). The actor's train/eval mode is restored on exit.

        Returns:
            The MSE as a float, or NaN when the loader yields no batches.
        """
        was_training = self.actor.training
        self.actor.eval()
        targets, predictions = [], []
        try:
            for batch in dataloader:
                features = self.actor.get_feature(batch['states'])
                predictions.append(self.actor.project(features))
                targets.append(batch['original_actions'])
        finally:
            self.actor.train(was_training)

        if not predictions:
            return float('nan')

        # Concatenate once instead of growing a tensor per batch (quadratic copying).
        y = torch.cat(targets, dim=0).to(self.device)
        yhat = torch.cat(predictions, dim=0).to(self.device)
        return compute_metrics(y=y, yhat=yhat, whitening=dataloader.dataset.whitening, device=self.device)

    def state_dict(self) -> Dict[str, Any]:
        """Checkpoint of actor weights, optimizer state and iteration count."""
        return {
            "actor": self.actor.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "total_it": self.total_it,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore from a dict produced by ``state_dict``."""
        self.actor.load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.total_it = state_dict["total_it"]


def run_BC(config: TrainConfig):
    """Train a BC actor for ``config`` and save its final MSEs and weights.

    Outputs go below the current working directory, in
    ``E{env}/whitening_{whitening}/S{seed}/dim_{single_task}/``:

    * ``mses/lamW_{lamW}.pkl``: dict with ``'train_mses'`` and ``'val_mses'`` lists.
    * ``weights/lamW_{lamW}.pkl``: the actor ``state_dict`` (written with ``torch.save``).

    The process working directory is not changed, so repeated calls in one
    process (e.g. a sweep) write to the same layout.

    Returns:
        The ``mses`` dict that was pickled.
    """
    seed = config.seed
    # Seed before anything that may draw random numbers (model init, shuffling).
    set_seed(seed)
    is_carla = "carla" in config.env

    train_dataset = Buffer(
        data_folder=config.data_folder,
        env=config.env,
        split='train',
        data_size=config.data_size,
        device=config.device,
        whitening=config.whitening,
        single_task=config.single_task,
        is_carla=is_carla
    )
    val_dataset = Buffer(
        data_folder=config.data_folder,
        env=config.env,
        split='test',
        data_size=config.data_size,
        device=config.device,
        single_task=config.single_task,
        is_carla=is_carla
    )
    # Validation predictions are de-whitened with statistics fitted on the training split.
    val_dataset.whitening = train_dataset.whitening

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    state_dim = train_dataset.get_state_dim()
    action_dim = train_dataset.get_action_dim()
    actor = Actor(state_dim, action_dim, arch=config.arch).to(config.device)

    # Carla + SGD: momentum and a step LR schedule (x0.1 at 50% and 75% of training).
    if config.optimizer == 'sgd' and is_carla:
        actor_optimizer = torch.optim.SGD(
            actor.parameters(),
            lr=config.lr,
            momentum=0.9
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            actor_optimizer,
            milestones=[int(0.5 * config.max_epochs),
                        int(0.75 * config.max_epochs)],
            gamma=0.1
        )
    else:
        actor_optimizer = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD
        }[config.optimizer](actor.parameters(), lr=config.lr)
        scheduler = None

    print("---------------------------------------")
    print(f"Training BC, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize policy
    trainer = BC(
        actor=actor,
        actor_optimizer=actor_optimizer,
        lamW=config.lamW,
        device=config.device,
    )

    if config.load_model != "":
        # map_location lets a checkpoint saved on one device be resumed on another.
        checkpoint = torch.load(Path(config.load_model), map_location=config.device)
        trainer.load_state_dict(checkpoint)
        actor = trainer.actor

    record_n = 1  # number of final epochs whose MSEs are recorded
    mses = {'train_mses': [], 'val_mses': []}

    for epoch in range(config.max_epochs):
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{config.max_epochs} Training"):
            trainer.train(batch)

        if scheduler is not None:
            scheduler.step()

        if epoch >= config.max_epochs - record_n:
            actor.eval()

            mses['train_mses'].append(trainer.NC_eval(train_loader))
            mses['val_mses'].append(trainer.NC_eval(val_loader))

            actor.train()

    # Build explicit paths instead of os.chdir(): changing the process-wide cwd
    # nested the output directories on every additional call in the same process.
    out_dir = Path(f'E{config.env}/whitening_{config.whitening}/S{seed}/dim_{config.single_task}')
    mses_dir = out_dir / 'mses'
    weights_dir = out_dir / 'weights'
    mses_dir.mkdir(parents=True, exist_ok=True)
    weights_dir.mkdir(parents=True, exist_ok=True)

    # save the mses
    with open(mses_dir / f'lamW_{config.lamW}.pkl', 'wb') as file:
        pickle.dump(mses, file)

    # save the actor weights
    with open(weights_dir / f'lamW_{config.lamW}.pkl', 'wb') as file:
        torch.save(actor.state_dict(), file)

    return mses
