import copy
import random
import math
import os
from typing import List, Tuple, Sequence

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from tasks.task_generator import configure_dataset
import hydra
from omegaconf import DictConfig

############################################################
# Activations ‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑ #
############################################################

class CReLU(nn.Module):
    """Concatenated ReLU: ``relu(cat(x, -x))`` along the feature/channel axis.

    Accepts ``(B, F)`` or ``(B, C, H, W)`` tensors and doubles the size of
    axis 1; any other rank raises ``ValueError``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Guard first so the happy path below needs no branching.
        if x.dim() not in (2, 4):
            raise ValueError(f"CReLU expects (B, F) or (B, C, H, W); got {x.shape}")
        # Axis 1 is the feature axis of a 2-D input (same as -1 there) and
        # the channel axis of a 4-D input.
        stacked = torch.cat((x, -x), dim=1)
        return F.relu(stacked)

############################################################
# MLP variants ‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑ #
############################################################

class MLP(nn.Module):
    """Plain ReLU MLP (kept unchanged)."""

    def __init__(self, input_dim: int, layer_widths: List[int], output_dim: int):
        super().__init__()
        sizes = [input_dim] + layer_widths + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)]
        )

    # ------------------------------------------------------------------
    def forward(
        self, x: torch.Tensor, params: Sequence[torch.Tensor] | None = None
    ) -> torch.Tensor:
        x = x.view(x.shape[0], -1)
        if params is None:
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i < len(self.layers) - 1:
                    x = F.relu(x)
            return x

        idx = 0
        for i in range(len(self.layers)):
            w, b = params[idx], params[idx + 1]
            idx += 2
            x = F.linear(x, w, b)
            if i < len(self.layers) - 1:
                x = F.relu(x)
        return x


class CReLUMLP(nn.Module):
    """MLP whose hidden activations are CReLU instead of ReLU.

    Each hidden ``nn.Linear`` produces ``h // 2`` features and CReLU doubles
    them, so an even requested width ``h`` yields exactly ``h`` activations.
    For an odd ``h`` the realised width is ``h - 1``; the next layer is sized
    from the realised width so the network stays internally consistent.

    Parameters
    ----------
    input_dim : int
        Number of features per sample once the input is flattened.
    layer_widths : sequence of int
        Requested post-CReLU hidden widths, in order.
    output_dim : int
        Width of the final linear layer, which has no activation.
    """

    def __init__(self, input_dim: int, layer_widths: List[int], output_dim: int):
        super().__init__()
        self.crelu = CReLU()
        layers: List[nn.Module] = []

        in_dim = input_dim
        for h in layer_widths:
            half = h // 2
            layers.append(nn.Linear(in_dim, half))  # halve, then CReLU doubles
            # Track what CReLU really emits (2 * half), not the requested h:
            # for odd h these differ and the next layer would not fit.
            in_dim = 2 * half

        layers.append(nn.Linear(in_dim, output_dim))
        self.layers = nn.ModuleList(layers)

    # ------------------------------------------------------------------
    def forward(
        self, x: torch.Tensor, params: Sequence[torch.Tensor] | None = None
    ) -> torch.Tensor:
        """Apply the network to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input whose first axis is the batch; all remaining axes are
            flattened into one feature axis.
        params : sequence of torch.Tensor, optional
            Flat ``[w0, b0, w1, b1, ...]`` list used instead of the module's
            own weights.

        Returns
        -------
        torch.Tensor
            Output of the last linear layer, shape ``(batch, output_dim)``.
        """
        # reshape (not view) so non-contiguous inputs are flattened instead
        # of raising.
        x = x.reshape(x.shape[0], -1)
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if params is None:
                x = layer(x)
            else:
                x = F.linear(x, params[2 * i], params[2 * i + 1])
            if i < last:
                x = self.crelu(x)
        return x

############################################################
# CNN variants ‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑‑ #
############################################################

def _make_pool(pool_type: str, k: int) -> nn.Module:
    """Build a 2-D pooling layer of the requested kind.

    ``pool_type`` is matched case-insensitively against ``"max"`` and
    ``"avg"``; ``k`` is passed through as the pooling ``kernel_size``.
    Any other ``pool_type`` raises ``ValueError``.
    """
    pool_classes = {"max": nn.MaxPool2d, "avg": nn.AvgPool2d}
    pool_cls = pool_classes.get(pool_type.lower())
    if pool_cls is None:
        raise ValueError(f"Unsupported pooling type '{pool_type}'")
    return pool_cls(kernel_size=k)


class CNN(nn.Module):
    """
    Conv -> ReLU -> pool stack followed by a single fully-connected layer.

    The per-layer lists are consumed in lockstep with ``zip``, so the number
    of conv stages equals the length of the shortest list.

    Parameters
    ----------
    input_shape : (C, H, W)
    cnn_channels : list[int]   output channels of each conv stage
    kernel_size  : list[int]
    padding      : list[int]
    stride       : list[int]
    pooling_type : list[str]  (e.g. "max", "avg")
    pooling_kernel : list[int]
    output_dim   : int        # classes
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        cnn_channels: List[int],
        kernel_size: List[int],
        padding: List[int],
        stride: List[int],
        pooling_type: List[str],
        pooling_kernel: List[int],
        output_dim: int,
    ):
        super().__init__()
        C_in, _, _ = input_shape

        self.convs = nn.ModuleList()
        self.pools = nn.ModuleList()

        in_ch = C_in
        for out_ch, k, p, s, pool_t, pool_k in zip(
            cnn_channels, kernel_size, padding, stride, pooling_type, pooling_kernel
        ):
            self.convs.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, stride=s)
            )
            self.pools.append(_make_pool(pool_t, pool_k))
            in_ch = out_ch

        # Run one zero-filled sample through the conv stack to measure the
        # flattened feature size the fully-connected layer must accept.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            f = self._forward_features(dummy)
            self.flat_dim = f.numel()

        self.fc = nn.Linear(self.flat_dim, output_dim)

    # ------------------------------------------------------------------
    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv/ReLU/pool stages with the module's own weights."""
        for conv, pool in zip(self.convs, self.pools):
            x = F.relu(conv(x))
            x = pool(x)
        return x

    def forward(
        self, x: torch.Tensor, params: Sequence[torch.Tensor] | None = None
    ) -> torch.Tensor:
        """Compute class scores for a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images, ``(B, C, H, W)``.
        params : sequence of torch.Tensor, optional
            Flat ``[conv_w0, conv_b0, ..., fc_w, fc_b]`` list used instead of
            the module's own weights (manual weight injection for MAML).

        Returns
        -------
        torch.Tensor
            Scores of shape ``(B, output_dim)``.
        """
        if params is None:
            x = self._forward_features(x)
            # reshape (not view): features may be non-contiguous, e.g. when
            # the input uses the channels_last memory format.
            return self.fc(x.reshape(x.size(0), -1))

        # ----- manual weight injection for MAML -----
        n_stages = 0
        for conv, pool in zip(self.convs, self.pools):
            w, b = params[2 * n_stages], params[2 * n_stages + 1]
            n_stages += 1
            # Stride/padding come from the registered conv so the injected
            # weights are applied with the same geometry.
            x = F.conv2d(x, w, b, stride=conv.stride, padding=conv.padding)
            x = pool(F.relu(x))
        w_fc, b_fc = params[2 * n_stages], params[2 * n_stages + 1]
        x = x.reshape(x.size(0), -1)
        return F.linear(x, w_fc, b_fc)


class CReLUCNN(nn.Module):
    """
    CNN that uses CReLU after every convolution.

    Each conv layer produces ``requested // 2`` channels and CReLU doubles
    them. For an even request this gives exactly the requested channel count;
    for an odd request the realised count is one less, and the next conv is
    sized from the realised count.

    The per-layer lists are consumed in lockstep with ``zip``, so the number
    of conv stages equals the length of the shortest list.

    Parameters
    ----------
    input_shape : (C, H, W)
    cnn_channels : list[int]   requested post-CReLU channels of each stage
    kernel_size  : list[int]
    padding      : list[int]
    stride       : list[int]
    pooling_type : list[str]  (e.g. "max", "avg")
    pooling_kernel : list[int]
    output_dim   : int        # classes
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        cnn_channels: List[int],
        kernel_size: List[int],
        padding: List[int],
        stride: List[int],
        pooling_type: List[str],
        pooling_kernel: List[int],
        output_dim: int,
    ):
        super().__init__()
        C_in, _, _ = input_shape
        self.crelu = CReLU()

        self.convs = nn.ModuleList()
        self.pools = nn.ModuleList()

        in_ch = C_in
        for out_ch_target, k, p, s, pool_t, pool_k in zip(
            cnn_channels, kernel_size, padding, stride, pooling_type, pooling_kernel
        ):
            # halve -> CReLU doubles
            out_ch = out_ch_target // 2
            self.convs.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, stride=s)
            )
            self.pools.append(_make_pool(pool_t, pool_k))
            # Use what CReLU really emits (2 * out_ch), not the requested
            # count: for an odd request they differ and the next conv would
            # reject its input.
            in_ch = 2 * out_ch

        # Run one zero-filled sample through the conv stack to measure the
        # flattened feature size the fully-connected layer must accept.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            f = self._forward_features(dummy)
            self.flat_dim = f.numel()

        self.fc = nn.Linear(self.flat_dim, output_dim)

    # ------------------------------------------------------------------
    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv/CReLU/pool stages with the module's own weights."""
        for conv, pool in zip(self.convs, self.pools):
            x = self.crelu(conv(x))
            x = pool(x)
        return x

    def forward(
        self, x: torch.Tensor, params: Sequence[torch.Tensor] | None = None
    ) -> torch.Tensor:
        """Compute class scores for a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images, ``(B, C, H, W)``.
        params : sequence of torch.Tensor, optional
            Flat ``[conv_w0, conv_b0, ..., fc_w, fc_b]`` list used instead of
            the module's own weights.

        Returns
        -------
        torch.Tensor
            Scores of shape ``(B, output_dim)``.
        """
        if params is None:
            x = self._forward_features(x)
            # reshape (not view): features may be non-contiguous, e.g. when
            # the input uses the channels_last memory format.
            return self.fc(x.reshape(x.size(0), -1))

        n_stages = 0
        for conv, pool in zip(self.convs, self.pools):
            w, b = params[2 * n_stages], params[2 * n_stages + 1]
            n_stages += 1
            # Stride/padding come from the registered conv so the injected
            # weights are applied with the same geometry.
            x = F.conv2d(x, w, b, stride=conv.stride, padding=conv.padding)
            x = pool(self.crelu(x))
        w_fc, b_fc = params[2 * n_stages], params[2 * n_stages + 1]
        x = x.reshape(x.size(0), -1)
        return F.linear(x, w_fc, b_fc)

############################################################
# Inner‑loop adaptation ---------------------------------- #
############################################################

def inner_loop(
    model: nn.Module,
    loss_fn,
    support_x: torch.Tensor,
    support_y: torch.Tensor,
    inner_lr: float,
    inner_steps: int,
):
    """Adapt ``model``'s parameters to a support set with plain SGD.

    Starts from ``model.parameters()`` and takes ``inner_steps`` gradient
    steps of size ``inner_lr`` on ``loss_fn(model.forward(support_x, w),
    support_y)``. Gradients are taken with ``create_graph=True`` so the
    returned tensors remain differentiable w.r.t. the starting parameters.
    The model's own parameters are not modified.

    Returns
    -------
    list of torch.Tensor
        The adapted parameters, in ``model.parameters()`` order. With
        ``inner_steps == 0`` this is just the model's parameters.
    """
    fast_weights = list(model.parameters())
    for _step in range(inner_steps):
        support_loss = loss_fn(model.forward(support_x, fast_weights), support_y)
        step_grads = torch.autograd.grad(
            support_loss, fast_weights, create_graph=True
        )
        stepped = []
        for weight, grad in zip(fast_weights, step_grads):
            stepped.append(weight - inner_lr * grad)
        fast_weights = stepped
    return fast_weights

############################################################
# Meta‑training loop ------------------------------------- #
############################################################

def meta_train(
    meta_model: nn.Module,
    n_meta_epochs: int,
    meta_batch_size: int,
    inner_steps: int,
    inner_lr: float,
    meta_lr: float,
    sample_task,
    support_bs: int | None = 256,
    query_bs: int | None = 256,
    device: str | torch.device = "cpu",
    save_path: str = "maml_init.pth",
):
    """Outer‑loop (meta) optimisation for MAML using dataset outputs from ``sample_task``.

    Every meta-epoch draws ``meta_batch_size`` tasks, adapts to one support
    batch per task with ``inner_loop``, evaluates the adapted weights on one
    query batch with cross-entropy, and takes a single Adam step on the mean
    query loss. The final ``state_dict`` is written to ``save_path``.

    Parameters
    ----------
    meta_model : nn.Module
        Model whose ``forward(x, params)`` accepts an explicit parameter list.
    n_meta_epochs : int
        Number of outer-loop optimiser steps.
    meta_batch_size : int
        Tasks per outer-loop step; must be at least 1.
    inner_steps, inner_lr
        Number and size of the SGD steps taken by ``inner_loop``.
    meta_lr : float
        Learning rate of the outer Adam optimiser.
    sample_task : callable
        Called with a running integer task id. Must return a sequence whose
        first two items are the support and query datasets; further items
        are ignored.
    support_bs, query_bs : int or None
        Batch sizes for the single support / query batch; ``None`` (or 0)
        uses the whole dataset.
    device : str or torch.device
        Device the model and batches are moved to.
    save_path : str
        Destination of the saved ``state_dict``.

    Raises
    ------
    ValueError
        If ``meta_batch_size`` is smaller than 1.
    """
    if meta_batch_size < 1:
        raise ValueError(f"meta_batch_size must be >= 1; got {meta_batch_size}")

    # Make sure the checkpoint directory exists *before* training, so a bad
    # path cannot cost a finished run at the final torch.save.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    device = torch.device(device)
    meta_model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    meta_opt = torch.optim.Adam(meta_model.parameters(), lr=meta_lr)

    task_id = 0
    for epoch in range(1, n_meta_epochs + 1):
        meta_opt.zero_grad()
        # Detached running mean, kept only for logging.
        meta_loss = torch.zeros((), device=device)

        for _ in range(meta_batch_size):
            task = sample_task(task_id)
            task_id += 1
            support_ds, query_ds, *_ = task  # ignore extra info

            s_loader = DataLoader(
                support_ds, batch_size=support_bs or len(support_ds), shuffle=True
            )
            q_loader = DataLoader(
                query_ds, batch_size=query_bs or len(query_ds), shuffle=False
            )

            # Only the first batch of each loader is used per task.
            sx, sy = next(iter(s_loader))
            qx, qy = next(iter(q_loader))
            sx, sy, qx, qy = (
                sx.to(device),
                sy.to(device),
                qx.to(device),
                qy.to(device),
            )

            adapted_params = inner_loop(
                meta_model, loss_fn, sx, sy, inner_lr, inner_steps
            )

            q_preds = meta_model.forward(qx, adapted_params)
            loss_q = loss_fn(q_preds, qy) / meta_batch_size
            # Backpropagate per task: gradients accumulate in .grad to the
            # same mean-loss gradient, but each task's second-order graph is
            # released immediately instead of being held for the whole batch.
            loss_q.backward()
            meta_loss += loss_q.detach()

        meta_opt.step()

        if epoch % 10 == 0:
            print(f"[Meta‑Epoch {epoch:04d}] meta‑loss = {meta_loss.item():.4f}")

    print("save_path:", save_path, type(save_path))
    torch.save(meta_model.state_dict(), save_path)
    print(f"\n✔️  Meta‑learned initialisation saved to '{save_path}'.")

############################################################
# Entry point -------------------------------------------- #
############################################################

@hydra.main(config_path="configs/sl", config_name="config")
def main(cfg: DictConfig):
    """Build the configured model, meta-train it, and save the initialisation.

    Reads ``cfg.agent.agent_type`` and ``cfg.arch.arch_name`` to pick one of
    the four model classes defined in this file, obtains the task sampler
    from ``configure_dataset``, and hands both to ``meta_train``.

    Raises
    ------
    ValueError
        If the agent type / architecture combination is not supported.
    """
    agent_config = cfg.agent
    arch_config = cfg.arch
    task_config = cfg.task

    input_shape = task_config.input_shape  # (C,H,W)
    output_dim = task_config.num_classes
    n_meta_epochs = cfg.meta_train_epoch

    # ------------------------------------------------------------------
    # Choose the architecture: agent type selects the activation variant,
    # arch name selects CNN vs MLP.
    # ------------------------------------------------------------------
    model_classes = {
        "CReLUAgent": {"CNN": CReLUCNN, "MLP": CReLUMLP},
        "BaseAgent": {"CNN": CNN, "MLP": MLP},
    }
    agent_type = agent_config.agent_type
    arch_name = arch_config.arch_name
    try:
        model_cls = model_classes[agent_type][arch_name]
    except KeyError:
        # Previously an unknown combination left `model` unbound and failed
        # later with an unrelated-looking error at print(model).
        raise ValueError(
            f"Unsupported agent_type/arch_name combination: "
            f"{agent_type!r} / {arch_name!r}"
        ) from None

    if arch_name == "CNN":
        model = model_cls(
            input_shape=input_shape,
            cnn_channels=arch_config.cnn_channels,          # e.g. [8,16,32,64]
            kernel_size=arch_config.kernel_size,            # [3,3,3,3]
            padding=arch_config.padding,                    # [1,1,1,1]
            stride=arch_config.stride,                      # [1,1,1,1]
            pooling_type=arch_config.pooling_type,          # ['max',...]
            pooling_kernel=arch_config.pooling_kernel,      # [2,2,2,2]
            output_dim=output_dim,
        )
    else:
        # fc_channels is only read here, so it is only required for MLPs.
        model = model_cls(
            math.prod(input_shape), arch_config.fc_channels, output_dim
        )

    print(model)

    get_task_dataset, _, _ = configure_dataset(
        task_config=task_config, arch_config=arch_config, args=cfg
    )

    save_dir = "/localhome/srr8/projects/17/maml_model"
    os.makedirs(save_dir, exist_ok=True)

    meta_train(
        meta_model=model,
        n_meta_epochs=n_meta_epochs,
        meta_batch_size=4,
        inner_steps=5,
        inner_lr=0.01,
        meta_lr=1e-3,
        sample_task=get_task_dataset,
        device="cuda" if torch.cuda.is_available() else "cpu",
        save_path=f"{save_dir}/{arch_config.arch_name}_{agent_config.agent_type}_{task_config.benchmark}_{n_meta_epochs}_init.pth",
    )

# Example invocation (Hydra-style key=value overrides on the command line):
#   python maml_main_sl.py task=shuffle_cifar10 agent=baseline wandb=true monitor_backward_transfer=false proj_name=new_train agent.optimizer=adam agent.lr=0.001 arch=mlp
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
