import numpy as np
import pandas as pd
import tensorflow as tf

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import accuracy_score, brier_score_loss

import matplotlib.pyplot as plt

# --------- Data ----------

def _sample_disk_region(rng, n, center):
    """Draw ``n`` points from the disk of radius sqrt(2) around ``center``.

    Returns ``(points, labels)`` where ``labels`` is 1 for points whose
    distance to ``center`` is below 1 and 0 otherwise.
    """
    theta = rng.uniform(0, 2 * np.pi, size=n)
    # sqrt of a uniform draw on [0, 2) makes the squared radius uniform,
    # i.e. the points are spread evenly over the disk's area.
    r = np.sqrt(rng.uniform(0, 2, size=n))
    points = np.vstack([r * np.cos(theta), r * np.sin(theta)]).T + center
    labels = (r < 1).astype(int)
    return points, labels


def generate_circle_line(
    N_train=1000,
    N_test=500,
    seed=1,
    dtype=np.float32,
    t=1,
):
    """Generate a 2-D binary classification data set with two regions.

    Half of each split (region type 0, "linear") is a Gaussian blob centred
    at ``(-t, 0)`` and labelled by which side of the line
    ``x0 + x1 = -t`` a point falls on.  The other half (region type 1,
    "circular") is sampled from a disk centred at ``(t, 0)`` and labelled 1
    inside the unit circle, 0 outside.  Each split is shuffled.

    Parameters
    ----------
    N_train, N_test : int
        Number of training / test samples (split as evenly as possible
        between the two regions; the circular region gets the extra sample).
    seed : int
        Seed for the training generator; the test generator uses ``seed + 1``.
    dtype : numpy dtype
        Dtype of the returned feature and label arrays.
    t : float
        Horizontal offset of the two region centres from the origin.

    Returns
    -------
    tuple
        ``(x_train, y_train, y_oh_train, types_train,
        x_test, y_test, y_oh_test, types_test)`` where ``x_*`` has shape
        ``(N, 2)``, ``y_*`` and ``types_*`` have shape ``(N, 1)`` and
        ``y_oh_*`` are the depth-2 one-hot labels produced by ``tf.one_hot``.
    """
    offset_linear = np.array((-t, 0.0))
    offset_circle = np.array((t, 0.0))
    line_threshold = np.sum(offset_linear)

    # ---- training split ----
    rng = np.random.RandomState(seed)
    N1 = N_train // 2
    N2 = N_train - N1
    # linear region: Gaussian around offset_linear.  Drawn from the seeded
    # generator (not the global numpy state) so that `seed` makes the
    # training set reproducible.
    X1 = rng.multivariate_normal([0.0, 0.0], 0.1 * np.eye(2), size=N1) + offset_linear
    y1 = (X1[:, 0] + X1[:, 1] > line_threshold).astype(int)
    types1 = np.zeros(N1, dtype=int)
    # circular region: disk centred at offset_circle
    X2, y2 = _sample_disk_region(rng, N2, offset_circle)
    types2 = np.ones(N2, dtype=int)

    x_train = np.vstack([X1, X2]).astype(dtype)
    y_train = np.concatenate([y1, y2]).reshape(-1, 1).astype(dtype)
    types_train = np.concatenate([types1, types2]).reshape(-1, 1)
    perm = rng.permutation(N_train)
    x_train, y_train, types_train = x_train[perm], y_train[perm], types_train[perm]

    # ---- test split (same offsets, independent generator) ----
    rng2 = np.random.RandomState(seed + 1)
    N1t = N_test // 2
    N2t = N_test - N1t
    # NOTE(review): the test blob uses unit variance while the training blob
    # uses covariance 0.1 * I -- confirm this distribution shift is intended.
    Xt1 = rng2.randn(N1t, 2) + offset_linear
    yt1 = (Xt1[:, 0] + Xt1[:, 1] > line_threshold).astype(int)
    types1t = np.zeros(N1t, dtype=int)
    Xt2, yt2 = _sample_disk_region(rng2, N2t, offset_circle)
    types2t = np.ones(N2t, dtype=int)

    x_test = np.vstack([Xt1, Xt2]).astype(dtype)
    y_test = np.concatenate([yt1, yt2]).reshape(-1, 1).astype(dtype)
    types_test = np.concatenate([types1t, types2t]).reshape(-1, 1)
    perm2 = rng2.permutation(N_test)
    x_test, y_test, types_test = x_test[perm2], y_test[perm2], types_test[perm2]

    # tf.one_hot needs integer indices; the label arrays were cast to `dtype`
    # (float by default), so convert back to int before encoding.
    y_oh_train = tf.one_hot(y_train.ravel().astype(np.int32), depth=2)
    y_oh_test = tf.one_hot(y_test.ravel().astype(np.int32), depth=2)

    return x_train, y_train, y_oh_train, types_train, x_test, y_test, y_oh_test, types_test


def plot_circle_line(x_test, y_test, types_test, path=None):
    """Scatter-plot samples, encoding region type by marker and class by colour.

    Parameters
    ----------
    x_test : array with at least two columns; columns 0 and 1 are plotted.
    y_test : array of class labels (0 or 1), flattened before use.
    types_test : array of region types (0 or 1), flattened before use.
    path : optional output location; when given, the figure is also saved
        there as a PDF before being shown.
    """
    marker_by_type = {0: 'o', 1: 's'}
    color_by_class = {0: 'C0', 1: 'C1'}
    # Appearance shared by every group of points.
    common_style = dict(edgecolor='k', alpha=0.6, s=50)

    region_ids = types_test.flatten()
    class_ids = y_test.flatten()

    plt.figure(figsize=(6, 4))
    for t, marker in marker_by_type.items():
        in_region = region_ids == t
        for cls, color in color_by_class.items():
            selected = in_region & (class_ids == cls)
            plt.scatter(
                x_test[selected, 0],
                x_test[selected, 1],
                marker=marker,
                color=color,
                label=f"type={t}, class={cls}",
                **common_style,
            )

    # Equal axis scaling so circles are not drawn as ellipses.
    plt.gca().set_aspect('equal', adjustable='box')
    if path is not None:
        plt.savefig(path, format="pdf")
    plt.show()
    
        
        
# --------- Predictors ----------

class SoftCircleClassifier(BaseEstimator, ClassifierMixin):
    """Hand-specified soft "inside the circle" classifier.

    The probability of class 1 is a logistic function of how far a point
    lies inside the circle: ``sigmoid(gamma * (radius - ||x - center||))``,
    with ``center = (0.8 * t, 0)``.  Nothing is learned from data.

    Parameters
    ----------
    t : float
        Offset used to place the circle centre at ``(0.8 * t, 0)``.
    radius : float
        Distance from the centre at which the class-1 probability is 0.5.
    gamma : float
        Steepness of the logistic transition across the circle boundary.
    """

    def __init__(self, t=1.0, radius=1.0, gamma=5.0):
        self.t = float(t)
        self.radius = float(radius)
        self.gamma = float(gamma)
        # Centre is derived once here from the constructor argument.
        self.center = np.array((0.8 * t, 0.0), dtype=float)

    def fit(self, X, y):
        """Record the class labels; ``X`` and ``y`` are not otherwise used."""
        self.classes_ = np.array([0, 1], dtype=int)
        return self

    def predict(self, X):
        """Return 1 where the class-1 probability exceeds 0.5, else 0."""
        prob_inside = self.predict_proba(X)[:, 1]
        return (prob_inside > 0.5).astype(int)

    def predict_proba(self, X):
        """Return an ``(n_samples, 2)`` array of ``[P(class 0), P(class 1)]``."""
        # Euclidean distance of every sample to the circle centre.
        offsets = X - self.center
        dist = np.sqrt(np.sum(offsets**2, axis=1))
        # Positive score inside the circle, negative outside.
        score = self.gamma * (self.radius - dist)
        prob_inside = 1.0 / (1.0 + np.exp(-score))
        prob_outside = 1.0 - prob_inside
        return np.vstack([prob_outside, prob_inside]).T
        
def baseline_models(t=1):
    """Return a dict of unfitted baseline classifiers keyed by name.

    Contains degree-2 and degree-3 polynomial logistic regressions, two
    ``SoftCircleClassifier`` instances built from ``t`` and a linear
    discriminant analysis model.
    """
    def poly_logreg(degree):
        # Polynomial feature expansion (no bias column) followed by logistic regression.
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree, include_bias=False)),
            ("logreg", LogisticRegression(max_iter=1000)),
        ])

    return {
        "poly2_logreg": poly_logreg(2),
        "poly3_logreg": poly_logreg(3),
        # NOTE(review): circle1 and circle2 are configured identically here.
        "circle1": SoftCircleClassifier(t),
        "circle2": SoftCircleClassifier(t),
        "lda": LinearDiscriminantAnalysis(),
    }

    
def train_models(x, y, models):
    """Fit every model in ``models`` on ``(x, y)``.

    Parameters
    ----------
    x : array-like
        Training features, passed unchanged to each model's ``fit``.
    y : array-like
        Training labels.  Anything with more than one dimension (e.g. a
        column vector of shape ``(n, 1)`` or a nested list) is flattened
        to 1-D before fitting.
    models : dict
        Mapping from model name to an object exposing ``fit(x, y)``.

    Returns
    -------
    dict
        Mapping from the same names to the same model objects, which have
        been fitted in place.
    """
    # Convert first so plain lists / nested lists work, not only ndarrays.
    y = np.asarray(y)
    if y.ndim > 1:
        y = y.ravel()

    fitted = {}
    for name, mdl in models.items():
        print(f"Training {name}")
        mdl.fit(x, y)
        fitted[name] = mdl
    return fitted
    
    
# --------- Tailored Evaluation ----------

def accuracy_by_region(x_test, y_test, types_test, fitted_models):
    """Compute each model's accuracy separately on the two region types.

    Parameters
    ----------
    x_test : ndarray
        Test features, indexable by a boolean row mask.
    y_test : ndarray
        Test labels; the selected rows are flattened before scoring.
    types_test : array-like
        Region type per sample (0 = linear region, 1 = circular region).
    fitted_models : dict
        Mapping from model name to a fitted object exposing ``predict``.

    Returns
    -------
    pandas.DataFrame
        One row per model (indexed by name) with the columns
        ``"Linear Accuracy"`` and ``"Circular Accuracy"``.  An empty
        ``fitted_models`` yields an empty frame with those two columns.
    """
    # The region split does not depend on the model, so compute it once.
    region = np.asarray(types_test).ravel()
    is_linear = region == 0
    is_circular = region == 1

    x_lin, y_lin = x_test[is_linear], y_test[is_linear].ravel()
    x_circ, y_circ = x_test[is_circular], y_test[is_circular].ravel()

    model_region_acc = {}
    for name, mdl in fitted_models.items():
        acc_lin = accuracy_score(y_lin, mdl.predict(x_lin))
        acc_circ = accuracy_score(y_circ, mdl.predict(x_circ))
        model_region_acc[name] = (acc_lin, acc_circ)

    # orient="index" builds one row per model directly and keeps the column
    # names even when there are no models at all.
    return pd.DataFrame.from_dict(
        model_region_acc,
        orient="index",
        columns=["Linear Accuracy", "Circular Accuracy"],
    )
    