import copy

import numpy as np
import plotly.graph_objects as go
import torch

from utils.data import DiffusionRContourDataset, GaussianDiffusion
from utils.orchestrator import TrainerRContour

# Set seeds for reproducibility. NumPy drives the data sampling below; torch
# is seeded too so that anything drawn from torch's default generator
# (e.g. model initialisation / training noise) is repeatable as well.
np.random.seed(42)
torch.manual_seed(42)

# flags: which stages of the script to run
visualise_plot = True  # scatter-plot the sampled train/val points
train = True           # train a new model and save it under `expt_name`
evaluate = True        # reload the saved model and plot the error heatmap

# Dataset parameters
radius = 0.5  # Radius of each arc
n_data = 80  # Number of points to generate per split (train and val each)

# Work on a private copy of the trainer's defaults. Assigning the class
# attribute directly would make every edit below (e.g. 'expt_name') mutate
# TrainerRContour.default_config itself, leaking into any other trainer
# created in the same process.
config = copy.deepcopy(TrainerRContour.default_config)

# experiment name (also used later to reload the saved model)
expt_name = 'ce_heatmap'
config['expt_name'] = expt_name

def _sample_two_arcs(n_points, radius):
    """Sample ``n_points`` 2-D points lying on two offset circular arcs.

    The first arc uses angles drawn uniformly from [pi/8, pi) on a circle
    centred at (-radius, 0). The second arc uses angles drawn uniformly from
    [9*pi/8, 2*pi) on a circle centred at (+radius, 0). Angles are drawn from
    NumPy's global RNG: the first arc first, then the second.

    The first arc gets ``n_points // 2`` points and the second gets the
    remainder, so exactly ``n_points`` points are returned even when
    ``n_points`` is odd.

    Returns:
        (x, y): two 1-D float arrays of length ``n_points``.
    """
    n_first = n_points // 2
    n_second = n_points - n_first

    theta_1 = np.random.uniform(np.pi / 8, np.pi, n_first)
    theta_2 = np.random.uniform(9 * np.pi / 8, 2 * np.pi, n_second)

    x_1, y_1 = radius * np.cos(theta_1) - radius, radius * np.sin(theta_1)
    x_2, y_2 = radius * np.cos(theta_2) + radius, radius * np.sin(theta_2)

    return np.concatenate([x_1, x_2]), np.concatenate([y_1, y_2])


# train data: (n_data, 2) float32 tensor of (x, y) pairs
x_train, y_train = _sample_two_arcs(n_data, radius)
train_sequences = torch.tensor(np.column_stack((x_train, y_train)), dtype=torch.float32)
dataset_train = DiffusionRContourDataset(config, train_sequences)

# val data: independent draw from the same distribution
x_val, y_val = _sample_two_arcs(n_data, radius)
val_sequences = torch.tensor(np.column_stack((x_val, y_val)), dtype=torch.float32)
dataset_val = DiffusionRContourDataset(config, val_sequences)


if visualise_plot:
    # Overlay the train (blue) and validation (red) samples on a fixed
    # [-2, 2] x [-2, 2] window so the two arcs can be inspected by eye.
    scatter_train = go.Scatter(
        x=x_train,
        y=y_train,
        mode='markers',
        marker={'color': 'blue'},
        name='Train',
    )
    scatter_val = go.Scatter(
        x=x_val,
        y=y_val,
        mode='markers',
        marker={'color': 'red'},
        name='Val',
    )

    fig = go.Figure(data=[scatter_train, scatter_val])
    fig.update_layout(
        xaxis={'range': [-2, 2]},
        yaxis={'range': [-2, 2]},
        showlegend=True,
        width=800,
        height=800,
    )
    fig.show()

# train: build a fresh trainer from `config` and fit it on the train/val splits
if train:
    trainer = TrainerRContour(config)
    trainer.train(
        dataset_train,
        dataset_val,
        use_wandb=True,  # presumably enables Weights & Biases logging
        # False: do not keep only the best-validation model; just train to completion
        save_best_val_model=False,
    )
    print(f'Experiment name = {config["expt_name"]}')

if evaluate:
    # Reload the saved trainer by experiment name; this works whether or not
    # training ran in this session. `data_train` is used for the overlay below.
    trainer, info, data_train = TrainerRContour.load(expt_name)

    # Evaluate the combined error at every point of an n x n grid covering
    # [-2, 2] x [-2, 2] (the same window as the data scatter plot).
    n = 100
    xlin = np.linspace(-2, 2, n)
    ylin = np.linspace(-2, 2, n)

    # With meshgrid's default 'xy' indexing, x varies fastest in ravel()
    # order, so reshaping the results back to (n, n) gives rows indexed by y
    # and columns indexed by x, which is the layout go.Heatmap expects for z.
    xx, yy = np.meshgrid(xlin, ylin)
    points = np.stack([xx.ravel(), yy.ravel()], axis=1)
    points = torch.tensor(points).unsqueeze(-1).float()  # shape (n*n, 2, 1)

    # Replace the predictor's normalisation statistics with identity values
    # (mean 0, std 1 for both coordinates).
    # NOTE(review): presumably this makes the grid be interpreted in raw data
    # coordinates; confirm this is consistent with how the model was trained.
    trainer.predictor.data_means = torch.zeros(2, 1).float()
    trainer.predictor.data_stds = torch.ones(2, 1).float()

    ce_errors = trainer.predictor.compute_combined_error(x_sampled=points)
    # detach().cpu() makes the NumPy conversion safe even if the result still
    # requires grad or lives on an accelerator. reshape (unlike view) also
    # tolerates a non-contiguous result.
    ce_grid = ce_errors.detach().cpu().reshape(n, n).numpy()

    heatmap_ce = go.Heatmap(
        z=ce_grid,
        x=xlin,
        y=ylin,
        colorscale='Viridis',
        colorbar=dict(title='Combined Error')
    )

    # Overlay the training samples returned by load() as red 'x' markers.
    # Assumes data_train is (N, 2)-like (x in column 0, y in column 1).
    data_train = np.array(data_train)
    scatter = go.Scatter(
        x=data_train[:, 0],
        y=data_train[:, 1],
        mode='markers',
        marker=dict(color='red', symbol='x', size=5),  # Setting marker as 'x'
        name='Samples'
    )

    # Generate figure: axes hidden, transparent background, and y locked
    # to x (scaleanchor) so the heatmap keeps a 1:1 aspect ratio.
    fig = go.Figure(data=[heatmap_ce, scatter])
    fig.update_layout(
        title='Combined Error Heatmap',
        xaxis_title='',
        yaxis_title='',
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            visible=False
        ),
        yaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor='x',
            visible=False
        ),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        margin=dict(l=0, r=0, t=30, b=0)
    )
    fig.show()
